diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000..0f450d5
Binary files /dev/null and b/.DS_Store differ
diff --git a/.README.md.swp b/.README.md.swp
new file mode 100644
index 0000000..b3d1b72
Binary files /dev/null and b/.README.md.swp differ
diff --git a/README.md b/README.md
index 4b22684..6a4a1b3 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,81 @@
-# magic
+# Image Captioning with MAGIC
+
+*author: David Wang*
+
+
+
+
+
+## Description
+
+This operator generates the caption with [MAGIC](https://arxiv.org/abs/2205.02655) which describes the content of the given image. MAGIC is a simple yet efficient plug-and-play framework, which directly combines an off-the-shelf LM (i.e., GPT-2) and an image-text matching model (i.e., CLIP) for image-grounded text generation. During decoding, MAGIC influences the generation of the LM by introducing a CLIP-induced score, called magic score, which regularizes the generated result to be semantically related to a given image while being coherent to the previously generated context. This is an adaptation from [yxuansu / MAGIC](https://github.com/yxuansu/MAGIC).
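+
+At each decoding step, MAGIC scores every candidate token in the top-k pool by combining the LM confidence, a degeneration penalty, and the CLIP image-text similarity, then greedily keeps the best candidate. The snippet below is only a minimal sketch of this scoring rule (the function name and random inputs are illustrative assumptions, not this operator's API); the values shown (k=45, alpha=0.1, beta=2.0) are the defaults this operator uses.
+
+```python
+import torch
+
+def magic_score(lm_probs, degeneration_penalty, clip_similarity, alpha, beta):
+    """Sketch of MAGIC's per-candidate score over the top-k candidate tokens.
+
+    lm_probs:             [k] LM probability of each candidate token
+    degeneration_penalty: [k] max cosine similarity between each candidate and the previous context
+    clip_similarity:      [k] CLIP similarity between the image and each candidate continuation
+    """
+    return (1.0 - alpha) * lm_probs - alpha * degeneration_penalty + beta * clip_similarity
+
+# the candidate with the highest score is emitted as the next token
+scores = magic_score(torch.rand(45), torch.rand(45), torch.rand(45), alpha=0.1, beta=2.0)
+next_candidate = int(scores.argmax())
+```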
+
+
+
+
+
+## Code Example
+
+Load an image from path './image.jpg' to generate the caption.
+
+*Write the pipeline in simplified style:*
+
+```python
+import towhee
+
+towhee.glob('./image.jpg') \
+ .image_decode() \
+ .image_captioning.magic(model_name='magic_mscoco') \
+ .show()
+```
+
+
+*Write the same pipeline with explicit input/output name specifications:*
+
+```python
+import towhee
+
+towhee.glob['path']('./image.jpg') \
+ .image_decode['path', 'img']() \
+ .image_captioning.magic['img', 'text'](model_name='magic_mscoco') \
+ .select['img', 'text']() \
+ .show()
+```
+
+
+
+
+
+
+## Factory Constructor
+
+Create the operator via the following factory method
+
+***magic(model_name)***
+
+**Parameters:**
+
+ ***model_name:*** *str*
+
+ The model name of MAGIC. Supported model names (see the usage sketch below):
+- magic_mscoco
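+
+A minimal usage sketch, assuming the operator is constructed through `towhee.ops` (the same hub namespace used in the pipeline examples above):
+
+```python
+import towhee
+
+# construct the operator via the factory method; 'magic_mscoco' is the only supported model name
+op = towhee.ops.image_captioning.magic(model_name='magic_mscoco')
+```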
+
+
+
+## Interface
+
+An image captioning operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generates the corresponding caption.
+
+
+**Parameters:**
+
+ ***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)*
+
+ The image to generate the caption for.
+
+
+
+**Returns:** *str*
+
+ The caption generated by the model.
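+
+A short end-to-end sketch, assuming the operator is invoked directly on a decoded image (the image path is illustrative):
+
+```python
+import towhee
+
+img = towhee.ops.image_decode()('./image.jpg')
+caption = towhee.ops.image_captioning.magic(model_name='magic_mscoco')(img)
+print(caption)
+```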
diff --git a/language_model/README.md b/language_model/README.md
new file mode 100644
index 0000000..da8f45e
--- /dev/null
+++ b/language_model/README.md
@@ -0,0 +1,167 @@
+## Unsupervised Domain Adaptation of Language Model
+****
+### Catalogue:
+* 1. MSCOCO Benchmark
+ * 1.1. MSCOCO Data Preparation
+ * 1.2. Unsupervised Domain Adaptation on MSCOCO
+* 2. Flickr30k Benchmark
+ * 2.1. Flickr30k Data Preparation
+ * 2.2. Unsupervised Domain Adaptation on Flickr30k
+* 3. Unsupervised Baselines
+ * 3.1. Contrastive Search
+ * 3.2. Top-k Sampling
+ * 3.3. Nucleus Sampling
+
+****
+
+
+#### 1. MSCOCO Benchmark:
+
+We first describe how to perform unsupervised domain adaptation of the language model on the text corpus of the MSCOCO benchmark.
+
+
+
+##### 1.1. MSCOCO Data Preparation:
+
+To prepare the MSCOCO benchmark, please follow the instructions [[here]](https://github.com/yxuansu/MAGIC/tree/main/image_captioning/data#1-mscoco-benchmark).
+
+
+
+##### 1.2. Unsupervised Domain Adaptation on MSCOCO:
+After preparing the MSCOCO data, run the following command to train the language model.
+```yaml
+chmod +x ./train_mscoco.sh
+./train_mscoco.sh
+```
+The arguments are as follows:
+* `--model_name`: The name of the huggingface pre-trained GPT-2 model (e.g., gpt2, gpt2-large).
+* `--train_path`: The file path of training set.
+* `--dev_path`: The file path of validation set.
+* `--test_path`: The file path of test set.
+* `--add_eos_token_to_data`: Whether to add an eos token at the end of each text sequence ('True' or 'False').
+* `--margin`: The contrastive margin $\rho$.
+* `--max_len`: The maximum length of training samples.
+* `--number_of_gpu`: The number of available GPUs.
+* `--batch_size_per_gpu`: The batch size for each GPU.
+* `--gradient_accumulation_steps`: How many forward computations between two gradient updates.
+* `--effective_batch_size`: The overall batch size. It equals batch_size_per_gpu x gradient_accumulation_steps x number_of_gpu (a quick sanity check follows after this list).
+* `--total_steps`: The number of total gradient update steps.
+* `--print_every`: How many update steps between printing intermediate results.
+* `--save_every`: How many update steps between saving checkpoints.
+* `--learning_rate`: The learning rate.
+* `--save_path_prefix`: Where to save the checkpoints.
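+
+As a quick sanity check of the batch-size relationship (a sketch using the values from `train_mscoco.sh`):
+
+```python
+batch_size_per_gpu, gradient_accumulation_steps, number_of_gpu = 32, 4, 1
+effective_batch_size = batch_size_per_gpu * gradient_accumulation_steps * number_of_gpu
+assert effective_batch_size == 128  # the value passed as --effective_batch_size
+```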
+
+****
+
+
+#### 2. Flickr30k Benchmark:
+
+We then describe how to perform unsupervised domain adaptation of the language model on the text corpus of the Flickr30k benchmark.
+
+
+
+##### 2.1. Flickr30k Data Preparation:
+
+To prepare the Flickr30k benchmark, please follow the instructions [[here]](https://github.com/yxuansu/MAGIC/tree/main/image_captioning/data#2-flickr30k-benchmark).
+
+
+
+##### 2.2. Unsupervised Domain Adaptation on Flickr30k:
+After preparing the Flickr30k data, run the following command to train the language model.
+```yaml
+chmod +x ./train_flickr30k.sh
+./train_flickr30k.sh
+```
+
+****
+
+
+#### 3. Unsupervised Baselines:
+
+Here, we illustrate how to use the language model to perform the unsupervised baselines described in our paper. Note that all these methods are **unsupervised**, as the language model is text-only and does not take images as input.
+
+```python
+# first, load the language model
+import torch
+from simctg import SimCTG
+sos_token, pad_token = r'<-start_of_text->', r'<-pad->'
+# we use the language model adapted on MSCOCO as an example.
+language_model_name = r'cambridgeltl/magic_mscoco'
+generation_model = SimCTG(language_model_name, sos_token, pad_token)
+generation_model.eval()
+
+# then, prepare the input ids. Note that, the text is always generated from the same start of sentence token.
+tokens = generation_model.tokenizer.tokenize(sos_token)
+input_ids = generation_model.tokenizer.convert_tokens_to_ids(tokens)
+input_ids = torch.LongTensor(input_ids).view(1,-1)
+```
+
+
+
+##### 3.1. Contrastive Search:
+```python
+'''
+ use contrastive search to generate the result.
+ note that, contrastive search is a deterministic decoding method, thus the generated text is always the same.
+'''
+beam_width, alpha, decoding_len = 45, 0.1, 16
+output_text = generation_model.fast_contrastive_search(input_ids, beam_width, alpha, decoding_len)
+print (output_text)
+'''
+ A man is riding a skateboard down a street.
+'''
+```
+The arguments are as follows:
+* `--input_ids`: The id of the start of sentence token.
+* `--beam_width`: The k in contrastive search.
+* `--alpha`: The alpha in contrastive search.
+* `--decoding_len`: Number of tokens to generate.
+
+
+
+##### 3.2. Top-k Sampling:
+```python
+'''
+    use top-k sampling to generate the result.
+    note that this method is stochastic, thus the generated text may differ between runs.
+'''
+top_k, decoding_len = 40, 16
+output_text = generation_model.top_k_sampling(input_ids, top_k, decoding_len)
+print (output_text)
+'''
+ some very different types of vases with flowers together
+'''
+```
+The arguments are as follows:
+* `--input_ids`: The id of the start of sentence token.
+* `--k`: The k in top-k sampling.
+* `--decoding_len`: Number of tokens to generate.
+
+
+
+##### 3.3. Nucleus Sampling:
+```python
+'''
+    use nucleus sampling to generate the result.
+    note that this method is stochastic, thus the generated text may differ between runs.
+'''
+nucleus_p, decoding_len = 0.95, 16
+output_text = generation_model.nucleus_sampling(input_ids, nucleus_p, decoding_len)
+print (output_text)
+'''
+ Two young girls enjoying a hot dog hot dog bun.
+'''
+```
+The arguments are as follows:
+* `--input_ids`: The id of the start of sentence token.
+* `--nucleus_p`: The probability in nucleus sampling.
+* `--decoding_len`: Number of tokens to generate.
+
+
+
+
+
+
+
+
+
diff --git a/language_model/dataclass.py b/language_model/dataclass.py
new file mode 100644
index 0000000..3786dc5
--- /dev/null
+++ b/language_model/dataclass.py
@@ -0,0 +1,157 @@
+import json
+import random
+import torch
+import numpy as np
+import progressbar
+from torch.nn.utils import rnn
+
+class Data:
+ def __init__(self, model_name, train_path, dev_path, test_path, max_len,
+ sos_token, pad_token, add_eos_token_to_data):
+ '''
+ model_name: gpt2
+ train_path: training data path
+ dev_path: validation data path
+ test_path: test data path
+ max_len: maximum length for training sequences
+ sos_token: initialized sos token <-start_of_text->
+ pad_token: used to pad the sequences <-pad->
+        add_eos_token_to_data: whether we want the model to learn to generate the eos token;
+            if so, the model can automatically stop generation by producing the eos token
+ '''
+ from transformers import GPT2TokenizerFast
+ self.tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
+ self.sos_token, self.sos_token_id = self.add_special_token(sos_token)
+ print ('sos token is {}, sos token id is {}'.format(self.sos_token, self.sos_token_id))
+ self.pad_token, self.pad_token_id = self.add_special_token(pad_token)
+ print ('pad token is {}, pad token id is {}'.format(self.pad_token, self.pad_token_id))
+ self.eos_token, self.eos_token_id = self.tokenizer.bos_token, self.tokenizer.bos_token_id
+ print ('eos token is {}, eos token id is {}'.format(self.eos_token, self.eos_token_id))
+ self.add_eos_token_to_data = add_eos_token_to_data
+
+ self.max_len = max_len
+ self.train_token_list, self.train_token_id_list = self.process_one_file(train_path)
+ self.dev_token_list, self.dev_token_id_list = self.process_one_file(dev_path)
+ self.test_token_list, self.test_token_id_list = self.process_one_file(test_path)
+ self.train_num, self.dev_num, self.test_num = len(self.train_token_list), len(self.dev_token_list), \
+ len(self.test_token_list)
+ print ('train number:{}, dev number:{}, test number:{}'.format(self.train_num, self.dev_num, self.test_num))
+
+ self.train_idx_list = [i for i in range(self.train_num)]
+ random.shuffle(self.train_idx_list)
+ self.dev_idx_list = [j for j in range(self.dev_num)]
+ self.test_idx_list = [j for j in range(self.test_num)]
+ self.dev_current_idx, self.test_current_idx = 0, 0
+
+ def add_special_token(self, special_token):
+ if special_token in self.tokenizer.vocab:
+ print (special_token + ' token exists.')
+ else:
+ print ('Add token to the tokenizer.')
+ print ('Original vocabulary size is {}'.format(len(self.tokenizer)))
+ self.tokenizer.add_tokens([special_token])
+ print ('Vocabulary size after extension is {}'.format(len(self.tokenizer)))
+ assert len(self.tokenizer.convert_tokens_to_ids([special_token])) == 1
+ special_token_id = self.tokenizer.convert_tokens_to_ids([special_token])[0]
+ return special_token, special_token_id
+
+ def process_one_file(self, path):
+ print ('Processing {}'.format(path))
+ with open(path) as f:
+ item_list = json.load(f)
+ lines = []
+ for item in item_list:
+ captions_list = item['captions']
+ for one_caption in captions_list:
+ lines.append(one_caption.strip())
+
+ res_token_list, res_token_id_list = [], []
+ n = len(lines)
+ p = progressbar.ProgressBar(n)
+ p.start()
+ for i in range(n):
+ p.update(i)
+ text = lines[i].strip('\n')
+ self.process_one_text(text, res_token_list, res_token_id_list)
+ p.finish()
+ print ('{} processed!'.format(path))
+ return res_token_list, res_token_id_list
+
+ def process_one_text(self, text, res_token_list, res_token_id_list):
+ tokens = self.tokenizer.tokenize(text, max_length=self.max_len, truncation=True)
+ if len(tokens) <= 1: # filter out too short sequence
+ return
+ tokens = [self.sos_token] + tokens[:self.max_len]
+ if self.add_eos_token_to_data:
+ tokens = tokens + [self.eos_token]
+ token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
+ res_token_list.append(tokens)
+ res_token_id_list.append(token_ids)
+ return
+
+ def pad_batch(self, batch_id_list):
+ batch_id_list = [torch.LongTensor(item) for item in batch_id_list]
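+        # pad every sequence to the longest one in the batch; the mask is 1.0 for real tokens and 0.0 for padding positions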
+ batch_tensor = rnn.pad_sequence(batch_id_list, batch_first=True, padding_value=self.pad_token_id)
+ batch_mask = torch.ones_like(batch_tensor)
+ batch_mask = batch_mask.masked_fill(batch_tensor.eq(self.pad_token_id), 0.0).type(torch.FloatTensor)
+ return batch_tensor, batch_mask
+
+ def process_output(self, batch_tgt_id_list):
+ batch_tgt_id_list = [torch.LongTensor(item) for item in batch_tgt_id_list]
+ batch_tgt_tensor, _ = self.pad_batch(batch_tgt_id_list) # padded target sequence
+ batch_tgt_input_tensor = batch_tgt_tensor[:, :-1].clone()
+ batch_tgt_output_tensor = batch_tgt_tensor[:, 1:].clone()
+ return batch_tgt_input_tensor, batch_tgt_output_tensor
+
+ def parse_batch(self, batch_id_list):
+ batch_input, batch_labels = self.process_output(batch_id_list)
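+        # set padded label positions to -100, the default ignore_index of torch.nn.CrossEntropyLoss, so they do not contribute to the loss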
+ batch_labels[batch_labels[:, :] == self.pad_token_id] = -100
+ return batch_input, batch_labels
+
+ def get_next_train_batch(self, batch_size):
+ batch_idx_list = random.sample(self.train_idx_list, batch_size)
+ batch_id_list, batch_token_list = [], []
+
+ for idx in batch_idx_list:
+ batch_id_list.append(self.train_token_id_list[idx])
+ batch_token_list.append(self.train_token_list[idx])
+ batch_input_tensor, batch_labels = self.parse_batch(batch_id_list)
+ return batch_input_tensor, batch_labels, batch_token_list
+
+ def get_next_validation_batch(self, batch_size, mode):
+ batch_id_list, batch_token_list = [], []
+ if mode == 'dev':
+ curr_select_idx, instance_num = self.dev_current_idx, self.dev_num
+ tgt_token_id_list, tgt_token_list = self.dev_token_id_list, self.dev_token_list
+ elif mode == 'test':
+ curr_select_idx, instance_num = self.test_current_idx, self.test_num
+ tgt_token_id_list, tgt_token_list = self.test_token_id_list, self.test_token_list
+ else:
+ raise Exception('Wrong Validation Mode!!!')
+
+ if curr_select_idx + batch_size < instance_num:
+ for i in range(batch_size):
+ curr_idx = curr_select_idx + i
+ batch_id_list.append(tgt_token_id_list[curr_idx])
+ batch_token_list.append(tgt_token_list[curr_idx])
+ if mode == 'dev':
+ self.dev_current_idx += batch_size
+ else:
+ self.test_current_idx += batch_size
+ else:
+ for i in range(batch_size):
+ curr_idx = curr_select_idx + i
+ if curr_idx > instance_num - 1:
+ curr_idx = 0
+ if mode == 'dev':
+ self.dev_current_idx = 0
+ else:
+ self.test_current_idx = 0
+ batch_id_list.append(tgt_token_id_list[curr_idx])
+ batch_token_list.append(tgt_token_list[curr_idx])
+ if mode == 'dev':
+ self.dev_current_idx = 0
+ else:
+ self.test_current_idx = 0
+ batch_input_tensor, batch_labels = self.parse_batch(batch_id_list)
+ return batch_input_tensor, batch_labels, batch_token_list
diff --git a/language_model/loss_func.py b/language_model/loss_func.py
new file mode 100644
index 0000000..96a4243
--- /dev/null
+++ b/language_model/loss_func.py
@@ -0,0 +1,80 @@
+import torch
+
+def compute_valid_token_num(valid_len_list):
+ res = 0
+ for one_len in valid_len_list:
+ res += one_len * (one_len - 1)
+ return res
+
+def build_mask_matrix(seqlen, valid_len_list, prefix_len = 0):
+ '''
+ prefix_len: the length of prefix that we do not want to compute CL loss for.
+
+    (1) if a sequence of length 4 contains zero padding tokens (i.e., the valid length is 4),
+ then the loss padding matrix looks like
+ [0., 1., 1., 1.],
+ [1., 0., 1., 1.],
+ [1., 1., 0., 1.],
+ [1., 1., 1., 0.]
+
+ (2) if a sequence of length 4 contains 1 padding token (i.e., the valid length is 3),
+ then the loss padding matrix looks like
+ [0., 1., 1., 0.],
+ [1., 0., 1., 0.],
+ [1., 1., 0., 0.],
+ [0., 0., 0., 0.]
+ '''
+ res_list = []
+ base_mask = torch.ones(seqlen, seqlen) - torch.eye(seqlen, seqlen)
+ base_mask = base_mask.type(torch.FloatTensor)
+ bsz = len(valid_len_list)
+ for i in range(bsz):
+ one_base_mask = base_mask.clone()
+ one_valid_len = valid_len_list[i]
+ one_base_mask[:,one_valid_len:] = 0.
+ one_base_mask[one_valid_len:, :] = 0.
+ if prefix_len > 0:
+ one_base_mask[:prefix_len, :prefix_len] = 0.
+ res_list.append(one_base_mask)
+ res_mask = torch.stack(res_list, dim = 0)#torch.FloatTensor(res_list)
+ #print (res_mask)
+ assert res_mask.size() == torch.Size([bsz, seqlen, seqlen])
+ return res_mask
+
+def contrastive_loss(margin, score_matrix, input_ids, pad_token_id, prefix_len=0):
+ '''
+ margin: predefined margin to push similarity score away
+ score_matrix: bsz x seqlen x seqlen
+ input_ids: bsz x seqlen
+ pad_token_id: indicating which tokens are padding token
+ '''
+ bsz, seqlen, _ = score_matrix.size()
+ gold_score = torch.diagonal(score_matrix, offset=0, dim1=1, dim2=2) # bsz x seqlen
+ gold_score = torch.unsqueeze(gold_score, -1)
+ assert gold_score.size() == torch.Size([bsz, seqlen, 1])
+ difference_matrix = gold_score - score_matrix
+ assert difference_matrix.size() == torch.Size([bsz, seqlen, seqlen])
+ loss_matrix = margin - difference_matrix # bsz x seqlen x seqlen
+ loss_matrix = torch.nn.functional.relu(loss_matrix)
+
+ ### input mask
+ input_mask = torch.ones_like(input_ids).type(torch.FloatTensor)
+ if loss_matrix.is_cuda:
+ input_mask = input_mask.cuda(loss_matrix.get_device())
+ input_mask = input_mask.masked_fill(input_ids.eq(pad_token_id), 0.0)
+
+ if loss_matrix.is_cuda:
+ input_mask = input_mask.cuda(loss_matrix.get_device())
+
+ valid_len_list = torch.sum(input_mask, dim = -1).tolist()
+ loss_mask = build_mask_matrix(seqlen, [int(item) for item in valid_len_list], prefix_len)
+ if score_matrix.is_cuda:
+ loss_mask = loss_mask.cuda(score_matrix.get_device())
+ masked_loss_matrix = loss_matrix * loss_mask
+
+ loss_matrix = torch.sum(masked_loss_matrix, dim = -1)
+ assert loss_matrix.size() == input_ids.size()
+ loss_matrix = loss_matrix * input_mask
+ cl_loss = torch.sum(loss_matrix) / torch.sum(loss_mask)
+ return cl_loss
+
\ No newline at end of file
diff --git a/language_model/simctg.py b/language_model/simctg.py
new file mode 100644
index 0000000..25599e9
--- /dev/null
+++ b/language_model/simctg.py
@@ -0,0 +1,233 @@
+import os
+import sys
+import operator
+from tqdm import tqdm
+from operator import itemgetter
+import torch
+from torch import nn
+import random
+import argparse
+import numpy as np
+import torch.nn.functional as F
+from torch.nn import CrossEntropyLoss
+from loss_func import contrastive_loss
+from utlis import ContrastiveDecodingOneStepFast, PlugAndPlayContrastiveDecodingOneStepFast
+
+import seaborn as sns
+import matplotlib.pyplot as plt
+import pandas as pd
+import datetime
+
+train_fct = CrossEntropyLoss()
+val_fct = CrossEntropyLoss(reduction='none')
+class SimCTG(nn.Module):
+ def __init__(self, model_name, sos_token, pad_token):
+ super(SimCTG, self).__init__()
+ from transformers import AutoTokenizer, GPT2LMHeadModel
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+ self.sos_token, self.sos_token_id = self.add_special_token(sos_token)
+ print ('sos token is {}, sos token id is {}'.format(self.sos_token, self.sos_token_id))
+ self.pad_token, self.pad_token_id = self.add_special_token(pad_token)
+ print ('pad token is {}, pad token id is {}'.format(self.pad_token, self.pad_token_id))
+ self.eos_token, self.eos_token_id = self.tokenizer.bos_token, self.tokenizer.bos_token_id
+ print ('eos token is {}, eos token id is {}'.format(self.eos_token, self.eos_token_id))
+ self.model = GPT2LMHeadModel.from_pretrained(model_name)
+ self.vocab_size = len(self.tokenizer)
+ print ('Resizing model embedding...')
+ self.model.resize_token_embeddings(len(self.tokenizer))
+ print ('Model embedding resized!')
+ self.embed_dim = self.model.config.hidden_size
+
+ def add_special_token(self, special_token):
+ if special_token in self.tokenizer.vocab:
+ print (special_token + ' token exists.')
+ else:
+ print ('Add token to the tokenizer.')
+ print ('Original vocabulary size is {}'.format(len(self.tokenizer)))
+ self.tokenizer.add_tokens([special_token])
+ print ('Vocabulary size after extension is {}'.format(len(self.tokenizer)))
+ assert len(self.tokenizer.convert_tokens_to_ids([special_token])) == 1
+ special_token_id = self.tokenizer.convert_tokens_to_ids([special_token])[0]
+ return special_token, special_token_id
+
+ def compute_logits_and_hidden_states(self, input_ids):
+ # used for advanced decoding
+ # input_ids: 1 x seqlen
+ outputs = self.model(input_ids=input_ids, output_hidden_states=True)
+ last_hidden_states = outputs.hidden_states[-1]
+ logits = outputs.logits
+ return last_hidden_states, logits
+
+ def forward(self, input_ids, labels, margin):
+ bsz, seqlen = input_ids.size()
+ outputs = self.model(input_ids=input_ids, output_hidden_states=True)
+ logits = outputs.logits
+ assert logits.size() == torch.Size([bsz, seqlen, self.vocab_size])
+ last_hidden_states = outputs.hidden_states[-1]
+ assert last_hidden_states.size() == torch.Size([bsz, seqlen, self.embed_dim])
+ mle_loss = train_fct(logits.view(-1, self.vocab_size), labels.view(-1))
+
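+        # token-level cosine similarity between normalized hidden states, used by the SimCTG contrastive loss below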
+ norm_rep = last_hidden_states / last_hidden_states.norm(dim=2, keepdim=True)
+ cosine_scores = torch.matmul(norm_rep, norm_rep.transpose(1,2))
+ assert cosine_scores.size() == torch.Size([bsz, seqlen, seqlen])
+ cl_loss = contrastive_loss(margin, cosine_scores, input_ids, self.pad_token_id, prefix_len=0)
+ return mle_loss, cl_loss
+
+ def eval_loss(self, input_ids, labels):
+ bsz, seqlen = input_ids.size()
+ outputs = self.model(input_ids=input_ids, output_hidden_states=True)
+ logits = outputs.logits
+ assert logits.size() == torch.Size([bsz, seqlen, self.vocab_size])
+ last_hidden_states = outputs.hidden_states[-1]
+ assert last_hidden_states.size() == torch.Size([bsz, seqlen, self.embed_dim])
+ mle_loss = val_fct(logits.view(-1, self.vocab_size), labels.view(-1))
+ assert mle_loss.size() == torch.Size([bsz * seqlen])
+ mask_tmp = labels.masked_fill(~labels.eq(-100), 1.0)
+ mask = mask_tmp.masked_fill(mask_tmp.eq(-100), 0.0)
+ # sum
+ mle_loss_sum = torch.sum(mle_loss)
+ token_num_sum = torch.sum(mask)
+ return mle_loss_sum, token_num_sum
+
+ def save_model(self, ckpt_save_path):
+ import os
+ if os.path.exists(ckpt_save_path):
+ pass
+ else: # recursively construct directory
+ os.makedirs(ckpt_save_path, exist_ok=True)
+ # save model
+ self.model.save_pretrained(ckpt_save_path)
+ # save tokenizer
+ self.tokenizer.save_pretrained(ckpt_save_path)
+
+ def parse_sentences(self, text, num_of_sentences_to_keep):
+ item_list = text.split('.')
+ res_list = item_list[:num_of_sentences_to_keep]
+ if len(item_list) > num_of_sentences_to_keep:
+ res_text = '.'.join(res_list).strip('.') + '.'
+ else:
+ res_text = '.'.join(res_list).strip('.').strip()
+ return res_text
+
+ def parse_generated_result(self, output, num_of_sentences_to_keep):
+ output_text = self.tokenizer.decode(output)
+ item_list = output_text.split(self.eos_token)
+ full_text = self.eos_token.join(item_list[:2]).strip()
+ full_text = self.parse_sentences(full_text, num_of_sentences_to_keep)
+ generated_text = item_list[1].strip()
+ generated_text = self.parse_sentences(generated_text, num_of_sentences_to_keep)
+ return full_text, generated_text
+
+ # decoding functions
+ # ------------------------------------------------------- #
+
+ def parse_output_token_list(self, output):
+ output = output.tolist()
+ res_list = []
+ for token_id in output:
+ if token_id == self.sos_token_id:
+ continue
+ elif token_id == self.eos_token_id:
+ break
+ else:
+ res_list.append(token_id)
+ text = self.tokenizer.decode(res_list).strip()
+ return ' '.join(text.split()).strip()
+
+ @torch.no_grad()
+ def magic_search(self, input_ids, beam_width, alpha, decoding_len, beta, image_instance, clip,
+ clip_text_max_len):#, add_token_level_score=False):
+ prefix_len = input_ids.size()[1]
+ #from utlis import PlugAndPlayContrastiveDecodingOneStepFast
+ past_key_values, last_hidden_states, logits = None, None, None
+ generated = [item for item in input_ids.tolist()]
+ input_ids_for_class = input_ids.clone()
+
+ image_embeds = clip.compute_image_representation_from_image_instance(image_instance)
+
+ start_time = datetime.datetime.now()
+
+ # the maximum supported length of generation for SimCTG is 256
+ # to support longer generated length, you can re-train the SimCTG model with longer sequences
+ decoding_len = decoding_len - prefix_len
+ for step in range(decoding_len):
+ input_ids, past_key_values, last_hidden_states, logits, input_ids_for_class = \
+ PlugAndPlayContrastiveDecodingOneStepFast(
+ self.model,
+ input_ids,
+ prefix_len,
+ beam_width,
+ alpha,
+ beta,
+ self.tokenizer,
+ image_embeds,
+ clip,
+ clip_text_max_len,
+ past_key_values,
+ last_hidden_states,
+ logits,
+ first_step=step==0,
+ input_ids_for_class=input_ids_for_class,
+ )
+ end_time = datetime.datetime.now()
+ time_diff = (end_time - start_time)
+ execution_time = time_diff.total_seconds() * 1000
+ return self.parse_output_token_list(input_ids_for_class[0])
+
+ def fast_contrastive_search(self, input_ids, beam_width, alpha, decoding_len):
+ '''
+ input_ids: prefix input; 1 x prefix_len
+ decoding_len: how many tokens to generate
+ beam_width: size of candidate pool during decoding
+ alpha: regulates importance of model confidence and degeneration penalty
+ '''
+ self.model.eval()
+ #from utlis import ContrastiveDecodingOneStepFast
+ # sanity check
+ assert alpha >= 0. and alpha <= 1.0
+
+ # fast mode
+ prefix_len = input_ids.size()[1]
+ batch_size, seqlen = input_ids.size()
+ #generated = [[] for _ in range(batch_size)]
+ generated = [item for item in input_ids.tolist()]
+ past_key_values = None
+ last_hidden_states = None
+ logits = None
+ decoding_len = decoding_len - prefix_len
+ for step in range(decoding_len):
+ input_ids, past_key_values, last_hidden_states, logits = ContrastiveDecodingOneStepFast(
+ self.model,
+ input_ids,
+ beam_width,
+ alpha,
+ past_key_values,
+ last_hidden_states,
+ self.tokenizer,
+ logits,
+ first_step=step == 0,
+ )
+ tokens = input_ids.squeeze(dim=-1).tolist()
+ for idx, t in enumerate(tokens):
+ generated[idx].append(t)
+ return self.parse_output_token_list(torch.LongTensor(generated[0]))
+
+ def top_k_sampling(self, input_ids, k, decoding_len):
+ _, prefix_len = input_ids.size()
+ output = self.model.generate(
+ input_ids,
+ do_sample=True,
+ max_length=decoding_len,
+ top_p=1.0,
+ top_k=k)
+ return self.parse_output_token_list(output[0])
+
+ def nucleus_sampling(self, input_ids, nucleus_p, decoding_len):
+ _, prefix_len = input_ids.size()
+ output = self.model.generate(
+ input_ids,
+ do_sample=True,
+ max_length=decoding_len,
+ top_p=nucleus_p,
+ top_k=0)
+ return self.parse_output_token_list(output[0])
diff --git a/language_model/train.py b/language_model/train.py
new file mode 100644
index 0000000..f401e44
--- /dev/null
+++ b/language_model/train.py
@@ -0,0 +1,107 @@
+# coding=utf-8
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.multiprocessing as mp
+import argparse, os
+import random
+import numpy as np
+import time
+import logging
+import progressbar
+
+logging.getLogger('transformers.generation_utils').disabled = True
+
+def parse_config():
+ parser = argparse.ArgumentParser()
+ # data configuration
+ parser.add_argument("--model_name", type=str, default='gpt2')
+ parser.add_argument("--train_path", type=str)
+ parser.add_argument("--dev_path", type=str)
+ parser.add_argument("--test_path", type=str)
+ parser.add_argument("--max_len", type=int)
+ parser.add_argument("--add_eos_token_to_data", type=str)
+ # mini-batch training configuration
+ parser.add_argument("--number_of_gpu", type=int, help="Number of available GPUs.")
+ parser.add_argument("--batch_size_per_gpu", type=int, help='batch size for each gpu.')
+ parser.add_argument("--gradient_accumulation_steps", type=int, help="gradient accumulation step.")
+ parser.add_argument("--effective_batch_size", type=int,
+ help="effective_bsz = batch_size_per_gpu x number_of_gpu x gradient_accumulation_steps")
+ # pre-training configuration
+ parser.add_argument("--total_steps", type=int,
+ help="total effective training steps")
+ parser.add_argument("--print_every", type=int,
+ help="how many update steps to print one intermediate result")
+ parser.add_argument("--save_every", type=int,
+ help="how many update steps to save one model")
+ # learning configuration
+ parser.add_argument("--learning_rate", type=float, default=2e-5)
+ parser.add_argument("--margin", type=float)
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+ parser.add_argument("--save_path_prefix", type=str, help="directory to save the model parameters.")
+ return parser.parse_args()
+
+def load_previous_best_model(path):
+ import os
+ filenames = os.listdir(path)
+ for file in filenames:
+ if file.startswith('training_step'):
+ return path + '/' + file
+ raise Exception('No best model found!')
+
+if __name__ == '__main__':
+ if torch.cuda.is_available():
+ print ('Cuda is available.')
+ cuda_available = torch.cuda.is_available()
+ multi_gpu_training = False
+ if cuda_available:
+ if torch.cuda.device_count() > 1:
+ multi_gpu_training = True
+ print ('Using Multi-GPU training, number of GPU is {}'.format(torch.cuda.device_count()))
+ else:
+ print ('Using single GPU training.')
+ else:
+ pass
+ args = parse_config()
+ device = torch.device('cuda')
+ model_name = args.model_name
+
+ sos_token, pad_token = r'<-start_of_text->', r'<-pad->'
+ add_eos_token_to_data = args.add_eos_token_to_data
+ if add_eos_token_to_data == 'True':
+ add_eos_token_to_data = True
+ print ('Add eos token to data!')
+ elif add_eos_token_to_data == 'False':
+ add_eos_token_to_data = False
+ print ('Do not add eos token to data!')
+ else:
+ raise Exception('Wrong eos configuration for data!!!')
+ print ('Loading data...')
+ from dataclass import Data
+ data = Data(model_name, args.train_path, args.dev_path, args.test_path, args.max_len,
+ sos_token, pad_token, add_eos_token_to_data)
+ print ('Data loaded.')
+
+ from trainer import model_training
+ print ('############################################################')
+ print ('Start Training...')
+ from simctg import SimCTG
+    print ('Initializing SimCTG model...')
+ model = SimCTG(model_name, sos_token, pad_token)
+ if cuda_available:
+ if multi_gpu_training:
+ model = nn.DataParallel(model) # multi-gpu training
+ else:
+ pass
+ model = model.to(device)
+ else:
+ pass
+ print ('Model loaded')
+ total_steps, print_every, save_every = args.total_steps, args.print_every, args.save_every
+ ckpt_save_path = args.save_path_prefix
+ model = model_training(args, data, model, total_steps, print_every, save_every,
+ ckpt_save_path, cuda_available, device)
+ print ('Training stage completed!')
+ print ('############################################################')
diff --git a/language_model/train_flickr30k.sh b/language_model/train_flickr30k.sh
new file mode 100644
index 0000000..f76d082
--- /dev/null
+++ b/language_model/train_flickr30k.sh
@@ -0,0 +1,17 @@
+CUDA_VISIBLE_DEVICES=0 python train.py\
+ --model_name gpt2\
+ --train_path ../data/flickr30k/flickr30k_train.json\
+ --dev_path ../data/flickr30k/flickr30k_val.json\
+ --test_path ../data/flickr30k/flickr30k_test.json\
+ --add_eos_token_to_data True\
+ --margin 0.5\
+ --max_len 64\
+ --number_of_gpu 1\
+ --batch_size_per_gpu 32\
+ --gradient_accumulation_steps 4\
+ --effective_batch_size 128\
+ --total_steps 10000\
+ --print_every 50\
+ --save_every 250\
+ --learning_rate 2e-5\
+ --save_path_prefix ./magic_flickr30k/
\ No newline at end of file
diff --git a/language_model/train_mscoco.sh b/language_model/train_mscoco.sh
new file mode 100644
index 0000000..7ab7582
--- /dev/null
+++ b/language_model/train_mscoco.sh
@@ -0,0 +1,17 @@
+CUDA_VISIBLE_DEVICES=0 python train.py\
+ --model_name gpt2\
+ --train_path ../data/mscoco/mscoco_train.json\
+ --dev_path ../data/mscoco/mscoco_val.json\
+ --test_path ../data/mscoco/mscoco_test.json\
+ --add_eos_token_to_data True\
+ --margin 0.5\
+ --max_len 64\
+ --number_of_gpu 1\
+ --batch_size_per_gpu 32\
+ --gradient_accumulation_steps 4\
+ --effective_batch_size 128\
+ --total_steps 20000\
+ --print_every 100\
+ --save_every 500\
+ --learning_rate 2e-5\
+ --save_path_prefix ./magic_mscoco/
\ No newline at end of file
diff --git a/language_model/trainer.py b/language_model/trainer.py
new file mode 100644
index 0000000..e51a850
--- /dev/null
+++ b/language_model/trainer.py
@@ -0,0 +1,165 @@
+# coding=utf-8
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.multiprocessing as mp
+import argparse, os
+import random
+import numpy as np
+import time
+import logging
+import progressbar
+
+logging.getLogger('transformers.generation_utils').disabled = True
+
+def eval_model(args, model, data, cuda_available, device):
+ dataset_batch_size = args.batch_size_per_gpu * args.number_of_gpu
+ eval_step = int(data.test_num / dataset_batch_size) + 1
+ val_loss, token_sum = 0., 0.
+ model.eval()
+ with torch.no_grad():
+ p = progressbar.ProgressBar(eval_step)
+ p.start()
+ for idx in range(eval_step):
+ p.update(idx)
+ batch_input_tensor, batch_labels, _ = \
+ data.get_next_validation_batch(batch_size=dataset_batch_size, mode='test')
+ if cuda_available:
+ batch_input_tensor = batch_input_tensor.cuda(device)
+ batch_labels = batch_labels.cuda(device)
+ one_val_loss, one_val_token_sum = model.eval_loss(batch_input_tensor, batch_labels)
+ one_val_loss = torch.sum(one_val_loss)
+ one_val_token_sum = torch.sum(one_val_token_sum)
+ val_loss += one_val_loss.item()
+ token_sum += one_val_token_sum.item()
+ p.finish()
+ model.train()
+ val_loss = val_loss / token_sum
+ return val_loss
+
+def model_training(args, data, model, total_steps, print_every, save_every, ckpt_save_path, cuda_available, device):
+ import os
+ if os.path.exists(ckpt_save_path):
+ pass
+ else: # recursively construct directory
+ os.makedirs(ckpt_save_path, exist_ok=True)
+
+ max_save_num = 1
+
+ batch_size_per_gpu, gradient_accumulation_steps, number_of_gpu, effective_batch_size = \
+ args.batch_size_per_gpu, args.gradient_accumulation_steps, args.number_of_gpu, args.effective_batch_size
+ assert effective_batch_size == batch_size_per_gpu * gradient_accumulation_steps * number_of_gpu
+
+ warmup_steps = int(0.1 * total_steps) # 10% of training steps are used for warmup
+ print ('total training steps is {}, warmup steps is {}'.format(total_steps, warmup_steps))
+ from transformers.optimization import AdamW, get_linear_schedule_with_warmup
+ optimizer = AdamW(model.parameters(), lr=args.learning_rate)
+ scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
+ optimizer.zero_grad()
+
+ effective_batch_acm = 0
+ all_batch_step = 1
+ print_valid, save_valid = False, False
+ train_loss, train_cl_loss, min_val_loss = 0., 0., 1e10
+ train_ave_bleu = 0.
+
+ print ('--------------------------------------------------------------------------')
+ print ('Start Training:')
+ model.train()
+ number_of_saves = 0
+
+ while effective_batch_acm < total_steps:
+ all_batch_step += 1
+ train_batch_input_tensor, train_batch_labels, _ = data.get_next_train_batch(batch_size_per_gpu * number_of_gpu)
+ if cuda_available:
+ train_batch_input_tensor = train_batch_input_tensor.cuda(device)
+ train_batch_labels = train_batch_labels.cuda(device)
+ mle_loss, cl_loss = model(train_batch_input_tensor, train_batch_labels, args.margin)
+
+ loss = mle_loss + cl_loss
+ loss = loss.mean()
+ loss.backward()
+ train_loss += mle_loss.item()
+ train_cl_loss += cl_loss.item()
+ torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+
+ # parameter update
+ if all_batch_step % gradient_accumulation_steps == 0:
+ optimizer.step()
+ scheduler.step()
+ optimizer.zero_grad()
+ effective_batch_acm += 1
+ print_valid, save_valid = True, True
+
+ # print intermediate result
+ if effective_batch_acm % print_every == 0 and print_valid:
+ denominator = (effective_batch_acm - (number_of_saves * save_every)) * gradient_accumulation_steps
+ one_train_loss = train_loss / denominator
+ one_train_cl_loss = train_cl_loss / denominator
+ print ('At training steps {}, training MLE loss is {}, train CL loss is {}'.format(effective_batch_acm,
+ one_train_loss, one_train_cl_loss))
+ print_valid = False
+
+ # saving result
+ if effective_batch_acm % save_every == 0 and save_valid:
+ number_of_saves += 1
+
+ save_valid = False
+ one_train_loss = train_loss / (save_every * gradient_accumulation_steps)
+ one_train_cl_loss = train_cl_loss / (save_every * gradient_accumulation_steps)
+
+ model.eval()
+ one_val_loss = eval_model(args, model, data, cuda_available, device)
+ model.train()
+
+ print ('At training steps {}, training MLE loss is {}, train CL loss is {}, validation loss is {}'.format(effective_batch_acm,
+ one_train_loss, one_train_cl_loss, one_val_loss))
+
+ train_loss, train_cl_loss = 0., 0.
+
+ if one_val_loss < min_val_loss:
+ # in finetuning stage, we always save the model
+ min_val_loss = min(one_val_loss, min_val_loss)
+ print ('Saving model...')
+ one_val_ppl = np.exp(one_val_loss)
+ one_val_ppl = round(one_val_ppl, 3)
+ save_name = 'training_step_{}_train_mle_loss_{}_train_cl_loss_{}_dev_loss_{}_dev_ppl_{}'.format(effective_batch_acm,
+ round(one_train_loss,5), round(one_train_cl_loss,5), round(one_val_loss,5), one_val_ppl)
+
+ model_save_path = ckpt_save_path + '/' + save_name
+ import os
+ if os.path.exists(model_save_path):
+ pass
+ else: # recursively construct directory
+ os.makedirs(model_save_path, exist_ok=True)
+ if cuda_available and torch.cuda.device_count() > 1:
+ model.module.save_model(model_save_path)
+ else:
+ model.save_model(model_save_path)
+ print ('Model Saved!')
+
+ # --------------------------------------------------------------------------------------------- #
+ # removing extra checkpoints...
+ import os
+ from operator import itemgetter
+ fileData = {}
+ test_output_dir = ckpt_save_path
+ for fname in os.listdir(test_output_dir):
+ if fname.startswith('training_step'):
+ fileData[fname] = os.stat(test_output_dir + '/' + fname).st_mtime
+ else:
+ pass
+ sortedFiles = sorted(fileData.items(), key=itemgetter(1))
+
+ if len(sortedFiles) < max_save_num:
+ pass
+ else:
+ delete = len(sortedFiles) - max_save_num
+ for x in range(0, delete):
+ one_folder_name = test_output_dir + '/' + sortedFiles[x][0]
+ os.system('rm -r ' + one_folder_name)
+ print ('-----------------------------------')
+ # --------------------------------------------------------------------------------------------- #
+ return model
+
diff --git a/language_model/utlis.py b/language_model/utlis.py
new file mode 100644
index 0000000..a739cd2
--- /dev/null
+++ b/language_model/utlis.py
@@ -0,0 +1,291 @@
+import sys
+import os
+import operator
+from operator import itemgetter
+import torch
+from torch import nn
+import torch.nn.functional as F
+import random
+import numpy as np
+import argparse
+import random
+
+def parse_prompt(text):
+ '''
+ process the prompt text;
+ '''
+ eos_token = '<|endoftext|>'
+ text = text.strip(eos_token).strip()
+ left_bracket_idx, right_bracket_idx = -1, -1
+ for idx in range(len(text)):
+ char = text[idx]
+ if char == '[' and left_bracket_idx == -1: # first [ is met
+ left_bracket_idx = idx
+ elif char == ']' and right_bracket_idx == -1: # first ] is met
+ right_bracket_idx = idx
+ else:
+ pass
+ res_text = ''
+ remove = False
+ if left_bracket_idx > -1 and right_bracket_idx > left_bracket_idx:
+ if right_bracket_idx - left_bracket_idx <= 6:
+ remove = True
+ else:
+ pass
+
+ for idx in range(len(text)):
+ if remove:
+ if idx >= left_bracket_idx and idx <= right_bracket_idx:
+ continue
+ else:
+ res_text += text[idx]
+ else:
+ res_text += text[idx]
+ res_text = res_text.strip()
+ res_text = ' '.join(res_text.split()).strip()
+ return res_text
+
+def typical_filtering(scores, mass, min_tokens_to_keep, filter_value):
+ # calculate entropy
+ normalized = torch.nn.functional.log_softmax(scores, dim=-1)
+ p = torch.exp(normalized)
+ ent = -(normalized * p).nansum(-1, keepdim=True)
+
+ # shift and sort
+ shifted_scores = torch.abs((-normalized) - ent)
+ sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
+ sorted_logits = scores.gather(-1, sorted_indices)
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+
+ # Remove tokens with cumulative mass above the threshold
+ last_ind = (cumulative_probs < mass).sum(dim=1)
+ last_ind[last_ind < 0] = 0
+ sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
+ if min_tokens_to_keep > 1:
+ # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
+ sorted_indices_to_remove[..., : min_tokens_to_keep] = 0
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+
+ scores = scores.masked_fill(indices_to_remove, filter_value)
+ return scores
+
+def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, threshold=-float('Inf'), filter_value=-np.inf):
+ assert logits.dim() == 1
+ top_k = min(top_k, logits.size(-1))
+ if top_k > 0:
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+ logits[indices_to_remove] = filter_value
+ if top_p > 0.0:
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+ sorted_indices_to_remove = cumulative_probs > top_p
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+ sorted_indices_to_remove[..., 0] = 0
+
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
+ logits[indices_to_remove] = filter_value
+
+ indices_to_remove = logits < threshold
+ logits[indices_to_remove] = filter_value
+ return logits
+
+# ========== batch version ========= #
+def ranking_fast(context_hidden, next_hidden, next_top_k_probs, alpha, beam_width):
+ '''
+ context_hidden: bsz*beam x seqlen x embed_dim
+ next_hidden: bsz*beam x 1 x embed_dim
+ next_top_k_probs: bsz x beam
+ '''
+ _, context_len, embed_dim = context_hidden.size()
+ norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
+ norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
+ cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1) # [B*K, S]
+ scores, _ = torch.max(cosine_matrix, dim=-1) # [B*K]
+ next_top_k_probs = next_top_k_probs.view(-1) # [B*K]
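+    # contrastive search score: model confidence minus the degeneration penalty (max similarity to the previous context)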
+ scores = (1.0 - alpha) * next_top_k_probs - alpha * scores
+ scores = torch.stack(torch.split(scores, beam_width)) # [B, K]
+ selected_idx = scores.max(dim=-1)[1] # [B]
+ return selected_idx
+
+def ContrastiveDecodingOneStepFast(
+ model,
+ ids,
+ beam_width,
+ alpha,
+ past_key_values,
+ last_hidden_states,
+ vocab,
+ logit_for_next_step,
+ first_step=False,
+ ):
+ # input_ids: [B, S]
+ if first_step:
+ output = model(
+ input_ids=ids,
+ past_key_values=past_key_values,
+ use_cache=True,
+ output_hidden_states=True
+ )
+ past_key_values = output.past_key_values
+ last_hidden_states = output.hidden_states[-1] # [B, S, E]
+ logit_for_next_step = output.logits[:, -1, :] # [B, V]
+ bsz, seqlen, embed_dim = last_hidden_states.size()
+ p = random.uniform(0, 1)
+
+ next_probs = F.softmax(logit_for_next_step, dim=-1)
+ _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width) # [B, K]
+ top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids) # [B, K]
+ # compute new hidden
+ past_key_values = enlarge_past_key_values(past_key_values, beam_width)
+ output = model(
+ input_ids=top_k_ids.view(-1, 1),
+ attention_mask=torch.ones_like(top_k_ids.view(-1, 1)),
+ past_key_values=past_key_values,
+ output_hidden_states=True,
+ use_cache=True,
+ )
+ past_key_values = output.past_key_values
+ logits = output.logits[:, -1, :] # [B*K, V]
+ next_hidden = output.hidden_states[-1] # [B*K, 1, E]
+ context_hidden = last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz*beam_width, seqlen, embed_dim) # [B*K, S, E]
+
+ selected_idx = ranking_fast(
+ context_hidden,
+ next_hidden,
+ top_k_probs, # [B, K]
+ alpha,
+ beam_width,
+ ) # [B]
+ # prepare for the next step
+ next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1) # [B, 1]
+ next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width)) # [B, K, E]
+ next_hidden = next_hidden[range(bsz), selected_idx, :] # [B, E]
+ last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) # [B, S, E]
+ past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx)
+ logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :] # [B, V]
+ # next_id: [B, 1]
+ return next_id, past_key_values, last_hidden_states, logits
+
+def enlarge_past_key_values(past_key_values, beam_width):
+ # from [B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz]
+ new_key_values = []
+ for layer in past_key_values:
+ items = []
+ for item in layer:
+ # item is the key and value matrix
+ bsz, num_head, seq_len, esz = item.size()
+ item = item.unsqueeze(1).expand(-1, beam_width, -1, -1, -1).reshape(bsz*beam_width, num_head, seq_len, esz) # [bsz*beam, num_head, seq_len, esz]
+ items.append(item)
+ new_key_values.append(items)
+ return new_key_values
+
+def select_past_key_values(past_key_values, beam_width, selected_idx):
+ '''select_idx: [B]'''
+ new_key_values = []
+ for layer in past_key_values:
+ items = []
+ for item in layer:
+ bsz_and_beam, num_head, seq_len, esz = item.size()
+ bsz = int(bsz_and_beam//beam_width)
+ item = torch.stack(torch.split(item, beam_width, dim=0)) # [B, K, num_head, seq_len, esz]
+ item = item[range(bsz), selected_idx, :, :, :] # [B, num_head, seq_len, esz]
+ items.append(item)
+ new_key_values.append(items)
+ return new_key_values
+
+# ========== fast plug and play version ========= #
+def plug_and_play_fast_ranking(
+ context_hidden,
+ next_hidden,
+ next_top_k_ids,
+ next_top_k_probs,
+ alpha,
+ beta,
+ batch_class_score,
+ beam_width,
+):
+ '''
+ context_hidden: beam_width x context_len x embed_dim
+ next_hidden: beam_width x 1 x embed_dim
+ next_top_k_ids: beam_width x 1
+ batch_class_score: beam_width x 1
+ '''
+ _, context_len, embed_dim = context_hidden.size()
+ norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
+ norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
+ cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1)
+ scores, _ = torch.max(cosine_matrix, dim = -1)
+ next_top_k_probs = next_top_k_probs.view(-1)
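+    # magic score: model confidence, degeneration penalty, and the CLIP-induced image-text score weighted by beta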
+ scores = (1.0 - alpha) * next_top_k_probs - alpha * scores + beta * batch_class_score.view([beam_width])
+ scores = torch.stack(torch.split(scores, beam_width))
+ selected_idx = scores.max(dim=-1)[1]
+ return selected_idx
+
+def PlugAndPlayContrastiveDecodingOneStepFast(model, input_ids, prefix_len, beam_width, alpha, beta,
+ simctg_tokenizer, image_embeds, clip, clip_text_max_len, past_key_values, last_hidden_states,
+ logit_for_next_step, first_step=False, input_ids_for_class=None):#, add_token_level_score=False):
+ '''
+ model: the generation model, e.g., gpt2
+ input_ids: 1 x seqlen
+ '''
+
+ if first_step:
+ output = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True, output_hidden_states=True)
+ past_key_values = output.past_key_values
+ last_hidden_states = output.hidden_states[-1] # [B, S, E]
+ logit_for_next_step = output.logits[:, -1, :] # [B, V]
+ bsz, seqlen, embed_dim = last_hidden_states.size()
+ next_probs = F.softmax(logit_for_next_step, dim = -1)
+ _, top_k_ids = torch.topk(logit_for_next_step, dim = -1, k = beam_width)
+ top_k_probs = torch.gather(next_probs, dim = 1, index=top_k_ids)
+
+ # compute the new hidden
+ past_key_values = enlarge_past_key_values(past_key_values, beam_width)
+ output = model(
+ input_ids=top_k_ids.view(-1, 1) ,
+ attention_mask=torch.ones_like(top_k_ids.view(-1, 1)),
+ past_key_values=past_key_values,
+ output_hidden_states=True,
+ use_cache=True,
+ )
+ past_key_values = output.past_key_values
+ logits = output.logits[:, -1, :]
+ next_hidden = output.hidden_states[-1]
+ context_hidden = last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz*beam_width, seqlen, embed_dim)
+
+ # prepare for the classification model
+ input_ids_for_class_ = torch.cat([
+ input_ids_for_class.unsqueeze(1).expand(-1, beam_width, -1).reshape(bsz*beam_width, seqlen),
+ top_k_ids.view(-1, 1)
+ ], dim=-1
+ )
+
+ batch_text_list = []
+ for one_input_id in input_ids_for_class_:
+ one_text = simctg_tokenizer.decode(one_input_id[prefix_len:][-clip_text_max_len:])
+ # we only consider the class score of the generated text continuation
+ batch_text_list.append(one_text)
+ batch_score = clip.compute_image_text_similarity_via_raw_text(image_embeds, batch_text_list)
+
+ selected_idx = plug_and_play_fast_ranking(
+ context_hidden,
+ next_hidden,
+ top_k_ids,
+ top_k_probs,
+ alpha,
+ beta,
+ batch_score,
+ beam_width,
+ )
+
+ # prepare for the next step
+ next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1)
+ next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width))
+ next_hidden = next_hidden[range(bsz), selected_idx, :]
+ last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)
+ past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx)
+ logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :]
+ input_ids_for_class = torch.cat([input_ids_for_class, next_id], dim=-1)
+ return next_id, past_key_values, last_hidden_states, logits, input_ids_for_class
+
+
diff --git a/magic.py b/magic.py
index 85f46a7..d291142 100644
--- a/magic.py
+++ b/magic.py
@@ -29,7 +29,6 @@ from towhee.types.arg import arg, to_image_color
from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee import register
-from towhee.models import clip
class Magic(NNOperator):
"""
@@ -38,22 +37,32 @@ class Magic(NNOperator):
def __init__(self, model_name: str):
super().__init__()
path = str(pathlib.Path(__file__).parent)
- sys.path.append(path)
+ sys.path.append(path + '/clip')
+ sys.path.append(path + '/language_model')
+ print(sys.path)
from clip import CLIP
from simctg import SimCTG
sys.path.pop()
+ sys.path.pop()
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# Load Language Model
- language_model_name = r'cambridgeltl/magic_mscoco' # or r'/path/to/downloaded/cambridgeltl/magic_mscoco'
+ cfg = self._configs()[model_name]
+ language_model_name = cfg['language_model'] # or r'/path/to/downloaded/cambridgeltl/magic_mscoco'
sos_token, pad_token = r'<-start_of_text->', r'<-pad->'
self.generation_model = SimCTG(language_model_name, sos_token, pad_token).to(self.device)
self.generation_model.eval()
- model_name = r"openai/clip-vit-base-patch32" # or r"/path/to/downloaded/openai/clip-vit-base-patch32"
+ model_name = cfg['clip_model'] # or r"/path/to/downloaded/openai/clip-vit-base-patch32"
self.clip = CLIP(model_name).to(self.device)
+ self.clip.to(self.device)
self.clip.eval()
+ sos_token = r'<-start_of_text->'
+ start_token = self.generation_model.tokenizer.tokenize(sos_token)
+ start_token_id = self.generation_model.tokenizer.convert_tokens_to_ids(start_token)
+ self.input_ids = torch.LongTensor(start_token_id).view(1,-1).to(self.device)
+
def _preprocess(self, img):
img = to_pil(img)
@@ -87,13 +96,15 @@ class Magic(NNOperator):
k, alpha, beta, decoding_len = 45, 0.1, 2.0, 16
eos_token = '<|endoftext|>'
with torch.no_grad():
- output = generation_model.magic_search(input_ids, k,
- alpha, decoding_len, beta, image_instance, clip, 60)
+ print(type(img))
+ output = self.generation_model.magic_search(self.input_ids, k,
+ alpha, decoding_len, beta, img, self.clip, 60)
- return out
+ return output
def _configs(self):
config = {}
- config['expansionnet_rf'] = {}
- config['expansionnet_rf']['weights'] = 'rf_model.pth'
+ config['magic_mscoco'] = {}
+ config['magic_mscoco']['language_model'] = 'cambridgeltl/magic_mscoco'
+ config['magic_mscoco']['clip_model'] = 'openai/clip-vit-base-patch32'
return config
diff --git a/zerocap/README.md b/zerocap/README.md
deleted file mode 100644
index 1839550..0000000
--- a/zerocap/README.md
+++ /dev/null
@@ -1,89 +0,0 @@
-### Our Implementation of the ZeroCap Baseline Model
-
-****
-### Catalogue:
-* 1. Environment Preparation
-* 2. Image Captioning on MSCOCO
-* 3. Image Captioning on Flickr30k
-* 4. Cross Domain Image Captioning on MSCOCO
-* 5. Cross Domain Image Captioning on Flickr30k
-* 6. Citation
-* 7. Acknowledgements
-
-****
-
-
-
-#### 1. Environment Preparation:
-To install the correct environment, please run the following command:
-```yaml
-pip install -r requirements.txt
-```
-
-****
-
-
-
-#### 2. Image Captioning on MSCOCO:
-To perform image captioning on MSCOCO, please run the following command:
-```yaml
-chmod +x ./mscoco_zerocap.sh
-./mscoco_zerocap.sh
-```
-
-****
-
-
-
-#### 3. Image Captioning on Flickr30k:
-To perform image captioning on Flickr30k, please run the following command:
-```yaml
-chmod +x ./flickr30k_zerocap.sh
-./flickr30k_zerocap.sh
-```
-
-****
-
-
-
-#### 4. Cross Domain Image Captioning on MSCOCO:
-To perform image captioning on MSCOCO with the language model from Flickr30k domain, please run the following command:
-```yaml
-chmod +x ./flickr30k_to_mscoco_zerocap.sh
-./flickr30k_to_mscoco_zerocap.sh
-```
-
-****
-
-
-
-#### 5. Cross Domain Image Captioning on Flickr30k:
-To perform image captioning on Flickr30k with the language model from MSCOCO domain, please run the following command:
-```yaml
-chmod +x ./mscoco_to_flickr30k_zerocap.sh
-./mscoco_to_flickr30k_zerocap.sh
-```
-
-****
-
-
-
-#### 6. Citation:
-If you find our code helpful, please cite the original paper as
-
-```bibtex
-@article{tewel2021zero,
- title={Zero-Shot Image-to-Text Generation for Visual-Semantic Arithmetic},
- author={Tewel, Yoad and Shalev, Yoav and Schwartz, Idan and Wolf, Lior},
- journal={arXiv preprint arXiv:2111.14447},
- year={2021}
-}
-```
-
-****
-
-
-
-#### 7. Acknowledgements:
-We thank the authors for releasing their code. Our reimplementation of the baseline is based on their original codebase [[here]](https://github.com/yoadtew/zero-shot-image-to-text).
-
diff --git a/zerocap/cog.yaml b/zerocap/cog.yaml
deleted file mode 100644
index 92f13da..0000000
--- a/zerocap/cog.yaml
+++ /dev/null
@@ -1,12 +0,0 @@
-build:
- gpu: true
- python_version: "3.8"
- system_packages:
- - "libgl1-mesa-glx"
- - "libglib2.0-0"
- python_packages:
- - "git+https://github.com/openai/CLIP.git"
- - "git+https://github.com/YoadTew/zero-shot-image-to-text.git"
-
-predict: "predict.py:Predictor"
-#predict: "predict_arithmetic.py:Predictor"
\ No newline at end of file
diff --git a/zerocap/flickr30k_zerocap.sh b/zerocap/flickr30k_zerocap.sh
deleted file mode 100755
index b727b12..0000000
--- a/zerocap/flickr30k_zerocap.sh
+++ /dev/null
@@ -1,14 +0,0 @@
-#!/bin/bash
-
-# lm_model:
-# 1. cambridgeltl/magic_mscoco
-# 2. cambridgeltl/magic_flickr30k
-CUDA_VISIBLE_DEVICES=1 python run.py \
- --beam_size 1 \
- --target_seq_length 16 \
- --reset_context_delta \
- --lm_model cambridgeltl/magic_flickr30k \
- --test_image_prefix_path ../data/flickr30k/test_images \
- --test_path ../data/flickr30k/flickr30k_test.json \
- --save_path_prefix ../inference_result/flickr30k/baselines/ \
- --save_name zerocap_result.json
diff --git a/zerocap/forbidden_tokens.npy b/zerocap/forbidden_tokens.npy
deleted file mode 100644
index aeed51b..0000000
Binary files a/zerocap/forbidden_tokens.npy and /dev/null differ
diff --git a/zerocap/model/ZeroCLIP.py b/zerocap/model/ZeroCLIP.py
deleted file mode 100644
index 2c36fd2..0000000
--- a/zerocap/model/ZeroCLIP.py
+++ /dev/null
@@ -1,389 +0,0 @@
-import numpy as np
-from torch import nn
-from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer
-from transformers.models.gpt_neo import GPTNeoForCausalLM
-import torch
-import clip
-from PIL import Image
-from datetime import datetime
-import sys
-
-
-def log_info(text, verbose=True):
- if verbose:
- dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
- print(f'{dt_string} | {text}')
- sys.stdout.flush()
-
-
-def add_context(x, y):
- return (x[0] + y[0], x[1] + y[1])
-
-
-def convert_models_to_fp32(model):
- for p in model.parameters():
- p.data = p.data.float()
-
-
-class CLIPTextGenerator:
- def __init__(self,
- seed=0,
- lm_model='gpt-2',
- forbidden_tokens_file_path='./forbidden_tokens.npy',
- clip_checkpoints='./clip_checkpoints',
- target_seq_length=15,
- reset_context_delta=True,
- num_iterations=5,
- clip_loss_temperature=0.01,
- clip_scale=1.,
- ce_scale=0.2,
- stepsize=0.3,
- grad_norm_factor=0.9,
- fusion_factor=0.99,
- repetition_penalty=1.,
- end_token='.',
- end_factor=1.01,
- forbidden_factor=20,
- **kwargs):
-
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
- # set Random seed
- torch.manual_seed(seed)
- np.random.seed(seed)
-
- # Initialize Language model
- self.context_prefix = ''
-
- self.lm_tokenizer = GPT2Tokenizer.from_pretrained(lm_model)
- self.lm_model = GPT2LMHeadModel.from_pretrained(lm_model, output_hidden_states=True)
- self.context_prefix = self.lm_tokenizer.bos_token
-
- self.lm_model.to(self.device)
- self.lm_model.eval()
-
- self.forbidden_tokens = np.load(forbidden_tokens_file_path)
- self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if
- (x[0] == 'Ġ' and len(x) > 1 and x[1].isupper())]
-
- # Freeze LM weights
- for param in self.lm_model.parameters():
- param.requires_grad = False
-
- # Initialize CLIP
- self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device,
- download_root=clip_checkpoints, jit=False)
- # convert_models_to_fp32(self.clip)
-
- # Init arguments
- self.target_seq_length = target_seq_length
- self.reset_context_delta = reset_context_delta
- self.num_iterations = num_iterations
- self.clip_loss_temperature = clip_loss_temperature
- self.clip_scale = clip_scale
- self.ce_scale = ce_scale
- self.stepsize = stepsize
- self.grad_norm_factor = grad_norm_factor
- self.fusion_factor = fusion_factor
- self.repetition_penalty = repetition_penalty
- self.end_token = self.lm_tokenizer.encode(end_token)[0]
- self.end_factor = end_factor
- self.ef_idx = 1
- self.forbidden_factor = forbidden_factor
-
- def get_img_feature(self, img_path, weights):
- imgs = [Image.open(x) for x in img_path]
- clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]
-
- with torch.no_grad():
- image_fts = [self.clip.encode_image(x) for x in clip_imgs]
-
- if weights is not None:
- image_features = sum([x * weights[i] for i, x in enumerate(image_fts)])
- else:
- image_features = sum(image_fts)
-
- image_features = image_features / image_features.norm(dim=-1, keepdim=True)
- return image_features.detach()
-
- def get_txt_features(self, text):
- clip_texts = clip.tokenize(text).to(self.device)
-
- with torch.no_grad():
- text_features = self.clip.encode_text(clip_texts)
-
- text_features = text_features / text_features.norm(dim=-1, keepdim=True)
- return text_features.detach()
-
- def get_combined_feature(self, img_path, texts, weights_i, weights_t):
- imgs = [Image.open(x) for x in img_path]
- clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]
- clip_texts = [clip.tokenize(x).to(self.device) for x in texts]
-
- with torch.no_grad():
- image_fts = [self.clip.encode_image(x) for x in clip_imgs]
- text_fts = [self.clip.encode_text(x) for x in clip_texts]
-
- features = sum([x * weights_i[i] for i, x in enumerate(image_fts)])
- if weights_t is not None:
- features += sum([x * weights_t[i] for i, x in enumerate(text_fts)])
-
- features = features / features.norm(dim=-1, keepdim=True)
- return features.detach()
-
- def run(self, image_features, cond_text, beam_size):
- self.image_features = image_features
-
- context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text)
-
- output_tokens, output_text = self.generate_text(context_tokens, beam_size)
-
- return output_text
-
- def generate_text(self, context_tokens, beam_size):
- context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0)
-
- gen_tokens = None
- scores = None
- seq_lengths = torch.ones(beam_size, device=self.device)
- is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool)
-
- for i in range(self.target_seq_length):
- probs = self.get_next_probs(i, context_tokens)
- logits = probs.log()
-
- if scores is None:
- scores, next_tokens = logits.topk(beam_size, -1)
- context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:])
- next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
-
- if gen_tokens is None:
- gen_tokens = next_tokens
- else:
- gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:])
- gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1)
- else:
- logits[is_stopped] = -float(np.inf)
- logits[is_stopped, 0] = 0
- scores_sum = scores[:, None] + logits
- seq_lengths[~is_stopped] += 1
- scores_sum_average = scores_sum / seq_lengths[:, None]
- scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
- beam_size, -1)
- next_tokens_source = next_tokens // scores_sum.shape[1]
- seq_lengths = seq_lengths[next_tokens_source]
- next_tokens = next_tokens % scores_sum.shape[1]
- next_tokens = next_tokens.unsqueeze(1)
- gen_tokens = gen_tokens[next_tokens_source]
- gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1)
- context_tokens = context_tokens[next_tokens_source]
- scores = scores_sum_average * seq_lengths
- is_stopped = is_stopped[next_tokens_source]
-
- context_tokens = torch.cat((context_tokens, next_tokens), dim=1)
- is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze()
-
- ####
- tmp_scores = scores / seq_lengths
- tmp_output_list = gen_tokens.cpu().numpy()
- tmp_output_texts = [
- self.lm_tokenizer.decode(tmp_output)
- for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths)
- ]
- tmp_order = tmp_scores.argsort(descending=True)
- tmp_output_texts = [tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order]
- log_info(tmp_output_texts, verbose=True)
- ####
-
- if is_stopped.all():
- break
-
- scores = scores / seq_lengths
- output_list = gen_tokens.cpu().numpy()
- output_texts = [
- self.lm_tokenizer.decode(output[: int(length)])
- for output, length in zip(output_list, seq_lengths)
- ]
- order = scores.argsort(descending=True)
- output_texts = [output_texts[i] for i in order]
-
- return context_tokens, output_texts
-
- def get_next_probs(self, i, context_tokens):
- last_token = context_tokens[:, -1:]
-
- if self.reset_context_delta and context_tokens.size(1) > 1:
- context = self.lm_model(context_tokens[:, :-1])["past_key_values"]
-
- # Logits of LM with unshifted context
- logits_before_shift = self.lm_model(context_tokens)["logits"]
- logits_before_shift = logits_before_shift[:, -1, :]
- probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1)
-
- if context:
- context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift)
-
- lm_output = self.lm_model(last_token, past_key_values=context)
- logits, past = (
- lm_output["logits"],
- lm_output["past_key_values"],
- )
- logits = logits[:, -1, :]
-
- logits = self.update_special_tokens_logits(context_tokens, i, logits)
-
- probs = nn.functional.softmax(logits, dim=-1)
- probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor))
- probs = probs / probs.sum()
-
- return probs
-
- def shift_context(self, i, context, last_token, context_tokens, probs_before_shift):
- context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context]
-
- window_mask = torch.ones_like(context[0][0]).to(self.device)
-
- for i in range(self.num_iterations):
- curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in
- context_delta]
-
- for p0, p1 in curr_shift:
- p0.retain_grad()
- p1.retain_grad()
-
- shifted_context = list(map(add_context, context, curr_shift))
-
- shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context)
- logits = shifted_outputs["logits"][:, -1, :]
- probs = nn.functional.softmax(logits, dim=-1)
-
- loss = 0.0
-
- # CLIP LOSS
- clip_loss, clip_losses = self.clip_loss(probs, context_tokens)
- loss += self.clip_scale * clip_loss
-
- # CE/Fluency loss
- ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * probs_before_shift.log())).sum(-1)
- loss += ce_loss.sum()
-
- loss.backward()
-
- # ---------- Weights ----------
- combined_scores_k = -(ce_loss)
- combined_scores_c = -(self.clip_scale * torch.stack(clip_losses))
-
- # minmax
- if combined_scores_k.shape[0] == 1:
- tmp_weights_c = tmp_weights_k = torch.ones(*combined_scores_k.shape).to(self.device)
- else:
- tmp_weights_k = ((combined_scores_k - combined_scores_k.min())) / (
- combined_scores_k.max() - combined_scores_k.min())
- tmp_weights_c = ((combined_scores_c - combined_scores_c.min())) / (
- combined_scores_c.max() - combined_scores_c.min())
-
- tmp_weights = 0.5 * tmp_weights_k + 0.5 * tmp_weights_c
- tmp_weights = tmp_weights.view(tmp_weights.shape[0], 1, 1, 1)
-
- factor = 1
-
- # --------- Specific Gen ---------
- sep_grads = None
-
- for b in range(context_tokens.shape[0]):
- tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_]
- for p_ in curr_shift]
-
- # normalize gradients
- tmp_grad = [tuple([-self.stepsize * factor * (
- x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][
- j] ** self.grad_norm_factor).data.cpu().numpy()
- for j, x in enumerate(p_)])
- for i, p_ in enumerate(curr_shift)]
- if sep_grads is None:
- sep_grads = tmp_grad
- else:
- for l_index in range(len(sep_grads)):
- sep_grads[l_index] = list(sep_grads[l_index])
- for k_index in range(len(sep_grads[0])):
- sep_grads[l_index][k_index] = np.concatenate(
- (sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0)
- sep_grads[l_index] = tuple(sep_grads[l_index])
- final_grads = sep_grads
-
- # --------- update context ---------
- context_delta = list(map(add_context, final_grads, context_delta))
-
- for p0, p1 in curr_shift:
- p0.grad.data.zero_()
- p1.grad.data.zero_()
-
- new_context = []
- for p0, p1 in context:
- new_context.append((p0.detach(), p1.detach()))
- context = new_context
-
- context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_])
- for p_ in context_delta]
- context = list(map(add_context, context, context_delta))
-
- new_context = []
- for p0, p1 in context:
- new_context.append((p0.detach(), p1.detach()))
- context = new_context
-
- return context
-
- def update_special_tokens_logits(self, context_tokens, i, logits):
- for beam_id in range(context_tokens.shape[0]):
- for token_idx in set(context_tokens[beam_id][-4:].tolist()):
- factor = self.repetition_penalty if logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty)
- logits[beam_id, token_idx] /= factor
-
- if i >= self.ef_idx:
- factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor)
- logits[beam_id, self.end_token] *= factor
- if i == 0:
- start_factor = 1.6
- factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor)
- logits[beam_id, self.end_token] /= factor
-
- for token_idx in list(self.forbidden_tokens):
- factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor)
- logits[beam_id, token_idx] /= factor
-
- return logits
-
- def clip_loss(self, probs, context_tokens):
- for p_ in self.clip.transformer.parameters():
- if p_.grad is not None:
- p_.grad.data.zero_()
-
- top_size = 512
- _, top_indices = probs.topk(top_size, -1)
-
- prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens]
-
- clip_loss = 0
- losses = []
- for idx_p in range(probs.shape[0]):
- top_texts = []
- prefix_text = prefix_texts[idx_p]
- for x in top_indices[idx_p]:
- top_texts.append(prefix_text + self.lm_tokenizer.decode(x))
- text_features = self.get_txt_features(top_texts)
-
- with torch.no_grad():
- similiraties = (self.image_features @ text_features.T)
- target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach()
- target_probs = target_probs.type(torch.float32)
-
- target = torch.zeros_like(probs[idx_p])
- target[top_indices[idx_p]] = target_probs[0]
- target = target.unsqueeze(0)
- cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)])))
-
- clip_loss += cur_clip_loss
- losses.append(cur_clip_loss)
-
- return clip_loss, losses
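Besides single-image captioning, `CLIPTextGenerator` also exposes `get_combined_feature` for CLIP-space image arithmetic, which the removed `predict_arithmetic.py` further down uses with weights `[1, 1, -1]` (image1 + image2 - image3). A minimal sketch, reusing the example image paths from the removed `run.py` defaults and a MAGIC language model in place of the original GPT-2 checkpoint:

```python
# Image-arithmetic sketch (image1 + image2 - image3), mirroring the removed predict_arithmetic.py.
# The image files are the example defaults listed in the removed run.py.
from model.ZeroCLIP import CLIPTextGenerator

gen = CLIPTextGenerator(lm_model='cambridgeltl/magic_mscoco')
features = gen.get_combined_feature(
    ['example_images/arithmetics/woman2.jpg',
     'example_images/arithmetics/king2.jpg',
     'example_images/arithmetics/man2.jpg'],
    [],           # no text features mixed in
    [1, 1, -1],   # weights_i: add the first two images, subtract the third
    None,         # weights_t is unused when no texts are given
)
captions = gen.run(features, 'Image of a', beam_size=3)
```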
diff --git a/zerocap/model/ZeroCLIP_batched.py b/zerocap/model/ZeroCLIP_batched.py
deleted file mode 100644
index 2c0209f..0000000
--- a/zerocap/model/ZeroCLIP_batched.py
+++ /dev/null
@@ -1,449 +0,0 @@
-import numpy as np
-from torch import nn
-from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer
-from transformers.models.gpt_neo import GPTNeoForCausalLM
-import torch
-import clip
-from PIL import Image
-from datetime import datetime
-import sys
-
-class TextCLIP(nn.Module):
- def __init__(self, model):
- super(TextCLIP, self).__init__()
- self.model = model
-
- def forward(self, text):
- return self.model.encode_text(text)
-
-
-class ImageCLIP(nn.Module):
- def __init__(self, model):
- super(ImageCLIP, self).__init__()
- self.model = model
-
- def forward(self, image):
- return self.model.encode_image(image)
-
-def log_info(text, verbose=True):
- if verbose:
- dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
- print(f'{dt_string} | {text}')
- sys.stdout.flush()
-
-
-def add_context(x, y):
- return (x[0] + y[0], x[1] + y[1])
-
-
-def convert_models_to_fp32(model):
- for p in model.parameters():
- p.data = p.data.float()
-
-
-class CLIPTextGenerator:
- def __init__(self,
- seed=0,
- lm_model='gpt-2',
- forbidden_tokens_file_path='./forbidden_tokens.npy',
- clip_checkpoints='./clip_checkpoints',
- target_seq_length=15,
- reset_context_delta=True,
- num_iterations=5,
- clip_loss_temperature=0.01,
- clip_scale=1.,
- ce_scale=0.2,
- stepsize=0.3,
- grad_norm_factor=0.9,
- fusion_factor=0.99,
- repetition_penalty=1.,
- end_token='.',
- end_factor=1.01,
- forbidden_factor=20,
- **kwargs):
-
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
- # set Random seed
- torch.manual_seed(seed)
- np.random.seed(seed)
-
- # Initialize Language model
- self.context_prefix = ''
-
- if lm_model == 'gpt-neo':
- self.lm_tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-125M')
- self.lm_model = GPTNeoForCausalLM.from_pretrained('EleutherAI/gpt-neo-125M', output_hidden_states=True)
- elif lm_model == 'gpt-2':
- self.lm_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
- self.lm_model = GPT2LMHeadModel.from_pretrained('gpt2-medium', output_hidden_states=True)
- self.context_prefix = self.lm_tokenizer.bos_token
-
- self.lm_model.to(self.device)
- self.lm_model.eval()
-
- self.forbidden_tokens = np.load(forbidden_tokens_file_path)
- self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if
- (x[0] == 'Ġ' and len(x) > 1 and x[1].isupper())]
-
- # Freeze LM weights
- for param in self.lm_model.parameters():
- param.requires_grad = False
-
- # Initialize CLIP
- self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device,
- download_root=clip_checkpoints, jit=False)
- self.clip_image = ImageCLIP(self.clip)
- self.clip_image = torch.nn.DataParallel(self.clip_image)
- self.clip_text = TextCLIP(self.clip)
- self.clip_text = torch.nn.DataParallel(self.clip_text)
-
- # Init arguments
- self.target_seq_length = target_seq_length
- self.reset_context_delta = reset_context_delta
- self.num_iterations = num_iterations
- self.clip_loss_temperature = clip_loss_temperature
- self.clip_scale = clip_scale
- self.ce_scale = ce_scale
- self.stepsize = stepsize
- self.grad_norm_factor = grad_norm_factor
- self.fusion_factor = fusion_factor
- self.repetition_penalty = repetition_penalty
- self.end_token = self.lm_tokenizer.encode(end_token)[0]
- self.end_factor = end_factor
- self.ef_idx = 1
- self.forbidden_factor = forbidden_factor
-
- def get_img_feature(self, img_path, weights):
- imgs = [Image.open(x) for x in img_path]
- clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]
-
- with torch.no_grad():
- image_fts = [self.clip_image(x) for x in clip_imgs]
-
- if weights is not None:
- image_features = sum([x * weights[i] for i, x in enumerate(image_fts)])
- else:
- image_features = sum(image_fts)
-
- image_features = torch.nn.functional.normalize(image_features, dim=-1)
- return image_features.detach()
-
- def get_txt_features(self, text):
- clip_texts = clip.tokenize(text).to(self.device)
-
- with torch.no_grad():
- text_features = self.clip_text(clip_texts)
-
- text_features = torch.nn.functional.normalize(text_features, dim=-1)
- return text_features.detach()
-
- def get_combined_feature(self, img_path, texts, weights_i, weights_t):
- imgs = [Image.open(x) for x in img_path]
- clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]
- clip_texts = [clip.tokenize(x).to(self.device) for x in texts]
-
- with torch.no_grad():
- image_fts = [self.clip.encode_image(x) for x in clip_imgs]
- text_fts = [self.clip.encode_text(x) for x in clip_texts]
-
- features = sum([x * weights_i[i] for i, x in enumerate(image_fts)])
- if weights_t is not None:
- features += sum([x * weights_t[i] for i, x in enumerate(text_fts)])
-
- features = features / features.norm(dim=-1, keepdim=True)
- return features.detach()
-
- def run(self, image_features, cond_text, beam_size):
- self.image_features = image_features
-
- context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text)
-
- output_tokens, output_text = self.generate_text(context_tokens, beam_size)
-
- return output_text
-
- def generate_text(self, context_tokens, beam_size):
- context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0)
-
- gen_tokens = None
- scores = None
- seq_lengths = torch.ones(beam_size, device=self.device)
- is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool)
-
- for i in range(self.target_seq_length):
- probs = self.get_next_probs(i, context_tokens)
- logits = probs.log()
-
- if scores is None:
- scores, next_tokens = logits.topk(beam_size, -1)
- context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:])
- next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
-
- if gen_tokens is None:
- gen_tokens = next_tokens
- else:
- gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:])
- gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1)
- else:
- logits[is_stopped] = -float(np.inf)
- logits[is_stopped, 0] = 0
- scores_sum = scores[:, None] + logits
- seq_lengths[~is_stopped] += 1
- scores_sum_average = scores_sum / seq_lengths[:, None]
- scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
- beam_size, -1)
- next_tokens_source = next_tokens // scores_sum.shape[1]
- seq_lengths = seq_lengths[next_tokens_source]
- next_tokens = next_tokens % scores_sum.shape[1]
- next_tokens = next_tokens.unsqueeze(1)
- gen_tokens = gen_tokens[next_tokens_source]
- gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1)
- context_tokens = context_tokens[next_tokens_source]
- scores = scores_sum_average * seq_lengths
- is_stopped = is_stopped[next_tokens_source]
-
- context_tokens = torch.cat((context_tokens, next_tokens), dim=1)
- is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze()
-
- ####
- tmp_scores = scores / seq_lengths
- tmp_output_list = gen_tokens.cpu().numpy()
- tmp_output_texts = [
- self.lm_tokenizer.decode(tmp_output)
- for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths)
- ]
- tmp_order = tmp_scores.argsort(descending=True)
- tmp_output_texts = [tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order]
- log_info(tmp_output_texts, verbose=True)
- ####
-
- if is_stopped.all():
- break
-
- scores = scores / seq_lengths
- output_list = gen_tokens.cpu().numpy()
- output_texts = [
- self.lm_tokenizer.decode(output[: int(length)])
- for output, length in zip(output_list, seq_lengths)
- ]
- order = scores.argsort(descending=True)
- output_texts = [output_texts[i] for i in order]
-
- return context_tokens, output_texts
-
- def get_next_probs(self, i, context_tokens):
- last_token = context_tokens[:, -1:]
-
- if self.reset_context_delta and context_tokens.size(1) > 1:
- context = self.lm_model(context_tokens[:, :-1])["past_key_values"]
-
- # Logits of LM with unshifted context
- logits_before_shift = self.lm_model(context_tokens)["logits"]
- logits_before_shift = logits_before_shift[:, -1, :]
- probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1)
-
- if context:
- context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift)
-
- lm_output = self.lm_model(last_token, past_key_values=context)
- logits, past = (
- lm_output["logits"],
- lm_output["past_key_values"],
- )
- logits = logits[:, -1, :]
-
- logits = self.update_special_tokens_logits(context_tokens, i, logits)
-
- probs = nn.functional.softmax(logits, dim=-1)
- probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor))
- probs = probs / probs.sum()
-
- return probs
-
- def shift_context(self, i, context, last_token, context_tokens, probs_before_shift):
- context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context]
-
- for i in range(self.num_iterations):
- curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in
- context_delta]
-
- for p0, p1 in curr_shift:
- p0.retain_grad()
- p1.retain_grad()
-
- shifted_context = list(map(add_context, context, curr_shift))
-
- shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context)
- logits = shifted_outputs["logits"][:, -1, :]
- probs = nn.functional.softmax(logits, dim=-1)
-
- loss = 0.0
-
- # CLIP LOSS
- clip_loss, clip_losses = self.clip_loss(probs, context_tokens)
- loss += self.clip_scale * clip_loss
-
- # CE/Fluency loss
- ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * probs_before_shift.log())).sum(-1)
- loss += ce_loss.sum()
-
- loss.backward()
-
- # --------- Specific Gen ---------
- final_grads = self.norm_grad(context, context_tokens, curr_shift)
-
- # --------- update context ---------
- context_delta = list(map(add_context, final_grads, context_delta))
-
- for p0, p1 in curr_shift:
- p0.grad.data.zero_()
- p1.grad.data.zero_()
-
- new_context = []
- for p0, p1 in context:
- new_context.append((p0.detach(), p1.detach()))
- context = new_context
-
- context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_])
- for p_ in context_delta]
- context = list(map(add_context, context, context_delta))
-
- new_context = []
- for p0, p1 in context:
- new_context.append((p0.detach(), p1.detach()))
- context = new_context
-
- return context
-
- def norm_grad(self, context, context_tokens, curr_shift, ):
- factor = 1
- sep_grads = None
- window_mask = torch.ones_like(context[0][0]).to(self.device)
-
- for b in range(context_tokens.shape[0]):
- tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_]
- for p_ in curr_shift]
-
- # normalize gradients
- tmp_grad = [tuple([-self.stepsize * factor * (
- x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][
- j] ** self.grad_norm_factor).data.cpu().numpy()
- for j, x in enumerate(p_)])
- for i, p_ in enumerate(curr_shift)]
- if sep_grads is None:
- sep_grads = tmp_grad
- else:
- for l_index in range(len(sep_grads)):
- sep_grads[l_index] = list(sep_grads[l_index])
- for k_index in range(len(sep_grads[0])):
- sep_grads[l_index][k_index] = np.concatenate(
- (sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0)
- sep_grads[l_index] = tuple(sep_grads[l_index])
- final_grads = sep_grads
-
- return final_grads
-
- def update_special_tokens_logits(self, context_tokens, i, logits):
- for beam_id in range(context_tokens.shape[0]):
- for token_idx in set(context_tokens[beam_id][-4:].tolist()):
- factor = self.repetition_penalty if logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty)
- logits[beam_id, token_idx] /= factor
-
- if i >= self.ef_idx:
- factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor)
- logits[beam_id, self.end_token] *= factor
- if i == 0:
- start_factor = 1.6
- factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor)
- logits[beam_id, self.end_token] /= factor
-
- for token_idx in list(self.forbidden_tokens):
- factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor)
- logits[beam_id, token_idx] /= factor
-
- return logits
-
- def clip_loss(self, probs, context_tokens):
- for p_ in self.clip.transformer.parameters():
- if p_.grad is not None:
- p_.grad.data.zero_()
-
- top_size = 512
- top_probs, top_indices = probs.topk(top_size, -1)
-
- prefix_texts = [self.lm_tokenizer.decode(x, skip_special_tokens=True) for x in context_tokens]
-
- clip_loss = 0
- losses = []
-
- top_texts = []
- for idx_p in range(probs.shape[0]):
- prefix_text = prefix_texts[idx_p]
- for x in top_indices[idx_p]:
- top_texts.append(prefix_text + self.lm_tokenizer.decode(x))
-
- text_features = self.get_txt_features(top_texts)#.reshape(probs.size(0), top_size, -1)
-
- with torch.no_grad():
- similiraties = (self.image_features @ text_features.T).reshape(probs.size(0), -1)
- similiraties = similiraties.reshape(probs.size(0), -1)
- target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach()
- target_probs = target_probs.type(torch.float32)
-
- clip_loss += torch.sum(-(target_probs * torch.log(top_probs)))
- # for idx_p in range(probs.shape[0]):
- # top_texts = []
- # prefix_text = prefix_texts[idx_p]
- # for x in top_indices[idx_p]:
- # top_texts.append(prefix_text + self.lm_tokenizer.decode(x))
- # text_features = self.get_txt_features(top_texts)
- #
- # with torch.no_grad():
- # similiraties = (self.image_features @ text_features.T)
- # target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach()
- # target_probs = target_probs.type(torch.float32)
- #
- # target = torch.zeros_like(probs[idx_p])
- # target[top_indices[idx_p]] = target_probs[0]
- # target = target.unsqueeze(0)
- # cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)])))
- #
- # clip_loss += cur_clip_loss
- # losses.append(cur_clip_loss)
-
- return clip_loss, losses
-
- def clip_loss_old(self, probs, context_tokens):
- for p_ in self.clip.transformer.parameters():
- if p_.grad is not None:
- p_.grad.data.zero_()
-
- top_size = 512
- _, top_indices = probs.topk(top_size, -1)
-
- prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens]
-
- clip_loss = 0
- losses = []
- for idx_p in range(probs.shape[0]):
- top_texts = []
- prefix_text = prefix_texts[idx_p]
- for x in top_indices[idx_p]:
- top_texts.append(prefix_text + self.lm_tokenizer.decode(x))
- text_features = self.get_txt_features(top_texts)
-
- with torch.no_grad():
- similiraties = (self.image_features @ text_features.T)
- target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach()
- target_probs = target_probs.type(torch.float32)
-
- target = torch.zeros_like(probs[idx_p])
- target[top_indices[idx_p]] = target_probs[0]
- target = target.unsqueeze(0)
- cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)])))
-
- clip_loss += cur_clip_loss
- losses.append(cur_clip_loss)
-
- return clip_loss, losses
\ No newline at end of file
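The batched variant above wraps CLIP's image and text encoders in `torch.nn.DataParallel` and scores all beams in one pass of the CLIP loss. The removed `run.py` further down imports it as `CLIPTextGenerator_multigpu` and defines a `--multi_gpu` flag, but as reproduced in this diff it always instantiates the single-GPU class; a small sketch of how the flag could select between the two:

```python
# Sketch: choosing the DataParallel (batched) generator when --multi_gpu is set.
# The removed run.py imports both classes but always builds the single-GPU one.
from model.ZeroCLIP import CLIPTextGenerator
from model.ZeroCLIP_batched import CLIPTextGenerator as CLIPTextGenerator_multigpu

def build_text_generator(args):
    cls = CLIPTextGenerator_multigpu if args.multi_gpu else CLIPTextGenerator
    return cls(**vars(args))
```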
diff --git a/zerocap/model/__init__.py b/zerocap/model/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/zerocap/model/__pycache__/ZeroCLIP.cpython-36.pyc b/zerocap/model/__pycache__/ZeroCLIP.cpython-36.pyc
deleted file mode 100644
index 093e083..0000000
Binary files a/zerocap/model/__pycache__/ZeroCLIP.cpython-36.pyc and /dev/null differ
diff --git a/zerocap/model/__pycache__/ZeroCLIP.cpython-37.pyc b/zerocap/model/__pycache__/ZeroCLIP.cpython-37.pyc
deleted file mode 100644
index 5f08c9e..0000000
Binary files a/zerocap/model/__pycache__/ZeroCLIP.cpython-37.pyc and /dev/null differ
diff --git a/zerocap/model/__pycache__/ZeroCLIP_batched.cpython-36.pyc b/zerocap/model/__pycache__/ZeroCLIP_batched.cpython-36.pyc
deleted file mode 100644
index aa91bfc..0000000
Binary files a/zerocap/model/__pycache__/ZeroCLIP_batched.cpython-36.pyc and /dev/null differ
diff --git a/zerocap/model/__pycache__/ZeroCLIP_batched.cpython-37.pyc b/zerocap/model/__pycache__/ZeroCLIP_batched.cpython-37.pyc
deleted file mode 100644
index 816cae9..0000000
Binary files a/zerocap/model/__pycache__/ZeroCLIP_batched.cpython-37.pyc and /dev/null differ
diff --git a/zerocap/model/__pycache__/__init__.cpython-36.pyc b/zerocap/model/__pycache__/__init__.cpython-36.pyc
deleted file mode 100644
index 5a18c45..0000000
Binary files a/zerocap/model/__pycache__/__init__.cpython-36.pyc and /dev/null differ
diff --git a/zerocap/model/__pycache__/__init__.cpython-37.pyc b/zerocap/model/__pycache__/__init__.cpython-37.pyc
deleted file mode 100644
index 3108a5a..0000000
Binary files a/zerocap/model/__pycache__/__init__.cpython-37.pyc and /dev/null differ
diff --git a/zerocap/mscoco_zerocap.sh b/zerocap/mscoco_zerocap.sh
deleted file mode 100755
index a06ae50..0000000
--- a/zerocap/mscoco_zerocap.sh
+++ /dev/null
@@ -1,14 +0,0 @@
-#!/bin/bash
-
-# lm_model:
-# 1. cambridgeltl/magic_mscoco
-# 2. cambridgeltl/magic_flickr30k
-CUDA_VISIBLE_DEVICES=1 python run.py \
- --beam_size 1 \
- --target_seq_length 16 \
- --reset_context_delta \
- --lm_model cambridgeltl/magic_mscoco \
- --test_image_prefix_path ../data/mscoco/test_images \
- --test_path ../data/mscoco/mscoco_test.json \
- --save_path_prefix ../inference_result/mscoco/baselines/ \
- --save_name zerocap_result.json
diff --git a/zerocap/predict.py b/zerocap/predict.py
deleted file mode 100644
index 46271f4..0000000
--- a/zerocap/predict.py
+++ /dev/null
@@ -1,117 +0,0 @@
-import os
-import tempfile
-import sys
-sys.path.append('CLIP')
-from pathlib import Path
-import cog
-import argparse
-import torch
-import clip
-from model.ZeroCLIP import CLIPTextGenerator
-
-def perplexity_score(text, lm_model, lm_tokenizer, device):
- encodings = lm_tokenizer(f'{lm_tokenizer.bos_token + text}', return_tensors='pt')
- input_ids = encodings.input_ids.to(device)
- target_ids = input_ids.clone()
-
- outputs = lm_model(input_ids, labels=target_ids)
- log_likelihood = outputs[0]
- ll = log_likelihood.item()
-
- return ll
-
-class Predictor(cog.Predictor):
- def setup(self):
- self.args = get_args()
- self.args.reset_context_delta = True
- self.text_generator = CLIPTextGenerator(**vars(self.args))
-
- @cog.input(
- "image",
- type=Path,
- help="input image"
- )
- @cog.input(
- "cond_text",
- type=str,
- default='Image of a',
- help="conditional text",
- )
- @cog.input(
- "beam_size",
- type=int,
- default=5, min=1, max=10,
- help="Number of beams to use",
- )
- @cog.input(
- "end_factor",
- type=float,
- default=1.01, min=1.0, max=1.10,
- help="Higher value for shorter captions",
- )
- @cog.input(
- "max_seq_length",
- type=int,
- default=15, min=1, max=20,
- help="Maximum number of tokens to generate",
- )
- @cog.input(
- "ce_loss_scale",
- type=float,
- default=0.2, min=0.0, max=0.6,
- help="Scale of cross-entropy loss with un-shifted language model",
- )
- def predict(self, image, cond_text, beam_size, end_factor, max_seq_length, ce_loss_scale):
- self.args.cond_text = cond_text
- self.text_generator.end_factor = end_factor
- self.text_generator.target_seq_length = max_seq_length
- self.text_generator.ce_scale = ce_loss_scale
-
- image_features = self.text_generator.get_img_feature([str(image)], None)
- captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size)
-
- # CLIP SCORE
- encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device))
- for c in captions]
- encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
- best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item()
-
- # Perplexity SCORE
- ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions]
- best_ppl_index = torch.tensor(ppl_scores).argmin().item()
-
- best_clip_caption = self.args.cond_text + captions[best_clip_idx]
- best_mixed = self.args.cond_text + captions[0]
- best_PPL = self.args.cond_text + captions[best_ppl_index]
-
- final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}'
-
- return final
- # return self.args.cond_text + captions[best_clip_idx]
-
-
-def get_args():
- parser = argparse.ArgumentParser()
-
- parser.add_argument("--seed", type=int, default=0)
- parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
- parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
- parser.add_argument("--target_seq_length", type=int, default=15)
- parser.add_argument("--cond_text", type=str, default="Image of a")
- parser.add_argument("--reset_context_delta", action="store_true",
- help="Should we reset the context at each token gen")
- parser.add_argument("--num_iterations", type=int, default=5)
- parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
- parser.add_argument("--clip_scale", type=float, default=1)
- parser.add_argument("--ce_scale", type=float, default=0.2)
- parser.add_argument("--stepsize", type=float, default=0.3)
- parser.add_argument("--grad_norm_factor", type=float, default=0.9)
- parser.add_argument("--fusion_factor", type=float, default=0.99)
- parser.add_argument("--repetition_penalty", type=float, default=1)
- parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
- parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
- parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
- parser.add_argument("--beam_size", type=int, default=5)
-
- args = parser.parse_args('')
- return args
diff --git a/zerocap/predict_arithmetic.py b/zerocap/predict_arithmetic.py
deleted file mode 100644
index 1e2ade2..0000000
--- a/zerocap/predict_arithmetic.py
+++ /dev/null
@@ -1,129 +0,0 @@
-import os
-import tempfile
-import sys
-sys.path.append('CLIP')
-from pathlib import Path
-import cog
-import argparse
-import torch
-import clip
-from model.ZeroCLIP import CLIPTextGenerator
-
-def perplexity_score(text, lm_model, lm_tokenizer, device):
- encodings = lm_tokenizer(f'{lm_tokenizer.bos_token + text}', return_tensors='pt')
- input_ids = encodings.input_ids.to(device)
- target_ids = input_ids.clone()
-
- outputs = lm_model(input_ids, labels=target_ids)
- log_likelihood = outputs[0]
- ll = log_likelihood.item()
-
- return ll
-
-class Predictor(cog.Predictor):
- def setup(self):
- self.args = get_args()
- self.args.reset_context_delta = True
- self.text_generator = CLIPTextGenerator(**vars(self.args))
-
- @cog.input(
- "image1",
- type=Path,
- help="Final result will be: image1 + (image2 - image3)"
- )
- @cog.input(
- "image2",
- type=Path,
- help="Final result will be: image1 + (image2 - image3)"
- )
- @cog.input(
- "image3",
- type=Path,
- help="Final result will be: image1 + (image2 - image3)"
- )
- @cog.input(
- "cond_text",
- type=str,
- default='Image of a',
- help="conditional text",
- )
- @cog.input(
- "beam_size",
- type=int,
- default=3, min=1, max=10,
- help="Number of beams to use",
- )
- @cog.input(
- "end_factors",
- type=float,
- default=1.06, min=1.0, max=1.10,
- help="Higher value for shorter captions",
- )
- @cog.input(
- "max_seq_lengths",
- type=int,
- default=3, min=1, max=20,
- help="Maximum number of tokens to generate",
- )
- @cog.input(
- "ce_loss_scale",
- type=float,
- default=0.2, min=0.0, max=0.6,
- help="Scale of cross-entropy loss with un-shifted language model",
- )
- def predict(self, image1, image2, image3, cond_text, beam_size, end_factors, max_seq_lengths, ce_loss_scale):
- self.args.cond_text = cond_text
- self.text_generator.end_factor = end_factors
- self.text_generator.target_seq_length = max_seq_lengths
- self.text_generator.ce_scale = ce_loss_scale
- self.text_generator.fusion_factor = 0.95
- self.text_generator.grad_norm_factor = 0.95
-
- image_features = self.text_generator.get_combined_feature([str(image1), str(image2), str(image3)], [], [1, 1, -1], None)
- captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size)
-
- # CLIP SCORE
- encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device))
- for c in captions]
- encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
- best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item()
-
- # Perplexity SCORE
- ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions]
- best_ppl_index = torch.tensor(ppl_scores).argmin().item()
-
- best_clip_caption = self.args.cond_text + captions[best_clip_idx]
- best_mixed = self.args.cond_text + captions[0]
- best_PPL = self.args.cond_text + captions[best_ppl_index]
-
- final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}'
-
- return final
- # return self.args.cond_text + captions[best_clip_idx]
-
-
-def get_args():
- parser = argparse.ArgumentParser()
-
- parser.add_argument("--seed", type=int, default=0)
- parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
- parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
- parser.add_argument("--target_seq_length", type=int, default=15)
- parser.add_argument("--cond_text", type=str, default="Image of a")
- parser.add_argument("--reset_context_delta", action="store_true",
- help="Should we reset the context at each token gen")
- parser.add_argument("--num_iterations", type=int, default=5)
- parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
- parser.add_argument("--clip_scale", type=float, default=1)
- parser.add_argument("--ce_scale", type=float, default=0.2)
- parser.add_argument("--stepsize", type=float, default=0.3)
- parser.add_argument("--grad_norm_factor", type=float, default=0.95)
- parser.add_argument("--fusion_factor", type=float, default=0.95)
- parser.add_argument("--repetition_penalty", type=float, default=1)
- parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
- parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
- parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
- parser.add_argument("--beam_size", type=int, default=5)
-
- args = parser.parse_args('')
- return args
diff --git a/zerocap/requirements.txt b/zerocap/requirements.txt
deleted file mode 100644
index 0eaf0ad..0000000
--- a/zerocap/requirements.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-ftfy
-regex
-tqdm
diff --git a/zerocap/run.py b/zerocap/run.py
deleted file mode 100644
index fab33b9..0000000
--- a/zerocap/run.py
+++ /dev/null
@@ -1,131 +0,0 @@
-import argparse
-import ipdb
-from tqdm import tqdm
-import progressbar
-import torch
-import ipdb
-import clip
-from model.ZeroCLIP import CLIPTextGenerator
-from model.ZeroCLIP_batched import CLIPTextGenerator as CLIPTextGenerator_multigpu
-
-def get_args():
- parser = argparse.ArgumentParser()
-
- parser.add_argument("--test_image_prefix_path", type=str, help="the folder that stores all test images")
- parser.add_argument("--test_path", type=str)
- parser.add_argument("--save_path_prefix", type=str, help="save the result in which directory")
- parser.add_argument("--save_name", type=str, help="the name of the saved file")
-
- parser.add_argument("--seed", type=int, default=0)
- parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
- parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
- parser.add_argument("--target_seq_length", type=int, default=15)
- parser.add_argument("--cond_text", type=str, default="Image of a")
- parser.add_argument("--reset_context_delta", action="store_true",
- help="Should we reset the context at each token gen")
- parser.add_argument("--num_iterations", type=int, default=5)
- parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
- parser.add_argument("--clip_scale", type=float, default=1)
- parser.add_argument("--ce_scale", type=float, default=0.2)
- parser.add_argument("--stepsize", type=float, default=0.3)
- parser.add_argument("--grad_norm_factor", type=float, default=0.9)
- parser.add_argument("--fusion_factor", type=float, default=0.99)
- parser.add_argument("--repetition_penalty", type=float, default=1)
- parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
- parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
- parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
- parser.add_argument("--beam_size", type=int, default=1)
-
- parser.add_argument("--multi_gpu", action="store_true")
-
- parser.add_argument('--run_type',
- default='caption',
- nargs='?',
- choices=['caption', 'arithmetics'])
-
- parser.add_argument("--caption_img_path", type=str, default='example_images/captions/COCO_val2014_000000008775.jpg',
- help="Path to image for captioning")
-
- parser.add_argument("--arithmetics_imgs", nargs="+",
- default=['example_images/arithmetics/woman2.jpg',
- 'example_images/arithmetics/king2.jpg',
- 'example_images/arithmetics/man2.jpg'])
- parser.add_argument("--arithmetics_weights", nargs="+", default=[1, 1, -1])
-
- args = parser.parse_args()
-
- return args
-
-def run(args, text_generator, img_path):
- image_features = text_generator.get_img_feature([img_path], None)
- captions = text_generator.run(image_features, args.cond_text, beam_size=args.beam_size)
-
- encoded_captions = [text_generator.clip.encode_text(clip.tokenize(c).to(text_generator.device)) for c in captions]
- encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
- best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item()
- return captions
-
-
-if __name__ == '__main__':
- if torch.cuda.is_available():
- print ('Cuda is available.')
- cuda_available = torch.cuda.is_available()
- args = get_args()
- device = torch.device('cuda')
-
- save_path_prefix = args.save_path_prefix
- import os
- if os.path.exists(save_path_prefix):
- pass
- else: # recursively construct directory
- os.makedirs(save_path_prefix, exist_ok=True)
- # parse save name
- save_name = args.save_name
- full_save_path = save_path_prefix + '/' + save_name
- print ('full save path is {}'.format(full_save_path))
-
- print ('Loading data...')
- import json
- with open(args.test_path) as f:
- item_list = json.load(f)
- print ('Data loaded.')
- print ('Number of test instances is {}'.format(len(item_list)))
-
- # ZeroCap generator
- text_generator = CLIPTextGenerator(**vars(args))
-
- result_list = []
- invalid_num = 0
- print ('----------------------------------------------------------------')
- test_num = len(item_list)
- #test_num = 10
- print ('Number of inference instances is {}'.format(test_num))
- p = progressbar.ProgressBar(test_num)
- p.start()
- for p_idx in tqdm(range(test_num)):
- p.update(p_idx)
- one_test_dict = item_list[p_idx]
-
- one_res_dict = {
- 'split':one_test_dict['split'],
- 'image_name':one_test_dict['image_name'],
- #'file_path':one_test_dict['file_path'],
- 'captions':one_test_dict['captions']
- }
-
- image_full_path = args.test_image_prefix_path + '/' + one_test_dict['image_name']
- try:
- output_text = run(args, text_generator, img_path=image_full_path)
- one_res_dict['prediction'] = output_text[0]
- result_list.append(one_res_dict)
- except Exception as error:
- print(f'[!] ERROR:', error)
- invalid_num += 1
- print ('invalid number is {}'.format(invalid_num))
- continue
- p.finish()
- print ('Inference completed!')
-
- import json
- with open(full_save_path, 'w') as outfile:
- json.dump(result_list, outfile, indent=4)
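Each record that the removed `run.py` appends to `zerocap_result.json` copies `split`, `image_name`, and the reference `captions` from the test JSON and adds the generated `prediction`. A sketch of one record (all values are placeholders):

```python
# Shape of one entry in the saved zerocap_result.json; every value below is a placeholder.
one_res_dict = {
    "split": "<split name from the test JSON>",
    "image_name": "<image file name from the test JSON>",
    "captions": ["<reference caption 1>", "<reference caption 2>"],
    "prediction": "<caption generated for this image>",
}
```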
diff --git a/zerocap/setup.py b/zerocap/setup.py
deleted file mode 100644
index 8ae2efe..0000000
--- a/zerocap/setup.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import os
-
-import pkg_resources
-from setuptools import setup, find_packages
-
-setup(
- name="zero-shot-image-to-text",
- py_modules=["zero-shot-image-to-text"],
- version="1.0",
- description="",
- packages=find_packages(),
- install_requires=[
- str(r)
- for r in pkg_resources.parse_requirements(
- open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
- )
- ],
- include_package_data=True
-)
\ No newline at end of file