ru-clip
75 lines · 2.9 KiB
# -*- coding: utf-8 -*-
import os
import json

import torch
import numpy as np
import youtokentome as yttm
import torchvision.transforms as T
from torch.nn.utils.rnn import pad_sequence

class RuCLIPProcessor:
    # service token ids of the YouTokenToMe BPE vocabulary
    eos_id = 3
    bos_id = 2
    unk_id = 1
    pad_id = 0

    def __init__(self, tokenizer_path, image_size=224, text_seq_length=77, mean=None, std=None):
        self.tokenizer = yttm.BPE(tokenizer_path)
        # defaults are the OpenAI CLIP normalization statistics
        self.mean = mean or [0.48145466, 0.4578275, 0.40821073]
        self.std = std or [0.26862954, 0.26130258, 0.27577711]
        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            # scale and ratio pinned to (1., 1.): a deterministic center square crop
            T.RandomResizedCrop(image_size, scale=(1., 1.), ratio=(1., 1.)),
            T.ToTensor(),
            T.Normalize(mean=self.mean, std=self.std)
        ])
        self.text_seq_length = text_seq_length
        self.image_size = image_size

    def encode_text(self, text):
        text = text.lower()
        tokens = self.tokenizer.encode([text], output_type=yttm.OutputType.ID, dropout_prob=0.0)[0]
        # truncate to leave room for the BOS and EOS tokens
        tokens = tokens[:self.text_seq_length - 2]
        tokens = [self.bos_id] + tokens + [self.eos_id]
        return self.prepare_tokens(tokens)

    def prepare_tokens(self, tokens):
        empty_positions = self.text_seq_length - len(tokens)
        if empty_positions > 0:
            tokens = np.hstack((tokens, np.zeros(empty_positions)))  # pad positions after the text
        if len(tokens) > self.text_seq_length:
            tokens = tokens[:self.text_seq_length - 1] + tokens[-1:]  # keep the final EOS when truncating
        return torch.tensor(tokens).long()

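    # Illustrative example for prepare_tokens (not part of the original file):
    # with text_seq_length=8, prepare_tokens([2, 10, 11, 3]) right-pads with
    # pad_id=0 and returns tensor([2, 10, 11, 3, 0, 0, 0, 0]).
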
    def decode_text(self, encoded):
        return self.tokenizer.decode(encoded.cpu().numpy().tolist(), ignore_ids=[
            self.eos_id, self.bos_id, self.unk_id, self.pad_id
        ])[0]

    def __call__(self, text=None, images=None, **kwargs):
        inputs = {}
        if text is not None:
            input_ids = []
            texts = [text] if isinstance(text, str) else text
            for text in texts:
                tokens = self.encode_text(text)
                input_ids.append(tokens)
            inputs['input_ids'] = pad_sequence(input_ids, batch_first=True)
        if images is not None:
            pixel_values = []
            for image in images:
                pixel_values.append(self.image_transform(image))
            inputs['pixel_values'] = pad_sequence(pixel_values, batch_first=True)
        return inputs

    @classmethod
    def from_pretrained(cls, folder):
        tokenizer_path = os.path.join(folder, 'bpe.model')
        with open(os.path.join(folder, 'config.json')) as f:  # close the config file after reading
            config = json.load(f)
        image_size = config['image_resolution']
        text_seq_length = config['context_length']
        mean, std = config.get('mean'), config.get('std')
        return cls(tokenizer_path, image_size=image_size, text_seq_length=text_seq_length, mean=mean, std=std)
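
A minimal usage sketch (not part of the original file; the checkpoint folder and image paths are hypothetical, and the folder must contain the bpe.model and config.json files that from_pretrained reads):

from PIL import Image

processor = RuCLIPProcessor.from_pretrained('./ruclip-checkpoint')  # hypothetical folder
inputs = processor(
    text=['кошка', 'собака'],  # Russian captions; encode_text lowercases them
    images=[Image.open('cat.jpg'), Image.open('dog.jpg')],  # hypothetical images
)
print(inputs['input_ids'].shape)     # torch.Size([2, 77]) with the default context length
print(inputs['pixel_values'].shape)  # torch.Size([2, 3, 224, 224]) with the default image size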