slip
copied
wxywb
2 years ago
6 changed files with 612 additions and 0 deletions
@ -0,0 +1,19 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
from .blip import Blip |
|||
|
|||
|
|||
def blip(model_name: str, modality: str): |
|||
return Blip(model_name, modality) |
@ -0,0 +1,331 @@ |
|||
# Copyright (c) Meta Platforms, Inc. and affiliates. |
|||
# All rights reserved. |
|||
|
|||
# This source code is licensed under the license found in the |
|||
# LICENSE file in the root directory of this source tree. |
|||
|
|||
# Modified from github.com/openai/CLIP |
|||
from collections import OrderedDict |
|||
|
|||
import numpy as np |
|||
import timm |
|||
import torch |
|||
from torch import nn |
|||
|
|||
import losses |
|||
|
|||
|
|||
class LayerNorm(nn.LayerNorm): |
|||
"""Subclass torch's LayerNorm to handle fp16.""" |
|||
|
|||
def forward(self, x: torch.Tensor): |
|||
orig_type = x.dtype |
|||
ret = super().forward(x.type(torch.float32)) |
|||
return ret.type(orig_type) |
|||
|
|||
|
|||
class QuickGELU(nn.Module): |
|||
def forward(self, x: torch.Tensor): |
|||
return x * torch.sigmoid(1.702 * x) |
|||
|
|||
|
|||
class ResidualAttentionBlock(nn.Module): |
|||
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): |
|||
super().__init__() |
|||
|
|||
self.attn = nn.MultiheadAttention(d_model, n_head) |
|||
self.ln_1 = LayerNorm(d_model) |
|||
self.mlp = nn.Sequential(OrderedDict([ |
|||
("c_fc", nn.Linear(d_model, d_model * 4)), |
|||
("gelu", QuickGELU()), |
|||
("c_proj", nn.Linear(d_model * 4, d_model)) |
|||
])) |
|||
self.ln_2 = LayerNorm(d_model) |
|||
self.attn_mask = attn_mask |
|||
|
|||
def attention(self, x: torch.Tensor): |
|||
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None |
|||
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] |
|||
|
|||
def forward(self, x: torch.Tensor): |
|||
x = x + self.attention(self.ln_1(x)) |
|||
x = x + self.mlp(self.ln_2(x)) |
|||
return x |
|||
|
|||
|
|||
class Transformer(nn.Module): |
|||
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): |
|||
super().__init__() |
|||
self.width = width |
|||
self.layers = layers |
|||
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) |
|||
|
|||
def forward(self, x: torch.Tensor): |
|||
return self.resblocks(x) |
|||
|
|||
|
|||
class CLIP(nn.Module): |
|||
def __init__(self, |
|||
embed_dim: int, |
|||
# vision |
|||
vision_width: int, |
|||
vision_model: nn.Module, |
|||
# text |
|||
context_length: int, |
|||
vocab_size: int, |
|||
transformer_width: int, |
|||
transformer_heads: int, |
|||
transformer_layers: int, |
|||
**kwargs, |
|||
): |
|||
super().__init__() |
|||
|
|||
self.context_length = context_length |
|||
self.vision_width = vision_width |
|||
|
|||
self.visual = vision_model |
|||
|
|||
self.transformer = Transformer( |
|||
width=transformer_width, |
|||
layers=transformer_layers, |
|||
heads=transformer_heads, |
|||
attn_mask=self.build_attention_mask(), |
|||
) |
|||
|
|||
self.vocab_size = vocab_size |
|||
self.token_embedding = nn.Embedding(vocab_size, transformer_width) |
|||
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) |
|||
self.ln_final = LayerNorm(transformer_width) |
|||
|
|||
self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim)) |
|||
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) |
|||
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) |
|||
|
|||
self.initialize_parameters() |
|||
|
|||
def initialize_parameters(self): |
|||
nn.init.normal_(self.token_embedding.weight, std=0.02) |
|||
nn.init.normal_(self.positional_embedding, std=0.01) |
|||
|
|||
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) |
|||
attn_std = self.transformer.width ** -0.5 |
|||
fc_std = (2 * self.transformer.width) ** -0.5 |
|||
for block in self.transformer.resblocks: |
|||
nn.init.normal_(block.attn.in_proj_weight, std=attn_std) |
|||
nn.init.normal_(block.attn.out_proj.weight, std=proj_std) |
|||
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) |
|||
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) |
|||
|
|||
nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5) |
|||
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) |
|||
|
|||
def build_attention_mask(self): |
|||
# lazily create causal attention mask, with full attention between the vision tokens |
|||
# pytorch uses additive attention mask; fill with -inf |
|||
mask = torch.empty(self.context_length, self.context_length) |
|||
mask.fill_(float("-inf")) |
|||
mask.triu_(1) # zero out the lower diagonal |
|||
return mask |
|||
|
|||
def encode_image(self, image): |
|||
x = self.visual(image) |
|||
x = x @ self.image_projection |
|||
|
|||
return x |
|||
|
|||
def encode_text(self, text): |
|||
x = self.token_embedding(text) # [batch_size, n_ctx, d_model] |
|||
x = x + self.positional_embedding |
|||
x = x.permute(1, 0, 2) # NLD -> LND |
|||
x = self.transformer(x) |
|||
x = x.permute(1, 0, 2) # LND -> NLD |
|||
x = self.ln_final(x) |
|||
|
|||
# x.shape = [batch_size, n_ctx, transformer.width] |
|||
# take features from the eot embedding (eot_token is the highest number in each sequence) |
|||
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection |
|||
|
|||
return x |
|||
|
|||
def forward(self, image, text): |
|||
image_embed = self.encode_image(image) |
|||
text_embed = self.encode_text(text) |
|||
|
|||
return {'image_embed': image_embed, |
|||
'text_embed': text_embed, |
|||
'logit_scale': self.logit_scale.exp()} |
|||
|
|||
|
|||
class SIMCLR(nn.Module): |
|||
def __init__(self, |
|||
# vision |
|||
vision_width: int, |
|||
vision_model: nn.Module, |
|||
# ssl |
|||
ssl_mlp_dim: int, |
|||
ssl_emb_dim: int, |
|||
**kwargs, |
|||
): |
|||
super().__init__() |
|||
|
|||
self.vision_width = vision_width |
|||
self.visual = vision_model |
|||
|
|||
self.image_mlp = self._build_mlp(in_dim=vision_width, mlp_dim=ssl_mlp_dim, out_dim=ssl_emb_dim) |
|||
|
|||
def _build_mlp(self, in_dim, mlp_dim, out_dim): |
|||
return nn.Sequential(OrderedDict([ |
|||
("layer1", nn.Linear(in_dim, mlp_dim)), |
|||
("bn1", nn.SyncBatchNorm(mlp_dim)), |
|||
("relu1", nn.ReLU(inplace=True)), |
|||
("layer2", nn.Linear(mlp_dim, mlp_dim)), |
|||
("bn2", nn.SyncBatchNorm(mlp_dim)), |
|||
("relu2", nn.ReLU(inplace=True)), |
|||
("layer3", nn.Linear(mlp_dim, out_dim)), |
|||
])) |
|||
|
|||
def encode_image(self, image): |
|||
x = self.visual(image) |
|||
|
|||
return x |
|||
|
|||
def forward(self, aug1, aug2): |
|||
h1 = self.visual(aug1) |
|||
h2 = self.visual(aug2) |
|||
|
|||
aug1_embed = self.image_mlp(h1) |
|||
aug2_embed = self.image_mlp(h2) |
|||
|
|||
return {'aug1_embed': aug1_embed, |
|||
'aug2_embed': aug2_embed} |
|||
|
|||
|
|||
class SLIP(CLIP): |
|||
def __init__(self, |
|||
ssl_mlp_dim: int, |
|||
ssl_emb_dim: int, |
|||
**kwargs, |
|||
): |
|||
super().__init__(**kwargs) |
|||
|
|||
self.image_mlp = self._build_mlp(in_dim=self.vision_width, mlp_dim=ssl_mlp_dim, out_dim=ssl_emb_dim) |
|||
|
|||
def _build_mlp(self, in_dim, mlp_dim, out_dim): |
|||
return nn.Sequential(OrderedDict([ |
|||
("layer1", nn.Linear(in_dim, mlp_dim)), |
|||
("bn1", nn.SyncBatchNorm(mlp_dim)), |
|||
("relu1", nn.ReLU(inplace=True)), |
|||
("layer2", nn.Linear(mlp_dim, mlp_dim)), |
|||
("bn2", nn.SyncBatchNorm(mlp_dim)), |
|||
("relu2", nn.ReLU(inplace=True)), |
|||
("layer3", nn.Linear(mlp_dim, out_dim)), |
|||
])) |
|||
|
|||
def forward(self, image, text, aug1, aug2): |
|||
aug1_embed = self.image_mlp(self.visual(aug1)) |
|||
aug2_embed = self.image_mlp(self.visual(aug2)) |
|||
|
|||
image_embed = self.encode_image(image) |
|||
text_embed = self.encode_text(text) |
|||
|
|||
return {'image_embed': image_embed, |
|||
'text_embed': text_embed, |
|||
'logit_scale': self.logit_scale.exp(), |
|||
'aug1_embed': aug1_embed, |
|||
'aug2_embed': aug2_embed} |
|||
|
|||
|
|||
def get_loss(model, ssl_temp, ssl_scale): |
|||
if model.startswith('SLIP'): |
|||
ssl_loss = losses.SIMCLRLoss(temperature=ssl_temp) |
|||
return losses.SLIPLoss(ssl_loss, ssl_scale) |
|||
if model.startswith('CLIP'): |
|||
return losses.CLIPLoss() |
|||
if model.startswith('SIMCLR'): |
|||
return losses.SIMCLRLoss(temperature=ssl_temp) |
|||
|
|||
|
|||
def get_metric_names(model): |
|||
if model.startswith('SLIP'): |
|||
return ['loss', 'clip_loss', 'ssl_loss', 'clip_acc', 'ssl_acc'] |
|||
elif model.startswith('CLIP'): |
|||
return ['loss', 'clip_loss', 'clip_acc'] |
|||
else: |
|||
return ['loss', 'ssl_loss', 'ssl_acc'] |
|||
|
|||
|
|||
@timm.models.registry.register_model |
|||
def vit_small_mocov3_patch16_224(**kwargs): |
|||
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=12, **kwargs) |
|||
model = timm.models.vision_transformer._create_vision_transformer('vit_small_patch16_224', **model_kwargs) |
|||
|
|||
return model |
|||
|
|||
|
|||
def CLIP_VITS16(**kwargs): |
|||
vision_model = timm.create_model('vit_small_mocov3_patch16_224', num_classes=0) |
|||
model = CLIP(embed_dim=512, vision_width=384, vision_model=vision_model, context_length=77, vocab_size=49408, |
|||
transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs) |
|||
|
|||
return model |
|||
|
|||
|
|||
def SIMCLR_VITS16(**kwargs): |
|||
vision_model = timm.create_model('vit_small_mocov3_patch16_224', num_classes=0) |
|||
model = SIMCLR(vision_width=384, vision_model=vision_model, **kwargs) |
|||
|
|||
return model |
|||
|
|||
|
|||
def SLIP_VITS16(**kwargs): |
|||
vision_model = timm.create_model('vit_small_mocov3_patch16_224', num_classes=0) |
|||
model = SLIP(embed_dim=512, vision_width=384, vision_model=vision_model, context_length=77, vocab_size=49408, |
|||
transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs) |
|||
|
|||
return model |
|||
|
|||
|
|||
def CLIP_VITB16(**kwargs): |
|||
vision_model = timm.create_model('vit_base_patch16_224', num_classes=0) |
|||
model = CLIP(embed_dim=512, vision_width=768, vision_model=vision_model, context_length=77, vocab_size=49408, |
|||
transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs) |
|||
|
|||
return model |
|||
|
|||
|
|||
def SIMCLR_VITB16(**kwargs): |
|||
vision_model = timm.create_model('vit_base_patch16_224', num_classes=0) |
|||
model = SIMCLR(vision_width=768, vision_model=vision_model, **kwargs) |
|||
|
|||
return model |
|||
|
|||
|
|||
def SLIP_VITB16(**kwargs): |
|||
vision_model = timm.create_model('vit_base_patch16_224', num_classes=0) |
|||
model = SLIP(embed_dim=512, vision_width=768, vision_model=vision_model, context_length=77, vocab_size=49408, |
|||
transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs) |
|||
|
|||
return model |
|||
|
|||
|
|||
def CLIP_VITL16(**kwargs): |
|||
vision_model = timm.create_model('vit_large_patch16_224', num_classes=0) |
|||
model = CLIP(embed_dim=512, vision_width=1024, vision_model=vision_model, context_length=77, vocab_size=49408, |
|||
transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs) |
|||
|
|||
return model |
|||
|
|||
|
|||
def SIMCLR_VITL16(**kwargs): |
|||
vision_model = timm.create_model('vit_large_patch16_224', num_classes=0) |
|||
model = SIMCLR(vision_width=1024, vision_model=vision_model, **kwargs) |
|||
|
|||
return model |
|||
|
|||
|
|||
def SLIP_VITL16(**kwargs): |
|||
vision_model = timm.create_model('vit_large_patch16_224', num_classes=0) |
|||
model = SLIP(embed_dim=512, vision_width=1024, vision_model=vision_model, context_length=77, vocab_size=49408, |
|||
transformer_width=512, transformer_heads=8, transformer_layers=12, **kwargs) |
|||
|
|||
return model |
@ -0,0 +1,4 @@ |
|||
torch>=1.9.0 |
|||
torchvision>=0.10.0 |
|||
Pillow |
|||
towhee.models |
@ -0,0 +1,93 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
import sys |
|||
from pathlib import Path |
|||
|
|||
import torch |
|||
from torchvision import transforms |
|||
|
|||
from towhee import register |
|||
from towhee.operator.base import NNOperator, OperatorFlag |
|||
from towhee.types.arg import arg, to_image_color |
|||
from towhee.types.image_utils import from_pil, to_pil |
|||
|
|||
from tokenizer import SimpleTokenizer |
|||
|
|||
def get_model(model): |
|||
if isinstance(model, torch.nn.DataParallel) \ |
|||
or isinstance(model, torch.nn.parallel.DistributedDataParallel): |
|||
return model.module |
|||
else: |
|||
return model |
|||
|
|||
@register(output_schema=['vec']) |
|||
class Slip(NNOperator) |
|||
""" |
|||
SLIP multi-modal embedding operator |
|||
""" |
|||
def __init__(self, model_name: str, modality: str): |
|||
super().__init__() |
|||
sys.path.append(str(Path(__file__).parent)) |
|||
self.tokenizer = SimpleTokenizer() |
|||
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|||
self.model.to(self.device) |
|||
self.model.eval() |
|||
|
|||
self.tfms = transforms.Compose([ |
|||
transforms.Resize(224), |
|||
transforms.CenterCrop(224), |
|||
lambda x: x.convert('RGB'), |
|||
transforms.ToTensor(), |
|||
transforms.Normalize(mean=[0.485, 0.456, 0.406], |
|||
std=[0.229, 0.224, 0.225]) |
|||
]) |
|||
|
|||
def __call__(self, data): |
|||
if self._modality == 'image': |
|||
vec = self._inference_from_image(data) |
|||
elif self._modality == 'text': |
|||
vec = self._inference_from_text(data) |
|||
else: |
|||
raise ValueError("modality[{}] not implemented.".format(self._modality)) |
|||
return vec.detach().cpu().numpy().flatten() |
|||
|
|||
def _inference_from_text(self, text): |
|||
texts = tokenizer(texts).cuda(non_blocking=True) |
|||
texts = texts.view(-1, 77).contiguous() |
|||
embedding = get_model(self.model).encode_text(texts) |
|||
embedding = embedding / embedding.norm(dim=-1, keepdim=True) |
|||
|
|||
@arg(1, to_image_color('RGB')) |
|||
def _inference_from_image(self, img): |
|||
img = self._preprocess(img) |
|||
img = img.to(self.device) |
|||
embedding = get_model(self.model).encode_image(img) |
|||
return embedding |
|||
|
|||
def _preprocess(self, img): |
|||
img = to_pil(img) |
|||
processed_img = self.tfms(img).unsqueeze(0).to(self.device) |
|||
return processed_img |
|||
|
|||
def _configs(self): |
|||
config = {} |
|||
config['slip_vit_small'] = {} |
|||
config['slip_vit_small']['weights'] = 'https://dl.fbaipublicfiles.com/slip/slip_small_100ep.pt' |
|||
config['slip_vit_base'] = {} |
|||
config['slip_vit_base']['weights'] = 'https://dl.fbaipublicfiles.com/slip/slip_base_100ep.pt' |
|||
config['slip_vit_large'] = {} |
|||
config['slip_vit_large']['weights'] = 'https://dl.fbaipublicfiles.com/slip/slip_large_100ep.pt' |
|||
return config |
|||
|
@ -0,0 +1,157 @@ |
|||
# Copyright (c) Meta Platforms, Inc. and affiliates. |
|||
# All rights reserved. |
|||
|
|||
# This source code is licensed under the license found in the |
|||
# LICENSE file in the root directory of this source tree. |
|||
|
|||
# Modified from github.com/openai/CLIP |
|||
import gzip |
|||
import html |
|||
import os |
|||
from functools import lru_cache |
|||
|
|||
import ftfy |
|||
import regex as re |
|||
import torch |
|||
|
|||
|
|||
@lru_cache() |
|||
def default_bpe(): |
|||
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") |
|||
|
|||
|
|||
@lru_cache() |
|||
def bytes_to_unicode(): |
|||
""" |
|||
Returns list of utf-8 byte and a corresponding list of unicode strings. |
|||
The reversible bpe codes work on unicode strings. |
|||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. |
|||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. |
|||
This is a signficant percentage of your normal, say, 32K bpe vocab. |
|||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings. |
|||
And avoids mapping to whitespace/control characters the bpe code barfs on. |
|||
""" |
|||
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) |
|||
cs = bs[:] |
|||
n = 0 |
|||
for b in range(2**8): |
|||
if b not in bs: |
|||
bs.append(b) |
|||
cs.append(2**8+n) |
|||
n += 1 |
|||
cs = [chr(n) for n in cs] |
|||
return dict(zip(bs, cs)) |
|||
|
|||
|
|||
def get_pairs(word): |
|||
"""Return set of symbol pairs in a word. |
|||
Word is represented as tuple of symbols (symbols being variable-length strings). |
|||
""" |
|||
pairs = set() |
|||
prev_char = word[0] |
|||
for char in word[1:]: |
|||
pairs.add((prev_char, char)) |
|||
prev_char = char |
|||
return pairs |
|||
|
|||
|
|||
def basic_clean(text): |
|||
text = ftfy.fix_text(text) |
|||
text = html.unescape(html.unescape(text)) |
|||
return text.strip() |
|||
|
|||
|
|||
def whitespace_clean(text): |
|||
text = re.sub(r'\s+', ' ', text) |
|||
text = text.strip() |
|||
return text |
|||
|
|||
|
|||
class SimpleTokenizer(object): |
|||
def __init__(self, bpe_path: str = default_bpe()): |
|||
self.byte_encoder = bytes_to_unicode() |
|||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} |
|||
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') |
|||
merges = merges[1:49152-256-2+1] |
|||
merges = [tuple(merge.split()) for merge in merges] |
|||
vocab = list(bytes_to_unicode().values()) |
|||
vocab = vocab + [v+'</w>' for v in vocab] |
|||
for merge in merges: |
|||
vocab.append(''.join(merge)) |
|||
vocab.extend(['<|startoftext|>', '<|endoftext|>']) |
|||
self.encoder = dict(zip(vocab, range(len(vocab)))) |
|||
self.decoder = {v: k for k, v in self.encoder.items()} |
|||
self.bpe_ranks = dict(zip(merges, range(len(merges)))) |
|||
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} |
|||
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) |
|||
|
|||
def bpe(self, token): |
|||
if token in self.cache: |
|||
return self.cache[token] |
|||
word = tuple(token[:-1]) + ( token[-1] + '</w>',) |
|||
pairs = get_pairs(word) |
|||
|
|||
if not pairs: |
|||
return token+'</w>' |
|||
|
|||
while True: |
|||
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) |
|||
if bigram not in self.bpe_ranks: |
|||
break |
|||
first, second = bigram |
|||
new_word = [] |
|||
i = 0 |
|||
while i < len(word): |
|||
try: |
|||
j = word.index(first, i) |
|||
new_word.extend(word[i:j]) |
|||
i = j |
|||
except: |
|||
new_word.extend(word[i:]) |
|||
break |
|||
|
|||
if word[i] == first and i < len(word)-1 and word[i+1] == second: |
|||
new_word.append(first+second) |
|||
i += 2 |
|||
else: |
|||
new_word.append(word[i]) |
|||
i += 1 |
|||
new_word = tuple(new_word) |
|||
word = new_word |
|||
if len(word) == 1: |
|||
break |
|||
else: |
|||
pairs = get_pairs(word) |
|||
word = ' '.join(word) |
|||
self.cache[token] = word |
|||
return word |
|||
|
|||
def encode(self, text): |
|||
bpe_tokens = [] |
|||
text = whitespace_clean(basic_clean(text)).lower() |
|||
for token in re.findall(self.pat, text): |
|||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) |
|||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) |
|||
return bpe_tokens |
|||
|
|||
def decode(self, tokens): |
|||
text = ''.join([self.decoder[token] for token in tokens]) |
|||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') |
|||
return text |
|||
|
|||
def __call__(self, texts, context_length=77): |
|||
if isinstance(texts, str): |
|||
texts = [texts] |
|||
|
|||
sot_token = self.encoder["<|startoftext|>"] |
|||
eot_token = self.encoder["<|endoftext|>"] |
|||
all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts] |
|||
result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) |
|||
|
|||
for i, tokens in enumerate(all_tokens): |
|||
tokens = tokens[:context_length] |
|||
result[i, :len(tokens)] = torch.tensor(tokens) |
|||
|
|||
if len(result) == 1: |
|||
return result[0] |
|||
return result |
@ -0,0 +1,8 @@ |
|||
import torch |
|||
|
|||
def get_model(model): |
|||
if isinstance(model, torch.nn.DataParallel) \ |
|||
or isinstance(model, torch.nn.parallel.DistributedDataParallel): |
|||
return model.module |
|||
else: |
|||
return model |
Loading…
Reference in new issue