32 changed files with 1334 additions and 1376 deletions
@@ -1,2 +1,81 @@
# Image Captioning with MAGIC

*author: David Wang*

<br />

## Description

This operator uses [MAGIC](https://arxiv.org/abs/2205.02655) to generate a caption that describes the content of a given image. MAGIC is a simple yet efficient plug-and-play framework that directly combines an off-the-shelf language model (i.e., GPT-2) with an image-text matching model (i.e., CLIP) for image-grounded text generation. During decoding, MAGIC influences the generation of the LM by introducing a CLIP-induced score, called the magic score, which regularizes the generated result to be semantically related to the given image while remaining coherent with the previously generated context. This is an adaptation of [yxuansu/MAGIC](https://github.com/yxuansu/MAGIC).
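Concretely, at each decoding step MAGIC selects the next token from the language model's top-k candidates by trading off model confidence, a degeneration penalty, and image relevance. Roughly paraphrasing the decoding rule from the paper (where $V^{(k)}$ is the top-k candidate set, $h_v$ a candidate's hidden representation, and $\mathcal{I}$ the image):

$$x_t = \underset{v \in V^{(k)}}{\arg\max} \left\{ (1-\alpha)\, p_\theta(v \mid x_{<t}) - \alpha \max_{1 \le j \le t-1} s(h_v, h_{x_j}) + \beta \, f_{\mathrm{magic}}(v \mid \mathcal{I}, x_{<t}, V^{(k)}) \right\}$$

where $f_{\mathrm{magic}}$ is the magic score, the softmax-normalized CLIP similarity between the image and each candidate continuation, and $\beta$ controls its weight.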

<br />

## Code Example

Load an image from path './image.jpg' to generate the caption.

*Write the pipeline in simplified style*:

```python
import towhee

towhee.glob('./image.jpg') \
      .image_decode() \
      .image_captioning.magic(model_name='magic_mscoco') \
      .show()
```
<img src="./cap.png" alt="result1" style="height:20px;"/>

*Write the same pipeline with explicit input/output name specifications:*

```python
import towhee

towhee.glob['path']('./image.jpg') \
      .image_decode['path', 'img']() \
      .image_captioning.magic['img', 'text'](model_name='magic_mscoco') \
      .select['img', 'text']() \
      .show()
```
<img src="./tabular.png" alt="result2" style="height:60px;"/>

<br />
## Factory Constructor

Create the operator via the following factory method:

***magic(model_name)***

**Parameters:**

***model_name:*** *str*

The model name of MAGIC. Supported model names:
- magic_mscoco

<br />
## Interface

An image captioning operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generates the corresponding caption.

**Parameters:**

***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)*

The image for which to generate the caption.

**Returns:** *str*

The caption generated by the model.
@@ -0,0 +1,167 @@
## Unsupervised Domain Adaptation of Language Model

****

### Catalogue:
* <a href='#mscoco'>1. MSCOCO Benchmark</a>
    * <a href='#mscoco_data_preparation'>1.1. MSCOCO Data Preparation</a>
    * <a href='#mscoco_training'>1.2. Unsupervised Domain Adaptation on MSCOCO</a>
* <a href='#flickr30k'>2. Flickr30k Benchmark</a>
    * <a href='#flickr30k_data_preparation'>2.1. Flickr30k Data Preparation</a>
    * <a href='#flickr30k_training'>2.2. Unsupervised Domain Adaptation on Flickr30k</a>
* <a href='#unsupervised_baselines'>3. Unsupervised Baselines</a>
    * <a href='#contrastive_search'>3.1. Contrastive Search</a>
    * <a href='#top_k_sampling'>3.2. Top-k Sampling</a>
    * <a href='#nucleus_sampling'>3.3. Nucleus Sampling</a>

****
<span id='mscoco'/>

#### 1. MSCOCO Benchmark:

We first describe how to perform unsupervised domain adaptation of the language model on the text corpus of the MSCOCO benchmark.

<span id='mscoco_data_preparation'/>

##### 1.1. MSCOCO Data Preparation:

To prepare the MSCOCO benchmark, please follow the instructions [[here]](https://github.com/yxuansu/MAGIC/tree/main/image_captioning/data#1-mscoco-benchmark).

<span id='mscoco_training'/>

##### 1.2. Unsupervised Domain Adaptation on MSCOCO:
After preparing the MSCOCO data, run the following command to train the language model.
```bash
chmod +x ./train_mscoco.sh
./train_mscoco.sh
```
The arguments are as follows:
* `--model_name`: The name of the huggingface pre-trained GPT model (e.g. gpt2, gpt2-large).
* `--train_path`: The file path of the training set.
* `--dev_path`: The file path of the validation set.
* `--test_path`: The file path of the test set.
* `--add_eos_token_to_data`: Whether to add an eos token at the end of each text sequence.
* `--margin`: The contrastive margin $\rho$.
* `--max_len`: The maximum length of training samples.
* `--number_of_gpu`: The number of available GPUs.
* `--batch_size_per_gpu`: The batch size for each GPU.
* `--gradient_accumulation_steps`: How many forward computations between two gradient updates.
* `--effective_batch_size`: The overall batch size. It equals batch_size_per_gpu x gradient_accumulation_steps x number_of_gpu (see the worked example after this list).
* `--total_steps`: The total number of gradient update steps.
* `--print_every`: How many steps between printing intermediate results.
* `--save_every`: How many steps between saving checkpoints.
* `--learning_rate`: The learning rate.
* `--save_path_prefix`: Where to save the checkpoints.
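For example, with the values used in `train_mscoco.sh` below (`--batch_size_per_gpu 32`, `--gradient_accumulation_steps 4`, `--number_of_gpu 1`), the effective batch size is 32 x 4 x 1 = 128, matching the script's `--effective_batch_size 128`.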

****
<span id='flickr30k'/>

#### 2. Flickr30k Benchmark:

We then describe how to perform unsupervised domain adaptation of the language model on the text corpus of the Flickr30k benchmark.

<span id='flickr30k_data_preparation'/>

##### 2.1. Flickr30k Data Preparation:

To prepare the Flickr30k benchmark, please follow the instructions [[here]](https://github.com/yxuansu/MAGIC/tree/main/image_captioning/data#2-flickr30k-benchmark).

<span id='flickr30k_training'/>

##### 2.2. Unsupervised Domain Adaptation on Flickr30k:
After preparing the Flickr30k data, run the following command to train the language model.
```bash
chmod +x ./train_flickr30k.sh
./train_flickr30k.sh
```

****
<span id='unsupervised_baselines'/>

#### 3. Unsupervised Baselines:

Here, we illustrate how to use the language model to run the unsupervised baselines described in our paper. Note that all these methods are **unsupervised**, as the language model is a text-only model and does not take images as input.

```python
# first, load the language model
import torch
from simctg import SimCTG
sos_token, pad_token = r'<-start_of_text->', r'<-pad->'
# we use the language model adapted on MSCOCO as an example.
language_model_name = r'cambridgeltl/magic_mscoco'
generation_model = SimCTG(language_model_name, sos_token, pad_token)
generation_model.eval()

# then, prepare the input ids. Note that the text is always generated from the same start-of-sentence token.
tokens = generation_model.tokenizer.tokenize(sos_token)
input_ids = generation_model.tokenizer.convert_tokens_to_ids(tokens)
input_ids = torch.LongTensor(input_ids).view(1,-1)
```

<span id='contrastive_search'/>

##### 3.1. Contrastive Search:
```python
'''
   Use contrastive search to generate the result.
   Note that contrastive search is a deterministic decoding method, so the generated text is always the same.
'''
beam_width, alpha, decoding_len = 45, 0.1, 16
output_text = generation_model.fast_contrastive_search(input_ids, beam_width, alpha, decoding_len)
print (output_text)
'''
   A man is riding a skateboard down a street.
'''
```
The arguments are as follows:
* `--input_ids`: The id of the start-of-sentence token.
* `--beam_width`: The k in contrastive search.
* `--alpha`: The alpha in contrastive search.
* `--decoding_len`: Number of tokens to generate.
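Roughly, contrastive search picks each token from the model's top-k candidates by balancing model confidence against a degeneration penalty that penalizes tokens too similar to the existing context (a paraphrase of the SimCTG decoding rule; $s$ is cosine similarity between hidden states):

$$x_t = \underset{v \in V^{(k)}}{\arg\max} \left\{ (1-\alpha)\, p_\theta(v \mid x_{<t}) - \alpha \max_{1 \le j \le t-1} s(h_v, h_{x_j}) \right\}$$

With $\alpha = 0$ this reduces to greedy search; the `beam_width` argument above is the size $k$ of the candidate pool.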

<span id='top_k_sampling'/>

##### 3.2. Top-k Sampling:
```python
'''
   Use top-k sampling to generate the result.
   Note that this method is stochastic, so the generated text varies across runs.
'''
top_k, decoding_len = 40, 16
output_text = generation_model.top_k_sampling(input_ids, top_k, decoding_len)
print (output_text)
'''
   some very different types of vases with flowers together
'''
```
The arguments are as follows:
* `--input_ids`: The id of the start-of-sentence token.
* `--k`: The k in top-k sampling.
* `--decoding_len`: Number of tokens to generate.
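Internally, `top_k_sampling` delegates to HuggingFace's `generate` with `do_sample=True` and `top_k=k` (see `simctg.py` below). As a minimal illustrative sketch (not the repo's code path) of what a single top-k step does:

```python
import torch

def top_k_sample_step(logits, k):
    # keep the k highest-scoring tokens, renormalize, and sample one of them
    topk_vals, topk_idx = torch.topk(logits, k)
    probs = torch.softmax(topk_vals, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)
    return topk_idx[choice].item()
```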

<span id='nucleus_sampling'/>

##### 3.3. Nucleus Sampling:
```python
'''
   Use nucleus sampling to generate the result.
   Note that this method is stochastic, so the generated text varies across runs.
'''
nucleus_p, decoding_len = 0.95, 16
output_text = generation_model.nucleus_sampling(input_ids, nucleus_p, decoding_len)
print (output_text)
'''
   Two young girls enjoying a hot dog hot dog bun.
'''
```
The arguments are as follows:
* `--input_ids`: The id of the start-of-sentence token.
* `--nucleus_p`: The probability threshold in nucleus sampling.
* `--decoding_len`: Number of tokens to generate.
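Likewise, `nucleus_sampling` calls `generate` with `top_p=nucleus_p` and `top_k=0`. A minimal illustrative sketch of one nucleus (top-p) step, which samples from the smallest set of tokens whose cumulative probability exceeds $p$:

```python
import torch

def nucleus_sample_step(logits, p):
    # sort tokens by probability and keep the smallest prefix whose mass exceeds p
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    cutoff = min(int(torch.searchsorted(cumulative, torch.tensor(p)).item()) + 1, sorted_probs.numel())
    kept = sorted_probs[:cutoff] / sorted_probs[:cutoff].sum()
    choice = torch.multinomial(kept, num_samples=1)
    return sorted_idx[choice].item()
```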
@@ -0,0 +1,157 @@
import json
import random
import torch
import numpy as np
import progressbar
from torch.nn.utils import rnn

class Data:
    def __init__(self, model_name, train_path, dev_path, test_path, max_len, 
        sos_token, pad_token, add_eos_token_to_data):
        '''
            model_name: gpt2
            train_path: training data path
            dev_path: validation data path
            test_path: test data path
            max_len: maximum length for training sequences
            sos_token: initialized sos token <-start_of_text->
            pad_token: used to pad the sequences <-pad->
            add_eos_token_to_data: whether we want the model to learn to generate the eos token;
                if so, the model can automatically stop generation by producing the eos token
        '''
        from transformers import GPT2TokenizerFast
        self.tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
        self.sos_token, self.sos_token_id = self.add_special_token(sos_token)
        print ('sos token is {}, sos token id is {}'.format(self.sos_token, self.sos_token_id))
        self.pad_token, self.pad_token_id = self.add_special_token(pad_token)
        print ('pad token is {}, pad token id is {}'.format(self.pad_token, self.pad_token_id))
        self.eos_token, self.eos_token_id = self.tokenizer.bos_token, self.tokenizer.bos_token_id
        print ('eos token is {}, eos token id is {}'.format(self.eos_token, self.eos_token_id))
        self.add_eos_token_to_data = add_eos_token_to_data

        self.max_len = max_len
        self.train_token_list, self.train_token_id_list = self.process_one_file(train_path)
        self.dev_token_list, self.dev_token_id_list = self.process_one_file(dev_path)
        self.test_token_list, self.test_token_id_list = self.process_one_file(test_path)
        self.train_num, self.dev_num, self.test_num = len(self.train_token_list), len(self.dev_token_list), \
        len(self.test_token_list)
        print ('train number:{}, dev number:{}, test number:{}'.format(self.train_num, self.dev_num, self.test_num))

        self.train_idx_list = [i for i in range(self.train_num)]
        random.shuffle(self.train_idx_list)
        self.dev_idx_list = [j for j in range(self.dev_num)]
        self.test_idx_list = [j for j in range(self.test_num)]
        self.dev_current_idx, self.test_current_idx = 0, 0

    def add_special_token(self, special_token):
        if special_token in self.tokenizer.vocab:
            print (special_token + ' token exists.')
        else:
            print ('Add token to the tokenizer.')
            print ('Original vocabulary size is {}'.format(len(self.tokenizer)))
            self.tokenizer.add_tokens([special_token])
            print ('Vocabulary size after extension is {}'.format(len(self.tokenizer)))
            assert len(self.tokenizer.convert_tokens_to_ids([special_token])) == 1
        special_token_id = self.tokenizer.convert_tokens_to_ids([special_token])[0]
        return special_token, special_token_id

    def process_one_file(self, path):
        print ('Processing {}'.format(path))
        with open(path) as f:
            item_list = json.load(f)
        lines = []
        for item in item_list:
            captions_list = item['captions']
            for one_caption in captions_list:
                lines.append(one_caption.strip())

        res_token_list, res_token_id_list = [], []
        n = len(lines)
        p = progressbar.ProgressBar(n)
        p.start()
        for i in range(n):
            p.update(i)
            text = lines[i].strip('\n')
            self.process_one_text(text, res_token_list, res_token_id_list)
        p.finish()
        print ('{} processed!'.format(path))
        return res_token_list, res_token_id_list

    def process_one_text(self, text, res_token_list, res_token_id_list):
        tokens = self.tokenizer.tokenize(text, max_length=self.max_len, truncation=True)
        if len(tokens) <= 1: # filter out too short sequence
            return
        tokens = [self.sos_token] + tokens[:self.max_len]
        if self.add_eos_token_to_data:
            tokens = tokens + [self.eos_token]
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        res_token_list.append(tokens)
        res_token_id_list.append(token_ids)
        return

    def pad_batch(self, batch_id_list):
        batch_id_list = [torch.LongTensor(item) for item in batch_id_list]
        batch_tensor = rnn.pad_sequence(batch_id_list, batch_first=True, padding_value=self.pad_token_id)
        batch_mask = torch.ones_like(batch_tensor)
        batch_mask = batch_mask.masked_fill(batch_tensor.eq(self.pad_token_id), 0.0).type(torch.FloatTensor)
        return batch_tensor, batch_mask

    def process_output(self, batch_tgt_id_list):
        batch_tgt_id_list = [torch.LongTensor(item) for item in batch_tgt_id_list]
        batch_tgt_tensor, _ = self.pad_batch(batch_tgt_id_list) # padded target sequence
        batch_tgt_input_tensor = batch_tgt_tensor[:, :-1].clone()
        batch_tgt_output_tensor = batch_tgt_tensor[:, 1:].clone()
        return batch_tgt_input_tensor, batch_tgt_output_tensor

    def parse_batch(self, batch_id_list):
        batch_input, batch_labels = self.process_output(batch_id_list)
        batch_labels[batch_labels[:, :] == self.pad_token_id] = -100
        return batch_input, batch_labels

    def get_next_train_batch(self, batch_size):
        batch_idx_list = random.sample(self.train_idx_list, batch_size)
        batch_id_list, batch_token_list = [], []

        for idx in batch_idx_list:
            batch_id_list.append(self.train_token_id_list[idx])
            batch_token_list.append(self.train_token_list[idx])
        batch_input_tensor, batch_labels = self.parse_batch(batch_id_list)
        return batch_input_tensor, batch_labels, batch_token_list

    def get_next_validation_batch(self, batch_size, mode):
        batch_id_list, batch_token_list = [], []
        if mode == 'dev':
            curr_select_idx, instance_num = self.dev_current_idx, self.dev_num
            tgt_token_id_list, tgt_token_list = self.dev_token_id_list, self.dev_token_list
        elif mode == 'test':
            curr_select_idx, instance_num = self.test_current_idx, self.test_num
            tgt_token_id_list, tgt_token_list = self.test_token_id_list, self.test_token_list
        else:
            raise Exception('Wrong Validation Mode!!!')

        if curr_select_idx + batch_size < instance_num:
            for i in range(batch_size):
                curr_idx = curr_select_idx + i
                batch_id_list.append(tgt_token_id_list[curr_idx])
                batch_token_list.append(tgt_token_list[curr_idx])
            if mode == 'dev':
                self.dev_current_idx += batch_size
            else:
                self.test_current_idx += batch_size
        else:
            for i in range(batch_size):
                curr_idx = curr_select_idx + i
                if curr_idx > instance_num - 1: # wrap around when the split is exhausted
                    curr_idx = 0
                    if mode == 'dev':
                        self.dev_current_idx = 0
                    else:
                        self.test_current_idx = 0
                batch_id_list.append(tgt_token_id_list[curr_idx])
                batch_token_list.append(tgt_token_list[curr_idx])
            if mode == 'dev':
                self.dev_current_idx = 0
            else:
                self.test_current_idx = 0
        batch_input_tensor, batch_labels = self.parse_batch(batch_id_list)
        return batch_input_tensor, batch_labels, batch_token_list
@@ -0,0 +1,80 @@
import torch

def compute_valid_token_num(valid_len_list):
    res = 0
    for one_len in valid_len_list:
        res += one_len * (one_len - 1)
    return res

def build_mask_matrix(seqlen, valid_len_list, prefix_len = 0):
    '''
        prefix_len: the length of the prefix that we do not want to compute CL loss for.

        (1) if a sequence of length 4 contains zero padding tokens (i.e., the valid length is 4),
            then the loss padding matrix looks like
                 [0., 1., 1., 1.],
                 [1., 0., 1., 1.],
                 [1., 1., 0., 1.],
                 [1., 1., 1., 0.]

        (2) if a sequence of length 4 contains 1 padding token (i.e., the valid length is 3),
            then the loss padding matrix looks like
                 [0., 1., 1., 0.],
                 [1., 0., 1., 0.],
                 [1., 1., 0., 0.],
                 [0., 0., 0., 0.]
    '''
    res_list = []
    base_mask = torch.ones(seqlen, seqlen) - torch.eye(seqlen, seqlen)
    base_mask = base_mask.type(torch.FloatTensor)
    bsz = len(valid_len_list)
    for i in range(bsz):
        one_base_mask = base_mask.clone()
        one_valid_len = valid_len_list[i]
        one_base_mask[:, one_valid_len:] = 0.
        one_base_mask[one_valid_len:, :] = 0.
        if prefix_len > 0:
            one_base_mask[:prefix_len, :prefix_len] = 0.
        res_list.append(one_base_mask)
    res_mask = torch.stack(res_list, dim=0)
    assert res_mask.size() == torch.Size([bsz, seqlen, seqlen])
    return res_mask

def contrastive_loss(margin, score_matrix, input_ids, pad_token_id, prefix_len=0):
    '''
       margin: predefined margin to push similarity scores apart
       score_matrix: bsz x seqlen x seqlen
       input_ids: bsz x seqlen
       pad_token_id: indicates which tokens are padding tokens
    '''
    bsz, seqlen, _ = score_matrix.size()
    gold_score = torch.diagonal(score_matrix, offset=0, dim1=1, dim2=2) # bsz x seqlen
    gold_score = torch.unsqueeze(gold_score, -1)
    assert gold_score.size() == torch.Size([bsz, seqlen, 1])
    difference_matrix = gold_score - score_matrix
    assert difference_matrix.size() == torch.Size([bsz, seqlen, seqlen])
    loss_matrix = margin - difference_matrix # bsz x seqlen x seqlen
    loss_matrix = torch.nn.functional.relu(loss_matrix)

    ### input mask
    input_mask = torch.ones_like(input_ids).type(torch.FloatTensor)
    if loss_matrix.is_cuda:
        input_mask = input_mask.cuda(loss_matrix.get_device())
    input_mask = input_mask.masked_fill(input_ids.eq(pad_token_id), 0.0)

    valid_len_list = torch.sum(input_mask, dim=-1).tolist()
    loss_mask = build_mask_matrix(seqlen, [int(item) for item in valid_len_list], prefix_len)
    if score_matrix.is_cuda:
        loss_mask = loss_mask.cuda(score_matrix.get_device())
    masked_loss_matrix = loss_matrix * loss_mask

    loss_matrix = torch.sum(masked_loss_matrix, dim=-1)
    assert loss_matrix.size() == input_ids.size()
    loss_matrix = loss_matrix * input_mask
    cl_loss = torch.sum(loss_matrix) / torch.sum(loss_mask)
    return cl_loss
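
# A hedged usage sketch (added for illustration; not part of the original file):
# contrastive_loss expects a bsz x seqlen x seqlen similarity matrix whose diagonal
# holds each token's self-similarity (the "gold" score pushed above the off-diagonals).
if __name__ == '__main__':
    bsz, seqlen, pad_id = 2, 5, 0
    score_matrix = torch.rand(bsz, seqlen, seqlen)   # stand-in for token-token cosine scores
    input_ids = torch.randint(1, 10, (bsz, seqlen))  # random ids; none equal to pad_id
    print (contrastive_loss(0.5, score_matrix, input_ids, pad_token_id=pad_id))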
@@ -0,0 +1,233 @@
import os
import sys
import operator
from tqdm import tqdm
from operator import itemgetter
import torch
from torch import nn
import random
import argparse
import numpy as np
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from loss_func import contrastive_loss
from utlis import ContrastiveDecodingOneStepFast, PlugAndPlayContrastiveDecodingOneStepFast

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import datetime

train_fct = CrossEntropyLoss()
val_fct = CrossEntropyLoss(reduction='none')
class SimCTG(nn.Module):
    def __init__(self, model_name, sos_token, pad_token):
        super(SimCTG, self).__init__()
        from transformers import AutoTokenizer, GPT2LMHeadModel
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.sos_token, self.sos_token_id = self.add_special_token(sos_token)
        print ('sos token is {}, sos token id is {}'.format(self.sos_token, self.sos_token_id))
        self.pad_token, self.pad_token_id = self.add_special_token(pad_token)
        print ('pad token is {}, pad token id is {}'.format(self.pad_token, self.pad_token_id))
        self.eos_token, self.eos_token_id = self.tokenizer.bos_token, self.tokenizer.bos_token_id
        print ('eos token is {}, eos token id is {}'.format(self.eos_token, self.eos_token_id))
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        self.vocab_size = len(self.tokenizer)
        print ('Resizing model embedding...')
        self.model.resize_token_embeddings(len(self.tokenizer))
        print ('Model embedding resized!')
        self.embed_dim = self.model.config.hidden_size

    def add_special_token(self, special_token):
        if special_token in self.tokenizer.vocab:
            print (special_token + ' token exists.')
        else:
            print ('Add token to the tokenizer.')
            print ('Original vocabulary size is {}'.format(len(self.tokenizer)))
            self.tokenizer.add_tokens([special_token])
            print ('Vocabulary size after extension is {}'.format(len(self.tokenizer)))
            assert len(self.tokenizer.convert_tokens_to_ids([special_token])) == 1
        special_token_id = self.tokenizer.convert_tokens_to_ids([special_token])[0]
        return special_token, special_token_id

    def compute_logits_and_hidden_states(self, input_ids):
        # used for advanced decoding
        # input_ids: 1 x seqlen
        outputs = self.model(input_ids=input_ids, output_hidden_states=True)
        last_hidden_states = outputs.hidden_states[-1]
        logits = outputs.logits
        return last_hidden_states, logits

    def forward(self, input_ids, labels, margin):
        bsz, seqlen = input_ids.size()
        outputs = self.model(input_ids=input_ids, output_hidden_states=True)
        logits = outputs.logits
        assert logits.size() == torch.Size([bsz, seqlen, self.vocab_size])
        last_hidden_states = outputs.hidden_states[-1]
        assert last_hidden_states.size() == torch.Size([bsz, seqlen, self.embed_dim])
        mle_loss = train_fct(logits.view(-1, self.vocab_size), labels.view(-1))

        norm_rep = last_hidden_states / last_hidden_states.norm(dim=2, keepdim=True)
        cosine_scores = torch.matmul(norm_rep, norm_rep.transpose(1,2))
        assert cosine_scores.size() == torch.Size([bsz, seqlen, seqlen])
        cl_loss = contrastive_loss(margin, cosine_scores, input_ids, self.pad_token_id, prefix_len=0)
        return mle_loss, cl_loss

    def eval_loss(self, input_ids, labels):
        bsz, seqlen = input_ids.size()
        outputs = self.model(input_ids=input_ids, output_hidden_states=True)
        logits = outputs.logits
        assert logits.size() == torch.Size([bsz, seqlen, self.vocab_size])
        last_hidden_states = outputs.hidden_states[-1]
        assert last_hidden_states.size() == torch.Size([bsz, seqlen, self.embed_dim])
        mle_loss = val_fct(logits.view(-1, self.vocab_size), labels.view(-1))
        assert mle_loss.size() == torch.Size([bsz * seqlen])
        mask_tmp = labels.masked_fill(~labels.eq(-100), 1.0)
        mask = mask_tmp.masked_fill(mask_tmp.eq(-100), 0.0)
        # sum
        mle_loss_sum = torch.sum(mle_loss)
        token_num_sum = torch.sum(mask)
        return mle_loss_sum, token_num_sum

    def save_model(self, ckpt_save_path):
        import os
        if os.path.exists(ckpt_save_path):
            pass
        else: # recursively construct directory
            os.makedirs(ckpt_save_path, exist_ok=True)
        # save model
        self.model.save_pretrained(ckpt_save_path)
        # save tokenizer
        self.tokenizer.save_pretrained(ckpt_save_path)

    def parse_sentences(self, text, num_of_sentences_to_keep):
        item_list = text.split('.')
        res_list = item_list[:num_of_sentences_to_keep]
        if len(item_list) > num_of_sentences_to_keep:
            res_text = '.'.join(res_list).strip('.') + '.'
        else:
            res_text = '.'.join(res_list).strip('.').strip()
        return res_text

    def parse_generated_result(self, output, num_of_sentences_to_keep):
        output_text = self.tokenizer.decode(output)
        item_list = output_text.split(self.eos_token)
        full_text = self.eos_token.join(item_list[:2]).strip()
        full_text = self.parse_sentences(full_text, num_of_sentences_to_keep)
        generated_text = item_list[1].strip()
        generated_text = self.parse_sentences(generated_text, num_of_sentences_to_keep)
        return full_text, generated_text

    # decoding functions
    # ------------------------------------------------------- #

    def parse_output_token_list(self, output):
        output = output.tolist()
        res_list = []
        for token_id in output:
            if token_id == self.sos_token_id:
                continue
            elif token_id == self.eos_token_id:
                break
            else:
                res_list.append(token_id)
        text = self.tokenizer.decode(res_list).strip()
        return ' '.join(text.split()).strip()

    @torch.no_grad()
    def magic_search(self, input_ids, beam_width, alpha, decoding_len, beta, image_instance, clip, 
        clip_text_max_len):
        prefix_len = input_ids.size()[1]
        past_key_values, last_hidden_states, logits = None, None, None
        generated = [item for item in input_ids.tolist()]
        input_ids_for_class = input_ids.clone()

        image_embeds = clip.compute_image_representation_from_image_instance(image_instance)

        start_time = datetime.datetime.now()

        # the maximum supported length of generation for SimCTG is 256
        # to support longer generated length, you can re-train the SimCTG model with longer sequences
        decoding_len = decoding_len - prefix_len
        for step in range(decoding_len):
            input_ids, past_key_values, last_hidden_states, logits, input_ids_for_class = \
            PlugAndPlayContrastiveDecodingOneStepFast(
                self.model,
                input_ids,
                prefix_len,
                beam_width,
                alpha,
                beta,
                self.tokenizer,
                image_embeds,
                clip,
                clip_text_max_len,
                past_key_values,
                last_hidden_states,
                logits,
                first_step=step==0,
                input_ids_for_class=input_ids_for_class,
            )
        end_time = datetime.datetime.now()
        time_diff = (end_time - start_time)
        execution_time = time_diff.total_seconds() * 1000
        return self.parse_output_token_list(input_ids_for_class[0])

    def fast_contrastive_search(self, input_ids, beam_width, alpha, decoding_len):
        '''
           input_ids: prefix input; 1 x prefix_len
           decoding_len: how many tokens to generate
           beam_width: size of candidate pool during decoding
           alpha: regulates importance of model confidence and degeneration penalty
        '''
        self.model.eval()
        # sanity check
        assert alpha >= 0. and alpha <= 1.0

        # fast mode
        prefix_len = input_ids.size()[1]
        batch_size, seqlen = input_ids.size()
        generated = [item for item in input_ids.tolist()]
        past_key_values = None
        last_hidden_states = None
        logits = None
        decoding_len = decoding_len - prefix_len
        for step in range(decoding_len):
            input_ids, past_key_values, last_hidden_states, logits = ContrastiveDecodingOneStepFast(
                self.model,
                input_ids,
                beam_width,
                alpha,
                past_key_values,
                last_hidden_states,
                self.tokenizer,
                logits,
                first_step=step == 0,
            )
            tokens = input_ids.squeeze(dim=-1).tolist()
            for idx, t in enumerate(tokens):
                generated[idx].append(t)
        return self.parse_output_token_list(torch.LongTensor(generated[0]))

    def top_k_sampling(self, input_ids, k, decoding_len):
        _, prefix_len = input_ids.size()
        output = self.model.generate(
                            input_ids,
                            do_sample=True,
                            max_length=decoding_len,
                            top_p=1.0,
                            top_k=k)
        return self.parse_output_token_list(output[0])

    def nucleus_sampling(self, input_ids, nucleus_p, decoding_len):
        _, prefix_len = input_ids.size()
        output = self.model.generate(
                            input_ids,
                            do_sample=True,
                            max_length=decoding_len,
                            top_p=nucleus_p,
                            top_k=0)
        return self.parse_output_token_list(output[0])
@@ -0,0 +1,107 @@
# coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
import argparse, os
import random
import numpy as np
import time
import logging
import progressbar

logging.getLogger('transformers.generation_utils').disabled = True

def parse_config():
    parser = argparse.ArgumentParser()
    # data configuration
    parser.add_argument("--model_name", type=str, default='gpt2')
    parser.add_argument("--train_path", type=str)
    parser.add_argument("--dev_path", type=str)
    parser.add_argument("--test_path", type=str)
    parser.add_argument("--max_len", type=int)
    parser.add_argument("--add_eos_token_to_data", type=str)
    # mini-batch training configuration
    parser.add_argument("--number_of_gpu", type=int, help="Number of available GPUs.")
    parser.add_argument("--batch_size_per_gpu", type=int, help='batch size for each gpu.')
    parser.add_argument("--gradient_accumulation_steps", type=int, help="gradient accumulation step.")
    parser.add_argument("--effective_batch_size", type=int, 
        help="effective_bsz = batch_size_per_gpu x number_of_gpu x gradient_accumulation_steps")
    # pre-training configuration
    parser.add_argument("--total_steps", type=int, 
        help="total effective training steps")
    parser.add_argument("--print_every", type=int, 
        help="how many update steps to print one intermediate result")
    parser.add_argument("--save_every", type=int, 
        help="how many update steps to save one model")
    # learning configuration
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--margin", type=float)
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--save_path_prefix", type=str, help="directory to save the model parameters.")
    return parser.parse_args()

def load_previous_best_model(path):
    import os
    filenames = os.listdir(path)
    for file in filenames:
        if file.startswith('training_step'):
            return path + '/' + file
    raise Exception('No best model found!')

if __name__ == '__main__':
    if torch.cuda.is_available():
        print ('Cuda is available.')
    cuda_available = torch.cuda.is_available()
    multi_gpu_training = False
    if cuda_available:
        if torch.cuda.device_count() > 1:
            multi_gpu_training = True
            print ('Using Multi-GPU training, number of GPU is {}'.format(torch.cuda.device_count()))
        else:
            print ('Using single GPU training.')
    else:
        pass
    args = parse_config()
    device = torch.device('cuda' if cuda_available else 'cpu')
    model_name = args.model_name

    sos_token, pad_token = r'<-start_of_text->', r'<-pad->'
    add_eos_token_to_data = args.add_eos_token_to_data
    if add_eos_token_to_data == 'True':
        add_eos_token_to_data = True
        print ('Add eos token to data!')
    elif add_eos_token_to_data == 'False':
        add_eos_token_to_data = False
        print ('Do not add eos token to data!')
    else:
        raise Exception('Wrong eos configuration for data!!!')
    print ('Loading data...')
    from dataclass import Data
    data = Data(model_name, args.train_path, args.dev_path, args.test_path, args.max_len, 
        sos_token, pad_token, add_eos_token_to_data)
    print ('Data loaded.')

    from trainer import model_training
    print ('############################################################')
    print ('Start Training...')
    from simctg import SimCTG
    print ('Initializing SimCTG model...')
    model = SimCTG(model_name, sos_token, pad_token)
    if cuda_available:
        if multi_gpu_training:
            model = nn.DataParallel(model) # multi-gpu training
        else:
            pass
        model = model.to(device)
    else:
        pass
    print ('Model loaded')
    total_steps, print_every, save_every = args.total_steps, args.print_every, args.save_every
    ckpt_save_path = args.save_path_prefix
    model = model_training(args, data, model, total_steps, print_every, save_every, 
        ckpt_save_path, cuda_available, device)
    print ('Training stage completed!')
    print ('############################################################')
@@ -0,0 +1,17 @@
CUDA_VISIBLE_DEVICES=0 python train.py \
    --model_name gpt2 \
    --train_path ../data/flickr30k/flickr30k_train.json \
    --dev_path ../data/flickr30k/flickr30k_val.json \
    --test_path ../data/flickr30k/flickr30k_test.json \
    --add_eos_token_to_data True \
    --margin 0.5 \
    --max_len 64 \
    --number_of_gpu 1 \
    --batch_size_per_gpu 32 \
    --gradient_accumulation_steps 4 \
    --effective_batch_size 128 \
    --total_steps 10000 \
    --print_every 50 \
    --save_every 250 \
    --learning_rate 2e-5 \
    --save_path_prefix ./magic_flickr30k/
@@ -0,0 +1,17 @@
CUDA_VISIBLE_DEVICES=0 python train.py \
    --model_name gpt2 \
    --train_path ../data/mscoco/mscoco_train.json \
    --dev_path ../data/mscoco/mscoco_val.json \
    --test_path ../data/mscoco/mscoco_test.json \
    --add_eos_token_to_data True \
    --margin 0.5 \
    --max_len 64 \
    --number_of_gpu 1 \
    --batch_size_per_gpu 32 \
    --gradient_accumulation_steps 4 \
    --effective_batch_size 128 \
    --total_steps 20000 \
    --print_every 100 \
    --save_every 500 \
    --learning_rate 2e-5 \
    --save_path_prefix ./magic_mscoco/
@@ -0,0 +1,165 @@
# coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
import argparse, os
import random
import numpy as np
import time
import logging
import progressbar

logging.getLogger('transformers.generation_utils').disabled = True

def eval_model(args, model, data, cuda_available, device):
    dataset_batch_size = args.batch_size_per_gpu * args.number_of_gpu
    eval_step = int(data.test_num / dataset_batch_size) + 1
    val_loss, token_sum = 0., 0.
    model.eval()
    with torch.no_grad():
        p = progressbar.ProgressBar(eval_step)
        p.start()
        for idx in range(eval_step):
            p.update(idx)
            batch_input_tensor, batch_labels, _ = \
            data.get_next_validation_batch(batch_size=dataset_batch_size, mode='test')
            if cuda_available:
                batch_input_tensor = batch_input_tensor.cuda(device)
                batch_labels = batch_labels.cuda(device)
            one_val_loss, one_val_token_sum = model.eval_loss(batch_input_tensor, batch_labels)
            one_val_loss = torch.sum(one_val_loss)
            one_val_token_sum = torch.sum(one_val_token_sum)
            val_loss += one_val_loss.item()
            token_sum += one_val_token_sum.item()
        p.finish()
    model.train()
    val_loss = val_loss / token_sum
    return val_loss

def model_training(args, data, model, total_steps, print_every, save_every, ckpt_save_path, cuda_available, device):
    import os
    if os.path.exists(ckpt_save_path):
        pass
    else: # recursively construct directory
        os.makedirs(ckpt_save_path, exist_ok=True)

    max_save_num = 1

    batch_size_per_gpu, gradient_accumulation_steps, number_of_gpu, effective_batch_size = \
    args.batch_size_per_gpu, args.gradient_accumulation_steps, args.number_of_gpu, args.effective_batch_size
    assert effective_batch_size == batch_size_per_gpu * gradient_accumulation_steps * number_of_gpu

    warmup_steps = int(0.1 * total_steps) # 10% of training steps are used for warmup
    print ('total training steps is {}, warmup steps is {}'.format(total_steps, warmup_steps))
    from transformers.optimization import AdamW, get_linear_schedule_with_warmup
    optimizer = AdamW(model.parameters(), lr=args.learning_rate)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
    optimizer.zero_grad()

    effective_batch_acm = 0
    all_batch_step = 1
    print_valid, save_valid = False, False
    train_loss, train_cl_loss, min_val_loss = 0., 0., 1e10
    train_ave_bleu = 0.

    print ('--------------------------------------------------------------------------')
    print ('Start Training:')
    model.train()
    number_of_saves = 0

    while effective_batch_acm < total_steps:
        all_batch_step += 1
        train_batch_input_tensor, train_batch_labels, _ = data.get_next_train_batch(batch_size_per_gpu * number_of_gpu)
        if cuda_available:
            train_batch_input_tensor = train_batch_input_tensor.cuda(device)
            train_batch_labels = train_batch_labels.cuda(device)
        mle_loss, cl_loss = model(train_batch_input_tensor, train_batch_labels, args.margin)

        loss = mle_loss + cl_loss
        loss = loss.mean()
        loss.backward()
        train_loss += mle_loss.item()
        train_cl_loss += cl_loss.item()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

        # parameter update
        if all_batch_step % gradient_accumulation_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            effective_batch_acm += 1
            print_valid, save_valid = True, True

        # print intermediate result
        if effective_batch_acm % print_every == 0 and print_valid:
            denominator = (effective_batch_acm - (number_of_saves * save_every)) * gradient_accumulation_steps
            one_train_loss = train_loss / denominator
            one_train_cl_loss = train_cl_loss / denominator
            print ('At training steps {}, training MLE loss is {}, train CL loss is {}'.format(effective_batch_acm, 
 | 
				                one_train_loss, one_train_cl_loss)) | 
			
		||||
 | 
				            print_valid = False | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				        # saving result | 
			
		||||
 | 
				        if effective_batch_acm % save_every == 0 and save_valid: | 
			
		||||
 | 
				            number_of_saves += 1 | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				            save_valid = False | 
			
		||||
 | 
				            one_train_loss = train_loss / (save_every * gradient_accumulation_steps) | 
			
		||||
 | 
				            one_train_cl_loss = train_cl_loss / (save_every * gradient_accumulation_steps) | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				            model.eval() | 
			
		||||
 | 
				            one_val_loss = eval_model(args, model, data, cuda_available, device) | 
			
		||||
 | 
				            model.train() | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				            print ('At training steps {}, training MLE loss is {}, train CL loss is {}, validation loss is {}'.format(effective_batch_acm,  | 
			
		||||
 | 
				                one_train_loss, one_train_cl_loss, one_val_loss)) | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				            train_loss, train_cl_loss = 0., 0. | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				            if one_val_loss < min_val_loss: | 
			
		||||
 | 
				                # in finetuning stage, we always save the model | 
			
		||||
 | 
				                min_val_loss = min(one_val_loss, min_val_loss) | 
			
		||||
 | 
				                print ('Saving model...') | 
			
		||||
 | 
				                one_val_ppl = np.exp(one_val_loss) | 
			
		||||
 | 
				                one_val_ppl = round(one_val_ppl, 3) | 
			
		||||
 | 
				                save_name = 'training_step_{}_train_mle_loss_{}_train_cl_loss_{}_dev_loss_{}_dev_ppl_{}'.format(effective_batch_acm, | 
			
		||||
 | 
				                round(one_train_loss,5), round(one_train_cl_loss,5), round(one_val_loss,5), one_val_ppl) | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				                model_save_path = ckpt_save_path + '/' + save_name | 
			
		||||
 | 
				                import os | 
			
		||||
 | 
				                if os.path.exists(model_save_path): | 
			
		||||
 | 
				                    pass | 
			
		||||
 | 
				                else: # recursively construct directory | 
			
		||||
 | 
				                    os.makedirs(model_save_path, exist_ok=True) | 
			
		||||
 | 
				                if cuda_available and torch.cuda.device_count() > 1: | 
			
		||||
 | 
				                    model.module.save_model(model_save_path) | 
			
		||||
 | 
				                else: | 
			
		||||
 | 
				                    model.save_model(model_save_path) | 
			
		||||
 | 
				                print ('Model Saved!') | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				                # --------------------------------------------------------------------------------------------- # | 
			
		||||
 | 
				                # removing extra checkpoints... | 
			
		||||
 | 
				                import os | 
			
		||||
 | 
				                from operator import itemgetter | 
			
		||||
 | 
				                fileData = {} | 
			
		||||
 | 
				                test_output_dir = ckpt_save_path | 
			
		||||
 | 
				                for fname in os.listdir(test_output_dir): | 
			
		||||
 | 
				                    if fname.startswith('training_step'): | 
			
		||||
 | 
				                        fileData[fname] = os.stat(test_output_dir + '/' + fname).st_mtime | 
			
		||||
 | 
				                    else: | 
			
		||||
 | 
				                        pass | 
			
		||||
 | 
				                sortedFiles = sorted(fileData.items(), key=itemgetter(1)) | 
			
		||||
 | 
				
 | 
			
		||||
 | 
				                if len(sortedFiles) < max_save_num: | 
			
		||||
 | 
				                    pass | 
			
		||||
 | 
				                else: | 
			
		||||
 | 
				                    delete = len(sortedFiles) - max_save_num | 
			
		||||
 | 
				                    for x in range(0, delete): | 
			
		||||
 | 
				                        one_folder_name = test_output_dir + '/' + sortedFiles[x][0] | 
			
		||||
 | 
				                        os.system('rm -r ' + one_folder_name) | 
			
		||||
 | 
				                print ('-----------------------------------') | 
			
		||||
 | 
				                # --------------------------------------------------------------------------------------------- # | 
			
		||||
 | 
				    return model | 
			
		||||
 | 
				
 | 
			
		||||
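The update schedule above is the standard gradient-accumulation pattern: gradients accumulate across micro-batches and the optimizer steps only every `gradient_accumulation_steps` iterations. A minimal, self-contained sketch of that pattern in isolation (the toy linear model and random data are stand-ins, not the repo's classes):

```python
# Toy illustration of the gradient-accumulation pattern used by model_training().
import torch

model = torch.nn.Linear(16, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
gradient_accumulation_steps = 4

optimizer.zero_grad()
for step in range(1, 13):                      # 12 micro-batches -> 3 optimizer updates
    x, y = torch.randn(32, 16), torch.randn(32, 1)
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()                            # gradients accumulate across backward() calls
    if step % gradient_accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()                       # one effective batch of 4 * 32 = 128 examples
        optimizer.zero_grad()
```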
@@ -0,0 +1,291 @@
import sys
import os
import operator
from operator import itemgetter
import torch
from torch import nn
import torch.nn.functional as F
import random
import numpy as np
import argparse

def parse_prompt(text):
    '''
        process the prompt text;
    '''
    eos_token = '<|endoftext|>'
    text = text.strip(eos_token).strip()
    left_bracket_idx, right_bracket_idx = -1, -1
    for idx in range(len(text)):
        char = text[idx]
        if char == '[' and left_bracket_idx == -1:  # first [ is met
            left_bracket_idx = idx
        elif char == ']' and right_bracket_idx == -1:  # first ] is met
            right_bracket_idx = idx
    res_text = ''
    remove = False
    if left_bracket_idx > -1 and right_bracket_idx > left_bracket_idx:
        if right_bracket_idx - left_bracket_idx <= 6:
            remove = True

    for idx in range(len(text)):
        if remove and left_bracket_idx <= idx <= right_bracket_idx:
            continue
        res_text += text[idx]
    res_text = res_text.strip()
    res_text = ' '.join(res_text.split()).strip()
    return res_text
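Two quick illustrative calls (hypothetical inputs): the bracketed span is dropped only when `]` follows `[` within six positions, i.e. when the span is at most seven characters wide including the brackets.

```python
print(parse_prompt('a photo [sep] of a dog<|endoftext|>'))
# -> 'a photo of a dog'        (the 5-character '[sep]' span is removed)
print(parse_prompt('a [very long tag] dog'))
# -> 'a [very long tag] dog'   (span wider than 7 characters is kept)
```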
def typical_filtering(scores, mass, min_tokens_to_keep, filter_value):
    # calculate entropy
    normalized = torch.nn.functional.log_softmax(scores, dim=-1)
    p = torch.exp(normalized)
    ent = -(normalized * p).nansum(-1, keepdim=True)

    # shift and sort
    shifted_scores = torch.abs((-normalized) - ent)
    sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
    sorted_logits = scores.gather(-1, sorted_indices)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    # Remove tokens with cumulative mass above the threshold
    last_ind = (cumulative_probs < mass).sum(dim=1)
    last_ind[last_ind < 0] = 0
    sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
    if min_tokens_to_keep > 1:
        # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
        sorted_indices_to_remove[..., : min_tokens_to_keep] = 0
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

    scores = scores.masked_fill(indices_to_remove, filter_value)
    return scores
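`typical_filtering` matches the filtering step of locally typical sampling: each token is ranked by how far its surprisal (-log p) deviates from the distribution's entropy, and only the smallest-deviation set whose cumulative probability reaches `mass` survives; every other logit is pushed to `filter_value`.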
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, threshold=-float('Inf'), filter_value=-np.inf):
    assert logits.dim() == 1
    top_k = min(top_k, logits.size(-1))
    if top_k > 0:
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # shift right so the first token that crosses the threshold is still kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value

    indices_to_remove = logits < threshold
    logits[indices_to_remove] = filter_value
    return logits
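For example, with `top_k=2` only the two highest logits survive (note the function mutates its 1-D input in place):

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
filtered = top_k_top_p_filtering(logits.clone(), top_k=2)
# -> tensor([2., 1., -inf, -inf]); sampling from softmax(filtered)
#    can now only pick the first two tokens.
```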
# ========== batch version ========= #
def ranking_fast(context_hidden, next_hidden, next_top_k_probs, alpha, beam_width):
    '''
        context_hidden: bsz*beam x seqlen x embed_dim
        next_hidden: bsz*beam x 1 x embed_dim
        next_top_k_probs: bsz x beam
    '''
    _, context_len, embed_dim = context_hidden.size()
    norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1)    # [B*K, S]
    scores, _ = torch.max(cosine_matrix, dim=-1)    # [B*K]
    next_top_k_probs = next_top_k_probs.view(-1)    # [B*K]
    scores = (1.0 - alpha) * next_top_k_probs - alpha * scores
    scores = torch.stack(torch.split(scores, beam_width))    # [B, K]
    selected_idx = scores.max(dim=-1)[1]    # [B]
    return selected_idx
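In other words, each of the K candidates is scored as (1 - alpha) * model confidence - alpha * (max cosine similarity between its hidden state and the context hidden states). The second term is the degeneration penalty of contrastive search, so a larger `alpha` trades likelihood for diversity.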
def ContrastiveDecodingOneStepFast(
    model,
    ids,
    beam_width,
    alpha,
    past_key_values,
    last_hidden_states,
    vocab,
    logit_for_next_step,
    first_step=False,
    ):
    # input_ids: [B, S]
    if first_step:
        output = model(
            input_ids=ids,
            past_key_values=past_key_values,
            use_cache=True,
            output_hidden_states=True
        )
        past_key_values = output.past_key_values
        last_hidden_states = output.hidden_states[-1]    # [B, S, E]
        logit_for_next_step = output.logits[:, -1, :]    # [B, V]
    bsz, seqlen, embed_dim = last_hidden_states.size()

    next_probs = F.softmax(logit_for_next_step, dim=-1)
    _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width)    # [B, K]
    top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids)    # [B, K]
    # compute new hidden
    past_key_values = enlarge_past_key_values(past_key_values, beam_width)
    output = model(
        input_ids=top_k_ids.view(-1, 1),
        attention_mask=torch.ones_like(top_k_ids.view(-1, 1)),
        past_key_values=past_key_values,
        output_hidden_states=True,
        use_cache=True,
    )
    past_key_values = output.past_key_values
    logits = output.logits[:, -1, :]    # [B*K, V]
    next_hidden = output.hidden_states[-1]    # [B*K, 1, E]
    context_hidden = last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz*beam_width, seqlen, embed_dim)    # [B*K, S, E]

    selected_idx = ranking_fast(
        context_hidden,
        next_hidden,
        top_k_probs,    # [B, K]
        alpha,
        beam_width,
    )    # [B]
    # prepare for the next step
    next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1)    # [B, 1]
    next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width))    # [B, K, E]
    next_hidden = next_hidden[range(bsz), selected_idx, :]    # [B, E]
    last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)    # [B, S, E]
    past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx)
    logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :]    # [B, V]
    # next_id: [B, 1]
    return next_id, past_key_values, last_hidden_states, logits
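A hedged sketch of the decoding loop this one-step function is designed for; the GPT-2 checkpoint, prompt, and hyperparameters are illustrative placeholders, not the repo's actual entry point:

```python
# Illustrative loop around ContrastiveDecodingOneStepFast (assumed setup).
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2').eval()

ids = tokenizer('A photo of', return_tensors='pt').input_ids   # [1, S]
past_key_values, last_hidden_states, logits = None, None, None
generated = ids
with torch.no_grad():
    for step in range(16):
        next_id, past_key_values, last_hidden_states, logits = \
            ContrastiveDecodingOneStepFast(
                model, ids, beam_width=5, alpha=0.6,
                past_key_values=past_key_values,
                last_hidden_states=last_hidden_states,
                vocab=tokenizer, logit_for_next_step=logits,
                first_step=(step == 0),
            )
        ids = next_id                      # after the first step, feed one token at a time
        generated = torch.cat([generated, next_id], dim=-1)
print(tokenizer.decode(generated[0]))
```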
def enlarge_past_key_values(past_key_values, beam_width):
    # from [B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz]
    new_key_values = []
    for layer in past_key_values:
        items = []
        for item in layer:
            # item is the key and value matrix
            bsz, num_head, seq_len, esz = item.size()
            item = item.unsqueeze(1).expand(-1, beam_width, -1, -1, -1).reshape(bsz*beam_width, num_head, seq_len, esz)    # [bsz*beam, num_head, seq_len, esz]
            items.append(item)
        new_key_values.append(items)
    return new_key_values

def select_past_key_values(past_key_values, beam_width, selected_idx):
    '''select_idx: [B]'''
    new_key_values = []
    for layer in past_key_values:
        items = []
        for item in layer:
            bsz_and_beam, num_head, seq_len, esz = item.size()
            bsz = int(bsz_and_beam // beam_width)
            item = torch.stack(torch.split(item, beam_width, dim=0))    # [B, K, num_head, seq_len, esz]
            item = item[range(bsz), selected_idx, :, :, :]    # [B, num_head, seq_len, esz]
            items.append(item)
        new_key_values.append(items)
    return new_key_values
# ========== fast plug and play version ========= #
def plug_and_play_fast_ranking(
    context_hidden,
    next_hidden,
    next_top_k_ids,
    next_top_k_probs,
    alpha,
    beta,
    batch_class_score,
    beam_width,
):
    '''
        context_hidden: beam_width x context_len x embed_dim
        next_hidden: beam_width x 1 x embed_dim
        next_top_k_ids: beam_width x 1
        batch_class_score: beam_width x 1
    '''
    _, context_len, embed_dim = context_hidden.size()
    norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1)
    scores, _ = torch.max(cosine_matrix, dim=-1)
    next_top_k_probs = next_top_k_probs.view(-1)
    scores = (1.0 - alpha) * next_top_k_probs - alpha * scores + beta * batch_class_score.view([beam_width])
    scores = torch.stack(torch.split(scores, beam_width))
    selected_idx = scores.max(dim=-1)[1]
    return selected_idx
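Relative to `ranking_fast` above, the only change is the additional `beta * batch_class_score` term, the CLIP image-text score for each candidate continuation; `beta` therefore controls how strongly decoding is steered toward the image while the first two terms still enforce confidence and non-degeneration.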
def PlugAndPlayContrastiveDecodingOneStepFast(model, input_ids, prefix_len, beam_width, alpha, beta,
    simctg_tokenizer, image_embeds, clip, clip_text_max_len, past_key_values, last_hidden_states,
    logit_for_next_step, first_step=False, input_ids_for_class=None):  # add_token_level_score=False (disabled option)
    '''
        model: the generation model, e.g., gpt2
        input_ids: 1 x seqlen
    '''

    if first_step:
        output = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True, output_hidden_states=True)
        past_key_values = output.past_key_values
        last_hidden_states = output.hidden_states[-1]    # [B, S, E]
        logit_for_next_step = output.logits[:, -1, :]    # [B, V]
    bsz, seqlen, embed_dim = last_hidden_states.size()
    next_probs = F.softmax(logit_for_next_step, dim=-1)
    _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width)
    top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids)

    # compute the new hidden
    past_key_values = enlarge_past_key_values(past_key_values, beam_width)
    output = model(
        input_ids=top_k_ids.view(-1, 1),
        attention_mask=torch.ones_like(top_k_ids.view(-1, 1)),
        past_key_values=past_key_values,
        output_hidden_states=True,
        use_cache=True,
    )
    past_key_values = output.past_key_values
    logits = output.logits[:, -1, :]
    next_hidden = output.hidden_states[-1]
    context_hidden = last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz*beam_width, seqlen, embed_dim)

    # prepare for the classification model
    input_ids_for_class_ = torch.cat([
        input_ids_for_class.unsqueeze(1).expand(-1, beam_width, -1).reshape(bsz*beam_width, seqlen),
        top_k_ids.view(-1, 1)
        ], dim=-1
    )

    batch_text_list = []
    for one_input_id in input_ids_for_class_:
        # we only consider the class score of the generated text continuation
        one_text = simctg_tokenizer.decode(one_input_id[prefix_len:][-clip_text_max_len:])
        batch_text_list.append(one_text)
    batch_score = clip.compute_image_text_similarity_via_raw_text(image_embeds, batch_text_list)

    selected_idx = plug_and_play_fast_ranking(
        context_hidden,
        next_hidden,
        top_k_ids,
        top_k_probs,
        alpha,
        beta,
        batch_score,
        beam_width,
    )

    # prepare for the next step
    next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1)
    next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width))
    next_hidden = next_hidden[range(bsz), selected_idx, :]
    last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)
    past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx)
    logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :]
    input_ids_for_class = torch.cat([input_ids_for_class, next_id], dim=-1)
    return next_id, past_key_values, last_hidden_states, logits, input_ids_for_class
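The calling pattern mirrors the loop sketched earlier for `ContrastiveDecodingOneStepFast`, with two extra per-step inputs: `image_embeds` (the CLIP embedding of the query image) feeds the scoring term above, and `input_ids_for_class` accumulates the full decoded sequence so each candidate caption can be re-scored against the image as raw text via `clip.compute_image_text_similarity_via_raw_text`.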
@@ -1,89 +0,0 @@
### Our Implementation of the ZeroCap Baseline Model

****
### Catalogue:
* <a href='#environment'>1. Environment Preparation</a>
* <a href='#mscoco'>2. Image Captioning on MSCOCO</a>
* <a href='#flickr30k'>3. Image Captioning on Flickr30k</a>
* <a href='#flickr30k_to_mscoco'>4. Cross Domain Image Captioning on MSCOCO</a>
* <a href='#mscoco_to_flickr30k'>5. Cross Domain Image Captioning on Flickr30k</a>
* <a href='#citation'>6. Citation</a>
* <a href='#acknowledgements'>7. Acknowledgements</a>

****

<span id='environment'/>

#### 1. Environment Preparation:
To install the correct environment, please run the following command:
```bash
pip install -r requirements.txt
```

****

<span id='mscoco'/>

#### 2. Image Captioning on MSCOCO:
To perform image captioning on MSCOCO, please run the following command:
```bash
chmod +x ./mscoco_zerocap.sh
./mscoco_zerocap.sh
```

****

<span id='flickr30k'/>

#### 3. Image Captioning on Flickr30k:
To perform image captioning on Flickr30k, please run the following command:
```bash
chmod +x ./flickr30k_zerocap.sh
./flickr30k_zerocap.sh
```

****

<span id='flickr30k_to_mscoco'/>

#### 4. Cross Domain Image Captioning on MSCOCO:
To perform image captioning on MSCOCO with the language model trained on the Flickr30k domain, please run the following command:
```bash
chmod +x ./flickr30k_to_mscoco_zerocap.sh
./flickr30k_to_mscoco_zerocap.sh
```

****

<span id='mscoco_to_flickr30k'/>

#### 5. Cross Domain Image Captioning on Flickr30k:
To perform image captioning on Flickr30k with the language model trained on the MSCOCO domain, please run the following command:
```bash
chmod +x ./mscoco_to_flickr30k_zerocap.sh
./mscoco_to_flickr30k_zerocap.sh
```

****

<span id='citation'/>

#### 6. Citation:
If you find our code helpful, please cite the original paper as:

```bibtex
@article{tewel2021zero,
  title={Zero-Shot Image-to-Text Generation for Visual-Semantic Arithmetic},
  author={Tewel, Yoad and Shalev, Yoav and Schwartz, Idan and Wolf, Lior},
  journal={arXiv preprint arXiv:2111.14447},
  year={2021}
}
```

****

<span id='acknowledgements'/>

#### 7. Acknowledgements:
We thank the authors for releasing their code. Our reimplementation of the baseline is based on their original codebase [[here]](https://github.com/yoadtew/zero-shot-image-to-text).
@@ -1,12 +0,0 @@
build:
  gpu: true
  python_version: "3.8"
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
  python_packages:
    - "git+https://github.com/openai/CLIP.git"
    - "git+https://github.com/YoadTew/zero-shot-image-to-text.git"

predict: "predict.py:Predictor"
#predict: "predict_arithmetic.py:Predictor"
@@ -1,14 +0,0 @@
#!/bin/bash

# lm_model:
#  1. cambridgeltl/magic_mscoco
#  2. cambridgeltl/magic_flickr30k
CUDA_VISIBLE_DEVICES=1 python run.py \
    --beam_size 1 \
    --target_seq_length 16 \
    --reset_context_delta \
    --lm_model cambridgeltl/magic_flickr30k \
    --test_image_prefix_path ../data/flickr30k/test_images \
    --test_path ../data/flickr30k/flickr30k_test.json \
    --save_path_prefix ../inference_result/flickr30k/baselines/ \
    --save_name zerocap_result.json

Binary file not shown.
@@ -1,389 +0,0 @@
import numpy as np
from torch import nn
from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer
from transformers.models.gpt_neo import GPTNeoForCausalLM
import torch
import clip
from PIL import Image
from datetime import datetime
import sys


def log_info(text, verbose=True):
    if verbose:
        dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
        print(f'{dt_string} | {text}')
        sys.stdout.flush()


def add_context(x, y):
    return (x[0] + y[0], x[1] + y[1])


def convert_models_to_fp32(model):
    for p in model.parameters():
        p.data = p.data.float()


class CLIPTextGenerator:
    def __init__(self,
                 seed=0,
                 lm_model='gpt-2',
                 forbidden_tokens_file_path='./forbidden_tokens.npy',
                 clip_checkpoints='./clip_checkpoints',
                 target_seq_length=15,
                 reset_context_delta=True,
                 num_iterations=5,
                 clip_loss_temperature=0.01,
                 clip_scale=1.,
                 ce_scale=0.2,
                 stepsize=0.3,
                 grad_norm_factor=0.9,
                 fusion_factor=0.99,
                 repetition_penalty=1.,
                 end_token='.',
                 end_factor=1.01,
                 forbidden_factor=20,
                 **kwargs):

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # set Random seed
        torch.manual_seed(seed)
        np.random.seed(seed)

        # Initialize Language model
        self.context_prefix = ''

        self.lm_tokenizer = GPT2Tokenizer.from_pretrained(lm_model)
        self.lm_model = GPT2LMHeadModel.from_pretrained(lm_model, output_hidden_states=True)
        self.context_prefix = self.lm_tokenizer.bos_token

        self.lm_model.to(self.device)
        self.lm_model.eval()

        self.forbidden_tokens = np.load(forbidden_tokens_file_path)
        # 'Ġ' is the GPT-2 BPE marker for a leading space
        self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if
                                      (x[0] == 'Ġ' and len(x) > 1 and x[1].isupper())]

        # Freeze LM weights
        for param in self.lm_model.parameters():
            param.requires_grad = False

        # Initialize CLIP
        self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device,
                                                    download_root=clip_checkpoints, jit=False)
        # convert_models_to_fp32(self.clip)

        # Init arguments
        self.target_seq_length = target_seq_length
        self.reset_context_delta = reset_context_delta
        self.num_iterations = num_iterations
        self.clip_loss_temperature = clip_loss_temperature
        self.clip_scale = clip_scale
        self.ce_scale = ce_scale
        self.stepsize = stepsize
        self.grad_norm_factor = grad_norm_factor
        self.fusion_factor = fusion_factor
        self.repetition_penalty = repetition_penalty
        self.end_token = self.lm_tokenizer.encode(end_token)[0]
        self.end_factor = end_factor
        self.ef_idx = 1
        self.forbidden_factor = forbidden_factor

    def get_img_feature(self, img_path, weights):
        imgs = [Image.open(x) for x in img_path]
        clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]

        with torch.no_grad():
            image_fts = [self.clip.encode_image(x) for x in clip_imgs]

            if weights is not None:
                image_features = sum([x * weights[i] for i, x in enumerate(image_fts)])
            else:
                image_features = sum(image_fts)

            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            return image_features.detach()

    def get_txt_features(self, text):
        clip_texts = clip.tokenize(text).to(self.device)

        with torch.no_grad():
            text_features = self.clip.encode_text(clip_texts)

            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        return text_features.detach()

    def get_combined_feature(self, img_path, texts, weights_i, weights_t):
        imgs = [Image.open(x) for x in img_path]
        clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]
        clip_texts = [clip.tokenize(x).to(self.device) for x in texts]

        with torch.no_grad():
            image_fts = [self.clip.encode_image(x) for x in clip_imgs]
            text_fts = [self.clip.encode_text(x) for x in clip_texts]

            features = sum([x * weights_i[i] for i, x in enumerate(image_fts)])
            if weights_t is not None:
                features += sum([x * weights_t[i] for i, x in enumerate(text_fts)])

            features = features / features.norm(dim=-1, keepdim=True)
            return features.detach()

    def run(self, image_features, cond_text, beam_size):
        self.image_features = image_features

        context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text)

        output_tokens, output_text = self.generate_text(context_tokens, beam_size)

        return output_text

    def generate_text(self, context_tokens, beam_size):
        context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0)

        gen_tokens = None
        scores = None
        seq_lengths = torch.ones(beam_size, device=self.device)
        is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool)

        for i in range(self.target_seq_length):
            probs = self.get_next_probs(i, context_tokens)
            logits = probs.log()

            if scores is None:
                scores, next_tokens = logits.topk(beam_size, -1)
                context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:])
                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)

                if gen_tokens is None:
                    gen_tokens = next_tokens
                else:
                    gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:])
                    gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1)
            else:
                logits[is_stopped] = -float(np.inf)
                logits[is_stopped, 0] = 0
                scores_sum = scores[:, None] + logits
                seq_lengths[~is_stopped] += 1
                scores_sum_average = scores_sum / seq_lengths[:, None]
                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
                    beam_size, -1)
                next_tokens_source = next_tokens // scores_sum.shape[1]
                seq_lengths = seq_lengths[next_tokens_source]
                next_tokens = next_tokens % scores_sum.shape[1]
                next_tokens = next_tokens.unsqueeze(1)
                gen_tokens = gen_tokens[next_tokens_source]
                gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1)
                context_tokens = context_tokens[next_tokens_source]
                scores = scores_sum_average * seq_lengths
                is_stopped = is_stopped[next_tokens_source]

            context_tokens = torch.cat((context_tokens, next_tokens), dim=1)
            is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze()

            # debug printout of the current beams, ranked by score
            tmp_scores = scores / seq_lengths
            tmp_output_list = gen_tokens.cpu().numpy()
            tmp_output_texts = [
                self.lm_tokenizer.decode(tmp_output)
                for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths)
            ]
            tmp_order = tmp_scores.argsort(descending=True)
            tmp_output_texts = [tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order]
            log_info(tmp_output_texts, verbose=True)

            if is_stopped.all():
                break

        scores = scores / seq_lengths
        output_list = gen_tokens.cpu().numpy()
        output_texts = [
            self.lm_tokenizer.decode(output[: int(length)])
            for output, length in zip(output_list, seq_lengths)
        ]
        order = scores.argsort(descending=True)
        output_texts = [output_texts[i] for i in order]

        return context_tokens, output_texts

    def get_next_probs(self, i, context_tokens):
        last_token = context_tokens[:, -1:]

        context = None  # guard against the path where reset_context_delta is False
        if self.reset_context_delta and context_tokens.size(1) > 1:
            context = self.lm_model(context_tokens[:, :-1])["past_key_values"]

        # Logits of LM with unshifted context
        logits_before_shift = self.lm_model(context_tokens)["logits"]
        logits_before_shift = logits_before_shift[:, -1, :]
        probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1)

        if context:
            context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift)

        lm_output = self.lm_model(last_token, past_key_values=context)
        logits, past = (
            lm_output["logits"],
            lm_output["past_key_values"],
        )
        logits = logits[:, -1, :]

        logits = self.update_special_tokens_logits(context_tokens, i, logits)

        probs = nn.functional.softmax(logits, dim=-1)
        probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor))
        probs = probs / probs.sum()

        return probs

    def shift_context(self, i, context, last_token, context_tokens, probs_before_shift):
        context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context]

        window_mask = torch.ones_like(context[0][0]).to(self.device)

        for _ in range(self.num_iterations):
            curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in
                          context_delta]

            for p0, p1 in curr_shift:
                p0.retain_grad()
                p1.retain_grad()

            shifted_context = list(map(add_context, context, curr_shift))

            shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context)
            logits = shifted_outputs["logits"][:, -1, :]
            probs = nn.functional.softmax(logits, dim=-1)

            loss = 0.0

            # CLIP LOSS
            clip_loss, clip_losses = self.clip_loss(probs, context_tokens)
            loss += self.clip_scale * clip_loss

            # CE/Fluency loss
            ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * probs_before_shift.log())).sum(-1)
            loss += ce_loss.sum()

            loss.backward()

            # ---------- Weights ----------
            combined_scores_k = -(ce_loss)
            combined_scores_c = -(self.clip_scale * torch.stack(clip_losses))

            # minmax
            if combined_scores_k.shape[0] == 1:
                tmp_weights_c = tmp_weights_k = torch.ones(*combined_scores_k.shape).to(self.device)
            else:
                tmp_weights_k = ((combined_scores_k - combined_scores_k.min())) / (
                        combined_scores_k.max() - combined_scores_k.min())
                tmp_weights_c = ((combined_scores_c - combined_scores_c.min())) / (
                        combined_scores_c.max() - combined_scores_c.min())

            tmp_weights = 0.5 * tmp_weights_k + 0.5 * tmp_weights_c
            tmp_weights = tmp_weights.view(tmp_weights.shape[0], 1, 1, 1)

            factor = 1

            # --------- Specific Gen ---------
            sep_grads = None

            for b in range(context_tokens.shape[0]):
                tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_]
                                 for p_ in curr_shift]

                # normalize gradients
                tmp_grad = [tuple([-self.stepsize * factor * (
                        x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][
                    j] ** self.grad_norm_factor).data.cpu().numpy()
                                   for j, x in enumerate(p_)])
                            for i, p_ in enumerate(curr_shift)]
                if sep_grads is None:
                    sep_grads = tmp_grad
                else:
                    for l_index in range(len(sep_grads)):
                        sep_grads[l_index] = list(sep_grads[l_index])
                        for k_index in range(len(sep_grads[0])):
                            sep_grads[l_index][k_index] = np.concatenate(
                                (sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0)
                        sep_grads[l_index] = tuple(sep_grads[l_index])
            final_grads = sep_grads

            # --------- update context ---------
            context_delta = list(map(add_context, final_grads, context_delta))

            for p0, p1 in curr_shift:
                p0.grad.data.zero_()
                p1.grad.data.zero_()

            new_context = []
            for p0, p1 in context:
                new_context.append((p0.detach(), p1.detach()))
            context = new_context

        context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_])
                         for p_ in context_delta]
        context = list(map(add_context, context, context_delta))

        new_context = []
        for p0, p1 in context:
            new_context.append((p0.detach(), p1.detach()))
        context = new_context

        return context

    def update_special_tokens_logits(self, context_tokens, i, logits):
        for beam_id in range(context_tokens.shape[0]):
            for token_idx in set(context_tokens[beam_id][-4:].tolist()):
                factor = self.repetition_penalty if logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty)
                logits[beam_id, token_idx] /= factor

            if i >= self.ef_idx:
                factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor)
                logits[beam_id, self.end_token] *= factor
            if i == 0:
                start_factor = 1.6
                factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor)
                logits[beam_id, self.end_token] /= factor

            for token_idx in list(self.forbidden_tokens):
                factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor)
                logits[beam_id, token_idx] /= factor

        return logits

    def clip_loss(self, probs, context_tokens):
        for p_ in self.clip.transformer.parameters():
            if p_.grad is not None:
                p_.grad.data.zero_()

        top_size = 512
        _, top_indices = probs.topk(top_size, -1)

        prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens]

        clip_loss = 0
        losses = []
        for idx_p in range(probs.shape[0]):
            top_texts = []
            prefix_text = prefix_texts[idx_p]
            for x in top_indices[idx_p]:
                top_texts.append(prefix_text + self.lm_tokenizer.decode(x))
            text_features = self.get_txt_features(top_texts)

            with torch.no_grad():
                similarities = (self.image_features @ text_features.T)
                target_probs = nn.functional.softmax(similarities / self.clip_loss_temperature, dim=-1).detach()
                target_probs = target_probs.type(torch.float32)

            target = torch.zeros_like(probs[idx_p])
            target[top_indices[idx_p]] = target_probs[0]
            target = target.unsqueeze(0)
            cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)])))

            clip_loss += cur_clip_loss
            losses.append(cur_clip_loss)

        return clip_loss, losses
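Taken together, this baseline steers a frozen GPT-2 at inference time: `clip_loss` turns CLIP image-text similarities over the top-512 candidate continuations into a target distribution and takes a cross-entropy against the LM's next-token distribution, and `shift_context` runs a few gradient steps on the cached key/value tensors against that loss (plus the CE fluency term) before each token is committed.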
@ -1,449 +0,0 @@ | 
			
		|||||
import numpy as np
from torch import nn
from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer
from transformers.models.gpt_neo import GPTNeoForCausalLM
import torch
import clip
from PIL import Image
from datetime import datetime
import sys

class TextCLIP(nn.Module):
    def __init__(self, model):
        super(TextCLIP, self).__init__()
        self.model = model

    def forward(self, text):
        return self.model.encode_text(text)


class ImageCLIP(nn.Module):
    def __init__(self, model):
        super(ImageCLIP, self).__init__()
        self.model = model

    def forward(self, image):
        return self.model.encode_image(image)
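
# Why these one-method wrappers exist (editorial note): nn.DataParallel only
# replicates a module's forward(), so exposing encode_text/encode_image as
# forward() lets CLIP's two encoders be parallelized across GPUs independently.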

def log_info(text, verbose=True):
    if verbose:
        dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
        print(f'{dt_string} | {text}')
        sys.stdout.flush()


def add_context(x, y):
    return (x[0] + y[0], x[1] + y[1])
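
# add_context sums two past_key_values entries elementwise, (key, value) pair
# by pair; it is how the learned "shift" is applied on top of the language
# model's cached attention context further below.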


def convert_models_to_fp32(model):
    for p in model.parameters():
        p.data = p.data.float()


class CLIPTextGenerator:
    def __init__(self,
                 seed=0,
                 lm_model='gpt-2',
                 forbidden_tokens_file_path='./forbidden_tokens.npy',
                 clip_checkpoints='./clip_checkpoints',
                 target_seq_length=15,
                 reset_context_delta=True,
                 num_iterations=5,
                 clip_loss_temperature=0.01,
                 clip_scale=1.,
                 ce_scale=0.2,
                 stepsize=0.3,
                 grad_norm_factor=0.9,
                 fusion_factor=0.99,
                 repetition_penalty=1.,
                 end_token='.',
                 end_factor=1.01,
                 forbidden_factor=20,
                 **kwargs):

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # set Random seed
        torch.manual_seed(seed)
        np.random.seed(seed)

        # Initialize Language model
        self.context_prefix = ''

        if lm_model == 'gpt-neo':
            self.lm_tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-125M')
            self.lm_model = GPTNeoForCausalLM.from_pretrained('EleutherAI/gpt-neo-125M', output_hidden_states=True)
        elif lm_model == 'gpt-2':
            self.lm_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
            self.lm_model = GPT2LMHeadModel.from_pretrained('gpt2-medium', output_hidden_states=True)
            self.context_prefix = self.lm_tokenizer.bos_token

        self.lm_model.to(self.device)
        self.lm_model.eval()

        self.forbidden_tokens = np.load(forbidden_tokens_file_path)
        self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if
                                      (x[0] == 'Ġ' and len(x) > 1 and x[1].isupper())]

        # Freeze LM weights
        for param in self.lm_model.parameters():
            param.requires_grad = False

        # Initialize CLIP
        self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device,
                                                    download_root=clip_checkpoints, jit=False)
        self.clip_image = ImageCLIP(self.clip)
        self.clip_image = torch.nn.DataParallel(self.clip_image)
        self.clip_text = TextCLIP(self.clip)
        self.clip_text = torch.nn.DataParallel(self.clip_text)

        # Init arguments
        self.target_seq_length = target_seq_length
        self.reset_context_delta = reset_context_delta
        self.num_iterations = num_iterations
        self.clip_loss_temperature = clip_loss_temperature
        self.clip_scale = clip_scale
        self.ce_scale = ce_scale
        self.stepsize = stepsize
        self.grad_norm_factor = grad_norm_factor
        self.fusion_factor = fusion_factor
        self.repetition_penalty = repetition_penalty
        self.end_token = self.lm_tokenizer.encode(end_token)[0]
        self.end_factor = end_factor
        self.ef_idx = 1
        self.forbidden_factor = forbidden_factor

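    # A reading guide for the knobs above (editorial, not in the original):
    # clip_scale and ce_scale trade image relevance against fluency, stepsize
    # and num_iterations control how far the cached context is pushed per
    # token, and fusion_factor (used in get_next_probs) mixes the shifted and
    # unshifted distributions. ef_idx = 1 means the end-token boost kicks in
    # from the second generated token onward.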
    def get_img_feature(self, img_path, weights):
        imgs = [Image.open(x) for x in img_path]
        clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]

        with torch.no_grad():
            image_fts = [self.clip_image(x) for x in clip_imgs]

            if weights is not None:
                image_features = sum([x * weights[i] for i, x in enumerate(image_fts)])
            else:
                image_features = sum(image_fts)

            image_features = torch.nn.functional.normalize(image_features, dim=-1)
            return image_features.detach()

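    # Note: passing several paths with weights yields a single weighted sum of
    # CLIP image embeddings, re-normalized to unit length -- the hook that the
    # image-arithmetic mode builds on.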
    def get_txt_features(self, text):
        clip_texts = clip.tokenize(text).to(self.device)

        with torch.no_grad():
            text_features = self.clip_text(clip_texts)

            text_features = torch.nn.functional.normalize(text_features, dim=-1)
        return text_features.detach()

    def get_combined_feature(self, img_path, texts, weights_i, weights_t):
        imgs = [Image.open(x) for x in img_path]
        clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]
        clip_texts = [clip.tokenize(x).to(self.device) for x in texts]

        with torch.no_grad():
            image_fts = [self.clip.encode_image(x) for x in clip_imgs]
            text_fts = [self.clip.encode_text(x) for x in clip_texts]

            features = sum([x * weights_i[i] for i, x in enumerate(image_fts)])
            if weights_t is not None:
                features += sum([x * weights_t[i] for i, x in enumerate(text_fts)])

            features = features / features.norm(dim=-1, keepdim=True)
            return features.detach()

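    # This is the embedding-arithmetic entry point: with weights_i = [1, 1, -1]
    # it computes image1 + image2 - image3 in CLIP space (the repository's
    # example: woman + king - man), and captioning the combined vector
    # verbalizes the analogy.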
    def run(self, image_features, cond_text, beam_size):
        self.image_features = image_features

        context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text)

        output_tokens, output_text = self.generate_text(context_tokens, beam_size)

        return output_text

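    # Minimal usage sketch (hypothetical paths, mirroring the predict scripts
    # further below in this diff):
    #   generator = CLIPTextGenerator()
    #   feats = generator.get_img_feature(['./example.jpg'], None)
    #   captions = generator.run(feats, 'Image of a', beam_size=5)
    #   print(captions[0])  # highest-ranked caption, without the prompt prefix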
    def generate_text(self, context_tokens, beam_size):
        context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0)

        gen_tokens = None
        scores = None
        seq_lengths = torch.ones(beam_size, device=self.device)
        is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool)

        for i in range(self.target_seq_length):
            probs = self.get_next_probs(i, context_tokens)
            logits = probs.log()

            if scores is None:
                scores, next_tokens = logits.topk(beam_size, -1)
                context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:])
                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)

                if gen_tokens is None:
                    gen_tokens = next_tokens
                else:
                    gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:])
                    gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1)
            else:
                logits[is_stopped] = -float(np.inf)
                logits[is_stopped, 0] = 0
                scores_sum = scores[:, None] + logits
                seq_lengths[~is_stopped] += 1
                scores_sum_average = scores_sum / seq_lengths[:, None]
                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
                    beam_size, -1)
                next_tokens_source = next_tokens // scores_sum.shape[1]
                seq_lengths = seq_lengths[next_tokens_source]
                next_tokens = next_tokens % scores_sum.shape[1]
                next_tokens = next_tokens.unsqueeze(1)
                gen_tokens = gen_tokens[next_tokens_source]
                gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1)
                context_tokens = context_tokens[next_tokens_source]
                scores = scores_sum_average * seq_lengths
                is_stopped = is_stopped[next_tokens_source]

            context_tokens = torch.cat((context_tokens, next_tokens), dim=1)
            is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze()

            ####
            tmp_scores = scores / seq_lengths
            tmp_output_list = gen_tokens.cpu().numpy()
            tmp_output_texts = [
                self.lm_tokenizer.decode(tmp_output)
                for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths)
            ]
            tmp_order = tmp_scores.argsort(descending=True)
            tmp_output_texts = [tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order]
            log_info(tmp_output_texts, verbose=True)
            ####

            if is_stopped.all():
                break

        scores = scores / seq_lengths
        output_list = gen_tokens.cpu().numpy()
        output_texts = [
            self.lm_tokenizer.decode(output[: int(length)])
            for output, length in zip(output_list, seq_lengths)
        ]
        order = scores.argsort(descending=True)
        output_texts = [output_texts[i] for i in order]

        return context_tokens, output_texts

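    # Beam bookkeeping above (editorial summary): scores hold cumulative
    # log-probabilities; candidates are ranked by score / length so longer
    # beams are not unfairly penalized; finished beams get -inf everywhere
    # except token 0 at cost 0, which lets them ride along unchanged until
    # every beam has emitted the end token.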
    def get_next_probs(self, i, context_tokens):
        last_token = context_tokens[:, -1:]

        context = None
        if self.reset_context_delta and context_tokens.size(1) > 1:
            context = self.lm_model(context_tokens[:, :-1])["past_key_values"]

        # Logits of LM with unshifted context
        logits_before_shift = self.lm_model(context_tokens)["logits"]
        logits_before_shift = logits_before_shift[:, -1, :]
        probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1)

        if context:
            context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift)

        lm_output = self.lm_model(last_token, past_key_values=context)
        logits, past = (
            lm_output["logits"],
            lm_output["past_key_values"],
        )
        logits = logits[:, -1, :]

        logits = self.update_special_tokens_logits(context_tokens, i, logits)

        probs = nn.functional.softmax(logits, dim=-1)
        probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor))
        probs = probs / probs.sum()

        return probs

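    # The final mixing step is a (renormalized) weighted geometric mean:
    # probs ** fusion_factor * probs_before_shift ** (1 - fusion_factor). With
    # the default fusion_factor = 0.99 the shifted distribution dominates and
    # the unshifted one acts as a mild fluency anchor.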
    def shift_context(self, i, context, last_token, context_tokens, probs_before_shift):
        context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context]

        for i in range(self.num_iterations):
            curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in
                          context_delta]

            for p0, p1 in curr_shift:
                p0.retain_grad()
                p1.retain_grad()

            shifted_context = list(map(add_context, context, curr_shift))

            shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context)
            logits = shifted_outputs["logits"][:, -1, :]
            probs = nn.functional.softmax(logits, dim=-1)

            loss = 0.0

            # CLIP LOSS
            clip_loss, clip_losses = self.clip_loss(probs, context_tokens)
            loss += self.clip_scale * clip_loss

            # CE/Fluency loss
            ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * probs_before_shift.log())).sum(-1)
            loss += ce_loss.sum()

            loss.backward()

            # --------- Specific Gen ---------
            final_grads = self.norm_grad(context, context_tokens, curr_shift)

            # --------- update context ---------
            context_delta = list(map(add_context, final_grads, context_delta))

            for p0, p1 in curr_shift:
                p0.grad.data.zero_()
                p1.grad.data.zero_()

            new_context = []
            for p0, p1 in context:
                new_context.append((p0.detach(), p1.detach()))
            context = new_context

        context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_])
                         for p_ in context_delta]
        context = list(map(add_context, context, context_delta))

        new_context = []
        for p0, p1 in context:
            new_context.append((p0.detach(), p1.detach()))
        context = new_context

        return context

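    # This loop is the core of the method: instead of fine-tuning any weights,
    # it performs num_iterations steps of gradient descent on a delta added to
    # the transformer's cached key/value pairs, minimizing clip_scale * CLIP
    # loss plus a KL-style fluency term (scaled by ce_scale) that keeps the
    # shifted next-token distribution close to the unshifted one.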
    def norm_grad(self, context, context_tokens, curr_shift):
        factor = 1
        sep_grads = None
        window_mask = torch.ones_like(context[0][0]).to(self.device)

        for b in range(context_tokens.shape[0]):
            tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_]
                             for p_ in curr_shift]

            # normalize gradients
            tmp_grad = [tuple([-self.stepsize * factor * (
                    x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][
                j] ** self.grad_norm_factor).data.cpu().numpy()
                               for j, x in enumerate(p_)])
                        for i, p_ in enumerate(curr_shift)]
            if sep_grads is None:
                sep_grads = tmp_grad
            else:
                for l_index in range(len(sep_grads)):
                    sep_grads[l_index] = list(sep_grads[l_index])
                    for k_index in range(len(sep_grads[0])):
                        sep_grads[l_index][k_index] = np.concatenate(
                            (sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0)
                    sep_grads[l_index] = tuple(sep_grads[l_index])
        final_grads = sep_grads

        return final_grads

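    # Per-beam, per-layer update rule: each gradient slice is divided by its
    # own norm raised to grad_norm_factor (0.9 by default, so close to full
    # normalization) and scaled by -stepsize, then the slices are re-stacked
    # along the batch axis. window_mask is all ones here, i.e. a no-op hook.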
    def update_special_tokens_logits(self, context_tokens, i, logits):
        for beam_id in range(context_tokens.shape[0]):
            for token_idx in set(context_tokens[beam_id][-4:].tolist()):
                factor = self.repetition_penalty if logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty)
                logits[beam_id, token_idx] /= factor

            if i >= self.ef_idx:
                factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor)
                logits[beam_id, self.end_token] *= factor
            if i == 0:
                start_factor = 1.6
                factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor)
                logits[beam_id, self.end_token] /= factor

            for token_idx in list(self.forbidden_tokens):
                factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor)
                logits[beam_id, token_idx] /= factor

        return logits

    def clip_loss(self, probs, context_tokens):
        for p_ in self.clip.transformer.parameters():
            if p_.grad is not None:
                p_.grad.data.zero_()

        top_size = 512
        top_probs, top_indices = probs.topk(top_size, -1)

        prefix_texts = [self.lm_tokenizer.decode(x, skip_special_tokens=True) for x in context_tokens]

        clip_loss = 0
        losses = []

        top_texts = []
        for idx_p in range(probs.shape[0]):
            prefix_text = prefix_texts[idx_p]
            for x in top_indices[idx_p]:
                top_texts.append(prefix_text + self.lm_tokenizer.decode(x))

        text_features = self.get_txt_features(top_texts)

        with torch.no_grad():
            similarities = (self.image_features @ text_features.T).reshape(probs.size(0), -1)
            target_probs = nn.functional.softmax(similarities / self.clip_loss_temperature, dim=-1).detach()
            target_probs = target_probs.type(torch.float32)

        clip_loss += torch.sum(-(target_probs * torch.log(top_probs)))

        return clip_loss, losses
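
    # Vectorized variant of the loss (editorial note): all beams' top-512
    # continuations are encoded in one CLIP batch, the image/text similarities
    # are reshaped to (num_beams, 512), and the cross-entropy is taken directly
    # against top_probs. The returned losses list is left empty here; the
    # caller in shift_context ignores it.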

    def clip_loss_old(self, probs, context_tokens):
        for p_ in self.clip.transformer.parameters():
            if p_.grad is not None:
                p_.grad.data.zero_()

        top_size = 512
        _, top_indices = probs.topk(top_size, -1)

        prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens]

        clip_loss = 0
        losses = []
        for idx_p in range(probs.shape[0]):
            top_texts = []
            prefix_text = prefix_texts[idx_p]
            for x in top_indices[idx_p]:
                top_texts.append(prefix_text + self.lm_tokenizer.decode(x))
            text_features = self.get_txt_features(top_texts)

            with torch.no_grad():
                similarities = (self.image_features @ text_features.T)
                target_probs = nn.functional.softmax(similarities / self.clip_loss_temperature, dim=-1).detach()
                target_probs = target_probs.type(torch.float32)

            target = torch.zeros_like(probs[idx_p])
            target[top_indices[idx_p]] = target_probs[0]
            target = target.unsqueeze(0)
            cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)])))

            clip_loss += cur_clip_loss
            losses.append(cur_clip_loss)

        return clip_loss, losses
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,14 +0,0 @@
#!/bin/bash

# lm_model:
#  1. cambridgeltl/magic_mscoco
#  2. cambridgeltl/magic_flickr30k
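# Baseline run (note for readers): this invokes ZeroCap over the MSCOCO test
# split; run.py writes one predicted caption per image into the JSON named by
# --save_name under --save_path_prefix.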
CUDA_VISIBLE_DEVICES=1 python run.py \
    --beam_size 1 \
    --target_seq_length 16 \
    --reset_context_delta \
    --lm_model cambridgeltl/magic_mscoco \
    --test_image_prefix_path ../data/mscoco/test_images \
    --test_path ../data/mscoco/mscoco_test.json \
    --save_path_prefix ../inference_result/mscoco/baselines/ \
    --save_name zerocap_result.json
@ -1,117 +0,0 @@
import os
import tempfile
import sys
sys.path.append('CLIP')
from pathlib import Path
import cog
import argparse
import torch
import clip
from model.ZeroCLIP import CLIPTextGenerator

def perplexity_score(text, lm_model, lm_tokenizer, device):
    encodings = lm_tokenizer(f'{lm_tokenizer.bos_token + text}', return_tensors='pt')
    input_ids = encodings.input_ids.to(device)
    target_ids = input_ids.clone()

    outputs = lm_model(input_ids, labels=target_ids)
    log_likelihood = outputs[0]
    ll = log_likelihood.item()

    return ll
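
# Naming caveat (editorial): outputs[0] from a HuggingFace LM called with
# labels is the mean token cross-entropy (a loss), not a log-likelihood, so
# lower values mean more fluent text -- which is why predict() below selects
# the caption with the argmin of these scores.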

class Predictor(cog.Predictor):
    def setup(self):
        self.args = get_args()
        self.args.reset_context_delta = True
        self.text_generator = CLIPTextGenerator(**vars(self.args))

    @cog.input(
        "image",
        type=Path,
        help="input image"
    )
    @cog.input(
        "cond_text",
        type=str,
        default='Image of a',
        help="conditional text",
    )
    @cog.input(
        "beam_size",
        type=int,
        default=5, min=1, max=10,
        help="Number of beams to use",
    )
    @cog.input(
        "end_factor",
        type=float,
        default=1.01, min=1.0, max=1.10,
        help="Higher value for shorter captions",
    )
    @cog.input(
        "max_seq_length",
        type=int,
        default=15, min=1, max=20,
        help="Maximum number of tokens to generate",
    )
    @cog.input(
        "ce_loss_scale",
        type=float,
        default=0.2, min=0.0, max=0.6,
        help="Scale of cross-entropy loss with un-shifted language model",
    )
    def predict(self, image, cond_text, beam_size, end_factor, max_seq_length, ce_loss_scale):
        self.args.cond_text = cond_text
        self.text_generator.end_factor = end_factor
        self.text_generator.target_seq_length = max_seq_length
        self.text_generator.ce_scale = ce_loss_scale

        image_features = self.text_generator.get_img_feature([str(image)], None)
        captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size)

        # CLIP SCORE
        encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device))
                            for c in captions]
        encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
        best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item()

        # Perplexity SCORE
        ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions]
        best_ppl_index = torch.tensor(ppl_scores).argmin().item()

        best_clip_caption = self.args.cond_text + captions[best_clip_idx]
        best_mixed = self.args.cond_text + captions[0]
        best_PPL = self.args.cond_text + captions[best_ppl_index]

        final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}'

        return final
        # return self.args.cond_text + captions[best_clip_idx]
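
    # Selection logic in predict() (summary): "Best CLIP" is the beam whose
    # CLIP text embedding has the highest cosine similarity to the image,
    # "Best fluency" is the lowest-perplexity beam, and "Best mixed" is simply
    # captions[0], the top beam under the combined decoding score.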


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
    parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
    parser.add_argument("--target_seq_length", type=int, default=15)
    parser.add_argument("--cond_text", type=str, default="Image of a")
    parser.add_argument("--reset_context_delta", action="store_true",
                        help="Should we reset the context at each token gen")
    parser.add_argument("--num_iterations", type=int, default=5)
    parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
    parser.add_argument("--clip_scale", type=float, default=1)
    parser.add_argument("--ce_scale", type=float, default=0.2)
    parser.add_argument("--stepsize", type=float, default=0.3)
    parser.add_argument("--grad_norm_factor", type=float, default=0.9)
    parser.add_argument("--fusion_factor", type=float, default=0.99)
    parser.add_argument("--repetition_penalty", type=float, default=1)
    parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
    parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
    parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
    parser.add_argument("--beam_size", type=int, default=5)

    args = parser.parse_args('')
    return args
@ -1,129 +0,0 @@
import os
import tempfile
import sys
sys.path.append('CLIP')
from pathlib import Path
import cog
import argparse
import torch
import clip
from model.ZeroCLIP import CLIPTextGenerator

def perplexity_score(text, lm_model, lm_tokenizer, device):
    encodings = lm_tokenizer(f'{lm_tokenizer.bos_token + text}', return_tensors='pt')
    input_ids = encodings.input_ids.to(device)
    target_ids = input_ids.clone()

    outputs = lm_model(input_ids, labels=target_ids)
    log_likelihood = outputs[0]
    ll = log_likelihood.item()

    return ll

class Predictor(cog.Predictor):
    def setup(self):
        self.args = get_args()
        self.args.reset_context_delta = True
        self.text_generator = CLIPTextGenerator(**vars(self.args))

    @cog.input(
        "image1",
        type=Path,
        help="Final result will be: image1 + (image2 - image3)"
    )
    @cog.input(
        "image2",
        type=Path,
        help="Final result will be: image1 + (image2 - image3)"
    )
    @cog.input(
        "image3",
        type=Path,
        help="Final result will be: image1 + (image2 - image3)"
    )
    @cog.input(
        "cond_text",
        type=str,
        default='Image of a',
        help="conditional text",
    )
    @cog.input(
        "beam_size",
        type=int,
        default=3, min=1, max=10,
        help="Number of beams to use",
    )
    @cog.input(
        "end_factors",
        type=float,
        default=1.06, min=1.0, max=1.10,
        help="Higher value for shorter captions",
    )
    @cog.input(
        "max_seq_lengths",
        type=int,
        default=3, min=1, max=20,
        help="Maximum number of tokens to generate",
    )
    @cog.input(
        "ce_loss_scale",
        type=float,
        default=0.2, min=0.0, max=0.6,
        help="Scale of cross-entropy loss with un-shifted language model",
    )
    def predict(self, image1, image2, image3, cond_text, beam_size, end_factors, max_seq_lengths, ce_loss_scale):
        self.args.cond_text = cond_text
        self.text_generator.end_factor = end_factors
        self.text_generator.target_seq_length = max_seq_lengths
        self.text_generator.ce_scale = ce_loss_scale
        self.text_generator.fusion_factor = 0.95
        self.text_generator.grad_norm_factor = 0.95

        image_features = self.text_generator.get_combined_feature([str(image1), str(image2), str(image3)], [], [1, 1, -1], None)
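        # The weights [1, 1, -1] realize the CLIP-space arithmetic promised in
        # the input help text: image1 + (image2 - image3); the caption below
        # then verbalizes the combined embedding.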
        captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size)

        # CLIP SCORE
        encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device))
                            for c in captions]
        encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
        best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item()

        # Perplexity SCORE
        ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions]
        best_ppl_index = torch.tensor(ppl_scores).argmin().item()

        best_clip_caption = self.args.cond_text + captions[best_clip_idx]
        best_mixed = self.args.cond_text + captions[0]
        best_PPL = self.args.cond_text + captions[best_ppl_index]

        final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}'

        return final
        # return self.args.cond_text + captions[best_clip_idx]


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
    parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
    parser.add_argument("--target_seq_length", type=int, default=15)
    parser.add_argument("--cond_text", type=str, default="Image of a")
    parser.add_argument("--reset_context_delta", action="store_true",
                        help="Should we reset the context at each token gen")
    parser.add_argument("--num_iterations", type=int, default=5)
    parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
    parser.add_argument("--clip_scale", type=float, default=1)
    parser.add_argument("--ce_scale", type=float, default=0.2)
    parser.add_argument("--stepsize", type=float, default=0.3)
    parser.add_argument("--grad_norm_factor", type=float, default=0.95)
    parser.add_argument("--fusion_factor", type=float, default=0.95)
    parser.add_argument("--repetition_penalty", type=float, default=1)
    parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
    parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
    parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
    parser.add_argument("--beam_size", type=int, default=5)

    args = parser.parse_args('')
    return args
@ -1,3 +0,0 @@
ftfy
regex
tqdm
@ -1,131 +0,0 @@
import argparse
import ipdb
from tqdm import tqdm
import progressbar
import torch
import clip
from model.ZeroCLIP import CLIPTextGenerator
from model.ZeroCLIP_batched import CLIPTextGenerator as CLIPTextGenerator_multigpu

def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--test_image_prefix_path", type=str, help="the folder that stores all test images")
    parser.add_argument("--test_path", type=str)
    parser.add_argument("--save_path_prefix", type=str, help="save the result in which directory")
    parser.add_argument("--save_name", type=str, help="the name of the saved file")

    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
    parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
    parser.add_argument("--target_seq_length", type=int, default=15)
    parser.add_argument("--cond_text", type=str, default="Image of a")
    parser.add_argument("--reset_context_delta", action="store_true",
                        help="Should we reset the context at each token gen")
    parser.add_argument("--num_iterations", type=int, default=5)
    parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
    parser.add_argument("--clip_scale", type=float, default=1)
    parser.add_argument("--ce_scale", type=float, default=0.2)
    parser.add_argument("--stepsize", type=float, default=0.3)
    parser.add_argument("--grad_norm_factor", type=float, default=0.9)
    parser.add_argument("--fusion_factor", type=float, default=0.99)
    parser.add_argument("--repetition_penalty", type=float, default=1)
    parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
    parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
    parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
    parser.add_argument("--beam_size", type=int, default=1)

    parser.add_argument("--multi_gpu", action="store_true")

    parser.add_argument('--run_type',
                        default='caption',
                        nargs='?',
                        choices=['caption', 'arithmetics'])

    parser.add_argument("--caption_img_path", type=str, default='example_images/captions/COCO_val2014_000000008775.jpg',
                        help="Path to image for captioning")

    parser.add_argument("--arithmetics_imgs", nargs="+",
                        default=['example_images/arithmetics/woman2.jpg',
                                 'example_images/arithmetics/king2.jpg',
                                 'example_images/arithmetics/man2.jpg'])
    parser.add_argument("--arithmetics_weights", nargs="+", default=[1, 1, -1])

    args = parser.parse_args()

    return args

def run(args, text_generator, img_path):
    image_features = text_generator.get_img_feature([img_path], None)
    captions = text_generator.run(image_features, args.cond_text, beam_size=args.beam_size)

    encoded_captions = [text_generator.clip.encode_text(clip.tokenize(c).to(text_generator.device)) for c in captions]
    encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
    best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item()
    return captions
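
# Note: run() computes best_clip_idx but returns the full beam list; the main
# loop below saves output_text[0], i.e. the top beam under the decoding score,
# as each image's prediction.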


if __name__ == '__main__':
    if torch.cuda.is_available():
        print ('Cuda is available.')
    cuda_available = torch.cuda.is_available()
    args = get_args()
    device = torch.device('cuda')

    save_path_prefix = args.save_path_prefix
    import os
    if os.path.exists(save_path_prefix):
        pass
    else: # recursively construct directory
        os.makedirs(save_path_prefix, exist_ok=True)
    # parse save name
    save_name = args.save_name
    full_save_path = save_path_prefix + '/' + save_name
    print ('full save path is {}'.format(full_save_path))

    print ('Loading data...')
    import json
    with open(args.test_path) as f:
        item_list = json.load(f)
    print ('Data loaded.')
    print ('Number of test instances is {}'.format(len(item_list)))

    # ZeroCap generator
    text_generator = CLIPTextGenerator(**vars(args))

    result_list = []
    invalid_num = 0
    print ('----------------------------------------------------------------')
    test_num = len(item_list)
    #test_num = 10
    print ('Number of inference instances is {}'.format(test_num))
    p = progressbar.ProgressBar(test_num)
    p.start()
    for p_idx in tqdm(range(test_num)):
        p.update(p_idx)
        one_test_dict = item_list[p_idx]

        one_res_dict = {
            'split':one_test_dict['split'],
            'image_name':one_test_dict['image_name'],
            #'file_path':one_test_dict['file_path'],
            'captions':one_test_dict['captions']
        }

        image_full_path = args.test_image_prefix_path + '/' + one_test_dict['image_name']
        try:
            output_text = run(args, text_generator, img_path=image_full_path)
            one_res_dict['prediction'] = output_text[0]
            result_list.append(one_res_dict)
        except Exception as error:
            print('[!] ERROR:', error)
            invalid_num += 1
            print ('invalid number is {}'.format(invalid_num))
            continue
    p.finish()
    print ('Inference completed!')

    import json
    with open(full_save_path, 'w') as outfile:
        json.dump(result_list, outfile, indent=4)
@ -1,19 +0,0 @@
import os

import pkg_resources
from setuptools import setup, find_packages

setup(
    name="zero-shot-image-to-text",
    py_modules=["zero-shot-image-to-text"],
    version="1.0",
    description="",
    packages=find_packages(),
    install_requires=[
        str(r)
        for r in pkg_resources.parse_requirements(
            open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
        )
    ],
    include_package_data=True
)