ru-clip: copied
wxywb, 2 years ago
6 changed files with 460 additions and 0 deletions
@@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-
import os

from huggingface_hub import hf_hub_url, cached_download

from . import model, processor, predictor
from .model import CLIP
from .processor import RuCLIPProcessor
from .predictor import Predictor

MODELS = {
    'ruclip-vit-base-patch32-224': dict(
        repo_id='sberbank-ai/ruclip-vit-base-patch32-224',
        filenames=[
            'bpe.model', 'config.json', 'pytorch_model.bin'
        ]
    ),
    'ruclip-vit-base-patch16-224': dict(
        repo_id='sberbank-ai/ruclip-vit-base-patch16-224',
        filenames=[
            'bpe.model', 'config.json', 'pytorch_model.bin'
        ]
    ),
    'ruclip-vit-large-patch14-224': dict(
        repo_id='sberbank-ai/ruclip-vit-large-patch14-224',
        filenames=[
            'bpe.model', 'config.json', 'pytorch_model.bin'
        ]
    ),
    'ruclip-vit-large-patch14-336': dict(
        repo_id='sberbank-ai/ruclip-vit-large-patch14-336',
        filenames=[
            'bpe.model', 'config.json', 'pytorch_model.bin'
        ]
    ),
    'ruclip-vit-base-patch32-384': dict(
        repo_id='sberbank-ai/ruclip-vit-base-patch32-384',
        filenames=[
            'bpe.model', 'config.json', 'pytorch_model.bin'
        ]
    ),
    'ruclip-vit-base-patch16-384': dict(
        repo_id='sberbank-ai/ruclip-vit-base-patch16-384',
        filenames=[
            'bpe.model', 'config.json', 'pytorch_model.bin'
        ]
    ),
}


def load(name, device='cpu', cache_dir='/tmp/ruclip', use_auth_token=None):
    """Load a ruCLIP model

    Parameters
    ----------
    name : str
        A model name listed in ruclip.MODELS.keys()
    device : Union[str, torch.device]
        The device to put the loaded model on
    cache_dir : str
        Path to download the model files to; by default "/tmp/ruclip"

    Returns
    -------
    clip : torch.nn.Module
        The ruCLIP model
    clip_processor : ruclip.processor.RuCLIPProcessor
        A ruCLIP processor which performs tokenization and image preprocessing
    """
    assert name in MODELS, f'All models: {MODELS.keys()}'
    config = MODELS[name]
    repo_id = config['repo_id']
    cache_dir = os.path.join(cache_dir, name)
    for filename in config['filenames']:
        config_file_url = hf_hub_url(repo_id=repo_id, filename=f'{filename}')
        cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename, use_auth_token=use_auth_token)

    clip = CLIP.from_pretrained(cache_dir).eval().to(device)
    clip_processor = RuCLIPProcessor.from_pretrained(cache_dir)
    return clip, clip_processor


__all__ = ['processor', 'model', 'predictor', 'CLIP', 'RuCLIPProcessor', 'Predictor', 'MODELS', 'load']
__version__ = '0.0.2'
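
Usage sketch (not part of the diff): load() resolves the checkpoint files for one of the MODELS entries on the Hugging Face Hub, caches them under cache_dir, and returns the model together with its processor. A minimal zero-shot example, assuming the package is importable as ruclip and that example.jpg is a hypothetical local image:

import torch
import ruclip
from PIL import Image

device = 'cuda' if torch.cuda.is_available() else 'cpu'
clip, processor = ruclip.load('ruclip-vit-base-patch32-224', device=device)

# tokenize two Russian captions and preprocess one image (example.jpg is a placeholder path)
inputs = processor(text=['кошка', 'собака'], images=[Image.open('example.jpg')])
with torch.no_grad():
    logits_per_image, logits_per_text = clip(
        inputs['input_ids'].to(device), inputs['pixel_values'].to(device))
probs = logits_per_image.softmax(dim=-1)  # per-image probabilities over the two captions
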
@@ -0,0 +1,239 @@
# -*- coding: utf-8 -*-
import os
import json
from collections import OrderedDict

import torch
import numpy as np
from torch import nn


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ('c_fc', nn.Linear(d_model, d_model * 4)),
            ('gelu', QuickGELU()),
            ('c_proj', nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([
            self.class_embedding.to(x.dtype) +
            torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
        ], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x


class CLIP(nn.Module):
    def __init__(
        self,
        embed_dim,
        image_resolution,
        vision_layers,
        vision_width,
        vision_patch_size,
        context_length,
        vocab_size,
        transformer_width,
        transformer_heads,
        transformer_layers,
        eos_id=3,
    ):
        super().__init__()

        self.eos_id = eos_id
        self.context_length = context_length

        vision_heads = vision_width // 64
        self.visual = VisionTransformer(
            input_resolution=image_resolution,
            patch_size=vision_patch_size,
            width=vision_width,
            layers=vision_layers,
            heads=vision_heads,
            output_dim=embed_dim,
        )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)
        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float('-inf'))
        mask.triu_(1)
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, pixel_values):
        """Encode images

        Parameters
        ----------
        pixel_values : torch.Tensor
            Processed images from RuCLIPProcessor class

        Returns
        -------
        image_latents : torch.Tensor
            Image embeddings
        """
        return self.visual(pixel_values.type(self.dtype))

    def encode_text(self, input_ids):
        """Encode texts

        Parameters
        ----------
        input_ids : torch.Tensor
            Tokenized texts from RuCLIPProcessor class

        Returns
        -------
        text_latents : torch.Tensor
            Text embeddings
        """
        x = self.token_embedding(input_ids).type(self.dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        # x.shape = [batch_size, n_ctx, transformer.width]
        x = x[torch.arange(x.shape[0]), torch.where(input_ids == self.eos_id)[1]] @ self.text_projection
        return x

    def forward(self, input_ids, pixel_values):
        image_features = self.encode_image(pixel_values)
        text_features = self.encode_text(input_ids)

        # normalize features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text

    @classmethod
    def from_pretrained(cls, folder):
        """Load model from folder"""
        config = json.load(open(os.path.join(folder, 'config.json')))
        model = cls(
            embed_dim=config['embed_dim'],
            image_resolution=config['image_resolution'],
            vision_layers=config['vision_layers'],
            vision_width=config['vision_width'],
            vision_patch_size=config['vision_patch_size'],
            context_length=config['context_length'],
            vocab_size=config['vocab_size'],
            transformer_width=config['transformer_width'],
            transformer_heads=config['transformer_heads'],
            transformer_layers=config['transformer_layers'],
        )
        checkpoint = torch.load(os.path.join(folder, 'pytorch_model.bin'), map_location='cpu')
        model.load_state_dict(checkpoint)
        return model
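
Usage sketch (not part of the diff): CLIP.from_pretrained expects a folder containing config.json and pytorch_model.bin, which is exactly what load() in the first file assembles under cache_dir. The following sketch calls the two encoders directly and reproduces forward() by hand; the folder path assumes the default cache location used by ruclip.load(), and example.jpg is a hypothetical image file:

import torch
from PIL import Image
from ruclip.model import CLIP
from ruclip.processor import RuCLIPProcessor

folder = '/tmp/ruclip/ruclip-vit-base-patch32-224'  # default cache location filled by ruclip.load()
clip = CLIP.from_pretrained(folder).eval()
processor = RuCLIPProcessor.from_pretrained(folder)

inputs = processor(text=['кошка', 'собака'], images=[Image.open('example.jpg')])  # placeholder image path
with torch.no_grad():
    text_latents = clip.encode_text(inputs['input_ids'])       # pooled at the EOS position, then projected
    image_latents = clip.encode_image(inputs['pixel_values'])  # CLS token of the vision transformer, projected

    # same normalization and scaling as CLIP.forward
    text_latents = text_latents / text_latents.norm(dim=-1, keepdim=True)
    image_latents = image_latents / image_latents.norm(dim=-1, keepdim=True)
    logits_per_image = clip.logit_scale.exp() * image_latents @ text_latents.t()
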
@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
import torch
import more_itertools
from tqdm import tqdm


class Predictor:
    def __init__(self, clip_model, clip_processor, device, templates=None, bs=8, quiet=False):
        self.device = device
        self.clip_model = clip_model.to(self.device)
        self.clip_model.eval()
        self.clip_processor = clip_processor
        self.bs = bs
        self.quiet = quiet
        self.templates = templates or [
            '{}',
            'фото, на котором изображено {}',
            'изображение с {}',
            'картинка с {}',
            'фото с {}',
            'на фото видно {}',
        ]

    def get_text_latents(self, class_labels):
        text_latents = []
        for template in self.templates:
            _text_latents = []
            for chunk in more_itertools.chunked(class_labels, self.bs):
                texts = [template.format(class_label.lower().strip()) for class_label in chunk]
                inputs = self.clip_processor(text=texts, return_tensors='pt', padding=True)
                _text_latents.append(self.clip_model.encode_text(inputs['input_ids'].to(self.device)))
            text_latents.append(torch.cat(_text_latents, dim=0))
        text_latents = torch.stack(text_latents).mean(0)
        text_latents = text_latents / text_latents.norm(dim=-1, keepdim=True)
        return text_latents

    def run(self, images, text_latents):
        if not self.quiet:
            pbar = tqdm()
        labels = []
        logit_scale = self.clip_model.logit_scale.exp()
        for pil_images in more_itertools.chunked(images, self.bs):
            inputs = self.clip_processor(text='', images=list(pil_images), return_tensors='pt', padding=True)
            image_latents = self.clip_model.encode_image(inputs['pixel_values'].to(self.device))
            image_latents = image_latents / image_latents.norm(dim=-1, keepdim=True)
            logits_per_text = torch.matmul(text_latents.to(self.device), image_latents.t()) * logit_scale
            _labels = logits_per_text.argmax(0).cpu().numpy().tolist()
            if not self.quiet:
                pbar.update(len(_labels))
            labels.extend(_labels)
        if not self.quiet:
            pbar.close()  # pbar exists only when not quiet
        return labels

    def get_image_latents(self, images):
        if not self.quiet:
            pbar = tqdm()
        image_latents = []
        for pil_images in more_itertools.chunked(images, self.bs):
            inputs = self.clip_processor(text='', images=list(pil_images), return_tensors='pt', padding=True)
            image_latents.append(self.clip_model.encode_image(inputs['pixel_values'].to(self.device)))
            if not self.quiet:
                pbar.update(len(pil_images))
        image_latents = torch.cat(image_latents)
        image_latents = image_latents / image_latents.norm(dim=-1, keepdim=True)
        return image_latents
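
Usage sketch (not part of the diff): Predictor.get_text_latents averages the text embeddings over the Russian prompt templates and L2-normalizes them; run() then assigns each image the class with the highest scaled cosine similarity. A hedged zero-shot classification sketch, again assuming the package is importable as ruclip and example.jpg is a hypothetical image:

import torch
import ruclip
from PIL import Image
from ruclip.predictor import Predictor

device = 'cpu'
clip, processor = ruclip.load('ruclip-vit-base-patch32-224', device=device)
predictor = Predictor(clip, processor, device, bs=8)

classes = ['кошка', 'собака', 'самолёт']
pil_images = [Image.open('example.jpg')]  # placeholder input image(s)

with torch.no_grad():
    text_latents = predictor.get_text_latents(classes)  # prompt templates averaged per class
    labels = predictor.run(pil_images, text_latents)     # one index into `classes` per image
print([classes[i] for i in labels])
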
@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
import os
import json

import torch
import numpy as np
import youtokentome as yttm
import torchvision.transforms as T
from torch.nn.utils.rnn import pad_sequence


class RuCLIPProcessor:
    eos_id = 3
    bos_id = 2
    unk_id = 1
    pad_id = 0

    def __init__(self, tokenizer_path, image_size=224, text_seq_length=77, mean=None, std=None):
        self.tokenizer = yttm.BPE(tokenizer_path)
        self.mean = mean or [0.48145466, 0.4578275, 0.40821073]
        self.std = std or [0.26862954, 0.26130258, 0.27577711]
        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.RandomResizedCrop(image_size, scale=(1., 1.), ratio=(1., 1.)),
            T.ToTensor(),
            T.Normalize(mean=self.mean, std=self.std)
        ])
        self.text_seq_length = text_seq_length
        self.image_size = image_size

    def encode_text(self, text):
        text = text.lower()
        tokens = self.tokenizer.encode([text], output_type=yttm.OutputType.ID, dropout_prob=0.0)[0]
        tokens = tokens[:self.text_seq_length - 2]
        tokens = [self.bos_id] + tokens + [self.eos_id]
        return self.prepare_tokens(tokens)

    def prepare_tokens(self, tokens):
        empty_positions = self.text_seq_length - len(tokens)
        if empty_positions > 0:
            tokens = np.hstack((tokens, np.zeros(empty_positions)))  # position tokens after text
        if len(tokens) > self.text_seq_length:
            tokens = tokens[:self.text_seq_length - 1] + tokens[-1:]
        return torch.tensor(tokens).long()

    def decode_text(self, encoded):
        return self.tokenizer.decode(encoded.cpu().numpy().tolist(), ignore_ids=[
            self.eos_id, self.bos_id, self.unk_id, self.pad_id
        ])[0]

    def __call__(self, text=None, images=None, **kwargs):
        inputs = {}
        if text is not None:
            input_ids = []
            texts = [text] if isinstance(text, str) else text
            for text in texts:
                tokens = self.encode_text(text)
                input_ids.append(tokens)
            inputs['input_ids'] = pad_sequence(input_ids, batch_first=True)
        if images is not None:
            pixel_values = []
            for i, image in enumerate(images):
                pixel_values.append(self.image_transform(image))
            inputs['pixel_values'] = pad_sequence(pixel_values, batch_first=True)
        return inputs

    @classmethod
    def from_pretrained(cls, folder):
        tokenizer_path = os.path.join(folder, 'bpe.model')
        config = json.load(open(os.path.join(folder, 'config.json')))
        image_size = config['image_resolution']
        text_seq_length = config['context_length']
        mean, std = config.get('mean'), config.get('std')
        return cls(tokenizer_path, image_size=image_size, text_seq_length=text_seq_length, mean=mean, std=std)
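
Usage sketch (not part of the diff): RuCLIPProcessor lower-cases and BPE-encodes text, wraps it in BOS/EOS ids, pads to text_seq_length, resizes and normalizes images, and returns a dict in the spirit of the Hugging Face processor API (extra kwargs such as return_tensors and padding are accepted but ignored). A small sketch of the tensors it produces, assuming the files downloaded by ruclip.load() sit in the default cache folder and example.jpg is a hypothetical image; the commented shapes assume context_length=77 and image_resolution=224 from that checkpoint's config:

from PIL import Image
from ruclip.processor import RuCLIPProcessor

# folder populated by ruclip.load(); must contain bpe.model and config.json
processor = RuCLIPProcessor.from_pretrained('/tmp/ruclip/ruclip-vit-base-patch32-224')

inputs = processor(text=['кошка на диване'], images=[Image.open('example.jpg')])  # placeholder image path
print(inputs['input_ids'].shape)     # torch.Size([1, 77]): BOS + BPE ids + EOS, zero-padded to context_length
print(inputs['pixel_values'].shape)  # torch.Size([1, 3, 224, 224]) for the 224px checkpoints
print(processor.decode_text(inputs['input_ids'][0]))  # 'кошка на диване'
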