diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..3a4024d --- /dev/null +++ b/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .blip import Blip + + +def blip(model_name: str, modality: str): + return Blip(model_name, modality) diff --git a/models.py b/models.py new file mode 100644 index 0000000..a082384 --- /dev/null +++ b/models.py @@ -0,0 +1,331 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from github.com/openai/CLIP +from collections import OrderedDict + +import numpy as np +import timm +import torch +from torch import nn + +import losses + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential(OrderedDict([ + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None + return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Module): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class CLIP(nn.Module): + def __init__(self, + embed_dim: int, + # vision + vision_width: int, + vision_model: nn.Module, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int, + **kwargs, + ): + super().__init__() + + self.context_length = context_length + self.vision_width = vision_width + + self.visual = vision_model + + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask(), + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) + self.ln_final = LayerNorm(transformer_width) + + self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim)) + self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + self.initialize_parameters() + + def initialize_parameters(self): + nn.init.normal_(self.token_embedding.weight, std=0.02) + nn.init.normal_(self.positional_embedding, std=0.01) + + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + nn.init.normal_(block.attn.in_proj_weight, std=attn_std) + nn.init.normal_(block.attn.out_proj.weight, std=proj_std) + nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) + nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) + + nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5) + nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = torch.empty(self.context_length, self.context_length) + mask.fill_(float("-inf")) + mask.triu_(1) # zero out the lower diagonal + return mask + + def encode_image(self, image): + x = self.visual(image) + x = x @ self.image_projection + + return x + + def encode_text(self, text): + x = self.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection + + return x + + def forward(self, image, text): + image_embed = self.encode_image(image) + text_embed = self.encode_text(text) + + return {'image_embed': image_embed, + 'text_embed': text_embed, + 'logit_scale': self.logit_scale.exp()} + + +class SIMCLR(nn.Module): + def __init__(self, + # vision + vision_width: int, + vision_model: nn.Module, + # ssl + ssl_mlp_dim: int, + ssl_emb_dim: int, + **kwargs, + ): + super().__init__() + + self.vision_width = vision_width + self.visual = vision_model + + self.image_mlp = self._build_mlp(in_dim=vision_width, mlp_dim=ssl_mlp_dim, out_dim=ssl_emb_dim) + + def _build_mlp(self, in_dim, mlp_dim, out_dim): + return nn.Sequential(OrderedDict([ + ("layer1", nn.Linear(in_dim, mlp_dim)), + ("bn1", nn.SyncBatchNorm(mlp_dim)), + ("relu1", nn.ReLU(inplace=True)), + ("layer2", nn.Linear(mlp_dim, mlp_dim)), + ("bn2", nn.SyncBatchNorm(mlp_dim)), + ("relu2", nn.ReLU(inplace=True)), + ("layer3", nn.Linear(mlp_dim, out_dim)), + ])) + + def encode_image(self, image): + x = self.visual(image) + + return x + + def forward(self, aug1, aug2): + h1 = self.visual(aug1) + h2 = self.visual(aug2) + + aug1_embed = self.image_mlp(h1) + aug2_embed = self.image_mlp(h2) + + return {'aug1_embed': aug1_embed, + 'aug2_embed': aug2_embed} + + +class SLIP(CLIP): + def __init__(self, + ssl_mlp_dim: int, + ssl_emb_dim: int, + **kwargs, + ): + super().__init__(**kwargs) + + self.image_mlp = self._build_mlp(in_dim=self.vision_width, mlp_dim=ssl_mlp_dim, out_dim=ssl_emb_dim) + + def _build_mlp(self, in_dim, mlp_dim, out_dim): + return nn.Sequential(OrderedDict([ + ("layer1", nn.Linear(in_dim, mlp_dim)), + ("bn1", nn.SyncBatchNorm(mlp_dim)), + ("relu1", nn.ReLU(inplace=True)), + ("layer2", nn.Linear(mlp_dim, mlp_dim)), + ("bn2", nn.SyncBatchNorm(mlp_dim)), + ("relu2", nn.ReLU(inplace=True)), + ("layer3", nn.Linear(mlp_dim, out_dim)), + ])) + + def forward(self, image, text, aug1, aug2): + aug1_embed = self.image_mlp(self.visual(aug1)) + aug2_embed = self.image_mlp(self.visual(aug2)) + + image_embed = self.encode_image(image) + text_embed = self.encode_text(text) + + return {'image_embed': image_embed, + 'text_embed': text_embed, + 'logit_scale': self.logit_scale.exp(), + 'aug1_embed': aug1_embed, + 'aug2_embed': aug2_embed} + + +def get_loss(model, ssl_temp, ssl_scale): + if model.startswith('SLIP'): + ssl_loss = losses.SIMCLRLoss(temperature=ssl_temp) + return losses.SLIPLoss(ssl_loss, ssl_scale) + if model.startswith('CLIP'): + return losses.CLIPLoss() + if model.startswith('SIMCLR'): + return losses.SIMCLRLoss(temperature=ssl_temp) + + +def get_metric_names(model): + if model.startswith('SLIP'): + return ['loss', 'clip_loss', 'ssl_loss', 'clip_acc', 'ssl_acc'] + elif model.startswith('CLIP'): + return ['loss', 'clip_loss', 'clip_acc'] + else: + return ['loss', 'ssl_loss', 'ssl_acc'] + + +@timm.models.registry.register_model +def vit_small_mocov3_patch16_224(**kwargs): + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=12, **kwargs) + model = timm.models.vision_transformer._create_vision_transformer('vit_small_patch16_224', **model_kwargs) + + return model + + +def CLIP_VITS16(**kwargs): + vision_model = timm.create_model('vit_small_mocov3_patch16_224', num_classes=0) + model = CLIP(embed_dim=512, vision_width=384, vision_model=vision_model, context_length=77, vocab_size=49408, + transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs) + + return model + + +def SIMCLR_VITS16(**kwargs): + vision_model = timm.create_model('vit_small_mocov3_patch16_224', num_classes=0) + model = SIMCLR(vision_width=384, vision_model=vision_model, **kwargs) + + return model + + +def SLIP_VITS16(**kwargs): + vision_model = timm.create_model('vit_small_mocov3_patch16_224', num_classes=0) + model = SLIP(embed_dim=512, vision_width=384, vision_model=vision_model, context_length=77, vocab_size=49408, + transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs) + + return model + + +def CLIP_VITB16(**kwargs): + vision_model = timm.create_model('vit_base_patch16_224', num_classes=0) + model = CLIP(embed_dim=512, vision_width=768, vision_model=vision_model, context_length=77, vocab_size=49408, + transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs) + + return model + + +def SIMCLR_VITB16(**kwargs): + vision_model = timm.create_model('vit_base_patch16_224', num_classes=0) + model = SIMCLR(vision_width=768, vision_model=vision_model, **kwargs) + + return model + + +def SLIP_VITB16(**kwargs): + vision_model = timm.create_model('vit_base_patch16_224', num_classes=0) + model = SLIP(embed_dim=512, vision_width=768, vision_model=vision_model, context_length=77, vocab_size=49408, + transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs) + + return model + + +def CLIP_VITL16(**kwargs): + vision_model = timm.create_model('vit_large_patch16_224', num_classes=0) + model = CLIP(embed_dim=512, vision_width=1024, vision_model=vision_model, context_length=77, vocab_size=49408, + transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs) + + return model + + +def SIMCLR_VITL16(**kwargs): + vision_model = timm.create_model('vit_large_patch16_224', num_classes=0) + model = SIMCLR(vision_width=1024, vision_model=vision_model, **kwargs) + + return model + + +def SLIP_VITL16(**kwargs): + vision_model = timm.create_model('vit_large_patch16_224', num_classes=0) + model = SLIP(embed_dim=512, vision_width=1024, vision_model=vision_model, context_length=77, vocab_size=49408, + transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs) + + return model diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f4a36c6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +torch>=1.9.0 +torchvision>=0.10.0 +Pillow +towhee.models diff --git a/slip.py b/slip.py new file mode 100644 index 0000000..11111ba --- /dev/null +++ b/slip.py @@ -0,0 +1,93 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from pathlib import Path + +import torch +from torchvision import transforms + +from towhee import register +from towhee.operator.base import NNOperator, OperatorFlag +from towhee.types.arg import arg, to_image_color +from towhee.types.image_utils import from_pil, to_pil + +from tokenizer import SimpleTokenizer + +def get_model(model): + if isinstance(model, torch.nn.DataParallel) \ + or isinstance(model, torch.nn.parallel.DistributedDataParallel): + return model.module + else: + return model + +@register(output_schema=['vec']) +class Slip(NNOperator) + """ + SLIP multi-modal embedding operator + """ + def __init__(self, model_name: str, modality: str): + super().__init__() + sys.path.append(str(Path(__file__).parent)) + self.tokenizer = SimpleTokenizer() + self.device = "cuda" if torch.cuda.is_available() else "cpu" + self.model.to(self.device) + self.model.eval() + + self.tfms = transforms.Compose([ + transforms.Resize(224), + transforms.CenterCrop(224), + lambda x: x.convert('RGB'), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + ]) + + def __call__(self, data): + if self._modality == 'image': + vec = self._inference_from_image(data) + elif self._modality == 'text': + vec = self._inference_from_text(data) + else: + raise ValueError("modality[{}] not implemented.".format(self._modality)) + return vec.detach().cpu().numpy().flatten() + + def _inference_from_text(self, text): + texts = tokenizer(texts).cuda(non_blocking=True) + texts = texts.view(-1, 77).contiguous() + embedding = get_model(self.model).encode_text(texts) + embedding = embedding / embedding.norm(dim=-1, keepdim=True) + + @arg(1, to_image_color('RGB')) + def _inference_from_image(self, img): + img = self._preprocess(img) + img = img.to(self.device) + embedding = get_model(self.model).encode_image(img) + return embedding + + def _preprocess(self, img): + img = to_pil(img) + processed_img = self.tfms(img).unsqueeze(0).to(self.device) + return processed_img + + def _configs(self): + config = {} + config['slip_vit_small'] = {} + config['slip_vit_small']['weights'] = 'https://dl.fbaipublicfiles.com/slip/slip_small_100ep.pt' + config['slip_vit_base'] = {} + config['slip_vit_base']['weights'] = 'https://dl.fbaipublicfiles.com/slip/slip_base_100ep.pt' + config['slip_vit_large'] = {} + config['slip_vit_large']['weights'] = 'https://dl.fbaipublicfiles.com/slip/slip_large_100ep.pt' + return config + diff --git a/tokenizer.py b/tokenizer.py new file mode 100644 index 0000000..4b3a26e --- /dev/null +++ b/tokenizer.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from github.com/openai/CLIP +import gzip +import html +import os +from functools import lru_cache + +import ftfy +import regex as re +import torch + + +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text + + def __call__(self, texts, context_length=77): + if isinstance(texts, str): + texts = [texts] + + sot_token = self.encoder["<|startoftext|>"] + eot_token = self.encoder["<|endoftext|>"] + all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + tokens = tokens[:context_length] + result[i, :len(tokens)] = torch.tensor(tokens) + + if len(result) == 1: + return result[0] + return result \ No newline at end of file diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..2ffd522 --- /dev/null +++ b/utils.py @@ -0,0 +1,8 @@ +import torch + +def get_model(model): + if isinstance(model, torch.nn.DataParallel) \ + or isinstance(model, torch.nn.parallel.DistributedDataParallel): + return model.module + else: + return model