From 49e3970b488beac207af6248023b1046468cdc27 Mon Sep 17 00:00:00 2001 From: wxywb Date: Fri, 28 Oct 2022 15:57:56 +0800 Subject: [PATCH] update the magic. Signed-off-by: wxywb --- .DS_Store | Bin 0 -> 6148 bytes .README.md.swp | Bin 0 -> 12288 bytes README.md | 81 +++- language_model/README.md | 167 +++++++ language_model/dataclass.py | 157 ++++++ language_model/loss_func.py | 80 ++++ language_model/simctg.py | 233 +++++++++ language_model/train.py | 107 +++++ language_model/train_flickr30k.sh | 17 + language_model/train_mscoco.sh | 17 + language_model/trainer.py | 165 +++++++ language_model/utlis.py | 291 ++++++++++++ magic.py | 29 +- zerocap/README.md | 89 ---- zerocap/cog.yaml | 12 - zerocap/flickr30k_zerocap.sh | 14 - zerocap/forbidden_tokens.npy | Bin 7464 -> 0 bytes zerocap/model/ZeroCLIP.py | 389 --------------- zerocap/model/ZeroCLIP_batched.py | 449 ------------------ zerocap/model/__init__.py | 0 .../model/__pycache__/ZeroCLIP.cpython-36.pyc | Bin 13665 -> 0 bytes .../model/__pycache__/ZeroCLIP.cpython-37.pyc | Bin 13594 -> 0 bytes .../ZeroCLIP_batched.cpython-36.pyc | Bin 15664 -> 0 bytes .../ZeroCLIP_batched.cpython-37.pyc | Bin 15568 -> 0 bytes .../model/__pycache__/__init__.cpython-36.pyc | Bin 163 -> 0 bytes .../model/__pycache__/__init__.cpython-37.pyc | Bin 167 -> 0 bytes zerocap/mscoco_zerocap.sh | 14 - zerocap/predict.py | 117 ----- zerocap/predict_arithmetic.py | 129 ----- zerocap/requirements.txt | 3 - zerocap/run.py | 131 ----- zerocap/setup.py | 19 - 32 files changed, 1334 insertions(+), 1376 deletions(-) create mode 100644 .DS_Store create mode 100644 .README.md.swp create mode 100644 language_model/README.md create mode 100644 language_model/dataclass.py create mode 100644 language_model/loss_func.py create mode 100644 language_model/simctg.py create mode 100644 language_model/train.py create mode 100644 language_model/train_flickr30k.sh create mode 100644 language_model/train_mscoco.sh create mode 100644 language_model/trainer.py create mode 100644 language_model/utlis.py delete mode 100644 zerocap/README.md delete mode 100644 zerocap/cog.yaml delete mode 100755 zerocap/flickr30k_zerocap.sh delete mode 100644 zerocap/forbidden_tokens.npy delete mode 100644 zerocap/model/ZeroCLIP.py delete mode 100644 zerocap/model/ZeroCLIP_batched.py delete mode 100644 zerocap/model/__init__.py delete mode 100644 zerocap/model/__pycache__/ZeroCLIP.cpython-36.pyc delete mode 100644 zerocap/model/__pycache__/ZeroCLIP.cpython-37.pyc delete mode 100644 zerocap/model/__pycache__/ZeroCLIP_batched.cpython-36.pyc delete mode 100644 zerocap/model/__pycache__/ZeroCLIP_batched.cpython-37.pyc delete mode 100644 zerocap/model/__pycache__/__init__.cpython-36.pyc delete mode 100644 zerocap/model/__pycache__/__init__.cpython-37.pyc delete mode 100755 zerocap/mscoco_zerocap.sh delete mode 100644 zerocap/predict.py delete mode 100644 zerocap/predict_arithmetic.py delete mode 100644 zerocap/requirements.txt delete mode 100644 zerocap/run.py delete mode 100644 zerocap/setup.py diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..0f450d5f94f0455bc3367ce02859a3e89bd19233 GIT binary patch literal 6148 zcmeHKO-jR15T2<4D=Kv9GA9sv0%M5_*InsCw5?PjB~f(aDhm%#LA-%y@B$v7;Kqel z@CGjY=DlgnYg(5@WCrGaFY}Xm`BE|+BI3^NpiR^wq9#a6~NORMpt;T>cU<}L{knckb6-*RM zkM7ff#f<>K2xceP%Pb)|v0|cFdV~kUJ{9OwSw{@^>9EHdmnfDVeLArYAFP#GClq$o zasSwb6DJ;RH3p1T$uzdaxlkCbEFb4h=101E@w1Y>|-dcM&>9sNR6e=Qd mrN>zcHm(&TR$B26)CugdY=DVk=@Ax){Sk09*kTNBej^$Z<^S}ol 
z*akN4V(dIP15SdM!6V>d@DR8c>;yZ&1TKOL-~;eJcniD<6iC1fcn~}Qc7S`p z6ky=BZTKDp7y=0*Fb!(pcR2eE{0c6EAHfgcGhlr_3*H0nk~s0-hSF66%RWXl_JEogwvAC}NMXR_mxRxlXdVR*PDh(Jd>b=#0MQKdx_-8)^IR zGm^yj5Cs`^Ww|8!9TA2y>3Mj}V`bY~tyahl$BMCJqt+rJJTpv1<}DVT8=SJdGZVx- z%Sd%8$@}ThOG2(SA2wTLcYTCw-WNt_yK(Gze^SmoW5l`{Pf04Z&Xgv@S4Bpgmdgpt z)+>!zCabG zIanxD1Qv_WPuT%Vr)3&kB?&FJEM(AbjC(e5m$3yckYQwZc*Wd6^5TQ8iet5AH;nB; zDUwi_Na1&0Z2RfzqXm^<_T1h_R7M8aKH2Y4ri1yK=NsP1W3Ih4>D7qGX1)fKd2HG> z%FHkp^R-CGUS#I>%%C?mgxWJW&MPQce7`^@&)otwSD~%?d z95pE7d8#@=HU%dHOm-dnn6GbXs=M{8h7Co4Q@C7Dxryxbb`|C9RZ{n2)v0c;H0oZN zUEZ0rSB%AE-^LYZ#ipL!WIBY;1nGT3OGE+3z~)3MQ-L%Uv?fhtQz@2#G-YCDzRC>} zG%M35OGO~NGT^pg;A9{Io5VxIRnm7&N^f6N`n8;HN~-IYSj23TwHvWTl`fmKS6s=2 zGWd;Q0|%$btJWN2&nX__FjrNot5u(D^k_?Pwp6d;u$L9s!>FCgeRr5vNy9*?(o?bvrP~VG3)IymNwip&6Ee2%&0Ig={pa1{> literal 0 HcmV?d00001 diff --git a/README.md b/README.md index 4b22684..6a4a1b3 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,81 @@ -# magic +# Image Captioning with MAGIC + +*author: David Wang* + + +
+
+
+## Description
+
+This operator generates a caption that describes the content of a given image with [MAGIC](https://arxiv.org/abs/2205.02655). MAGIC is a simple yet efficient plug-and-play framework that directly combines an off-the-shelf LM (i.e., GPT-2) and an image-text matching model (i.e., CLIP) for image-grounded text generation. During decoding, MAGIC steers the LM by introducing a CLIP-induced score, called the magic score, which regularizes the generated result to be semantically related to the given image while remaining coherent with the previously generated context. This is an adaptation of [yxuansu/MAGIC](https://github.com/yxuansu/MAGIC).
+
+
+<br />
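+At each decoding step, magic search re-ranks the language model's top-k candidate tokens by combining model confidence, a degeneration penalty, and the CLIP image-text similarity. The snippet below is a simplified, illustrative sketch of the scoring rule implemented in `language_model/utlis.py`; the function and tensor names here are made up for clarity and are not part of the operator's API.
+
+```python
+import torch
+
+def magic_step_scores(candidate_probs, candidate_hiddens, context_hiddens, clip_scores, alpha, beta):
+    """Score the top-k candidate tokens for one decoding step (illustrative sketch).
+
+    candidate_probs:   [k]     LM probability of each candidate token
+    candidate_hiddens: [k, d]  hidden state produced by appending each candidate
+    context_hiddens:   [t, d]  hidden states of the previously generated context
+    clip_scores:       [k]     CLIP similarity between the image and each candidate continuation
+    """
+    ctx = context_hiddens / context_hiddens.norm(dim=-1, keepdim=True)
+    cand = candidate_hiddens / candidate_hiddens.norm(dim=-1, keepdim=True)
+    # degeneration penalty: maximum cosine similarity to the existing context
+    degeneration_penalty = (ctx @ cand.t()).max(dim=0).values
+    # model confidence - degeneration penalty + image grounding (the magic score)
+    return (1.0 - alpha) * candidate_probs - alpha * degeneration_penalty + beta * clip_scores
+```
+
+The candidate with the highest score is selected as the next token, so a larger `beta` pushes the caption closer to the image while `alpha` discourages repetitive text.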
+
+
+## Code Example
+
+Load an image from path './image.jpg' to generate the caption.
+
+*Write the pipeline in simplified style*:
+
+```python
+import towhee
+
+towhee.glob('./image.jpg') \
+    .image_decode() \
+    .image_captioning.magic(model_name='magic_mscoco') \
+    .show()
+```
+result1
+
+*Write the same pipeline with explicit input/output name specifications:*
+
+```python
+import towhee
+
+towhee.glob['path']('./image.jpg') \
+    .image_decode['path', 'img']() \
+    .image_captioning.magic['img', 'text'](model_name='magic_mscoco') \
+    .select['img', 'text']() \
+    .show()
+```
+result2
+
+
+<br />
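+The operator can also be instantiated outside of a pipeline. This is a minimal sketch assuming towhee's `ops` accessor resolves hub operators and that the returned operators are directly callable; the image path is an example.
+
+```python
+import towhee
+
+# hypothetical direct-call usage of the hub operators
+decode = towhee.ops.image_decode()
+magic = towhee.ops.image_captioning.magic(model_name='magic_mscoco')
+
+img = decode('./image.jpg')
+caption = magic(img)
+print(caption)
+```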
+
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***magic(model_name)***
+
+**Parameters:**
+
+​ ***model_name:*** *str*
+
+​ The model name of MAGIC. Supported model names:
+- magic_mscoco
+
+<br />
+
+## Interface
+
+An image captioning operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generates the corresponding caption.
+
+
+**Parameters:**
+
+​ ***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)*
+
+​ The image for which the caption is generated.
+
+
+
+**Returns:** *str*
+
+​ The caption generated by the model.
diff --git a/language_model/README.md b/language_model/README.md
new file mode 100644
index 0000000..da8f45e
--- /dev/null
+++ b/language_model/README.md
@@ -0,0 +1,167 @@
+## Unsupervised Domain Adaptation of Language Model
+****
+### Catalogue:
+* 1. MSCOCO Benchmark
+    * 1.1. MSCOCO Data Preparation
+    * 1.2. Unsupervised Domain Adaptation on MSCOCO
+* 2. Flickr30k Benchmark
+    * 2.1. Flickr30k Data Preparation
+    * 2.2. Unsupervised Domain Adaptation on Flickr30k
+* 3. Unsupervised Baselines
+    * 3.1. Contrastive Search
+    * 3.2. Top-k Sampling
+    * 3.3. Nucleus Sampling
+
+****
+
+#### 1. MSCOCO Benchmark:
+
+We first describe how to perform unsupervised domain adaptation of the language model on the text corpus of the MSCOCO benchmark.
+
+##### 1.1. MSCOCO Data Preparation:
+
+To prepare the MSCOCO benchmark, please follow the instructions [[here]](https://github.com/yxuansu/MAGIC/tree/main/image_captioning/data#1-mscoco-benchmark).
+
+##### 1.2. Unsupervised Domain Adaptation on MSCOCO:
+After preparing the MSCOCO data, run the following command to train the language model.
+```yaml
+chmod +x ./train_mscoco.sh
+./train_mscoco.sh
+```
+The arguments are as follows:
+* `--model_name`: The name of the huggingface pre-trained GPT model (e.g. gpt2, gpt2-large).
+* `--train_path`: The file path of the training set.
+* `--dev_path`: The file path of the validation set.
+* `--test_path`: The file path of the test set.
+* `--add_eos_token_to_data`: Whether to add an eos token at the end of each text sequence.
+* `--margin`: The contrastive margin $\rho$.
+* `--max_len`: The maximum length of training samples.
+* `--number_of_gpu`: The number of available GPUs.
+* `--batch_size_per_gpu`: The batch size for each GPU.
+* `--gradient_accumulation_steps`: How many forward computations between two gradient updates.
+* `--effective_batch_size`: The overall batch size. It equals batch_size_per_gpu x gradient_accumulation_steps x number_of_gpu.
+* `--total_steps`: The number of total gradient update steps.
+* `--print_every`: How many steps between printing intermediate results.
+* `--save_every`: How many steps between saving checkpoints.
+* `--learning_rate`: The learning rate.
+* `--save_path_prefix`: Where to save the checkpoints.
+
+****
+
+#### 2. Flickr30k Benchmark:
+
+We then describe how to perform unsupervised domain adaptation of the language model on the text corpus of the Flickr30k benchmark.
+
+##### 2.1. Flickr30k Data Preparation:
+
+To prepare the Flickr30k benchmark, please follow the instructions [[here]](https://github.com/yxuansu/MAGIC/tree/main/image_captioning/data#2-flickr30k-benchmark).
+
+##### 2.2. Unsupervised Domain Adaptation on Flickr30k:
+After preparing the Flickr30k data, run the following command to train the language model.
+```yaml
+chmod +x ./train_flickr30k.sh
+./train_flickr30k.sh
+```
+
+****
+
+#### 3. Unsupervised Baselines:
+
+Here, we illustrate how to use the language model to perform the unsupervised baselines described in our paper. Note that all these methods are **unsupervised**, as the language model is a text-only model and does not take the image as input.
+```python
+# first, load the language model
+import torch
+from simctg import SimCTG
+sos_token, pad_token = r'<-start_of_text->', r'<-pad->'
+# we use the language model adapted on MSCOCO as an example.
+language_model_name = r'cambridgeltl/magic_mscoco'
+generation_model = SimCTG(language_model_name, sos_token, pad_token)
+generation_model.eval()
+
+# then, prepare the input ids. Note that the text is always generated from the same start-of-sentence token.
+tokens = generation_model.tokenizer.tokenize(sos_token)
+input_ids = generation_model.tokenizer.convert_tokens_to_ids(tokens)
+input_ids = torch.LongTensor(input_ids).view(1,-1)
+```
+
+##### 3.1. Contrastive Search:
+```python
+'''
+   use contrastive search to generate the result.
+   note that contrastive search is a deterministic decoding method, so the generated text is always the same.
+'''
+beam_width, alpha, decoding_len = 45, 0.1, 16
+output_text = generation_model.fast_contrastive_search(input_ids, beam_width, alpha, decoding_len)
+print (output_text)
+'''
+   A man is riding a skateboard down a street.
+'''
+```
+The arguments are as follows:
+* `--input_ids`: The id of the start-of-sentence token.
+* `--beam_width`: The k in contrastive search.
+* `--alpha`: The alpha in contrastive search.
+* `--decoding_len`: Number of tokens to generate.
+
+##### 3.2. Top-k Sampling:
+```python
+'''
+   use top-k sampling to generate the result.
+   note that this method is stochastic, so the generated text differs on every run.
+'''
+top_k, decoding_len = 40, 16
+output_text = generation_model.top_k_sampling(input_ids, top_k, decoding_len)
+print (output_text)
+'''
+   some very different types of vases with flowers together
+'''
+```
+The arguments are as follows:
+* `--input_ids`: The id of the start-of-sentence token.
+* `--k`: The k in top-k sampling.
+* `--decoding_len`: Number of tokens to generate.
+
+##### 3.3. Nucleus Sampling:
+```python
+'''
+   use nucleus sampling to generate the result.
+   note that this method is stochastic, so the generated text differs on every run.
+'''
+nucleus_p, decoding_len = 0.95, 16
+output_text = generation_model.nucleus_sampling(input_ids, nucleus_p, decoding_len)
+print (output_text)
+'''
+   Two young girls enjoying a hot dog hot dog bun.
+'''
+```
+The arguments are as follows:
+* `--input_ids`: The id of the start-of-sentence token.
+* `--nucleus_p`: The probability p in nucleus sampling.
+* `--decoding_len`: Number of tokens to generate.
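+
+For comparison with these text-only baselines, the sketch below shows how the same adapted language model is used for image-grounded magic search. It follows the `magic_search` signature in `simctg.py` and the default parameters used by `magic.py`; the CLIP wrapper import (from this repository's `clip` module) and the image path are assumptions for illustration.
+
+```python
+import torch
+from PIL import Image
+from clip import CLIP          # CLIP wrapper shipped with this repository (assumed to be on sys.path)
+from simctg import SimCTG
+
+sos_token, pad_token = r'<-start_of_text->', r'<-pad->'
+generation_model = SimCTG(r'cambridgeltl/magic_mscoco', sos_token, pad_token).eval()
+clip_model = CLIP(r'openai/clip-vit-base-patch32').eval()
+
+# generation always starts from the start-of-sentence token
+start_token = generation_model.tokenizer.tokenize(sos_token)
+input_ids = torch.LongTensor(
+    generation_model.tokenizer.convert_tokens_to_ids(start_token)).view(1, -1)
+
+image = Image.open('./example.jpg')   # hypothetical image path
+beam_width, alpha, beta, decoding_len = 45, 0.1, 2.0, 16
+caption = generation_model.magic_search(
+    input_ids, beam_width, alpha, decoding_len, beta, image, clip_model, 60)
+print(caption)
+```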
+ + + + + + + + + diff --git a/language_model/dataclass.py b/language_model/dataclass.py new file mode 100644 index 0000000..3786dc5 --- /dev/null +++ b/language_model/dataclass.py @@ -0,0 +1,157 @@ +import json +import random +import torch +import numpy as np +import progressbar +from torch.nn.utils import rnn + +class Data: + def __init__(self, model_name, train_path, dev_path, test_path, max_len, + sos_token, pad_token, add_eos_token_to_data): + ''' + model_name: gpt2 + train_path: training data path + dev_path: validation data path + test_path: test data path + max_len: maximum length for training sequences + sos_token: initialized sos token <-start_of_text-> + pad_token: used to pad the sequences <-pad-> + add_eos_token_to_data: whether we want to the model learn to generate eos token; + if so, the model could automatically stop generation by generating eos token + ''' + from transformers import GPT2TokenizerFast + self.tokenizer = GPT2TokenizerFast.from_pretrained(model_name) + self.sos_token, self.sos_token_id = self.add_special_token(sos_token) + print ('sos token is {}, sos token id is {}'.format(self.sos_token, self.sos_token_id)) + self.pad_token, self.pad_token_id = self.add_special_token(pad_token) + print ('pad token is {}, pad token id is {}'.format(self.pad_token, self.pad_token_id)) + self.eos_token, self.eos_token_id = self.tokenizer.bos_token, self.tokenizer.bos_token_id + print ('eos token is {}, eos token id is {}'.format(self.eos_token, self.eos_token_id)) + self.add_eos_token_to_data = add_eos_token_to_data + + self.max_len = max_len + self.train_token_list, self.train_token_id_list = self.process_one_file(train_path) + self.dev_token_list, self.dev_token_id_list = self.process_one_file(dev_path) + self.test_token_list, self.test_token_id_list = self.process_one_file(test_path) + self.train_num, self.dev_num, self.test_num = len(self.train_token_list), len(self.dev_token_list), \ + len(self.test_token_list) + print ('train number:{}, dev number:{}, test number:{}'.format(self.train_num, self.dev_num, self.test_num)) + + self.train_idx_list = [i for i in range(self.train_num)] + random.shuffle(self.train_idx_list) + self.dev_idx_list = [j for j in range(self.dev_num)] + self.test_idx_list = [j for j in range(self.test_num)] + self.dev_current_idx, self.test_current_idx = 0, 0 + + def add_special_token(self, special_token): + if special_token in self.tokenizer.vocab: + print (special_token + ' token exists.') + else: + print ('Add token to the tokenizer.') + print ('Original vocabulary size is {}'.format(len(self.tokenizer))) + self.tokenizer.add_tokens([special_token]) + print ('Vocabulary size after extension is {}'.format(len(self.tokenizer))) + assert len(self.tokenizer.convert_tokens_to_ids([special_token])) == 1 + special_token_id = self.tokenizer.convert_tokens_to_ids([special_token])[0] + return special_token, special_token_id + + def process_one_file(self, path): + print ('Processing {}'.format(path)) + with open(path) as f: + item_list = json.load(f) + lines = [] + for item in item_list: + captions_list = item['captions'] + for one_caption in captions_list: + lines.append(one_caption.strip()) + + res_token_list, res_token_id_list = [], [] + n = len(lines) + p = progressbar.ProgressBar(n) + p.start() + for i in range(n): + p.update(i) + text = lines[i].strip('\n') + self.process_one_text(text, res_token_list, res_token_id_list) + p.finish() + print ('{} processed!'.format(path)) + return res_token_list, res_token_id_list + + def process_one_text(self, text, 
res_token_list, res_token_id_list): + tokens = self.tokenizer.tokenize(text, max_length=self.max_len, truncation=True) + if len(tokens) <= 1: # filter out too short sequence + return + tokens = [self.sos_token] + tokens[:self.max_len] + if self.add_eos_token_to_data: + tokens = tokens + [self.eos_token] + token_ids = self.tokenizer.convert_tokens_to_ids(tokens) + res_token_list.append(tokens) + res_token_id_list.append(token_ids) + return + + def pad_batch(self, batch_id_list): + batch_id_list = [torch.LongTensor(item) for item in batch_id_list] + batch_tensor = rnn.pad_sequence(batch_id_list, batch_first=True, padding_value=self.pad_token_id) + batch_mask = torch.ones_like(batch_tensor) + batch_mask = batch_mask.masked_fill(batch_tensor.eq(self.pad_token_id), 0.0).type(torch.FloatTensor) + return batch_tensor, batch_mask + + def process_output(self, batch_tgt_id_list): + batch_tgt_id_list = [torch.LongTensor(item) for item in batch_tgt_id_list] + batch_tgt_tensor, _ = self.pad_batch(batch_tgt_id_list) # padded target sequence + batch_tgt_input_tensor = batch_tgt_tensor[:, :-1].clone() + batch_tgt_output_tensor = batch_tgt_tensor[:, 1:].clone() + return batch_tgt_input_tensor, batch_tgt_output_tensor + + def parse_batch(self, batch_id_list): + batch_input, batch_labels = self.process_output(batch_id_list) + batch_labels[batch_labels[:, :] == self.pad_token_id] = -100 + return batch_input, batch_labels + + def get_next_train_batch(self, batch_size): + batch_idx_list = random.sample(self.train_idx_list, batch_size) + batch_id_list, batch_token_list = [], [] + + for idx in batch_idx_list: + batch_id_list.append(self.train_token_id_list[idx]) + batch_token_list.append(self.train_token_list[idx]) + batch_input_tensor, batch_labels = self.parse_batch(batch_id_list) + return batch_input_tensor, batch_labels, batch_token_list + + def get_next_validation_batch(self, batch_size, mode): + batch_id_list, batch_token_list = [], [] + if mode == 'dev': + curr_select_idx, instance_num = self.dev_current_idx, self.dev_num + tgt_token_id_list, tgt_token_list = self.dev_token_id_list, self.dev_token_list + elif mode == 'test': + curr_select_idx, instance_num = self.test_current_idx, self.test_num + tgt_token_id_list, tgt_token_list = self.test_token_id_list, self.test_token_list + else: + raise Exception('Wrong Validation Mode!!!') + + if curr_select_idx + batch_size < instance_num: + for i in range(batch_size): + curr_idx = curr_select_idx + i + batch_id_list.append(tgt_token_id_list[curr_idx]) + batch_token_list.append(tgt_token_list[curr_idx]) + if mode == 'dev': + self.dev_current_idx += batch_size + else: + self.test_current_idx += batch_size + else: + for i in range(batch_size): + curr_idx = curr_select_idx + i + if curr_idx > instance_num - 1: + curr_idx = 0 + if mode == 'dev': + self.dev_current_idx = 0 + else: + self.test_current_idx = 0 + batch_id_list.append(tgt_token_id_list[curr_idx]) + batch_token_list.append(tgt_token_list[curr_idx]) + if mode == 'dev': + self.dev_current_idx = 0 + else: + self.test_current_idx = 0 + batch_input_tensor, batch_labels = self.parse_batch(batch_id_list) + return batch_input_tensor, batch_labels, batch_token_list diff --git a/language_model/loss_func.py b/language_model/loss_func.py new file mode 100644 index 0000000..96a4243 --- /dev/null +++ b/language_model/loss_func.py @@ -0,0 +1,80 @@ +import torch + +def compute_valid_token_num(valid_len_list): + res = 0 + for one_len in valid_len_list: + res += one_len * (one_len - 1) + return res + +def 
build_mask_matrix(seqlen, valid_len_list, prefix_len = 0): + ''' + prefix_len: the length of prefix that we do not want to compute CL loss for. + + (1) if a sequence of length 4 contains zero padding token (i.e., the valid length is 4), + then the loss padding matrix looks like + [0., 1., 1., 1.], + [1., 0., 1., 1.], + [1., 1., 0., 1.], + [1., 1., 1., 0.] + + (2) if a sequence of length 4 contains 1 padding token (i.e., the valid length is 3), + then the loss padding matrix looks like + [0., 1., 1., 0.], + [1., 0., 1., 0.], + [1., 1., 0., 0.], + [0., 0., 0., 0.] + ''' + res_list = [] + base_mask = torch.ones(seqlen, seqlen) - torch.eye(seqlen, seqlen) + base_mask = base_mask.type(torch.FloatTensor) + bsz = len(valid_len_list) + for i in range(bsz): + one_base_mask = base_mask.clone() + one_valid_len = valid_len_list[i] + one_base_mask[:,one_valid_len:] = 0. + one_base_mask[one_valid_len:, :] = 0. + if prefix_len > 0: + one_base_mask[:prefix_len, :prefix_len] = 0. + res_list.append(one_base_mask) + res_mask = torch.stack(res_list, dim = 0)#torch.FloatTensor(res_list) + #print (res_mask) + assert res_mask.size() == torch.Size([bsz, seqlen, seqlen]) + return res_mask + +def contrastive_loss(margin, score_matrix, input_ids, pad_token_id, prefix_len=0): + ''' + margin: predefined margin to push similarity score away + score_matrix: bsz x seqlen x seqlen + input_ids: bsz x seqlen + pad_token_id: indicating which tokens are padding token + ''' + bsz, seqlen, _ = score_matrix.size() + gold_score = torch.diagonal(score_matrix, offset=0, dim1=1, dim2=2) # bsz x seqlen + gold_score = torch.unsqueeze(gold_score, -1) + assert gold_score.size() == torch.Size([bsz, seqlen, 1]) + difference_matrix = gold_score - score_matrix + assert difference_matrix.size() == torch.Size([bsz, seqlen, seqlen]) + loss_matrix = margin - difference_matrix # bsz x seqlen x seqlen + loss_matrix = torch.nn.functional.relu(loss_matrix) + + ### input mask + input_mask = torch.ones_like(input_ids).type(torch.FloatTensor) + if loss_matrix.is_cuda: + input_mask = input_mask.cuda(loss_matrix.get_device()) + input_mask = input_mask.masked_fill(input_ids.eq(pad_token_id), 0.0) + + if loss_matrix.is_cuda: + input_mask = input_mask.cuda(loss_matrix.get_device()) + + valid_len_list = torch.sum(input_mask, dim = -1).tolist() + loss_mask = build_mask_matrix(seqlen, [int(item) for item in valid_len_list], prefix_len) + if score_matrix.is_cuda: + loss_mask = loss_mask.cuda(score_matrix.get_device()) + masked_loss_matrix = loss_matrix * loss_mask + + loss_matrix = torch.sum(masked_loss_matrix, dim = -1) + assert loss_matrix.size() == input_ids.size() + loss_matrix = loss_matrix * input_mask + cl_loss = torch.sum(loss_matrix) / torch.sum(loss_mask) + return cl_loss + \ No newline at end of file diff --git a/language_model/simctg.py b/language_model/simctg.py new file mode 100644 index 0000000..25599e9 --- /dev/null +++ b/language_model/simctg.py @@ -0,0 +1,233 @@ +import os +import sys +import operator +from tqdm import tqdm +from operator import itemgetter +import torch +from torch import nn +import random +import argparse +import numpy as np +import torch.nn.functional as F +from torch.nn import CrossEntropyLoss +from loss_func import contrastive_loss +from utlis import PlugAndPlayContrastiveDecodingOneStepFast + +import seaborn as sns +import matplotlib.pyplot as plt +import pandas as pd +import datetime + +train_fct = CrossEntropyLoss() +val_fct = CrossEntropyLoss(reduction='none') +class SimCTG(nn.Module): + def __init__(self, model_name, 
sos_token, pad_token): + super(SimCTG, self).__init__() + from transformers import AutoTokenizer, GPT2LMHeadModel + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.sos_token, self.sos_token_id = self.add_special_token(sos_token) + print ('sos token is {}, sos token id is {}'.format(self.sos_token, self.sos_token_id)) + self.pad_token, self.pad_token_id = self.add_special_token(pad_token) + print ('pad token is {}, pad token id is {}'.format(self.pad_token, self.pad_token_id)) + self.eos_token, self.eos_token_id = self.tokenizer.bos_token, self.tokenizer.bos_token_id + print ('eos token is {}, eos token id is {}'.format(self.eos_token, self.eos_token_id)) + self.model = GPT2LMHeadModel.from_pretrained(model_name) + self.vocab_size = len(self.tokenizer) + print ('Resizing model embedding...') + self.model.resize_token_embeddings(len(self.tokenizer)) + print ('Model embedding resized!') + self.embed_dim = self.model.config.hidden_size + + def add_special_token(self, special_token): + if special_token in self.tokenizer.vocab: + print (special_token + ' token exists.') + else: + print ('Add token to the tokenizer.') + print ('Original vocabulary size is {}'.format(len(self.tokenizer))) + self.tokenizer.add_tokens([special_token]) + print ('Vocabulary size after extension is {}'.format(len(self.tokenizer))) + assert len(self.tokenizer.convert_tokens_to_ids([special_token])) == 1 + special_token_id = self.tokenizer.convert_tokens_to_ids([special_token])[0] + return special_token, special_token_id + + def compute_logits_and_hidden_states(self, input_ids): + # used for advanced decoding + # input_ids: 1 x seqlen + outputs = self.model(input_ids=input_ids, output_hidden_states=True) + last_hidden_states = outputs.hidden_states[-1] + logits = outputs.logits + return last_hidden_states, logits + + def forward(self, input_ids, labels, margin): + bsz, seqlen = input_ids.size() + outputs = self.model(input_ids=input_ids, output_hidden_states=True) + logits = outputs.logits + assert logits.size() == torch.Size([bsz, seqlen, self.vocab_size]) + last_hidden_states = outputs.hidden_states[-1] + assert last_hidden_states.size() == torch.Size([bsz, seqlen, self.embed_dim]) + mle_loss = train_fct(logits.view(-1, self.vocab_size), labels.view(-1)) + + norm_rep = last_hidden_states / last_hidden_states.norm(dim=2, keepdim=True) + cosine_scores = torch.matmul(norm_rep, norm_rep.transpose(1,2)) + assert cosine_scores.size() == torch.Size([bsz, seqlen, seqlen]) + cl_loss = contrastive_loss(margin, cosine_scores, input_ids, self.pad_token_id, prefix_len=0) + return mle_loss, cl_loss + + def eval_loss(self, input_ids, labels): + bsz, seqlen = input_ids.size() + outputs = self.model(input_ids=input_ids, output_hidden_states=True) + logits = outputs.logits + assert logits.size() == torch.Size([bsz, seqlen, self.vocab_size]) + last_hidden_states = outputs.hidden_states[-1] + assert last_hidden_states.size() == torch.Size([bsz, seqlen, self.embed_dim]) + mle_loss = val_fct(logits.view(-1, self.vocab_size), labels.view(-1)) + assert mle_loss.size() == torch.Size([bsz * seqlen]) + mask_tmp = labels.masked_fill(~labels.eq(-100), 1.0) + mask = mask_tmp.masked_fill(mask_tmp.eq(-100), 0.0) + # sum + mle_loss_sum = torch.sum(mle_loss) + token_num_sum = torch.sum(mask) + return mle_loss_sum, token_num_sum + + def save_model(self, ckpt_save_path): + import os + if os.path.exists(ckpt_save_path): + pass + else: # recursively construct directory + os.makedirs(ckpt_save_path, exist_ok=True) + # save model + 
self.model.save_pretrained(ckpt_save_path) + # save tokenizer + self.tokenizer.save_pretrained(ckpt_save_path) + + def parse_sentences(self, text, num_of_sentences_to_keep): + item_list = text.split('.') + res_list = item_list[:num_of_sentences_to_keep] + if len(item_list) > num_of_sentences_to_keep: + res_text = '.'.join(res_list).strip('.') + '.' + else: + res_text = '.'.join(res_list).strip('.').strip() + return res_text + + def parse_generated_result(self, output, num_of_sentences_to_keep): + output_text = self.tokenizer.decode(output) + item_list = output_text.split(self.eos_token) + full_text = self.eos_token.join(item_list[:2]).strip() + full_text = self.parse_sentences(full_text, num_of_sentences_to_keep) + generated_text = item_list[1].strip() + generated_text = self.parse_sentences(generated_text, num_of_sentences_to_keep) + return full_text, generated_text + + # decoding functions + # ------------------------------------------------------- # + + def parse_output_token_list(self, output): + output = output.tolist() + res_list = [] + for token_id in output: + if token_id == self.sos_token_id: + continue + elif token_id == self.eos_token_id: + break + else: + res_list.append(token_id) + text = self.tokenizer.decode(res_list).strip() + return ' '.join(text.split()).strip() + + @torch.no_grad() + def magic_search(self, input_ids, beam_width, alpha, decoding_len, beta, image_instance, clip, + clip_text_max_len):#, add_token_level_score=False): + prefix_len = input_ids.size()[1] + #from utlis import PlugAndPlayContrastiveDecodingOneStepFast + past_key_values, last_hidden_states, logits = None, None, None + generated = [item for item in input_ids.tolist()] + input_ids_for_class = input_ids.clone() + + image_embeds = clip.compute_image_representation_from_image_instance(image_instance) + + start_time = datetime.datetime.now() + + # the maximum supported length of generation for SimCTG is 256 + # to support longer generated length, you can re-train the SimCTG model with longer sequences + decoding_len = decoding_len - prefix_len + for step in range(decoding_len): + input_ids, past_key_values, last_hidden_states, logits, input_ids_for_class = \ + PlugAndPlayContrastiveDecodingOneStepFast( + self.model, + input_ids, + prefix_len, + beam_width, + alpha, + beta, + self.tokenizer, + image_embeds, + clip, + clip_text_max_len, + past_key_values, + last_hidden_states, + logits, + first_step=step==0, + input_ids_for_class=input_ids_for_class, + ) + end_time = datetime.datetime.now() + time_diff = (end_time - start_time) + execution_time = time_diff.total_seconds() * 1000 + return self.parse_output_token_list(input_ids_for_class[0]) + + def fast_contrastive_search(self, input_ids, beam_width, alpha, decoding_len): + ''' + input_ids: prefix input; 1 x prefix_len + decoding_len: how many tokens to generate + beam_width: size of candidate pool during decoding + alpha: regulates importance of model confidence and degeneration penalty + ''' + self.model.eval() + #from utlis import ContrastiveDecodingOneStepFast + # sanity check + assert alpha >= 0. 
and alpha <= 1.0 + + # fast mode + prefix_len = input_ids.size()[1] + batch_size, seqlen = input_ids.size() + #generated = [[] for _ in range(batch_size)] + generated = [item for item in input_ids.tolist()] + past_key_values = None + last_hidden_states = None + logits = None + decoding_len = decoding_len - prefix_len + for step in range(decoding_len): + input_ids, past_key_values, last_hidden_states, logits = ContrastiveDecodingOneStepFast( + self.model, + input_ids, + beam_width, + alpha, + past_key_values, + last_hidden_states, + self.tokenizer, + logits, + first_step=step == 0, + ) + tokens = input_ids.squeeze(dim=-1).tolist() + for idx, t in enumerate(tokens): + generated[idx].append(t) + return self.parse_output_token_list(torch.LongTensor(generated[0])) + + def top_k_sampling(self, input_ids, k, decoding_len): + _, prefix_len = input_ids.size() + output = self.model.generate( + input_ids, + do_sample=True, + max_length=decoding_len, + top_p=1.0, + top_k=k) + return self.parse_output_token_list(output[0]) + + def nucleus_sampling(self, input_ids, nucleus_p, decoding_len): + _, prefix_len = input_ids.size() + output = self.model.generate( + input_ids, + do_sample=True, + max_length=decoding_len, + top_p=nucleus_p, + top_k=0) + return self.parse_output_token_list(output[0]) diff --git a/language_model/train.py b/language_model/train.py new file mode 100644 index 0000000..f401e44 --- /dev/null +++ b/language_model/train.py @@ -0,0 +1,107 @@ +# coding=utf-8 +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.multiprocessing as mp +import argparse, os +import random +import numpy as np +import time +import logging +import progressbar + +import logging +logging.getLogger('transformers.generation_utils').disabled = True + +def parse_config(): + parser = argparse.ArgumentParser() + # data configuration + parser.add_argument("--model_name", type=str, default='gpt2') + parser.add_argument("--train_path", type=str) + parser.add_argument("--dev_path", type=str) + parser.add_argument("--test_path", type=str) + parser.add_argument("--max_len", type=int) + parser.add_argument("--add_eos_token_to_data", type=str) + # mini-batch training configuration + parser.add_argument("--number_of_gpu", type=int, help="Number of available GPUs.") + parser.add_argument("--batch_size_per_gpu", type=int, help='batch size for each gpu.') + parser.add_argument("--gradient_accumulation_steps", type=int, help="gradient accumulation step.") + parser.add_argument("--effective_batch_size", type=int, + help="effective_bsz = batch_size_per_gpu x number_of_gpu x gradient_accumulation_steps") + # pre-training configuration + parser.add_argument("--total_steps", type=int, + help="total effective training steps") + parser.add_argument("--print_every", type=int, + help="how many update steps to print one intermediate result") + parser.add_argument("--save_every", type=int, + help="how many update steps to save one model") + # learning configuration + parser.add_argument("--learning_rate", type=float, default=2e-5) + parser.add_argument("--margin", type=float) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--save_path_prefix", type=str, help="directory to save the model parameters.") + return parser.parse_args() + +def load_previous_best_model(path): + import os + filenames = os.listdir(path) + for file in filenames: + if file.startswith('training_step'): + return path + '/' + file + raise Exception('No best model found!') + 
+import argparse +if __name__ == '__main__': + if torch.cuda.is_available(): + print ('Cuda is available.') + cuda_available = torch.cuda.is_available() + multi_gpu_training = False + if cuda_available: + if torch.cuda.device_count() > 1: + multi_gpu_training = True + print ('Using Multi-GPU training, number of GPU is {}'.format(torch.cuda.device_count())) + else: + print ('Using single GPU training.') + else: + pass + args = parse_config() + device = torch.device('cuda') + model_name = args.model_name + + sos_token, pad_token = r'<-start_of_text->', r'<-pad->' + add_eos_token_to_data = args.add_eos_token_to_data + if add_eos_token_to_data == 'True': + add_eos_token_to_data = True + print ('Add eos token to data!') + elif add_eos_token_to_data == 'False': + add_eos_token_to_data = False + print ('Do not add eos token to data!') + else: + raise Exception('Wrong eos configuration for data!!!') + print ('Loading data...') + from dataclass import Data + data = Data(model_name, args.train_path, args.dev_path, args.test_path, args.max_len, + sos_token, pad_token, add_eos_token_to_data) + print ('Data loaded.') + + from trainer import model_training + print ('############################################################') + print ('Start Training...') + from simctg import SimCTG + print ('Initializaing SimCTG model...') + model = SimCTG(model_name, sos_token, pad_token) + if cuda_available: + if multi_gpu_training: + model = nn.DataParallel(model) # multi-gpu training + else: + pass + model = model.to(device) + else: + pass + print ('Model loaded') + total_steps, print_every, save_every = args.total_steps, args.print_every, args.save_every + ckpt_save_path = args.save_path_prefix + model = model_training(args, data, model, total_steps, print_every, save_every, + ckpt_save_path, cuda_available, device) + print ('Training stage completed!') + print ('############################################################') diff --git a/language_model/train_flickr30k.sh b/language_model/train_flickr30k.sh new file mode 100644 index 0000000..f76d082 --- /dev/null +++ b/language_model/train_flickr30k.sh @@ -0,0 +1,17 @@ +CUDA_VISIBLE_DEVICES=0 python train.py\ + --model_name gpt2\ + --train_path ../data/flickr30k/flickr30k_train.json\ + --dev_path ../data/flickr30k/flickr30k_val.json\ + --test_path ../data/flickr30k/flickr30k_test.json\ + --add_eos_token_to_data True\ + --margin 0.5\ + --max_len 64\ + --number_of_gpu 1\ + --batch_size_per_gpu 32\ + --gradient_accumulation_steps 4\ + --effective_batch_size 128\ + --total_steps 10000\ + --print_every 50\ + --save_every 250\ + --learning_rate 2e-5\ + --save_path_prefix ./magic_flickr30k/ \ No newline at end of file diff --git a/language_model/train_mscoco.sh b/language_model/train_mscoco.sh new file mode 100644 index 0000000..7ab7582 --- /dev/null +++ b/language_model/train_mscoco.sh @@ -0,0 +1,17 @@ +CUDA_VISIBLE_DEVICES=0 python train.py\ + --model_name gpt2\ + --train_path ../data/mscoco/mscoco_train.json\ + --dev_path ../data/mscoco/mscoco_val.json\ + --test_path ../data/mscoco/mscoco_test.json\ + --add_eos_token_to_data True\ + --margin 0.5\ + --max_len 64\ + --number_of_gpu 1\ + --batch_size_per_gpu 32\ + --gradient_accumulation_steps 4\ + --effective_batch_size 128\ + --total_steps 20000\ + --print_every 100\ + --save_every 500\ + --learning_rate 2e-5\ + --save_path_prefix ./magic_mscoco/ \ No newline at end of file diff --git a/language_model/trainer.py b/language_model/trainer.py new file mode 100644 index 0000000..e51a850 --- /dev/null +++ 
b/language_model/trainer.py @@ -0,0 +1,165 @@ +# coding=utf-8 +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.multiprocessing as mp +import argparse, os +import random +import numpy as np +import time +import logging +import progressbar + +import logging +logging.getLogger('transformers.generation_utils').disabled = True + +def eval_model(args, model, data, cuda_available, device): + dataset_batch_size = args.batch_size_per_gpu * args.number_of_gpu + eval_step = int(data.test_num / dataset_batch_size) + 1 + val_loss, token_sum = 0., 0. + model.eval() + with torch.no_grad(): + p = progressbar.ProgressBar(eval_step) + p.start() + for idx in range(eval_step): + p.update(idx) + batch_input_tensor, batch_labels, _ = \ + data.get_next_validation_batch(batch_size=dataset_batch_size, mode='test') + if cuda_available: + batch_input_tensor = batch_input_tensor.cuda(device) + batch_labels = batch_labels.cuda(device) + one_val_loss, one_val_token_sum = model.eval_loss(batch_input_tensor, batch_labels) + one_val_loss = torch.sum(one_val_loss) + one_val_token_sum = torch.sum(one_val_token_sum) + val_loss += one_val_loss.item() + token_sum += one_val_token_sum.item() + p.finish() + model.train() + val_loss = val_loss / token_sum + return val_loss + +def model_training(args, data, model, total_steps, print_every, save_every, ckpt_save_path, cuda_available, device): + import os + if os.path.exists(ckpt_save_path): + pass + else: # recursively construct directory + os.makedirs(ckpt_save_path, exist_ok=True) + + max_save_num = 1 + + batch_size_per_gpu, gradient_accumulation_steps, number_of_gpu, effective_batch_size = \ + args.batch_size_per_gpu, args.gradient_accumulation_steps, args.number_of_gpu, args.effective_batch_size + assert effective_batch_size == batch_size_per_gpu * gradient_accumulation_steps * number_of_gpu + + warmup_steps = int(0.1 * total_steps) # 10% of training steps are used for warmup + print ('total training steps is {}, warmup steps is {}'.format(total_steps, warmup_steps)) + from transformers.optimization import AdamW, get_linear_schedule_with_warmup + optimizer = AdamW(model.parameters(), lr=args.learning_rate) + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps) + optimizer.zero_grad() + + effective_batch_acm = 0 + all_batch_step = 1 + print_valid, save_valid = False, False + train_loss, train_cl_loss, min_val_loss = 0., 0., 1e10 + train_ave_bleu = 0. 
+ + print ('--------------------------------------------------------------------------') + print ('Start Training:') + model.train() + number_of_saves = 0 + + while effective_batch_acm < total_steps: + all_batch_step += 1 + train_batch_input_tensor, train_batch_labels, _ = data.get_next_train_batch(batch_size_per_gpu * number_of_gpu) + if cuda_available: + train_batch_input_tensor = train_batch_input_tensor.cuda(device) + train_batch_labels = train_batch_labels.cuda(device) + mle_loss, cl_loss = model(train_batch_input_tensor, train_batch_labels, args.margin) + + loss = mle_loss + cl_loss + loss = loss.mean() + loss.backward() + train_loss += mle_loss.item() + train_cl_loss += cl_loss.item() + torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + + # parameter update + if all_batch_step % gradient_accumulation_steps == 0: + optimizer.step() + scheduler.step() + optimizer.zero_grad() + effective_batch_acm += 1 + print_valid, save_valid = True, True + + # print intermediate result + if effective_batch_acm % print_every == 0 and print_valid: + denominator = (effective_batch_acm - (number_of_saves * save_every)) * gradient_accumulation_steps + one_train_loss = train_loss / denominator + one_train_cl_loss = train_cl_loss / denominator + print ('At training steps {}, training MLE loss is {}, train CL loss is {}'.format(effective_batch_acm, + one_train_loss, one_train_cl_loss)) + print_valid = False + + # saving result + if effective_batch_acm % save_every == 0 and save_valid: + number_of_saves += 1 + + save_valid = False + one_train_loss = train_loss / (save_every * gradient_accumulation_steps) + one_train_cl_loss = train_cl_loss / (save_every * gradient_accumulation_steps) + + model.eval() + one_val_loss = eval_model(args, model, data, cuda_available, device) + model.train() + + print ('At training steps {}, training MLE loss is {}, train CL loss is {}, validation loss is {}'.format(effective_batch_acm, + one_train_loss, one_train_cl_loss, one_val_loss)) + + train_loss, train_cl_loss = 0., 0. + + if one_val_loss < min_val_loss: + # in finetuning stage, we always save the model + min_val_loss = min(one_val_loss, min_val_loss) + print ('Saving model...') + one_val_ppl = np.exp(one_val_loss) + one_val_ppl = round(one_val_ppl, 3) + save_name = 'training_step_{}_train_mle_loss_{}_train_cl_loss_{}_dev_loss_{}_dev_ppl_{}'.format(effective_batch_acm, + round(one_train_loss,5), round(one_train_cl_loss,5), round(one_val_loss,5), one_val_ppl) + + model_save_path = ckpt_save_path + '/' + save_name + import os + if os.path.exists(model_save_path): + pass + else: # recursively construct directory + os.makedirs(model_save_path, exist_ok=True) + if cuda_available and torch.cuda.device_count() > 1: + model.module.save_model(model_save_path) + else: + model.save_model(model_save_path) + print ('Model Saved!') + + # --------------------------------------------------------------------------------------------- # + # removing extra checkpoints... 
+ import os + from operator import itemgetter + fileData = {} + test_output_dir = ckpt_save_path + for fname in os.listdir(test_output_dir): + if fname.startswith('training_step'): + fileData[fname] = os.stat(test_output_dir + '/' + fname).st_mtime + else: + pass + sortedFiles = sorted(fileData.items(), key=itemgetter(1)) + + if len(sortedFiles) < max_save_num: + pass + else: + delete = len(sortedFiles) - max_save_num + for x in range(0, delete): + one_folder_name = test_output_dir + '/' + sortedFiles[x][0] + os.system('rm -r ' + one_folder_name) + print ('-----------------------------------') + # --------------------------------------------------------------------------------------------- # + return model + diff --git a/language_model/utlis.py b/language_model/utlis.py new file mode 100644 index 0000000..a739cd2 --- /dev/null +++ b/language_model/utlis.py @@ -0,0 +1,291 @@ +import sys +import os +import operator +from operator import itemgetter +import torch +from torch import nn +import torch.nn.functional as F +import random +import numpy as np +import argparse +import random + +def parse_prompt(text): + ''' + process the prompt text; + ''' + eos_token = '<|endoftext|>' + text = text.strip(eos_token).strip() + left_bracket_idx, right_bracket_idx = -1, -1 + for idx in range(len(text)): + char = text[idx] + if char == '[' and left_bracket_idx == -1: # first [ is met + left_bracket_idx = idx + elif char == ']' and right_bracket_idx == -1: # first ] is met + right_bracket_idx = idx + else: + pass + res_text = '' + remove = False + if left_bracket_idx > -1 and right_bracket_idx > left_bracket_idx: + if right_bracket_idx - left_bracket_idx <= 6: + remove = True + else: + pass + + for idx in range(len(text)): + if remove: + if idx >= left_bracket_idx and idx <= right_bracket_idx: + continue + else: + res_text += text[idx] + else: + res_text += text[idx] + res_text = res_text.strip() + res_text = ' '.join(res_text.split()).strip() + return res_text + +def typical_filtering(scores, mass, min_tokens_to_keep, filter_value): + # calculate entropy + normalized = torch.nn.functional.log_softmax(scores, dim=-1) + p = torch.exp(normalized) + ent = -(normalized * p).nansum(-1, keepdim=True) + + # shift and sort + shifted_scores = torch.abs((-normalized) - ent) + sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False) + sorted_logits = scores.gather(-1, sorted_indices) + cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) + + # Remove tokens with cumulative mass above the threshold + last_ind = (cumulative_probs < mass).sum(dim=1) + last_ind[last_ind < 0] = 0 + sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., : min_tokens_to_keep] = 0 + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + + scores = scores.masked_fill(indices_to_remove, filter_value) + return scores + +def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, threshold=-float('Inf'), filter_value=-np.inf): + assert logits.dim() == 1 + top_k = min(top_k, logits.size(-1)) + if top_k > 0: + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + if top_p > 0.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + 
sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + indices_to_remove = sorted_indices[sorted_indices_to_remove] + logits[indices_to_remove] = filter_value + + indices_to_remove = logits < threshold + logits[indices_to_remove] = filter_value + return logits + +# ========== batch version ========= # +def ranking_fast(context_hidden, next_hidden, next_top_k_probs, alpha, beam_width): + ''' + context_hidden: bsz*beam x seqlen x embed_dim + next_hidden: bsz*beam x 1 x embed_dim + next_top_k_probs: bsz x beam + ''' + _, context_len, embed_dim = context_hidden.size() + norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True) + norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True) + cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1) # [B*K, S] + scores, _ = torch.max(cosine_matrix, dim=-1) # [B*K] + next_top_k_probs = next_top_k_probs.view(-1) # [B*K] + scores = (1.0 - alpha) * next_top_k_probs - alpha * scores + scores = torch.stack(torch.split(scores, beam_width)) # [B, K] + selected_idx = scores.max(dim=-1)[1] # [B] + return selected_idx + +def ContrastiveDecodingOneStepFast( + model, + ids, + beam_width, + alpha, + past_key_values, + last_hidden_states, + vocab, + logit_for_next_step, + first_step=False, + ): + # input_ids: [B, S] + if first_step: + output = model( + input_ids=ids, + past_key_values=past_key_values, + use_cache=True, + output_hidden_states=True + ) + past_key_values = output.past_key_values + last_hidden_states = output.hidden_states[-1] # [B, S, E] + logit_for_next_step = output.logits[:, -1, :] # [B, V] + bsz, seqlen, embed_dim = last_hidden_states.size() + p = random.uniform(0, 1) + + next_probs = F.softmax(logit_for_next_step, dim=-1) + _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width) # [B, K] + top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids) # [B, K] + # compute new hidden + past_key_values = enlarge_past_key_values(past_key_values, beam_width) + output = model( + input_ids=top_k_ids.view(-1, 1), + attention_mask=torch.ones_like(top_k_ids.view(-1, 1)), + past_key_values=past_key_values, + output_hidden_states=True, + use_cache=True, + ) + past_key_values = output.past_key_values + logits = output.logits[:, -1, :] # [B*K, V] + next_hidden = output.hidden_states[-1] # [B*K, 1, E] + context_hidden = last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz*beam_width, seqlen, embed_dim) # [B*K, S, E] + + selected_idx = ranking_fast( + context_hidden, + next_hidden, + top_k_probs, # [B, K] + alpha, + beam_width, + ) # [B] + # prepare for the next step + next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1) # [B, 1] + next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width)) # [B, K, E] + next_hidden = next_hidden[range(bsz), selected_idx, :] # [B, E] + last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) # [B, S, E] + past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx) + logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :] # [B, V] + # next_id: [B, 1] + return next_id, past_key_values, last_hidden_states, logits + +def enlarge_past_key_values(past_key_values, beam_width): + # from [B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz] + new_key_values = [] + for layer in 
past_key_values: + items = [] + for item in layer: + # item is the key and value matrix + bsz, num_head, seq_len, esz = item.size() + item = item.unsqueeze(1).expand(-1, beam_width, -1, -1, -1).reshape(bsz*beam_width, num_head, seq_len, esz) # [bsz*beam, num_head, seq_len, esz] + items.append(item) + new_key_values.append(items) + return new_key_values + +def select_past_key_values(past_key_values, beam_width, selected_idx): + '''select_idx: [B]''' + new_key_values = [] + for layer in past_key_values: + items = [] + for item in layer: + bsz_and_beam, num_head, seq_len, esz = item.size() + bsz = int(bsz_and_beam//beam_width) + item = torch.stack(torch.split(item, beam_width, dim=0)) # [B, K, num_head, seq_len, esz] + item = item[range(bsz), selected_idx, :, :, :] # [B, num_head, seq_len, esz] + items.append(item) + new_key_values.append(items) + return new_key_values + +# ========== fast plug and play version ========= # +def plug_and_play_fast_ranking( + context_hidden, + next_hidden, + next_top_k_ids, + next_top_k_probs, + alpha, + beta, + batch_class_score, + beam_width, +): + ''' + context_hidden: beam_width x context_len x embed_dim + next_hidden: beam_width x 1 x embed_dim + next_top_k_ids: beam_width x 1 + batch_class_score: beam_width x 1 + ''' + _, context_len, embed_dim = context_hidden.size() + norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True) + norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True) + cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1) + scores, _ = torch.max(cosine_matrix, dim = -1) + next_top_k_probs = next_top_k_probs.view(-1) + scores = (1.0 - alpha) * next_top_k_probs - alpha * scores + beta * batch_class_score.view([beam_width]) + scores = torch.stack(torch.split(scores, beam_width)) + selected_idx = scores.max(dim=-1)[1] + return selected_idx + +def PlugAndPlayContrastiveDecodingOneStepFast(model, input_ids, prefix_len, beam_width, alpha, beta, + simctg_tokenizer, image_embeds, clip, clip_text_max_len, past_key_values, last_hidden_states, + logit_for_next_step, first_step=False, input_ids_for_class=None):#, add_token_level_score=False): + ''' + model: the generation model, e.g., gpt2 + input_ids: 1 x seqlen + ''' + + if first_step: + output = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True, output_hidden_states=True) + past_key_values = output.past_key_values + last_hidden_states = output.hidden_states[-1] # [B, S, E] + logit_for_next_step = output.logits[:, -1, :] # [B, V] + bsz, seqlen, embed_dim = last_hidden_states.size() + next_probs = F.softmax(logit_for_next_step, dim = -1) + _, top_k_ids = torch.topk(logit_for_next_step, dim = -1, k = beam_width) + top_k_probs = torch.gather(next_probs, dim = 1, index=top_k_ids) + + # compute the new hidden + past_key_values = enlarge_past_key_values(past_key_values, beam_width) + output = model( + input_ids=top_k_ids.view(-1, 1) , + attention_mask=torch.ones_like(top_k_ids.view(-1, 1)), + past_key_values=past_key_values, + output_hidden_states=True, + use_cache=True, + ) + past_key_values = output.past_key_values + logits = output.logits[:, -1, :] + next_hidden = output.hidden_states[-1] + context_hidden = last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz*beam_width, seqlen, embed_dim) + + # prepare for the classification model + input_ids_for_class_ = torch.cat([ + input_ids_for_class.unsqueeze(1).expand(-1, beam_width, -1).reshape(bsz*beam_width, seqlen), + 
top_k_ids.view(-1, 1) + ], dim=-1 + ) + + batch_text_list = [] + for one_input_id in input_ids_for_class_: + one_text = simctg_tokenizer.decode(one_input_id[prefix_len:][-clip_text_max_len:]) + # we only consider the class score of the generated text continuation + batch_text_list.append(one_text) + batch_score = clip.compute_image_text_similarity_via_raw_text(image_embeds, batch_text_list) + + selected_idx = plug_and_play_fast_ranking( + context_hidden, + next_hidden, + top_k_ids, + top_k_probs, + alpha, + beta, + batch_score, + beam_width, + ) + + # prepare for the next step + next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1) + next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width)) + next_hidden = next_hidden[range(bsz), selected_idx, :] + last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) + past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx) + logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :] + input_ids_for_class = torch.cat([input_ids_for_class, next_id], dim=-1) + return next_id, past_key_values, last_hidden_states, logits, input_ids_for_class + + diff --git a/magic.py b/magic.py index 85f46a7..d291142 100644 --- a/magic.py +++ b/magic.py @@ -29,7 +29,6 @@ from towhee.types.arg import arg, to_image_color from towhee.types.image_utils import to_pil from towhee.operator.base import NNOperator, OperatorFlag from towhee import register -from towhee.models import clip class Magic(NNOperator): """ @@ -38,22 +37,32 @@ class Magic(NNOperator): def __init__(self, model_name: str): super().__init__() path = str(pathlib.Path(__file__).parent) - sys.path.append(path) + sys.path.append(path + '/clip') + sys.path.append(path + '/language_model') + print(sys.path) from clip import CLIP from simctg import SimCTG sys.path.pop() + sys.path.pop() self.device = "cuda" if torch.cuda.is_available() else "cpu" # Load Language Model - language_model_name = r'cambridgeltl/magic_mscoco' # or r'/path/to/downloaded/cambridgeltl/magic_mscoco' + cfg = self._configs()[model_name] + language_model_name = cfg['language_model'] # or r'/path/to/downloaded/cambridgeltl/magic_mscoco' sos_token, pad_token = r'<-start_of_text->', r'<-pad->' self.generation_model = SimCTG(language_model_name, sos_token, pad_token).to(self.device) self.generation_model.eval() - model_name = r"openai/clip-vit-base-patch32" # or r"/path/to/downloaded/openai/clip-vit-base-patch32" + model_name = cfg['clip_model'] # or r"/path/to/downloaded/openai/clip-vit-base-patch32" self.clip = CLIP(model_name).to(self.device) + self.clip.to(self.device) self.clip.eval() + sos_token = r'<-start_of_text->' + start_token = self.generation_model.tokenizer.tokenize(sos_token) + start_token_id = self.generation_model.tokenizer.convert_tokens_to_ids(start_token) + self.input_ids = torch.LongTensor(start_token_id).view(1,-1).to(self.device) + def _preprocess(self, img): img = to_pil(img) @@ -87,13 +96,15 @@ class Magic(NNOperator): k, alpha, beta, decoding_len = 45, 0.1, 2.0, 16 eos_token = '<|endoftext|>' with torch.no_grad(): - output = generation_model.magic_search(input_ids, k, - alpha, decoding_len, beta, image_instance, clip, 60) + print(type(img)) + output = self.generation_model.magic_search(self.input_ids, k, + alpha, decoding_len, beta, img, self.clip, 60) - return out + return output def _configs(self): config = {} - config['expansionnet_rf'] = {} - config['expansionnet_rf']['weights'] = 'rf_model.pth' + 
config['magic_mscoco'] = {} + config['magic_mscoco']['language_model'] = 'cambridgeltl/magic_mscoco' + config['magic_mscoco']['clip_model'] = 'openai/clip-vit-base-patch32' return config diff --git a/zerocap/README.md b/zerocap/README.md deleted file mode 100644 index 1839550..0000000 --- a/zerocap/README.md +++ /dev/null @@ -1,89 +0,0 @@ -### Our Implementation of the ZeroCap Baseline Model - -**** -### Catalogue: -* 1. Environment Preparation -* 2. Image Captioning on MSCOCO -* 3. Image Captioning on Flickr30k -* 4. Cross Domain Image Captioning on MSCOCO -* 5. Cross Domain Image Captioning on Flickr30k -* 6. Citation -* 7. Acknowledgements - -**** - - - -#### 1. Environment Preparation: -To install the correct environment, please run the following command: -```yaml -pip install -r requirements.txt -``` - -**** - - - -#### 2. Image Captioning on MSCOCO: -To perform image captioning on MSCOCO, please run the following command: -```yaml -chmod +x ./mscoco_zerocap.sh -./mscoco_zerocap.sh -``` - -**** - - - -#### 3. Image Captioning on Flickr30k: -To perform image captioning on Flickr30k, please run the following command: -```yaml -chmod +x ./flickr30k_zerocap.sh -./flickr30k_zerocap.sh -``` - -**** - - - -#### 4. Cross Domain Image Captioning on MSCOCO: -To perform image captioning on MSCOCO with the language model from Flickr30k domain, please run the following command: -```yaml -chmod +x ./flickr30k_to_mscoco_zerocap.sh -./flickr30k_to_mscoco_zerocap.sh -``` - -**** - - - -#### 5. Cross Domain Image Captioning on Flickr30k: -To perform image captioning on Flickr30k with the language model from MSCOCO domain, please run the following command: -```yaml -chmod +x ./mscoco_to_flickr30k_zerocap.sh -./mscoco_to_flickr30k_zerocap.sh -``` - -**** - - - -#### 6. Citation: -If you find our code helpful, please cite the original paper as - -```bibtex -@article{tewel2021zero, - title={Zero-Shot Image-to-Text Generation for Visual-Semantic Arithmetic}, - author={Tewel, Yoad and Shalev, Yoav and Schwartz, Idan and Wolf, Lior}, - journal={arXiv preprint arXiv:2111.14447}, - year={2021} -} -``` - -**** - - - -#### 7. Acknowledgements: -We thank the authors for releasing their code. Our reimplementation of the baseline is based on their original codebase [[here]](https://github.com/yoadtew/zero-shot-image-to-text). - diff --git a/zerocap/cog.yaml b/zerocap/cog.yaml deleted file mode 100644 index 92f13da..0000000 --- a/zerocap/cog.yaml +++ /dev/null @@ -1,12 +0,0 @@ -build: - gpu: true - python_version: "3.8" - system_packages: - - "libgl1-mesa-glx" - - "libglib2.0-0" - python_packages: - - "git+https://github.com/openai/CLIP.git" - - "git+https://github.com/YoadTew/zero-shot-image-to-text.git" - -predict: "predict.py:Predictor" -#predict: "predict_arithmetic.py:Predictor" \ No newline at end of file diff --git a/zerocap/flickr30k_zerocap.sh b/zerocap/flickr30k_zerocap.sh deleted file mode 100755 index b727b12..0000000 --- a/zerocap/flickr30k_zerocap.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash - -# lm_model: -# 1. cambridgeltl/magic_mscoco -# 2. 
cambridgeltl/magic_flickr30k
-CUDA_VISIBLE_DEVICES=1 python run.py \
- --beam_size 1 \
- --target_seq_length 16 \
- --reset_context_delta \
- --lm_model cambridgeltl/magic_flickr30k \
- --test_image_prefix_path ../data/flickr30k/test_images \
- --test_path ../data/flickr30k/flickr30k_test.json \
- --save_path_prefix ../inference_result/flickr30k/baselines/ \
- --save_name zerocap_result.json
diff --git a/zerocap/forbidden_tokens.npy b/zerocap/forbidden_tokens.npy
deleted file mode 100644
index aeed51b4396b77c16800aa57b47eff81db7fe9fa..0000000000000000000000000000000000000000
Binary files a/zerocap/forbidden_tokens.npy and /dev/null differ
diff --git a/zerocap/model/ZeroCLIP.py b/zerocap/model/ZeroCLIP.py
deleted file mode 100644
index 2c36fd2..0000000
--- a/zerocap/model/ZeroCLIP.py
+++ /dev/null
@@ -1,389 +0,0 @@
-import numpy as np
-from torch import nn
-from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer
-from transformers.models.gpt_neo import GPTNeoForCausalLM
-import torch
-import clip
-from PIL import Image
-from datetime import datetime
-import sys
-
-
-def log_info(text, verbose=True):
-    if verbose:
-        dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
-        print(f'{dt_string} | {text}')
-        sys.stdout.flush()
-
-
-def add_context(x, y):
-    return (x[0] + y[0], x[1] + y[1])
-
-
-def convert_models_to_fp32(model):
-    for p in model.parameters():
-        p.data = p.data.float()
-
-
-class 
CLIPTextGenerator: - def __init__(self, - seed=0, - lm_model='gpt-2', - forbidden_tokens_file_path='./forbidden_tokens.npy', - clip_checkpoints='./clip_checkpoints', - target_seq_length=15, - reset_context_delta=True, - num_iterations=5, - clip_loss_temperature=0.01, - clip_scale=1., - ce_scale=0.2, - stepsize=0.3, - grad_norm_factor=0.9, - fusion_factor=0.99, - repetition_penalty=1., - end_token='.', - end_factor=1.01, - forbidden_factor=20, - **kwargs): - - self.device = "cuda" if torch.cuda.is_available() else "cpu" - - # set Random seed - torch.manual_seed(seed) - np.random.seed(seed) - - # Initialize Language model - self.context_prefix = '' - - self.lm_tokenizer = GPT2Tokenizer.from_pretrained(lm_model) - self.lm_model = GPT2LMHeadModel.from_pretrained(lm_model, output_hidden_states=True) - self.context_prefix = self.lm_tokenizer.bos_token - - self.lm_model.to(self.device) - self.lm_model.eval() - - self.forbidden_tokens = np.load(forbidden_tokens_file_path) - self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if - (x[0] == 'Ġ' and len(x) > 1 and x[1].isupper())] - - # Freeze LM weights - for param in self.lm_model.parameters(): - param.requires_grad = False - - # Initialize CLIP - self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device, - download_root=clip_checkpoints, jit=False) - # convert_models_to_fp32(self.clip) - - # Init arguments - self.target_seq_length = target_seq_length - self.reset_context_delta = reset_context_delta - self.num_iterations = num_iterations - self.clip_loss_temperature = clip_loss_temperature - self.clip_scale = clip_scale - self.ce_scale = ce_scale - self.stepsize = stepsize - self.grad_norm_factor = grad_norm_factor - self.fusion_factor = fusion_factor - self.repetition_penalty = repetition_penalty - self.end_token = self.lm_tokenizer.encode(end_token)[0] - self.end_factor = end_factor - self.ef_idx = 1 - self.forbidden_factor = forbidden_factor - - def get_img_feature(self, img_path, weights): - imgs = [Image.open(x) for x in img_path] - clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs] - - with torch.no_grad(): - image_fts = [self.clip.encode_image(x) for x in clip_imgs] - - if weights is not None: - image_features = sum([x * weights[i] for i, x in enumerate(image_fts)]) - else: - image_features = sum(image_fts) - - image_features = image_features / image_features.norm(dim=-1, keepdim=True) - return image_features.detach() - - def get_txt_features(self, text): - clip_texts = clip.tokenize(text).to(self.device) - - with torch.no_grad(): - text_features = self.clip.encode_text(clip_texts) - - text_features = text_features / text_features.norm(dim=-1, keepdim=True) - return text_features.detach() - - def get_combined_feature(self, img_path, texts, weights_i, weights_t): - imgs = [Image.open(x) for x in img_path] - clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs] - clip_texts = [clip.tokenize(x).to(self.device) for x in texts] - - with torch.no_grad(): - image_fts = [self.clip.encode_image(x) for x in clip_imgs] - text_fts = [self.clip.encode_text(x) for x in clip_texts] - - features = sum([x * weights_i[i] for i, x in enumerate(image_fts)]) - if weights_t is not None: - features += sum([x * weights_t[i] for i, x in enumerate(text_fts)]) - - features = features / features.norm(dim=-1, keepdim=True) - return features.detach() - - def run(self, image_features, cond_text, beam_size): - self.image_features = image_features - - 
context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text) - - output_tokens, output_text = self.generate_text(context_tokens, beam_size) - - return output_text - - def generate_text(self, context_tokens, beam_size): - context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0) - - gen_tokens = None - scores = None - seq_lengths = torch.ones(beam_size, device=self.device) - is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool) - - for i in range(self.target_seq_length): - probs = self.get_next_probs(i, context_tokens) - logits = probs.log() - - if scores is None: - scores, next_tokens = logits.topk(beam_size, -1) - context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:]) - next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0) - - if gen_tokens is None: - gen_tokens = next_tokens - else: - gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:]) - gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1) - else: - logits[is_stopped] = -float(np.inf) - logits[is_stopped, 0] = 0 - scores_sum = scores[:, None] + logits - seq_lengths[~is_stopped] += 1 - scores_sum_average = scores_sum / seq_lengths[:, None] - scores_sum_average, next_tokens = scores_sum_average.view(-1).topk( - beam_size, -1) - next_tokens_source = next_tokens // scores_sum.shape[1] - seq_lengths = seq_lengths[next_tokens_source] - next_tokens = next_tokens % scores_sum.shape[1] - next_tokens = next_tokens.unsqueeze(1) - gen_tokens = gen_tokens[next_tokens_source] - gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1) - context_tokens = context_tokens[next_tokens_source] - scores = scores_sum_average * seq_lengths - is_stopped = is_stopped[next_tokens_source] - - context_tokens = torch.cat((context_tokens, next_tokens), dim=1) - is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze() - - #### - tmp_scores = scores / seq_lengths - tmp_output_list = gen_tokens.cpu().numpy() - tmp_output_texts = [ - self.lm_tokenizer.decode(tmp_output) - for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths) - ] - tmp_order = tmp_scores.argsort(descending=True) - tmp_output_texts = [tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order] - log_info(tmp_output_texts, verbose=True) - #### - - if is_stopped.all(): - break - - scores = scores / seq_lengths - output_list = gen_tokens.cpu().numpy() - output_texts = [ - self.lm_tokenizer.decode(output[: int(length)]) - for output, length in zip(output_list, seq_lengths) - ] - order = scores.argsort(descending=True) - output_texts = [output_texts[i] for i in order] - - return context_tokens, output_texts - - def get_next_probs(self, i, context_tokens): - last_token = context_tokens[:, -1:] - - if self.reset_context_delta and context_tokens.size(1) > 1: - context = self.lm_model(context_tokens[:, :-1])["past_key_values"] - - # Logits of LM with unshifted context - logits_before_shift = self.lm_model(context_tokens)["logits"] - logits_before_shift = logits_before_shift[:, -1, :] - probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1) - - if context: - context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift) - - lm_output = self.lm_model(last_token, past_key_values=context) - logits, past = ( - lm_output["logits"], - lm_output["past_key_values"], - ) - logits = logits[:, -1, :] - - logits = self.update_special_tokens_logits(context_tokens, i, logits) - - probs = nn.functional.softmax(logits, 
dim=-1) - probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor)) - probs = probs / probs.sum() - - return probs - - def shift_context(self, i, context, last_token, context_tokens, probs_before_shift): - context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context] - - window_mask = torch.ones_like(context[0][0]).to(self.device) - - for i in range(self.num_iterations): - curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in - context_delta] - - for p0, p1 in curr_shift: - p0.retain_grad() - p1.retain_grad() - - shifted_context = list(map(add_context, context, curr_shift)) - - shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context) - logits = shifted_outputs["logits"][:, -1, :] - probs = nn.functional.softmax(logits, dim=-1) - - loss = 0.0 - - # CLIP LOSS - clip_loss, clip_losses = self.clip_loss(probs, context_tokens) - loss += self.clip_scale * clip_loss - - # CE/Fluency loss - ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * probs_before_shift.log())).sum(-1) - loss += ce_loss.sum() - - loss.backward() - - # ---------- Weights ---------- - combined_scores_k = -(ce_loss) - combined_scores_c = -(self.clip_scale * torch.stack(clip_losses)) - - # minmax - if combined_scores_k.shape[0] == 1: - tmp_weights_c = tmp_weights_k = torch.ones(*combined_scores_k.shape).to(self.device) - else: - tmp_weights_k = ((combined_scores_k - combined_scores_k.min())) / ( - combined_scores_k.max() - combined_scores_k.min()) - tmp_weights_c = ((combined_scores_c - combined_scores_c.min())) / ( - combined_scores_c.max() - combined_scores_c.min()) - - tmp_weights = 0.5 * tmp_weights_k + 0.5 * tmp_weights_c - tmp_weights = tmp_weights.view(tmp_weights.shape[0], 1, 1, 1) - - factor = 1 - - # --------- Specific Gen --------- - sep_grads = None - - for b in range(context_tokens.shape[0]): - tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_] - for p_ in curr_shift] - - # normalize gradients - tmp_grad = [tuple([-self.stepsize * factor * ( - x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][ - j] ** self.grad_norm_factor).data.cpu().numpy() - for j, x in enumerate(p_)]) - for i, p_ in enumerate(curr_shift)] - if sep_grads is None: - sep_grads = tmp_grad - else: - for l_index in range(len(sep_grads)): - sep_grads[l_index] = list(sep_grads[l_index]) - for k_index in range(len(sep_grads[0])): - sep_grads[l_index][k_index] = np.concatenate( - (sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0) - sep_grads[l_index] = tuple(sep_grads[l_index]) - final_grads = sep_grads - - # --------- update context --------- - context_delta = list(map(add_context, final_grads, context_delta)) - - for p0, p1 in curr_shift: - p0.grad.data.zero_() - p1.grad.data.zero_() - - new_context = [] - for p0, p1 in context: - new_context.append((p0.detach(), p1.detach())) - context = new_context - - context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) - for p_ in context_delta] - context = list(map(add_context, context, context_delta)) - - new_context = [] - for p0, p1 in context: - new_context.append((p0.detach(), p1.detach())) - context = new_context - - return context - - def update_special_tokens_logits(self, context_tokens, i, logits): - for beam_id in range(context_tokens.shape[0]): - for token_idx in set(context_tokens[beam_id][-4:].tolist()): - factor = self.repetition_penalty if 
logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty) - logits[beam_id, token_idx] /= factor - - if i >= self.ef_idx: - factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor) - logits[beam_id, self.end_token] *= factor - if i == 0: - start_factor = 1.6 - factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor) - logits[beam_id, self.end_token] /= factor - - for token_idx in list(self.forbidden_tokens): - factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor) - logits[beam_id, token_idx] /= factor - - return logits - - def clip_loss(self, probs, context_tokens): - for p_ in self.clip.transformer.parameters(): - if p_.grad is not None: - p_.grad.data.zero_() - - top_size = 512 - _, top_indices = probs.topk(top_size, -1) - - prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens] - - clip_loss = 0 - losses = [] - for idx_p in range(probs.shape[0]): - top_texts = [] - prefix_text = prefix_texts[idx_p] - for x in top_indices[idx_p]: - top_texts.append(prefix_text + self.lm_tokenizer.decode(x)) - text_features = self.get_txt_features(top_texts) - - with torch.no_grad(): - similiraties = (self.image_features @ text_features.T) - target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() - target_probs = target_probs.type(torch.float32) - - target = torch.zeros_like(probs[idx_p]) - target[top_indices[idx_p]] = target_probs[0] - target = target.unsqueeze(0) - cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)]))) - - clip_loss += cur_clip_loss - losses.append(cur_clip_loss) - - return clip_loss, losses diff --git a/zerocap/model/ZeroCLIP_batched.py b/zerocap/model/ZeroCLIP_batched.py deleted file mode 100644 index 2c0209f..0000000 --- a/zerocap/model/ZeroCLIP_batched.py +++ /dev/null @@ -1,449 +0,0 @@ -import numpy as np -from torch import nn -from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer -from transformers.models.gpt_neo import GPTNeoForCausalLM -import torch -import clip -from PIL import Image -from datetime import datetime -import sys - -class TextCLIP(nn.Module): - def __init__(self, model): - super(TextCLIP, self).__init__() - self.model = model - - def forward(self, text): - return self.model.encode_text(text) - - -class ImageCLIP(nn.Module): - def __init__(self, model): - super(ImageCLIP, self).__init__() - self.model = model - - def forward(self, image): - return self.model.encode_image(image) - -def log_info(text, verbose=True): - if verbose: - dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S") - print(f'{dt_string} | {text}') - sys.stdout.flush() - - -def add_context(x, y): - return (x[0] + y[0], x[1] + y[1]) - - -def convert_models_to_fp32(model): - for p in model.parameters(): - p.data = p.data.float() - - -class CLIPTextGenerator: - def __init__(self, - seed=0, - lm_model='gpt-2', - forbidden_tokens_file_path='./forbidden_tokens.npy', - clip_checkpoints='./clip_checkpoints', - target_seq_length=15, - reset_context_delta=True, - num_iterations=5, - clip_loss_temperature=0.01, - clip_scale=1., - ce_scale=0.2, - stepsize=0.3, - grad_norm_factor=0.9, - fusion_factor=0.99, - repetition_penalty=1., - end_token='.', - end_factor=1.01, - forbidden_factor=20, - **kwargs): - - self.device = "cuda" if torch.cuda.is_available() else "cpu" - - # set Random seed - torch.manual_seed(seed) - np.random.seed(seed) - - # Initialize Language model - 
self.context_prefix = '' - - if lm_model == 'gpt-neo': - self.lm_tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-125M') - self.lm_model = GPTNeoForCausalLM.from_pretrained('EleutherAI/gpt-neo-125M', output_hidden_states=True) - elif lm_model == 'gpt-2': - self.lm_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium') - self.lm_model = GPT2LMHeadModel.from_pretrained('gpt2-medium', output_hidden_states=True) - self.context_prefix = self.lm_tokenizer.bos_token - - self.lm_model.to(self.device) - self.lm_model.eval() - - self.forbidden_tokens = np.load(forbidden_tokens_file_path) - self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if - (x[0] == 'Ġ' and len(x) > 1 and x[1].isupper())] - - # Freeze LM weights - for param in self.lm_model.parameters(): - param.requires_grad = False - - # Initialize CLIP - self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device, - download_root=clip_checkpoints, jit=False) - self.clip_image = ImageCLIP(self.clip) - self.clip_image = torch.nn.DataParallel(self.clip_image) - self.clip_text = TextCLIP(self.clip) - self.clip_text = torch.nn.DataParallel(self.clip_text) - - # Init arguments - self.target_seq_length = target_seq_length - self.reset_context_delta = reset_context_delta - self.num_iterations = num_iterations - self.clip_loss_temperature = clip_loss_temperature - self.clip_scale = clip_scale - self.ce_scale = ce_scale - self.stepsize = stepsize - self.grad_norm_factor = grad_norm_factor - self.fusion_factor = fusion_factor - self.repetition_penalty = repetition_penalty - self.end_token = self.lm_tokenizer.encode(end_token)[0] - self.end_factor = end_factor - self.ef_idx = 1 - self.forbidden_factor = forbidden_factor - - def get_img_feature(self, img_path, weights): - imgs = [Image.open(x) for x in img_path] - clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs] - - with torch.no_grad(): - image_fts = [self.clip_image(x) for x in clip_imgs] - - if weights is not None: - image_features = sum([x * weights[i] for i, x in enumerate(image_fts)]) - else: - image_features = sum(image_fts) - - image_features = torch.nn.functional.normalize(image_features, dim=-1) - return image_features.detach() - - def get_txt_features(self, text): - clip_texts = clip.tokenize(text).to(self.device) - - with torch.no_grad(): - text_features = self.clip_text(clip_texts) - - text_features = torch.nn.functional.normalize(text_features, dim=-1) - return text_features.detach() - - def get_combined_feature(self, img_path, texts, weights_i, weights_t): - imgs = [Image.open(x) for x in img_path] - clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs] - clip_texts = [clip.tokenize(x).to(self.device) for x in texts] - - with torch.no_grad(): - image_fts = [self.clip.encode_image(x) for x in clip_imgs] - text_fts = [self.clip.encode_text(x) for x in clip_texts] - - features = sum([x * weights_i[i] for i, x in enumerate(image_fts)]) - if weights_t is not None: - features += sum([x * weights_t[i] for i, x in enumerate(text_fts)]) - - features = features / features.norm(dim=-1, keepdim=True) - return features.detach() - - def run(self, image_features, cond_text, beam_size): - self.image_features = image_features - - context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text) - - output_tokens, output_text = self.generate_text(context_tokens, beam_size) - - return output_text - - def generate_text(self, context_tokens, beam_size): - 
context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0) - - gen_tokens = None - scores = None - seq_lengths = torch.ones(beam_size, device=self.device) - is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool) - - for i in range(self.target_seq_length): - probs = self.get_next_probs(i, context_tokens) - logits = probs.log() - - if scores is None: - scores, next_tokens = logits.topk(beam_size, -1) - context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:]) - next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0) - - if gen_tokens is None: - gen_tokens = next_tokens - else: - gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:]) - gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1) - else: - logits[is_stopped] = -float(np.inf) - logits[is_stopped, 0] = 0 - scores_sum = scores[:, None] + logits - seq_lengths[~is_stopped] += 1 - scores_sum_average = scores_sum / seq_lengths[:, None] - scores_sum_average, next_tokens = scores_sum_average.view(-1).topk( - beam_size, -1) - next_tokens_source = next_tokens // scores_sum.shape[1] - seq_lengths = seq_lengths[next_tokens_source] - next_tokens = next_tokens % scores_sum.shape[1] - next_tokens = next_tokens.unsqueeze(1) - gen_tokens = gen_tokens[next_tokens_source] - gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1) - context_tokens = context_tokens[next_tokens_source] - scores = scores_sum_average * seq_lengths - is_stopped = is_stopped[next_tokens_source] - - context_tokens = torch.cat((context_tokens, next_tokens), dim=1) - is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze() - - #### - tmp_scores = scores / seq_lengths - tmp_output_list = gen_tokens.cpu().numpy() - tmp_output_texts = [ - self.lm_tokenizer.decode(tmp_output) - for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths) - ] - tmp_order = tmp_scores.argsort(descending=True) - tmp_output_texts = [tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order] - log_info(tmp_output_texts, verbose=True) - #### - - if is_stopped.all(): - break - - scores = scores / seq_lengths - output_list = gen_tokens.cpu().numpy() - output_texts = [ - self.lm_tokenizer.decode(output[: int(length)]) - for output, length in zip(output_list, seq_lengths) - ] - order = scores.argsort(descending=True) - output_texts = [output_texts[i] for i in order] - - return context_tokens, output_texts - - def get_next_probs(self, i, context_tokens): - last_token = context_tokens[:, -1:] - - if self.reset_context_delta and context_tokens.size(1) > 1: - context = self.lm_model(context_tokens[:, :-1])["past_key_values"] - - # Logits of LM with unshifted context - logits_before_shift = self.lm_model(context_tokens)["logits"] - logits_before_shift = logits_before_shift[:, -1, :] - probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1) - - if context: - context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift) - - lm_output = self.lm_model(last_token, past_key_values=context) - logits, past = ( - lm_output["logits"], - lm_output["past_key_values"], - ) - logits = logits[:, -1, :] - - logits = self.update_special_tokens_logits(context_tokens, i, logits) - - probs = nn.functional.softmax(logits, dim=-1) - probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor)) - probs = probs / probs.sum() - - return probs - - def shift_context(self, i, context, last_token, context_tokens, probs_before_shift): - 
context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context] - - for i in range(self.num_iterations): - curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in - context_delta] - - for p0, p1 in curr_shift: - p0.retain_grad() - p1.retain_grad() - - shifted_context = list(map(add_context, context, curr_shift)) - - shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context) - logits = shifted_outputs["logits"][:, -1, :] - probs = nn.functional.softmax(logits, dim=-1) - - loss = 0.0 - - # CLIP LOSS - clip_loss, clip_losses = self.clip_loss(probs, context_tokens) - loss += self.clip_scale * clip_loss - - # CE/Fluency loss - ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * probs_before_shift.log())).sum(-1) - loss += ce_loss.sum() - - loss.backward() - - # --------- Specific Gen --------- - final_grads = self.norm_grad(context, context_tokens, curr_shift) - - # --------- update context --------- - context_delta = list(map(add_context, final_grads, context_delta)) - - for p0, p1 in curr_shift: - p0.grad.data.zero_() - p1.grad.data.zero_() - - new_context = [] - for p0, p1 in context: - new_context.append((p0.detach(), p1.detach())) - context = new_context - - context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) - for p_ in context_delta] - context = list(map(add_context, context, context_delta)) - - new_context = [] - for p0, p1 in context: - new_context.append((p0.detach(), p1.detach())) - context = new_context - - return context - - def norm_grad(self, context, context_tokens, curr_shift, ): - factor = 1 - sep_grads = None - window_mask = torch.ones_like(context[0][0]).to(self.device) - - for b in range(context_tokens.shape[0]): - tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_] - for p_ in curr_shift] - - # normalize gradients - tmp_grad = [tuple([-self.stepsize * factor * ( - x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][ - j] ** self.grad_norm_factor).data.cpu().numpy() - for j, x in enumerate(p_)]) - for i, p_ in enumerate(curr_shift)] - if sep_grads is None: - sep_grads = tmp_grad - else: - for l_index in range(len(sep_grads)): - sep_grads[l_index] = list(sep_grads[l_index]) - for k_index in range(len(sep_grads[0])): - sep_grads[l_index][k_index] = np.concatenate( - (sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0) - sep_grads[l_index] = tuple(sep_grads[l_index]) - final_grads = sep_grads - - return final_grads - - def update_special_tokens_logits(self, context_tokens, i, logits): - for beam_id in range(context_tokens.shape[0]): - for token_idx in set(context_tokens[beam_id][-4:].tolist()): - factor = self.repetition_penalty if logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty) - logits[beam_id, token_idx] /= factor - - if i >= self.ef_idx: - factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor) - logits[beam_id, self.end_token] *= factor - if i == 0: - start_factor = 1.6 - factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor) - logits[beam_id, self.end_token] /= factor - - for token_idx in list(self.forbidden_tokens): - factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor) - logits[beam_id, token_idx] /= factor - - return logits - - def clip_loss(self, probs, context_tokens): - for p_ in self.clip.transformer.parameters(): - if 
p_.grad is not None: - p_.grad.data.zero_() - - top_size = 512 - top_probs, top_indices = probs.topk(top_size, -1) - - prefix_texts = [self.lm_tokenizer.decode(x, skip_special_tokens=True) for x in context_tokens] - - clip_loss = 0 - losses = [] - - top_texts = [] - for idx_p in range(probs.shape[0]): - prefix_text = prefix_texts[idx_p] - for x in top_indices[idx_p]: - top_texts.append(prefix_text + self.lm_tokenizer.decode(x)) - - text_features = self.get_txt_features(top_texts)#.reshape(probs.size(0), top_size, -1) - - with torch.no_grad(): - similiraties = (self.image_features @ text_features.T).reshape(probs.size(0), -1) - similiraties = similiraties.reshape(probs.size(0), -1) - target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() - target_probs = target_probs.type(torch.float32) - - clip_loss += torch.sum(-(target_probs * torch.log(top_probs))) - # for idx_p in range(probs.shape[0]): - # top_texts = [] - # prefix_text = prefix_texts[idx_p] - # for x in top_indices[idx_p]: - # top_texts.append(prefix_text + self.lm_tokenizer.decode(x)) - # text_features = self.get_txt_features(top_texts) - # - # with torch.no_grad(): - # similiraties = (self.image_features @ text_features.T) - # target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() - # target_probs = target_probs.type(torch.float32) - # - # target = torch.zeros_like(probs[idx_p]) - # target[top_indices[idx_p]] = target_probs[0] - # target = target.unsqueeze(0) - # cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)]))) - # - # clip_loss += cur_clip_loss - # losses.append(cur_clip_loss) - - return clip_loss, losses - - def clip_loss_old(self, probs, context_tokens): - for p_ in self.clip.transformer.parameters(): - if p_.grad is not None: - p_.grad.data.zero_() - - top_size = 512 - _, top_indices = probs.topk(top_size, -1) - - prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens] - - clip_loss = 0 - losses = [] - for idx_p in range(probs.shape[0]): - top_texts = [] - prefix_text = prefix_texts[idx_p] - for x in top_indices[idx_p]: - top_texts.append(prefix_text + self.lm_tokenizer.decode(x)) - text_features = self.get_txt_features(top_texts) - - with torch.no_grad(): - similiraties = (self.image_features @ text_features.T) - target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() - target_probs = target_probs.type(torch.float32) - - target = torch.zeros_like(probs[idx_p]) - target[top_indices[idx_p]] = target_probs[0] - target = target.unsqueeze(0) - cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)]))) - - clip_loss += cur_clip_loss - losses.append(cur_clip_loss) - - return clip_loss, losses \ No newline at end of file diff --git a/zerocap/model/__init__.py b/zerocap/model/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/zerocap/model/__pycache__/ZeroCLIP.cpython-36.pyc b/zerocap/model/__pycache__/ZeroCLIP.cpython-36.pyc deleted file mode 100644 index 093e0838443faaa3ea157d218bea1ec5ae38840e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 13665 zcmc&*X>1(VeV=Q0XHPDdE0UrpnU*ERbRE%_b2x@z#YgS9wNxLhl`O1`Ry(uQ4)@}l zp+&K?Z38M5RaA-6A|INdfm0N1(<4bQv>%Fo2+{y8(xOQFp_mUrixvkzz6T-LKXzJ^i?*{gXEKGm-x=uHYAtxSFfC zwVIB*(Kb3}&D2?DwyjR8mg=NyY58omGab8T%e>w;YKHtSTg&2}YUeunT3)_Ow+o#@ zt&r3!){1V~wcYGBy;gE_ZvL89n{W$m5%==5np<)wu4(SXmQkB@^_R89$vYU4c5YFR 
diff --git a/zerocap/model/__pycache__/ZeroCLIP.cpython-37.pyc b/zerocap/model/__pycache__/ZeroCLIP.cpython-37.pyc
deleted file mode 100644
index 5f08c9ef7e7b4218d18e35e990f999d7fd89577a..0000000000000000000000000000000000000000
Binary files a/zerocap/model/__pycache__/ZeroCLIP.cpython-37.pyc and /dev/null differ
diff --git a/zerocap/model/__pycache__/ZeroCLIP_batched.cpython-37.pyc b/zerocap/model/__pycache__/ZeroCLIP_batched.cpython-37.pyc
deleted file mode 100644
index 816cae9c86e358182300740c9e5c22541a55d211..0000000000000000000000000000000000000000
Binary files a/zerocap/model/__pycache__/ZeroCLIP_batched.cpython-37.pyc and /dev/null differ
diff --git a/zerocap/model/__pycache__/__init__.cpython-37.pyc b/zerocap/model/__pycache__/__init__.cpython-37.pyc
deleted file mode 100644
index 3108a5ae3baca75bab5ef489bf20563f7177ad29..0000000000000000000000000000000000000000
Binary files a/zerocap/model/__pycache__/__init__.cpython-37.pyc and /dev/null differ
diff --git a/zerocap/mscoco_zerocap.sh b/zerocap/mscoco_zerocap.sh
deleted file mode 100755
index a06ae50..0000000
--- a/zerocap/mscoco_zerocap.sh
+++ /dev/null
@@ -1,14 +0,0 @@
-#!/bin/bash
-
-# lm_model:
-# 1. cambridgeltl/magic_mscoco
-# 2. 
cambridgeltl/magic_flickr30k -CUDA_VISIBLE_DEVICES=1 python run.py \ - --beam_size 1 \ - --target_seq_length 16 \ - --reset_context_delta \ - --lm_model cambridgeltl/magic_mscoco \ - --test_image_prefix_path ../data/mscoco/test_images \ - --test_path ../data/mscoco/mscoco_test.json \ - --save_path_prefix ../inference_result/mscoco/baselines/ \ - --save_name zerocap_result.json diff --git a/zerocap/predict.py b/zerocap/predict.py deleted file mode 100644 index 46271f4..0000000 --- a/zerocap/predict.py +++ /dev/null @@ -1,117 +0,0 @@ -import os -import tempfile -import sys -sys.path.append('CLIP') -from pathlib import Path -import cog -import argparse -import torch -import clip -from model.ZeroCLIP import CLIPTextGenerator - -def perplexity_score(text, lm_model, lm_tokenizer, device): - encodings = lm_tokenizer(f'{lm_tokenizer.bos_token + text}', return_tensors='pt') - input_ids = encodings.input_ids.to(device) - target_ids = input_ids.clone() - - outputs = lm_model(input_ids, labels=target_ids) - log_likelihood = outputs[0] - ll = log_likelihood.item() - - return ll - -class Predictor(cog.Predictor): - def setup(self): - self.args = get_args() - self.args.reset_context_delta = True - self.text_generator = CLIPTextGenerator(**vars(self.args)) - - @cog.input( - "image", - type=Path, - help="input image" - ) - @cog.input( - "cond_text", - type=str, - default='Image of a', - help="conditional text", - ) - @cog.input( - "beam_size", - type=int, - default=5, min=1, max=10, - help="Number of beams to use", - ) - @cog.input( - "end_factor", - type=float, - default=1.01, min=1.0, max=1.10, - help="Higher value for shorter captions", - ) - @cog.input( - "max_seq_length", - type=int, - default=15, min=1, max=20, - help="Maximum number of tokens to generate", - ) - @cog.input( - "ce_loss_scale", - type=float, - default=0.2, min=0.0, max=0.6, - help="Scale of cross-entropy loss with un-shifted language model", - ) - def predict(self, image, cond_text, beam_size, end_factor, max_seq_length, ce_loss_scale): - self.args.cond_text = cond_text - self.text_generator.end_factor = end_factor - self.text_generator.target_seq_length = max_seq_length - self.text_generator.ce_scale = ce_loss_scale - - image_features = self.text_generator.get_img_feature([str(image)], None) - captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size) - - # CLIP SCORE - encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device)) - for c in captions] - encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions] - best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item() - - # Perplexity SCORE - ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions] - best_ppl_index = torch.tensor(ppl_scores).argmin().item() - - best_clip_caption = self.args.cond_text + captions[best_clip_idx] - best_mixed = self.args.cond_text + captions[0] - best_PPL = self.args.cond_text + captions[best_ppl_index] - - final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}' - - return final - # return self.args.cond_text + captions[best_clip_idx] - - -def get_args(): - parser = argparse.ArgumentParser() - - parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo") - parser.add_argument("--clip_checkpoints", type=str, 
-    parser.add_argument("--target_seq_length", type=int, default=15)
-    parser.add_argument("--cond_text", type=str, default="Image of a")
-    parser.add_argument("--reset_context_delta", action="store_true",
-                        help="Should we reset the context at each token gen")
-    parser.add_argument("--num_iterations", type=int, default=5)
-    parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
-    parser.add_argument("--clip_scale", type=float, default=1)
-    parser.add_argument("--ce_scale", type=float, default=0.2)
-    parser.add_argument("--stepsize", type=float, default=0.3)
-    parser.add_argument("--grad_norm_factor", type=float, default=0.9)
-    parser.add_argument("--fusion_factor", type=float, default=0.99)
-    parser.add_argument("--repetition_penalty", type=float, default=1)
-    parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
-    parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
-    parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
-    parser.add_argument("--beam_size", type=int, default=5)
-
-    args = parser.parse_args('')
-    return args
diff --git a/zerocap/predict_arithmetic.py b/zerocap/predict_arithmetic.py
deleted file mode 100644
index 1e2ade2..0000000
--- a/zerocap/predict_arithmetic.py
+++ /dev/null
@@ -1,129 +0,0 @@
-import os
-import tempfile
-import sys
-sys.path.append('CLIP')
-from pathlib import Path
-import cog
-import argparse
-import torch
-import clip
-from model.ZeroCLIP import CLIPTextGenerator
-
-def perplexity_score(text, lm_model, lm_tokenizer, device):
-    encodings = lm_tokenizer(f'{lm_tokenizer.bos_token + text}', return_tensors='pt')
-    input_ids = encodings.input_ids.to(device)
-    target_ids = input_ids.clone()
-
-    outputs = lm_model(input_ids, labels=target_ids)
-    log_likelihood = outputs[0]
-    ll = log_likelihood.item()
-
-    return ll
-
-class Predictor(cog.Predictor):
-    def setup(self):
-        self.args = get_args()
-        self.args.reset_context_delta = True
-        self.text_generator = CLIPTextGenerator(**vars(self.args))
-
-    @cog.input(
-        "image1",
-        type=Path,
-        help="Final result will be: image1 + (image2 - image3)"
-    )
-    @cog.input(
-        "image2",
-        type=Path,
-        help="Final result will be: image1 + (image2 - image3)"
-    )
-    @cog.input(
-        "image3",
-        type=Path,
-        help="Final result will be: image1 + (image2 - image3)"
-    )
-    @cog.input(
-        "cond_text",
-        type=str,
-        default='Image of a',
-        help="conditional text",
-    )
-    @cog.input(
-        "beam_size",
-        type=int,
-        default=3, min=1, max=10,
-        help="Number of beams to use",
-    )
-    @cog.input(
-        "end_factors",
-        type=float,
-        default=1.06, min=1.0, max=1.10,
-        help="Higher value for shorter captions",
-    )
-    @cog.input(
-        "max_seq_lengths",
-        type=int,
-        default=3, min=1, max=20,
-        help="Maximum number of tokens to generate",
-    )
-    @cog.input(
-        "ce_loss_scale",
-        type=float,
-        default=0.2, min=0.0, max=0.6,
-        help="Scale of cross-entropy loss with un-shifted language model",
-    )
-    def predict(self, image1, image2, image3, cond_text, beam_size, end_factors, max_seq_lengths, ce_loss_scale):
-        self.args.cond_text = cond_text
-        self.text_generator.end_factor = end_factors
-        self.text_generator.target_seq_length = max_seq_lengths
-        self.text_generator.ce_scale = ce_loss_scale
-        self.text_generator.fusion_factor = 0.95
-        self.text_generator.grad_norm_factor = 0.95
-
-        image_features = self.text_generator.get_combined_feature([str(image1), str(image2), str(image3)], [], [1, 1, -1], None)
-        captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size)
-
-        # CLIP SCORE
-        encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device))
-                            for c in captions]
-        encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
-        best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item()
-
-        # Perplexity SCORE
-        ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions]
-        best_ppl_index = torch.tensor(ppl_scores).argmin().item()
-
-        best_clip_caption = self.args.cond_text + captions[best_clip_idx]
-        best_mixed = self.args.cond_text + captions[0]
-        best_PPL = self.args.cond_text + captions[best_ppl_index]
-
-        final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}'
-
-        return final
-        # return self.args.cond_text + captions[best_clip_idx]
-
-
-def get_args():
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument("--seed", type=int, default=0)
-    parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
-    parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
-    parser.add_argument("--target_seq_length", type=int, default=15)
-    parser.add_argument("--cond_text", type=str, default="Image of a")
-    parser.add_argument("--reset_context_delta", action="store_true",
-                        help="Should we reset the context at each token gen")
-    parser.add_argument("--num_iterations", type=int, default=5)
-    parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
-    parser.add_argument("--clip_scale", type=float, default=1)
-    parser.add_argument("--ce_scale", type=float, default=0.2)
-    parser.add_argument("--stepsize", type=float, default=0.3)
-    parser.add_argument("--grad_norm_factor", type=float, default=0.95)
-    parser.add_argument("--fusion_factor", type=float, default=0.95)
-    parser.add_argument("--repetition_penalty", type=float, default=1)
-    parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
-    parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
-    parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
-    parser.add_argument("--beam_size", type=int, default=5)
-
-    args = parser.parse_args('')
-    return args
diff --git a/zerocap/requirements.txt b/zerocap/requirements.txt
deleted file mode 100644
index 0eaf0ad..0000000
--- a/zerocap/requirements.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-ftfy
-regex
-tqdm
diff --git a/zerocap/run.py b/zerocap/run.py
deleted file mode 100644
index fab33b9..0000000
--- a/zerocap/run.py
+++ /dev/null
@@ -1,131 +0,0 @@
-import argparse
-import ipdb
-from tqdm import tqdm
-import progressbar
-import torch
-import ipdb
-import clip
-from model.ZeroCLIP import CLIPTextGenerator
-from model.ZeroCLIP_batched import CLIPTextGenerator as CLIPTextGenerator_multigpu
-
-def get_args():
-    parser = argparse.ArgumentParser()
-
-    parser.add_argument("--test_image_prefix_path", type=str, help="the folder that stores all test images")
-    parser.add_argument("--test_path", type=str)
-    parser.add_argument("--save_path_prefix", type=str, help="save the result in which directory")
-    parser.add_argument("--save_name", type=str, help="the name of the saved file")
-
-    parser.add_argument("--seed", type=int, default=0)
-    parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
-    parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
-    parser.add_argument("--target_seq_length", type=int, default=15)
-    parser.add_argument("--cond_text", type=str, default="Image of a")
-    parser.add_argument("--reset_context_delta", action="store_true",
-                        help="Should we reset the context at each token gen")
-    parser.add_argument("--num_iterations", type=int, default=5)
-    parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
-    parser.add_argument("--clip_scale", type=float, default=1)
-    parser.add_argument("--ce_scale", type=float, default=0.2)
-    parser.add_argument("--stepsize", type=float, default=0.3)
-    parser.add_argument("--grad_norm_factor", type=float, default=0.9)
-    parser.add_argument("--fusion_factor", type=float, default=0.99)
-    parser.add_argument("--repetition_penalty", type=float, default=1)
-    parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
-    parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
-    parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
-    parser.add_argument("--beam_size", type=int, default=1)
-
-    parser.add_argument("--multi_gpu", action="store_true")
-
-    parser.add_argument('--run_type',
-                        default='caption',
-                        nargs='?',
-                        choices=['caption', 'arithmetics'])
-
-    parser.add_argument("--caption_img_path", type=str, default='example_images/captions/COCO_val2014_000000008775.jpg',
-                        help="Path to image for captioning")
-
-    parser.add_argument("--arithmetics_imgs", nargs="+",
-                        default=['example_images/arithmetics/woman2.jpg',
-                                 'example_images/arithmetics/king2.jpg',
-                                 'example_images/arithmetics/man2.jpg'])
-    parser.add_argument("--arithmetics_weights", nargs="+", default=[1, 1, -1])
-
-    args = parser.parse_args()
-
-    return args
-
-def run(args, text_generator, img_path):
-    image_features = text_generator.get_img_feature([img_path], None)
-    captions = text_generator.run(image_features, args.cond_text, beam_size=args.beam_size)
-
-    encoded_captions = [text_generator.clip.encode_text(clip.tokenize(c).to(text_generator.device)) for c in captions]
-    encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
-    best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item()
-    return captions
-
-
-if __name__ == '__main__':
-    if torch.cuda.is_available():
-        print ('Cuda is available.')
-    cuda_available = torch.cuda.is_available()
-    args = get_args()
-    device = torch.device('cuda')
-
-    save_path_prefix = args.save_path_prefix
-    import os
-    if os.path.exists(save_path_prefix):
-        pass
-    else: # recursively construct directory
-        os.makedirs(save_path_prefix, exist_ok=True)
-    # parse save name
-    save_name = args.save_name
-    full_save_path = save_path_prefix + '/' + save_name
-    print ('full save path is {}'.format(full_save_path))
-
-    print ('Loading data...')
-    import json
-    with open(args.test_path) as f:
-        item_list = json.load(f)
-    print ('Data loaded.')
-    print ('Number of test instances is {}'.format(len(item_list)))
-
-    # ZeroCap generator
-    text_generator = CLIPTextGenerator(**vars(args))
-
-    result_list = []
-    invalid_num = 0
-    print ('----------------------------------------------------------------')
-    test_num = len(item_list)
-    #test_num = 10
-    print ('Number of inference instances is {}'.format(test_num))
-    p = progressbar.ProgressBar(test_num)
-    p.start()
-    for p_idx in tqdm(range(test_num)):
-        p.update(p_idx)
-        one_test_dict = item_list[p_idx]
-
-        one_res_dict = {
-            'split':one_test_dict['split'],
-            'image_name':one_test_dict['image_name'],
-            #'file_path':one_test_dict['file_path'],
-            'captions':one_test_dict['captions']
-        }
-
-        image_full_path = args.test_image_prefix_path + '/' + one_test_dict['image_name']
-        try:
-            output_text = run(args, text_generator, img_path=image_full_path)
-            one_res_dict['prediction'] = output_text[0]
-            result_list.append(one_res_dict)
-        except Exception as error:
-            print(f'[!] ERROR:', error)
-            invalid_num += 1
-            print ('invalid number is {}'.format(invalid_num))
-            continue
-    p.finish()
-    print ('Inference completed!')
-
-    import json
-    with open(full_save_path, 'w') as outfile:
-        json.dump(result_list, outfile, indent=4)
diff --git a/zerocap/setup.py b/zerocap/setup.py
deleted file mode 100644
index 8ae2efe..0000000
--- a/zerocap/setup.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import os
-
-import pkg_resources
-from setuptools import setup, find_packages
-
-setup(
-    name="zero-shot-image-to-text",
-    py_modules=["zero-shot-image-to-text"],
-    version="1.0",
-    description="",
-    packages=find_packages(),
-    install_requires=[
-        str(r)
-        for r in pkg_resources.parse_requirements(
-            open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
-        )
-    ],
-    include_package_data=True
-)
\ No newline at end of file