You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Files and versions
74 lines
2.9 KiB
74 lines
2.9 KiB
# -*- coding: utf-8 -*-
import os
import json
import torch
import numpy as np
import youtokentome as yttm
import torchvision.transforms as T
from torch.nn.utils.rnn import pad_sequence
class RuCLIPProcessor:
    """Turn raw text and images into the tensors a ruCLIP model consumes.

    Text is lower-cased, BPE-encoded with a YouTokenToMe model, wrapped in
    BOS/EOS markers and padded to a fixed length. Images are converted to
    RGB, cropped/resized to a square of ``image_size`` pixels, converted to
    tensors and normalized channel-wise.
    """

    # Special token ids used when building and decoding sequences.
    eos_id = 3
    bos_id = 2
    unk_id = 1
    pad_id = 0

    def __init__(self, tokenizer_path, image_size=224, text_seq_length=77, mean=None, std=None):
        """Load the BPE model and build the image preprocessing pipeline.

        Args:
            tokenizer_path: Path to the YouTokenToMe BPE model file.
            image_size: Side length (pixels) of the square model input.
            text_seq_length: Fixed token-sequence length, BOS/EOS included.
            mean: Per-channel normalization mean; a built-in default is used
                when omitted.
            std: Per-channel normalization std; a built-in default is used
                when omitted.
        """
        self.tokenizer = yttm.BPE(tokenizer_path)
        self.mean = mean or [0.48145466, 0.4578275, 0.40821073]
        self.std = std or [0.26862954, 0.26130258, 0.27577711]
        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            # scale and ratio are both pinned to 1, so no random scale or
            # aspect-ratio jitter is requested from the crop.
            T.RandomResizedCrop(image_size, scale=(1., 1.), ratio=(1., 1.)),
            # Normalize operates on tensors, so the image is converted first.
            T.ToTensor(),
            T.Normalize(mean=self.mean, std=self.std),
        ])
        self.text_seq_length = text_seq_length
        self.image_size = image_size

    def encode_text(self, text):
        """Encode one string into a fixed-length tensor of token ids.

        The text is lower-cased and BPE-encoded, truncated so that BOS and
        EOS still fit, and then padded by :meth:`prepare_tokens`.
        """
        text = text.lower()
        tokens = self.tokenizer.encode([text], output_type=yttm.OutputType.ID, dropout_prob=0.0)[0]
        # Reserve two positions for the BOS and EOS markers.
        tokens = tokens[:self.text_seq_length - 2]
        tokens = [self.bos_id] + tokens + [self.eos_id]
        return self.prepare_tokens(tokens)

    def prepare_tokens(self, tokens):
        """Pad or truncate ``tokens`` to exactly ``text_seq_length`` ids.

        Short sequences are right-padded with ``pad_id``. Over-long
        sequences keep their first ``text_seq_length - 1`` ids plus their
        final id, so a trailing EOS marker survives truncation.

        Returns:
            A 1-D ``torch.long`` tensor of length ``text_seq_length``.
        """
        # Work on a plain list so that ``+`` below always concatenates, even
        # if the caller passed an array.
        tokens = list(tokens)
        limit = self.text_seq_length
        if len(tokens) > limit:
            tokens = tokens[:limit - 1] + tokens[-1:]
        else:
            tokens = tokens + [self.pad_id] * (limit - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)

    def decode_text(self, encoded):
        """Decode a tensor of token ids back to text, dropping special tokens.

        Args:
            encoded: Tensor of token ids for a single sequence.

        Returns:
            The decoded string.
        """
        ids = encoded.cpu().numpy().tolist()
        decoded = self.tokenizer.decode(ids, ignore_ids=[
            self.eos_id, self.bos_id, self.unk_id, self.pad_id,
        ])
        # The tokenizer returns a list with one string per sequence; unwrap
        # the single result.
        return decoded[0]

    def __call__(self, text=None, images=None, **kwargs):
        """Preprocess text and/or images into batched model inputs.

        Args:
            text: A single string or an iterable of strings.
            images: An iterable of images accepted by the image transform.
            **kwargs: Accepted for call compatibility; not used.

        Returns:
            A dict that contains ``'input_ids'`` when ``text`` is given and
            ``'pixel_values'`` when ``images`` is given.
        """
        inputs = {}
        if text is not None:
            texts = [text] if isinstance(text, str) else text
            input_ids = [self.encode_text(item) for item in texts]
            inputs['input_ids'] = pad_sequence(input_ids, batch_first=True)
        if images is not None:
            pixel_values = [self.image_transform(image) for image in images]
            inputs['pixel_values'] = pad_sequence(pixel_values, batch_first=True)
        return inputs

    @classmethod
    def from_pretrained(cls, folder):
        """Build a processor from a folder holding ``bpe.model`` and ``config.json``.

        ``config.json`` must provide ``image_resolution`` and
        ``context_length``; ``mean`` and ``std`` are optional.
        """
        tokenizer_path = os.path.join(folder, 'bpe.model')
        # Use a context manager so the config file handle is closed promptly.
        with open(os.path.join(folder, 'config.json')) as config_file:
            config = json.load(config_file)
        image_size = config['image_resolution']
        text_seq_length = config['context_length']
        mean, std = config.get('mean'), config.get('std')
        return cls(tokenizer_path, image_size=image_size, text_seq_length=text_seq_length, mean=mean, std=std)