
init the operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
Branch: main
Author: wxywb, 2 years ago
Commit: 651e84f553
1. __init__.py           +0
2. ru_clip.py            +0
3. ruclip/__init__.py   +82
4. ruclip/model.py     +239
5. ruclip/predictor.py  +65
6. ruclip/processor.py  +74

__init__.py (+0)

ru_clip.py (+0)

ruclip/__init__.py (+82)

@@ -0,0 +1,82 @@
# -*- coding: utf-8 -*-
import os

from huggingface_hub import hf_hub_url, cached_download

from . import model, processor, predictor
from .model import CLIP
from .processor import RuCLIPProcessor
from .predictor import Predictor

MODELS = {
    'ruclip-vit-base-patch32-224': dict(
        repo_id='sberbank-ai/ruclip-vit-base-patch32-224',
        filenames=[
            'bpe.model', 'config.json', 'pytorch_model.bin'
        ]
    ),
    'ruclip-vit-base-patch16-224': dict(
        repo_id='sberbank-ai/ruclip-vit-base-patch16-224',
        filenames=[
            'bpe.model', 'config.json', 'pytorch_model.bin'
        ]
    ),
    'ruclip-vit-large-patch14-224': dict(
        repo_id='sberbank-ai/ruclip-vit-large-patch14-224',
        filenames=[
            'bpe.model', 'config.json', 'pytorch_model.bin'
        ]
    ),
    'ruclip-vit-large-patch14-336': dict(
        repo_id='sberbank-ai/ruclip-vit-large-patch14-336',
        filenames=[
            'bpe.model', 'config.json', 'pytorch_model.bin'
        ]
    ),
    'ruclip-vit-base-patch32-384': dict(
        repo_id='sberbank-ai/ruclip-vit-base-patch32-384',
        filenames=[
            'bpe.model', 'config.json', 'pytorch_model.bin'
        ]
    ),
    'ruclip-vit-base-patch16-384': dict(
        repo_id='sberbank-ai/ruclip-vit-base-patch16-384',
        filenames=[
            'bpe.model', 'config.json', 'pytorch_model.bin'
        ]
    ),
}


def load(name, device='cpu', cache_dir='/tmp/ruclip', use_auth_token=None):
    """Load a ruCLIP model.

    Parameters
    ----------
    name : str
        A model name listed in ruclip.MODELS.keys()
    device : Union[str, torch.device]
        The device on which to put the loaded model
    cache_dir : str
        Path to which the model files are downloaded; defaults to "/tmp/ruclip"
    use_auth_token : Optional[str]
        Hugging Face Hub token passed through to cached_download (needed for private repos)

    Returns
    -------
    clip : torch.nn.Module
        The ruCLIP model
    clip_processor : ruclip.processor.RuCLIPProcessor
        A ruCLIP processor that performs tokenization and image preprocessing
    """
    assert name in MODELS, f'All models: {MODELS.keys()}'
    config = MODELS[name]
    repo_id = config['repo_id']
    cache_dir = os.path.join(cache_dir, name)
    for filename in config['filenames']:
        config_file_url = hf_hub_url(repo_id=repo_id, filename=f'{filename}')
        cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename, use_auth_token=use_auth_token)
    clip = CLIP.from_pretrained(cache_dir).eval().to(device)
    clip_processor = RuCLIPProcessor.from_pretrained(cache_dir)
    return clip, clip_processor


__all__ = ['processor', 'model', 'predictor', 'CLIP', 'RuCLIPProcessor', 'Predictor', 'MODELS', 'load']
__version__ = '0.0.2'
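
Usage sketch (reviewer note, not part of this diff): a minimal way to exercise load() above; the blank PIL image and the Russian labels are placeholder inputs.

import torch
from PIL import Image

import ruclip

# Download the checkpoint listed in MODELS and build the model + processor.
clip, processor = ruclip.load('ruclip-vit-base-patch32-224', device='cpu')

# Placeholder inputs: a blank image and two candidate captions.
inputs = processor(text=['кошка', 'собака'], images=[Image.new('RGB', (224, 224))])

with torch.no_grad():
    image_latents = clip.encode_image(inputs['pixel_values'])
    text_latents = clip.encode_text(inputs['input_ids'])

# Cosine similarity between the image and each caption, scaled the same way as CLIP.forward.
image_latents = image_latents / image_latents.norm(dim=-1, keepdim=True)
text_latents = text_latents / text_latents.norm(dim=-1, keepdim=True)
probs = (clip.logit_scale.exp() * image_latents @ text_latents.t()).softmax(dim=-1)
print(probs)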

ruclip/model.py (+239)

@@ -0,0 +1,239 @@
# -*- coding: utf-8 -*-
import os
import json
from collections import OrderedDict

import torch
import numpy as np
from torch import nn


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ('c_fc', nn.Linear(d_model, d_model * 4)),
            ('gelu', QuickGELU()),
            ('c_proj', nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(width, layers, heads)
        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([
            self.class_embedding.to(x.dtype) +
            torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x
        ], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_post(x[:, 0, :])
        if self.proj is not None:
            x = x @ self.proj
        return x


class CLIP(nn.Module):
    def __init__(
        self,
        embed_dim,
        image_resolution,
        vision_layers,
        vision_width,
        vision_patch_size,
        context_length,
        vocab_size,
        transformer_width,
        transformer_heads,
        transformer_layers,
        eos_id=3,
    ):
        super().__init__()
        self.eos_id = eos_id
        self.context_length = context_length

        vision_heads = vision_width // 64
        self.visual = VisionTransformer(
            input_resolution=image_resolution,
            patch_size=vision_patch_size,
            width=vision_width,
            layers=vision_layers,
            heads=vision_heads,
            output_dim=embed_dim,
        )

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # causal mask for the text transformer; PyTorch uses an additive mask,
        # so fill everything above the diagonal with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float('-inf'))
        mask.triu_(1)
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, pixel_values):
        """Encode images.

        Parameters
        ----------
        pixel_values : torch.Tensor
            Processed images from the RuCLIPProcessor class

        Returns
        -------
        image_latents : torch.Tensor
            Image embeddings
        """
        return self.visual(pixel_values.type(self.dtype))

    def encode_text(self, input_ids):
        """Encode texts.

        Parameters
        ----------
        input_ids : torch.Tensor
            Tokenized texts from the RuCLIPProcessor class

        Returns
        -------
        text_latents : torch.Tensor
            Text embeddings
        """
        x = self.token_embedding(input_ids).type(self.dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take the features at the eos token position and project into the shared embedding space
        x = x[torch.arange(x.shape[0]), torch.where(input_ids == self.eos_id)[1]] @ self.text_projection
        return x

    def forward(self, input_ids, pixel_values):
        image_features = self.encode_image(pixel_values)
        text_features = self.encode_text(input_ids)

        # normalize features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text

    @classmethod
    def from_pretrained(cls, folder):
        """Load the model config and weights from a local folder."""
        config = json.load(open(os.path.join(folder, 'config.json')))
        model = cls(
            embed_dim=config['embed_dim'],
            image_resolution=config['image_resolution'],
            vision_layers=config['vision_layers'],
            vision_width=config['vision_width'],
            vision_patch_size=config['vision_patch_size'],
            context_length=config['context_length'],
            vocab_size=config['vocab_size'],
            transformer_width=config['transformer_width'],
            transformer_heads=config['transformer_heads'],
            transformer_layers=config['transformer_layers'],
        )
        checkpoint = torch.load(os.path.join(folder, 'pytorch_model.bin'), map_location='cpu')
        model.load_state_dict(checkpoint)
        return model
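
A brief sketch of how CLIP.from_pretrained and forward fit together (not part of this diff). The folder path assumes the cache layout created by load() in ruclip/__init__.py; the blank image and labels are placeholders.

import torch
from PIL import Image

from ruclip.model import CLIP
from ruclip.processor import RuCLIPProcessor

folder = '/tmp/ruclip/ruclip-vit-base-patch32-224'  # assumed cache dir populated by ruclip.load()
clip = CLIP.from_pretrained(folder).eval()
processor = RuCLIPProcessor.from_pretrained(folder)

inputs = processor(text=['пейзаж', 'портрет'], images=[Image.new('RGB', (224, 224))])
with torch.no_grad():
    # forward() returns scaled cosine similarities of shape [n_images, n_texts] and their transpose.
    logits_per_image, logits_per_text = clip(inputs['input_ids'], inputs['pixel_values'])
print(logits_per_image.softmax(dim=-1))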

ruclip/predictor.py (+65)

@@ -0,0 +1,65 @@
# -*- coding: utf-8 -*-
import torch
import more_itertools
from tqdm import tqdm


class Predictor:
    def __init__(self, clip_model, clip_processor, device, templates=None, bs=8, quiet=False):
        self.device = device
        self.clip_model = clip_model.to(self.device)
        self.clip_model.eval()
        self.clip_processor = clip_processor
        self.bs = bs
        self.quiet = quiet
        self.templates = templates or [
            '{}',
            'фото, на котором изображено {}',
            'изображение с {}',
            'картинка с {}',
            'фото с {}',
            'на фото видно {}',
        ]

    def get_text_latents(self, class_labels):
        text_latents = []
        for template in self.templates:
            _text_latents = []
            for chunk in more_itertools.chunked(class_labels, self.bs):
                texts = [template.format(class_label.lower().strip()) for class_label in chunk]
                inputs = self.clip_processor(text=texts, return_tensors='pt', padding=True)
                _text_latents.append(self.clip_model.encode_text(inputs['input_ids'].to(self.device)))
            text_latents.append(torch.cat(_text_latents, dim=0))
        text_latents = torch.stack(text_latents).mean(0)
        text_latents = text_latents / text_latents.norm(dim=-1, keepdim=True)
        return text_latents

    def run(self, images, text_latents):
        if not self.quiet:
            pbar = tqdm()
        labels = []
        logit_scale = self.clip_model.logit_scale.exp()
        for pil_images in more_itertools.chunked(images, self.bs):
            inputs = self.clip_processor(text='', images=list(pil_images), return_tensors='pt', padding=True)
            image_latents = self.clip_model.encode_image(inputs['pixel_values'].to(self.device))
            image_latents = image_latents / image_latents.norm(dim=-1, keepdim=True)
            logits_per_text = torch.matmul(text_latents.to(self.device), image_latents.t()) * logit_scale
            _labels = logits_per_text.argmax(0).cpu().numpy().tolist()
            if not self.quiet:
                pbar.update(len(_labels))
            labels.extend(_labels)
        if not self.quiet:
            pbar.close()
        return labels

    def get_image_latents(self, images):
        if not self.quiet:
            pbar = tqdm()
        image_latents = []
        for pil_images in more_itertools.chunked(images, self.bs):
            inputs = self.clip_processor(text='', images=list(pil_images), return_tensors='pt', padding=True)
            image_latents.append(self.clip_model.encode_image(inputs['pixel_values'].to(self.device)))
            if not self.quiet:
                pbar.update(len(pil_images))
        image_latents = torch.cat(image_latents)
        image_latents = image_latents / image_latents.norm(dim=-1, keepdim=True)
        return image_latents
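
Usage sketch (not part of this diff): zero-shot classification with Predictor; the class labels and the blank image are placeholders, and the model/processor come from ruclip.load().

from PIL import Image

import ruclip
from ruclip.predictor import Predictor

clip, processor = ruclip.load('ruclip-vit-base-patch32-224', device='cpu')
predictor = Predictor(clip, processor, device='cpu', bs=8, quiet=True)

class_labels = ['кошка', 'собака', 'птица']   # placeholder label set
images = [Image.new('RGB', (224, 224))]       # placeholder image

text_latents = predictor.get_text_latents(class_labels)  # averaged over the Russian prompt templates
label_ids = predictor.run(images, text_latents)           # one index into class_labels per image
print([class_labels[i] for i in label_ids])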

ruclip/processor.py (+74)

@@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
import os
import json

import torch
import numpy as np
import youtokentome as yttm
import torchvision.transforms as T
from torch.nn.utils.rnn import pad_sequence


class RuCLIPProcessor:
    eos_id = 3
    bos_id = 2
    unk_id = 1
    pad_id = 0

    def __init__(self, tokenizer_path, image_size=224, text_seq_length=77, mean=None, std=None):
        self.tokenizer = yttm.BPE(tokenizer_path)
        self.mean = mean or [0.48145466, 0.4578275, 0.40821073]
        self.std = std or [0.26862954, 0.26130258, 0.27577711]
        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.RandomResizedCrop(image_size, scale=(1., 1.), ratio=(1., 1.)),
            T.ToTensor(),
            T.Normalize(mean=self.mean, std=self.std)
        ])
        self.text_seq_length = text_seq_length
        self.image_size = image_size

    def encode_text(self, text):
        text = text.lower()
        tokens = self.tokenizer.encode([text], output_type=yttm.OutputType.ID, dropout_prob=0.0)[0]
        tokens = tokens[:self.text_seq_length - 2]
        tokens = [self.bos_id] + tokens + [self.eos_id]
        return self.prepare_tokens(tokens)

    def prepare_tokens(self, tokens):
        empty_positions = self.text_seq_length - len(tokens)
        if empty_positions > 0:
            tokens = np.hstack((tokens, np.zeros(empty_positions)))  # pad positions after the text tokens
        if len(tokens) > self.text_seq_length:
            tokens = tokens[:self.text_seq_length - 1] + tokens[-1:]
        return torch.tensor(tokens).long()

    def decode_text(self, encoded):
        return self.tokenizer.decode(encoded.cpu().numpy().tolist(), ignore_ids=[
            self.eos_id, self.bos_id, self.unk_id, self.pad_id
        ])[0]

    def __call__(self, text=None, images=None, **kwargs):
        inputs = {}
        if text is not None:
            input_ids = []
            texts = [text] if isinstance(text, str) else text
            for text in texts:
                tokens = self.encode_text(text)
                input_ids.append(tokens)
            inputs['input_ids'] = pad_sequence(input_ids, batch_first=True)
        if images is not None:
            pixel_values = []
            for i, image in enumerate(images):
                pixel_values.append(self.image_transform(image))
            inputs['pixel_values'] = pad_sequence(pixel_values, batch_first=True)
        return inputs

    @classmethod
    def from_pretrained(cls, folder):
        tokenizer_path = os.path.join(folder, 'bpe.model')
        config = json.load(open(os.path.join(folder, 'config.json')))
        image_size = config['image_resolution']
        text_seq_length = config['context_length']
        mean, std = config.get('mean'), config.get('std')
        return cls(tokenizer_path, image_size=image_size, text_seq_length=text_seq_length, mean=mean, std=std)
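
Usage sketch (not part of this diff): calling RuCLIPProcessor directly to inspect its outputs; the folder path and inputs are placeholders, assuming weights were already cached by ruclip.load().

from PIL import Image

from ruclip.processor import RuCLIPProcessor

processor = RuCLIPProcessor.from_pretrained('/tmp/ruclip/ruclip-vit-base-patch32-224')

out = processor(text=['привет мир'], images=[Image.new('RGB', (256, 256))])
print(out['input_ids'].shape)     # [1, text_seq_length]: BOS + BPE ids + EOS, zero-padded
print(out['pixel_values'].shape)  # [1, 3, image_size, image_size], normalized tensor

tokens = processor.encode_text('привет мир')
print(tokens[:8])                 # first few token ids for a single lowercased string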