ru-clip
copied
wxywb
2 years ago
6 changed files with 460 additions and 0 deletions
@@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-
import os

from huggingface_hub import hf_hub_url, cached_download

from . import model, processor, predictor
from .model import CLIP
from .processor import RuCLIPProcessor
from .predictor import Predictor

MODELS = {
    'ruclip-vit-base-patch32-224': dict(
        repo_id='sberbank-ai/ruclip-vit-base-patch32-224',
        filenames=[
            'bpe.model', 'config.json', 'pytorch_model.bin'
        ]
    ),
    'ruclip-vit-base-patch16-224': dict(
        repo_id='sberbank-ai/ruclip-vit-base-patch16-224',
        filenames=[
            'bpe.model', 'config.json', 'pytorch_model.bin'
        ]
    ),
    'ruclip-vit-large-patch14-224': dict(
        repo_id='sberbank-ai/ruclip-vit-large-patch14-224',
        filenames=[
            'bpe.model', 'config.json', 'pytorch_model.bin'
        ]
    ),
    'ruclip-vit-large-patch14-336': dict(
        repo_id='sberbank-ai/ruclip-vit-large-patch14-336',
        filenames=[
            'bpe.model', 'config.json', 'pytorch_model.bin'
        ]
    ),
    'ruclip-vit-base-patch32-384': dict(
        repo_id='sberbank-ai/ruclip-vit-base-patch32-384',
        filenames=[
            'bpe.model', 'config.json', 'pytorch_model.bin'
        ]
    ),
    'ruclip-vit-base-patch16-384': dict(
        repo_id='sberbank-ai/ruclip-vit-base-patch16-384',
        filenames=[
            'bpe.model', 'config.json', 'pytorch_model.bin'
        ]
    ),
}


def load(name, device='cpu', cache_dir='/tmp/ruclip', use_auth_token=None):
    """Load a ruCLIP model

    Parameters
    ----------
    name : str
        A model name listed in ruclip.MODELS.keys()
    device : Union[str, torch.device]
        The device to put the loaded model on
    cache_dir : str
        Path to download the model files to; by default "/tmp/ruclip" is used

    Returns
    -------
    clip : torch.nn.Module
        The ruCLIP model
    clip_processor : ruclip.processor.RuCLIPProcessor
        A ruCLIP processor which performs tokenization and image preprocessing
    """
    assert name in MODELS, f'All models: {MODELS.keys()}'
    config = MODELS[name]
    repo_id = config['repo_id']
    cache_dir = os.path.join(cache_dir, name)
    for filename in config['filenames']:
        config_file_url = hf_hub_url(repo_id=repo_id, filename=f'{filename}')
        cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename, use_auth_token=use_auth_token)

    clip = CLIP.from_pretrained(cache_dir).eval().to(device)
    clip_processor = RuCLIPProcessor.from_pretrained(cache_dir)
    return clip, clip_processor


__all__ = ['processor', 'model', 'predictor', 'CLIP', 'RuCLIPProcessor', 'Predictor', 'MODELS', 'load']
__version__ = '0.0.2'
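A minimal usage sketch for the load() helper above, assuming the package is importable as ruclip and the checkpoint files can be downloaded from the Hugging Face Hub; the model name and device are only example choices.

    import ruclip

    # Downloads bpe.model, config.json and pytorch_model.bin on first call, then loads from cache.
    clip, clip_processor = ruclip.load('ruclip-vit-base-patch32-224', device='cpu')
    print(type(clip).__name__, type(clip_processor).__name__)  # CLIP RuCLIPProcessor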
@@ -0,0 +1,239 @@
# -*- coding: utf-8 -*-
import os
import json
from collections import OrderedDict

import torch
import numpy as np
from torch import nn


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ('c_fc', nn.Linear(d_model, d_model * 4)),
            ('gelu', QuickGELU()),
            ('c_proj', nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([
            self.class_embedding.to(x.dtype) +
            torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
        ], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x


class CLIP(nn.Module):
    def __init__(
        self,
        embed_dim,
        image_resolution,
        vision_layers,
        vision_width,
        vision_patch_size,
        context_length,
        vocab_size,
        transformer_width,
        transformer_heads,
        transformer_layers,
        eos_id=3,
    ):
        super().__init__()

        self.eos_id = eos_id
        self.context_length = context_length

        vision_heads = vision_width // 64
        self.visual = VisionTransformer(
            input_resolution=image_resolution,
            patch_size=vision_patch_size,
            width=vision_width,
            layers=vision_layers,
            heads=vision_heads,
            output_dim=embed_dim,
        )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)
        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # causal mask: each text position may only attend to itself and earlier positions
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float('-inf'))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, pixel_values):
        """Encode images

        Parameters
        ----------
        pixel_values : torch.Tensor
            Processed images from the RuCLIPProcessor class

        Returns
        -------
        image_latents : torch.Tensor
            Image embeddings
        """
        return self.visual(pixel_values.type(self.dtype))

    def encode_text(self, input_ids):
        """Encode texts

        Parameters
        ----------
        input_ids : torch.Tensor
            Tokenized texts from the RuCLIPProcessor class

        Returns
        -------
        text_latents : torch.Tensor
            Text embeddings
        """
        x = self.token_embedding(input_ids).type(self.dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take the features at the EOS token position and project them into the joint embedding space
        x = x[torch.arange(x.shape[0]), torch.where(input_ids == self.eos_id)[1]] @ self.text_projection
        return x

    def forward(self, input_ids, pixel_values):
        image_features = self.encode_image(pixel_values)
        text_features = self.encode_text(input_ids)

        # normalize features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text

    @classmethod
    def from_pretrained(cls, folder):
        """Load model from folder"""
        config = json.load(open(os.path.join(folder, 'config.json')))
        model = cls(
            embed_dim=config['embed_dim'],
            image_resolution=config['image_resolution'],
            vision_layers=config['vision_layers'],
            vision_width=config['vision_width'],
            vision_patch_size=config['vision_patch_size'],
            context_length=config['context_length'],
            vocab_size=config['vocab_size'],
            transformer_width=config['transformer_width'],
            transformer_heads=config['transformer_heads'],
            transformer_layers=config['transformer_layers'],
        )
        checkpoint = torch.load(os.path.join(folder, 'pytorch_model.bin'), map_location='cpu')
        model.load_state_dict(checkpoint)
        return model
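A sketch of how the CLIP module above is typically driven, assuming the model and processor come from ruclip.load; 'photo.jpg' is a placeholder path and the texts are arbitrary examples.

    import torch
    from PIL import Image
    import ruclip

    clip, clip_processor = ruclip.load('ruclip-vit-base-patch32-224', device='cpu')
    inputs = clip_processor(text=['кошка', 'собака'],  # 'cat', 'dog'
                            images=[Image.open('photo.jpg')])
    with torch.no_grad():
        logits_per_image, logits_per_text = clip(inputs['input_ids'], inputs['pixel_values'])
    probs = logits_per_image.softmax(dim=-1)  # one row per image, one probability per text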
@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
import torch
import more_itertools
from tqdm import tqdm


class Predictor:
    def __init__(self, clip_model, clip_processor, device, templates=None, bs=8, quiet=False):
        self.device = device
        self.clip_model = clip_model.to(self.device)
        self.clip_model.eval()
        self.clip_processor = clip_processor
        self.bs = bs
        self.quiet = quiet
        # default Russian prompt templates, e.g. 'фото, на котором изображено {}' ≈ 'a photo showing {}'
        self.templates = templates or [
            '{}',
            'фото, на котором изображено {}',
            'изображение с {}',
            'картинка с {}',
            'фото с {}',
            'на фото видно {}',
        ]

    def get_text_latents(self, class_labels):
        text_latents = []
        for template in self.templates:
            _text_latents = []
            for chunk in more_itertools.chunked(class_labels, self.bs):
                texts = [template.format(class_label.lower().strip()) for class_label in chunk]
                inputs = self.clip_processor(text=texts, return_tensors='pt', padding=True)
                _text_latents.append(self.clip_model.encode_text(inputs['input_ids'].to(self.device)))
            text_latents.append(torch.cat(_text_latents, dim=0))
        text_latents = torch.stack(text_latents).mean(0)
        text_latents = text_latents / text_latents.norm(dim=-1, keepdim=True)
        return text_latents

    def run(self, images, text_latents):
        if not self.quiet:
            pbar = tqdm()
        labels = []
        logit_scale = self.clip_model.logit_scale.exp()
        for pil_images in more_itertools.chunked(images, self.bs):
            inputs = self.clip_processor(text='', images=list(pil_images), return_tensors='pt', padding=True)
            image_latents = self.clip_model.encode_image(inputs['pixel_values'].to(self.device))
            image_latents = image_latents / image_latents.norm(dim=-1, keepdim=True)
            logits_per_text = torch.matmul(text_latents.to(self.device), image_latents.t()) * logit_scale
            _labels = logits_per_text.argmax(0).cpu().numpy().tolist()
            if not self.quiet:
                pbar.update(len(_labels))
            labels.extend(_labels)
        if not self.quiet:  # pbar only exists when a progress bar was requested
            pbar.close()
        return labels

    def get_image_latents(self, images):
        if not self.quiet:
            pbar = tqdm()
        image_latents = []
        for pil_images in more_itertools.chunked(images, self.bs):
            inputs = self.clip_processor(text='', images=list(pil_images), return_tensors='pt', padding=True)
            image_latents.append(self.clip_model.encode_image(inputs['pixel_values'].to(self.device)))
            if not self.quiet:
                pbar.update(len(pil_images))
        if not self.quiet:
            pbar.close()
        image_latents = torch.cat(image_latents)
        image_latents = image_latents / image_latents.norm(dim=-1, keepdim=True)
        return image_latents
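A zero-shot classification sketch using the Predictor above, assuming the package is importable as ruclip; the class labels are example Russian strings and 'photo.jpg' is a placeholder path.

    import torch
    import ruclip
    from PIL import Image

    device = 'cpu'
    clip, clip_processor = ruclip.load('ruclip-vit-base-patch32-224', device=device)
    predictor = ruclip.Predictor(clip, clip_processor, device, bs=8)

    class_labels = ['кошка', 'собака', 'птица']  # 'cat', 'dog', 'bird'
    images = [Image.open('photo.jpg')]

    with torch.no_grad():
        text_latents = predictor.get_text_latents(class_labels)  # averaged over the prompt templates
        predicted = predictor.run(images, text_latents)           # one index into class_labels per image
    print([class_labels[i] for i in predicted])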
@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
import os
import json

import torch
import numpy as np
import youtokentome as yttm
import torchvision.transforms as T
from torch.nn.utils.rnn import pad_sequence


class RuCLIPProcessor:
    eos_id = 3
    bos_id = 2
    unk_id = 1
    pad_id = 0

    def __init__(self, tokenizer_path, image_size=224, text_seq_length=77, mean=None, std=None):
        self.tokenizer = yttm.BPE(tokenizer_path)
        self.mean = mean or [0.48145466, 0.4578275, 0.40821073]
        self.std = std or [0.26862954, 0.26130258, 0.27577711]
        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.RandomResizedCrop(image_size, scale=(1., 1.), ratio=(1., 1.)),
            T.ToTensor(),
            T.Normalize(mean=self.mean, std=self.std)
        ])
        self.text_seq_length = text_seq_length
        self.image_size = image_size

    def encode_text(self, text):
        text = text.lower()
        tokens = self.tokenizer.encode([text], output_type=yttm.OutputType.ID, dropout_prob=0.0)[0]
        tokens = tokens[:self.text_seq_length - 2]
        tokens = [self.bos_id] + tokens + [self.eos_id]
        return self.prepare_tokens(tokens)

    def prepare_tokens(self, tokens):
        empty_positions = self.text_seq_length - len(tokens)
        if empty_positions > 0:
            tokens = np.hstack((tokens, np.zeros(empty_positions)))  # pad with pad_id (0) after the text tokens
        if len(tokens) > self.text_seq_length:
            tokens = tokens[:self.text_seq_length - 1] + tokens[-1:]  # truncate but keep the final EOS token
        return torch.tensor(tokens).long()

    def decode_text(self, encoded):
        return self.tokenizer.decode(encoded.cpu().numpy().tolist(), ignore_ids=[
            self.eos_id, self.bos_id, self.unk_id, self.pad_id
        ])[0]

    def __call__(self, text=None, images=None, **kwargs):
        inputs = {}
        if text is not None:
            input_ids = []
            texts = [text] if isinstance(text, str) else text
            for text in texts:
                tokens = self.encode_text(text)
                input_ids.append(tokens)
            inputs['input_ids'] = pad_sequence(input_ids, batch_first=True)
        if images is not None:
            pixel_values = []
            for image in images:
                pixel_values.append(self.image_transform(image))
            inputs['pixel_values'] = pad_sequence(pixel_values, batch_first=True)
        return inputs

    @classmethod
    def from_pretrained(cls, folder):
        tokenizer_path = os.path.join(folder, 'bpe.model')
        config = json.load(open(os.path.join(folder, 'config.json')))
        image_size = config['image_resolution']
        text_seq_length = config['context_length']
        mean, std = config.get('mean'), config.get('std')
        return cls(tokenizer_path, image_size=image_size, text_seq_length=text_seq_length, mean=mean, std=std)
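A sketch of using RuCLIPProcessor on its own, assuming the checkpoint files were already cached by ruclip.load() under its default cache directory; the cache path and image path are placeholder examples.

    from PIL import Image
    from ruclip import RuCLIPProcessor

    processor = RuCLIPProcessor.from_pretrained('/tmp/ruclip/ruclip-vit-base-patch32-224')
    batch = processor(text=['кошка', 'собака'], images=[Image.open('photo.jpg')])
    print(batch['input_ids'].shape)     # (n_texts, context_length) padded token ids
    print(batch['pixel_values'].shape)  # (n_images, 3, image_size, image_size) normalized tensors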