From 651e84f553b2328fa0b861451640fdf2b2e1e35c Mon Sep 17 00:00:00 2001
From: wxywb
Date: Wed, 19 Oct 2022 17:31:16 +0800
Subject: [PATCH] init the operator.

Signed-off-by: wxywb
---
 __init__.py         |   0
 ru_clip.py          |   0
 ruclip/__init__.py  |  82 +++++++++++++++
 ruclip/model.py     | 239 ++++++++++++++++++++++++++++++++++++++++++++
 ruclip/predictor.py |  65 ++++++++++++
 ruclip/processor.py |  74 ++++++++++++++
 6 files changed, 460 insertions(+)
 create mode 100644 __init__.py
 create mode 100644 ru_clip.py
 create mode 100644 ruclip/__init__.py
 create mode 100644 ruclip/model.py
 create mode 100644 ruclip/predictor.py
 create mode 100644 ruclip/processor.py

diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/ru_clip.py b/ru_clip.py
new file mode 100644
index 0000000..e69de29
diff --git a/ruclip/__init__.py b/ruclip/__init__.py
new file mode 100644
index 0000000..5d0c487
--- /dev/null
+++ b/ruclip/__init__.py
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+import os
+
+from huggingface_hub import hf_hub_url, cached_download
+
+from . import model, processor, predictor
+from .model import CLIP
+from .processor import RuCLIPProcessor
+from .predictor import Predictor
+
+MODELS = {
+    'ruclip-vit-base-patch32-224': dict(
+        repo_id='sberbank-ai/ruclip-vit-base-patch32-224',
+        filenames=[
+            'bpe.model', 'config.json', 'pytorch_model.bin'
+        ]
+    ),
+    'ruclip-vit-base-patch16-224': dict(
+        repo_id='sberbank-ai/ruclip-vit-base-patch16-224',
+        filenames=[
+            'bpe.model', 'config.json', 'pytorch_model.bin'
+        ]
+    ),
+    'ruclip-vit-large-patch14-224': dict(
+        repo_id='sberbank-ai/ruclip-vit-large-patch14-224',
+        filenames=[
+            'bpe.model', 'config.json', 'pytorch_model.bin'
+        ]
+    ),
+    'ruclip-vit-large-patch14-336': dict(
+        repo_id='sberbank-ai/ruclip-vit-large-patch14-336',
+        filenames=[
+            'bpe.model', 'config.json', 'pytorch_model.bin'
+        ]
+    ),
+    'ruclip-vit-base-patch32-384': dict(
+        repo_id='sberbank-ai/ruclip-vit-base-patch32-384',
+        filenames=[
+            'bpe.model', 'config.json', 'pytorch_model.bin'
+        ]
+    ),
+    'ruclip-vit-base-patch16-384': dict(
+        repo_id='sberbank-ai/ruclip-vit-base-patch16-384',
+        filenames=[
+            'bpe.model', 'config.json', 'pytorch_model.bin'
+        ]
+    ),
+}
+
+
+def load(name, device='cpu', cache_dir='/tmp/ruclip', use_auth_token=None):
+    """Load a ruCLIP model
+    Parameters
+    ----------
+    name : str
+        A model name listed in ruclip.MODELS.keys()
+    device : Union[str, torch.device]
+        The device on which to load the model
+    cache_dir : str
+        Directory the model files are downloaded to; defaults to "/tmp/ruclip"
+    Returns
+    -------
+    clip : torch.nn.Module
+        The ruCLIP model
+    clip_processor : ruclip.processor.RuCLIPProcessor
+        A ruCLIP processor which performs tokenization and image preprocessing
+    """
+    assert name in MODELS, f'Unknown model {name!r}; available models: {list(MODELS.keys())}'
+    config = MODELS[name]
+    repo_id = config['repo_id']
+    cache_dir = os.path.join(cache_dir, name)
+    for filename in config['filenames']:
+        config_file_url = hf_hub_url(repo_id=repo_id, filename=filename)
+        cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename, use_auth_token=use_auth_token)
+
+    clip = CLIP.from_pretrained(cache_dir).eval().to(device)
+    clip_processor = RuCLIPProcessor.from_pretrained(cache_dir)
+    return clip, clip_processor
+
+
+__all__ = ['processor', 'model', 'predictor', 'CLIP', 'RuCLIPProcessor', 'Predictor', 'MODELS', 'load']
+__version__ = '0.0.2'
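
A minimal usage sketch of the load() helper above, assuming the files in this patch are importable as the ruclip package and that torch and huggingface_hub are installed; the model name, device, and cache_dir below are just the documented defaults:

    import ruclip

    # Download the checkpoint files listed in MODELS and build the model and processor.
    clip, clip_processor = ruclip.load(
        'ruclip-vit-base-patch32-224',  # any key from ruclip.MODELS
        device='cpu',                   # or 'cuda:0' when a GPU is available
        cache_dir='/tmp/ruclip',        # files are cached under /tmp/ruclip/<model name>
    )
    print(type(clip).__name__, type(clip_processor).__name__)  # CLIP RuCLIPProcessor

load() returns the model already in eval mode and moved to the requested device, so the pair can be passed straight to the Predictor defined later in this patch.
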
diff --git a/ruclip/model.py b/ruclip/model.py
new file mode 100644
index 0000000..f992e7e
--- /dev/null
+++ b/ruclip/model.py
@@ -0,0 +1,239 @@
+# -*- coding: utf-8 -*-
+import os
+import json
+from collections import OrderedDict
+
+import torch
+import numpy as np
+from torch import nn
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
+
+class QuickGELU(nn.Module):
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(OrderedDict([
+            ('c_fc', nn.Linear(d_model, d_model * 4)),
+            ('gelu', QuickGELU()),
+            ('c_proj', nn.Linear(d_model * 4, d_model))
+        ]))
+        self.ln_2 = LayerNorm(d_model)
+        self.attn_mask = attn_mask
+
+    def attention(self, x: torch.Tensor):
+        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+    def forward(self, x: torch.Tensor):
+        x = x + self.attention(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class Transformer(nn.Module):
+    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
+
+    def forward(self, x: torch.Tensor):
+        return self.resblocks(x)
+
+
+class VisionTransformer(nn.Module):
+    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
+        super().__init__()
+        self.input_resolution = input_resolution
+        self.output_dim = output_dim
+        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+        scale = width ** -0.5
+        self.class_embedding = nn.Parameter(scale * torch.randn(width))
+        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
+        self.ln_pre = LayerNorm(width)
+
+        self.transformer = Transformer(width, layers, heads)
+
+        self.ln_post = LayerNorm(width)
+        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+    def forward(self, x: torch.Tensor):
+        x = self.conv1(x)  # shape = [*, width, grid, grid]
+        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
+        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
+        x = torch.cat([
+            self.class_embedding.to(x.dtype) +
+            torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
+        ], dim=1)  # shape = [*, grid ** 2 + 1, width]
+        x = x + self.positional_embedding.to(x.dtype)
+        x = self.ln_pre(x)
+
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+
+        x = self.ln_post(x[:, 0, :])
+
+        if self.proj is not None:
+            x = x @ self.proj
+
+        return x
+
+
+class CLIP(nn.Module):
+    def __init__(
+        self,
+        embed_dim,
+        image_resolution,
+        vision_layers,
+        vision_width,
+        vision_patch_size,
+        context_length,
+        vocab_size,
+        transformer_width,
+        transformer_heads,
+        transformer_layers,
+        eos_id=3,
+    ):
+        super().__init__()
+
+        self.eos_id = eos_id
+        self.context_length = context_length
+
+        vision_heads = vision_width // 64
+        self.visual = VisionTransformer(
+            input_resolution=image_resolution,
+            patch_size=vision_patch_size,
+            width=vision_width,
+            layers=vision_layers,
+            heads=vision_heads,
+            output_dim=embed_dim,
+        )
+
+        self.transformer = Transformer(
+            width=transformer_width,
+            layers=transformer_layers,
+            heads=transformer_heads,
+            attn_mask=self.build_attention_mask(),
+        )
+
+        self.vocab_size = vocab_size
+        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
+        self.ln_final = LayerNorm(transformer_width)
+
+        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
+        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
+
+        self.initialize_parameters()
+
+    def initialize_parameters(self):
+        nn.init.normal_(self.token_embedding.weight, std=0.02)
+        nn.init.normal_(self.positional_embedding, std=0.01)
+        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
+        attn_std = self.transformer.width ** -0.5
+        fc_std = (2 * self.transformer.width) ** -0.5
+        for block in self.transformer.resblocks:
+            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+        if self.text_projection is not None:
+            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
+
+    def build_attention_mask(self):
+        mask = torch.empty(self.context_length, self.context_length)
+        mask.fill_(float('-inf'))
+        mask.triu_(1)
+        return mask
+
+    @property
+    def dtype(self):
+        return self.visual.conv1.weight.dtype
+
+    def encode_image(self, pixel_values):
+        """Encode images
+        Parameters
+        ----------
+        pixel_values : torch.Tensor
+            Processed images from RuCLIPProcessor class
+        Returns
+        -------
+        image_latents : torch.Tensor
+            Image embeddings
+        """
+        return self.visual(pixel_values.type(self.dtype))
+
+    def encode_text(self, input_ids):
+        """Encode texts
+        Parameters
+        ----------
+        input_ids : torch.Tensor
+            Tokenized texts from RuCLIPProcessor class
+        Returns
+        -------
+        text_latents : torch.Tensor
+            Text embeddings
+        """
+        x = self.token_embedding(input_ids).type(self.dtype)  # [batch_size, n_ctx, d_model]
+        x = x + self.positional_embedding.type(self.dtype)
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.ln_final(x).type(self.dtype)
+        # x.shape = [batch_size, n_ctx, transformer.width]
+        x = x[torch.arange(x.shape[0]), torch.where(input_ids == self.eos_id)[1]] @ self.text_projection
+        return x
+
+    def forward(self, input_ids, pixel_values):
+        image_features = self.encode_image(pixel_values)
+        text_features = self.encode_text(input_ids)
+
+        # normalize features
+        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+        # cosine similarity as logits
+        logit_scale = self.logit_scale.exp()
+        logits_per_image = logit_scale * image_features @ text_features.t()
+        logits_per_text = logits_per_image.t()
+
+        return logits_per_image, logits_per_text
+
+    @classmethod
+    def from_pretrained(cls, folder):
+        """Load model from folder"""
+        config = json.load(open(os.path.join(folder, 'config.json')))
+        model = cls(
+            embed_dim=config['embed_dim'],
+            image_resolution=config['image_resolution'],
+            vision_layers=config['vision_layers'],
+            vision_width=config['vision_width'],
+            vision_patch_size=config['vision_patch_size'],
+            context_length=config['context_length'],
+            vocab_size=config['vocab_size'],
+            transformer_width=config['transformer_width'],
+            transformer_heads=config['transformer_heads'],
+            transformer_layers=config['transformer_layers'],
+        )
+        checkpoint = torch.load(os.path.join(folder, 'pytorch_model.bin'), map_location='cpu')
+        model.load_state_dict(checkpoint)
+        return model
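
A sketch of how the CLIP module above can be used directly, assuming clip and clip_processor come from ruclip.load() as in the earlier example; 'cat.jpg' is a placeholder path and the Russian labels are illustrative:

    import torch
    from PIL import Image

    labels = ['кошка', 'собака']  # "cat", "dog"
    inputs = clip_processor(text=labels, images=[Image.open('cat.jpg')])

    with torch.no_grad():
        # forward() normalizes both embeddings and scales the cosine similarities by logit_scale.
        logits_per_image, logits_per_text = clip(inputs['input_ids'], inputs['pixel_values'])
        probs = logits_per_image.softmax(dim=-1)  # shape [n_images, n_texts]
    print(probs)

encode_text() takes the hidden state at the eos_id position (eos_id=3, matching RuCLIPProcessor.eos_id below) and projects it through text_projection, so input_ids must contain the EOS token, which the processor guarantees.
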
diff --git a/ruclip/predictor.py b/ruclip/predictor.py
new file mode 100644
index 0000000..15fd370
--- /dev/null
+++ b/ruclip/predictor.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+import torch
+import more_itertools
+from tqdm import tqdm
+
+
+class Predictor:
+    def __init__(self, clip_model, clip_processor, device, templates=None, bs=8, quiet=False):
+        self.device = device
+        self.clip_model = clip_model.to(self.device)
+        self.clip_model.eval()
+        self.clip_processor = clip_processor
+        self.bs = bs
+        self.quiet = quiet
+        self.templates = templates or [  # Russian prompt templates
+            '{}',
+            'фото, на котором изображено {}',  # "a photo showing {}"
+            'изображение с {}',  # "an image with {}"
+            'картинка с {}',  # "a picture with {}"
+            'фото с {}',  # "a photo with {}"
+            'на фото видно {}',  # "the photo shows {}"
+        ]
+
+    def get_text_latents(self, class_labels):
+        text_latents = []
+        for template in self.templates:
+            _text_latents = []
+            for chunk in more_itertools.chunked(class_labels, self.bs):
+                texts = [template.format(class_label.lower().strip()) for class_label in chunk]
+                inputs = self.clip_processor(text=texts, return_tensors='pt', padding=True)
+                _text_latents.append(self.clip_model.encode_text(inputs['input_ids'].to(self.device)))
+            text_latents.append(torch.cat(_text_latents, dim=0))
+        text_latents = torch.stack(text_latents).mean(0)  # average the per-template embeddings
+        text_latents = text_latents / text_latents.norm(dim=-1, keepdim=True)
+        return text_latents
+
+    def run(self, images, text_latents):
+        if not self.quiet:
+            pbar = tqdm()
+        labels = []
+        logit_scale = self.clip_model.logit_scale.exp()
+        for pil_images in more_itertools.chunked(images, self.bs):
+            inputs = self.clip_processor(text='', images=list(pil_images), return_tensors='pt', padding=True)
+            image_latents = self.clip_model.encode_image(inputs['pixel_values'].to(self.device))
+            image_latents = image_latents / image_latents.norm(dim=-1, keepdim=True)
+            logits_per_text = torch.matmul(text_latents.to(self.device), image_latents.t()) * logit_scale
+            _labels = logits_per_text.argmax(0).cpu().numpy().tolist()
+            if not self.quiet:
+                pbar.update(len(_labels))
+            labels.extend(_labels)
+        if not self.quiet: pbar.close()  # pbar only exists when quiet is False
+        return labels
+
+    def get_image_latents(self, images):
+        if not self.quiet:
+            pbar = tqdm()
+        image_latents = []
+        for pil_images in more_itertools.chunked(images, self.bs):
+            inputs = self.clip_processor(text='', images=list(pil_images), return_tensors='pt', padding=True)
+            image_latents.append(self.clip_model.encode_image(inputs['pixel_values'].to(self.device)))
+            if not self.quiet:
+                pbar.update(len(pil_images))
+        image_latents = torch.cat(image_latents)
+        image_latents = image_latents / image_latents.norm(dim=-1, keepdim=True)
+        return image_latents
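
A zero-shot classification sketch built on the Predictor above; clip and clip_processor are assumed to come from ruclip.load(), and the image file names are placeholders:

    import torch
    from PIL import Image
    from ruclip import Predictor

    predictor = Predictor(clip, clip_processor, device='cpu', bs=8, quiet=True)
    class_labels = ['кошка', 'собака', 'птица']  # "cat", "dog", "bird"
    images = [Image.open('img0.jpg'), Image.open('img1.jpg')]

    with torch.no_grad():
        text_latents = predictor.get_text_latents(class_labels)  # [n_classes, embed_dim], L2-normalized
        label_ids = predictor.run(images, text_latents)          # one index into class_labels per image
    print([class_labels[i] for i in label_ids])

get_text_latents() embeds every class label with each prompt template and averages the results, the same prompt-ensembling trick used in the original CLIP zero-shot evaluation.
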
diff --git a/ruclip/processor.py b/ruclip/processor.py
new file mode 100644
index 0000000..3ea8135
--- /dev/null
+++ b/ruclip/processor.py
@@ -0,0 +1,74 @@
+# -*- coding: utf-8 -*-
+import os
+import json
+
+import torch
+import numpy as np
+import youtokentome as yttm
+import torchvision.transforms as T
+from torch.nn.utils.rnn import pad_sequence
+
+
+class RuCLIPProcessor:
+    eos_id = 3
+    bos_id = 2
+    unk_id = 1
+    pad_id = 0
+
+    def __init__(self, tokenizer_path, image_size=224, text_seq_length=77, mean=None, std=None):
+        self.tokenizer = yttm.BPE(tokenizer_path)
+        self.mean = mean or [0.48145466, 0.4578275, 0.40821073]
+        self.std = std or [0.26862954, 0.26130258, 0.27577711]
+        self.image_transform = T.Compose([
+            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+            T.RandomResizedCrop(image_size, scale=(1., 1.), ratio=(1., 1.)),
+            T.ToTensor(),
+            T.Normalize(mean=self.mean, std=self.std)
+        ])
+        self.text_seq_length = text_seq_length
+        self.image_size = image_size
+
+    def encode_text(self, text):
+        text = text.lower()
+        tokens = self.tokenizer.encode([text], output_type=yttm.OutputType.ID, dropout_prob=0.0)[0]
+        tokens = tokens[:self.text_seq_length-2]
+        tokens = [self.bos_id] + tokens + [self.eos_id]
+        return self.prepare_tokens(tokens)
+
+    def prepare_tokens(self, tokens):
+        empty_positions = self.text_seq_length - len(tokens)
+        if empty_positions > 0:
+            tokens = np.hstack((tokens, np.zeros(empty_positions)))  # pad with pad_id (0) after the text
+        if len(tokens) > self.text_seq_length:
+            tokens = tokens[:self.text_seq_length-1] + tokens[-1:]
+        return torch.tensor(tokens).long()
+
+    def decode_text(self, encoded):
+        return self.tokenizer.decode(encoded.cpu().numpy().tolist(), ignore_ids=[
+            self.eos_id, self.bos_id, self.unk_id, self.pad_id
+        ])[0]
+
+    def __call__(self, text=None, images=None, **kwargs):
+        inputs = {}
+        if text is not None:
+            input_ids = []
+            texts = [text] if isinstance(text, str) else text
+            for text in texts:
+                tokens = self.encode_text(text)
+                input_ids.append(tokens)
+            inputs['input_ids'] = pad_sequence(input_ids, batch_first=True)
+        if images is not None:
+            pixel_values = []
+            for image in images:
+                pixel_values.append(self.image_transform(image))
+            inputs['pixel_values'] = pad_sequence(pixel_values, batch_first=True)
+        return inputs
+
+    @classmethod
+    def from_pretrained(cls, folder):
+        tokenizer_path = os.path.join(folder, 'bpe.model')
+        config = json.load(open(os.path.join(folder, 'config.json')))
+        image_size = config['image_resolution']
+        text_seq_length = config['context_length']
+        mean, std = config.get('mean'), config.get('std')
+        return cls(tokenizer_path, image_size=image_size, text_seq_length=text_seq_length, mean=mean, std=std)
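
A short sketch of what RuCLIPProcessor produces, assuming clip_processor comes from ruclip.load() (or from RuCLIPProcessor.from_pretrained() on a downloaded checkpoint folder); the image path is a placeholder:

    from PIL import Image

    inputs = clip_processor(text=['кошка на диване'], images=[Image.open('cat.jpg')])  # "a cat on a sofa"
    print(inputs['input_ids'].shape)     # [1, 77]: BOS + BPE ids + EOS, padded with pad_id=0
    print(inputs['pixel_values'].shape)  # [1, 3, 224, 224] for the 224px checkpoints

Because the text side pads every sequence to text_seq_length and the image side resizes to image_size with CLIP's normalization constants, the two tensors can be fed directly to CLIP.encode_text(), CLIP.encode_image(), or CLIP.forward().
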