diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..627dc5d
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .magic import Magic
+
+def magic(model_name: str):
+ return Magic(model_name)
diff --git a/clip/README.md b/clip/README.md
new file mode 100644
index 0000000..dd7a695
--- /dev/null
+++ b/clip/README.md
@@ -0,0 +1,141 @@
+## CLIP
+This folder illustrates how to use CLIP to build text index and to conduct cross-modal retrieval baseline.
+
+****
+## Catalogue:
+* 1. Build Text Index
+ * 1.1. Build Text Index for MSCOCO
+ * 1.1.1. Download Our Built Index
+ * 1.1.2. Construct the Index by Yourself
+ * 1.2. Build Text Index for Flickr30k
+ * 1.2.1. Download Our Built Index
+ * 1.2.2. Construct the Index by Yourself
+* 2. CLIP Retrieval Baseline
+ * 2.1. In Domain CLIP Retrieval
+ * 2.2. Cross Domain CLIP Retrieval
+
+****
+
+
+
+### 1. Build Text Index:
+We show how to build the text index, from which the caption is retrieved from, using CLIP.
+
+
+
+#### 1.1. Build Text Index for MSCOCO:
+First, we demonstrate how to build text index for MSCOCO.
+
+
+
+#### 1.1.1. Download Our Post-processed Index:
+We share our built index for MSCOCO via this [[link]](https://drive.google.com/file/d/1Dx_RPeAmydS6ZYuiJ-dLlK9-DjDZkxAh/view?usp=sharing). After downloading, unzip the downloaded file **mscoco_index.zip** under the current directory.
+
+> **** The resulting directory looks like:
+
+ .
+ ├── ./mscoco_index/
+ ├── index_matrix.txt # The file that stores the representations of captions from the training set of MSCOCO. Each row is a vector that corresponds to a specific caption from the training set.
+ └── text_mapping.json # The file that stores the mappings between the representation and the corresponding caption.
+
+
+
+#### 1.1.2. Construct the Index by Yourself:
+
+You can also rebuild the index by yourself. First, you should make sure you have downloaded the MSCOCO data following instructions [[here]](https://github.com/yxuansu/MAGIC/tree/main/image_captioning/data#1-mscoco-benchmark). Then, you can run the following command to build the index.
+```yaml
+chmod +x ./build_mscoco_index.sh
+./build_mscoco_index.sh
+```
+The arguments are as follows:
+* `--clip_name`: The configuration of the pre-trained CLIP model from huggingface.
+* `--text_file_path`: Where the training text corpus stores.
+* `--save_index_prefix`: In which directory you would like to store your index files.
+* `--save_index_name`: The saved name of the caption representations.
+* `--save_mapping_dict_name`: The saved name of the mapping dictionary between representations and captions.
+* `--batch_size`: The inference batch size.
+
+
+
+
+#### 1.2. Build Text Index for Flickr30k:
+Next, we demonstrate how to build text index for Flickr30k.
+
+
+
+#### 1.2.1. Download Our Post-processed Index:
+We share our built index for Flickr30k via this [[link]](https://drive.google.com/file/d/1hS58_ir5pdZZPckApCtlz2RyasCQbrPf/view?usp=sharing). After downloading, unzip the downloaded file **flickr30k_index.zip** under the current directory.
+
+> **** The resulting directory looks like:
+
+ .
+ ├── ./flickr30k_index/
+ ├── index_matrix.txt # The file that stores the representations of captions from the training set of Flickr30k. Each row is a vector that corresponds to a specific caption from the training set.
+ └── text_mapping.json # The file that stores the mappings between the representation and the corresponding caption.
+
+
+
+#### 1.2.2. Construct the Index by Yourself:
+
+You can also rebuild the index by yourself. First, you should make sure you have downloaded the Flickr30k data following instructions [[here]](https://github.com/yxuansu/MAGIC/tree/main/image_captioning/data#2-flickr30k-benchmark). Then, you can run the following command to build the index.
+```yaml
+chmod +x ./build_flickr30k_index.sh
+./build_flickr30k_index.sh
+```
+The arguments are as follows:
+* `--clip_name`: The configuration of the pre-trained CLIP model from huggingface.
+* `--text_file_path`: Where the training text corpus stores.
+* `--save_index_prefix`: In which directory you would like to store your index files.
+* `--save_index_name`: The saved name of the caption representations.
+* `--save_mapping_dict_name`: The saved name of the mapping dictionary between representations and captions.
+* `--batch_size`: The inference batch size.
+
+****
+
+
+
+### 2. CLIP Retrieval Baseline:
+Here, we show how to conduct the CLIP retrieval baseline.
+
+
+
+#### 2.1. In Domain CLIP Retrieval:
+To retrieve the captions from the in domain training set, you should run the following command:
+```yaml
+chmod +x ./X_clip_retrieval.sh
+./X_clip_retrieval.sh
+```
+Here, X is in ['mscoco', 'flickr30k'] which corresponds for the MSCOCO and Flickr30k benchmarks.
+
+The arguments are as follows:
+* `--clip_name`: The configuration of the pre-trained CLIP model from huggingface.
+* `--test_image_prefix_path`: Where the test set images stores.
+* `--test_path`: Where the reference test captions file stores.
+* `--index_matrix_path`: The path of the representation index file.
+* `--mapping_dict_path`: The path of the mapping dictionary between representations and captions.
+* `--save_path_prefix`: Where to save the inferenced result.
+* `--save_name`: The saved name of the inferenced result.
+
+**[Note]** As we are conducting in domain CLIP retrieval, the test images and the caption index should come from the same benchmark.
+
+
+
+
+#### 2.2. Cross Domain CLIP Retrieval:
+To retrieve the captions from the cross domain training set, you should run the following command:
+```yaml
+chmod +x ./source_X_target_Y_clip_retrieval.sh
+./source_X_target_Y_clip_retrieval.sh
+```
+Here, X is the source domain from ['mscoco', 'flickr30k'] and Y is the target domain from ['flickr30k', 'mscoco'].
+
+The arguments are as follows:
+* `--clip_name`: The configuration of the pre-trained CLIP model from huggingface.
+* `--test_image_prefix_path`: Where the test set images stores.
+* `--test_path`: Where the reference test captions file stores.
+* `--index_matrix_path`: The path of the representation index file.
+* `--mapping_dict_path`: The path of the mapping dictionary between representations and captions.
+* `--save_path_prefix`: Where to save the inferenced result.
+* `--save_name`: The saved name of the inferenced result.
+
+**[Note]** As we are conducting cross domain CLIP retrieval, the test images and the caption index should come from **different** benchmarks.
diff --git a/clip/build_flickr30k_index.sh b/clip/build_flickr30k_index.sh
new file mode 100644
index 0000000..5579beb
--- /dev/null
+++ b/clip/build_flickr30k_index.sh
@@ -0,0 +1,7 @@
+CUDA_VISIBLE_DEVICES=1 python build_text_index.py\
+ --clip_name openai/clip-vit-base-patch32\
+ --text_file_path ../data/flickr30k/flickr30k_train.json\
+ --save_index_prefix ./flickr30k_index/\
+ --save_index_name index_matrix.txt\
+ --save_mapping_dict_name text_mapping.json\
+ --batch_size 128
\ No newline at end of file
diff --git a/clip/build_mscoco_index.sh b/clip/build_mscoco_index.sh
new file mode 100644
index 0000000..c053f75
--- /dev/null
+++ b/clip/build_mscoco_index.sh
@@ -0,0 +1,7 @@
+CUDA_VISIBLE_DEVICES=0 python build_text_index.py\
+ --clip_name openai/clip-vit-base-patch32\
+ --text_file_path ../data/mscoco/mscoco_train.json\
+ --save_index_prefix ./mscoco_index/\
+ --save_index_name index_matrix.txt\
+ --save_mapping_dict_name text_mapping.json\
+ --batch_size 128
\ No newline at end of file
diff --git a/clip/build_text_index.py b/clip/build_text_index.py
new file mode 100644
index 0000000..98461a5
--- /dev/null
+++ b/clip/build_text_index.py
@@ -0,0 +1,105 @@
+import sys
+import torch
+import numpy as np
+import progressbar
+import os
+
+def parse_config():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--clip_name", type=str, default="openai/clip-vit-base-patch32")
+ parser.add_argument("--text_file_path", type=str)
+ # save configuration
+ parser.add_argument("--save_index_prefix", type=str, help='where to save the mips index')
+ parser.add_argument("--save_index_name", type=str)
+ parser.add_argument("--save_mapping_dict_name", type=str,
+ help="a json file that stores a dictory. The dictory contains mapping between mips index and caption text")
+ # inference configuration
+ parser.add_argument("--batch_size", type=int, help="the batch size used to conduct inference with CLIP")
+ return parser.parse_args()
+
+def load_batch_text(text_file_path, batch_size):
+ import json
+ with open(text_file_path) as f:
+ item_list = json.load(f)
+
+ text_list = []
+ for item in item_list:
+ captions = item["captions"]
+ for cap in captions:
+ text_list.append(cap)
+ print ('Number of text instances is {}'.format(len(text_list)))
+
+ data_num = len(text_list)
+ batch_num = data_num // batch_size
+ batch_text_list = []
+ s_idx, e_idx = 0, batch_size
+ for p_idx in range(batch_num):
+ one_batch_text_list = []
+ for idx in range(s_idx, e_idx):
+ one_batch_text_list.append(text_list[idx])
+ batch_text_list.append(one_batch_text_list)
+ return batch_text_list
+
+
+import argparse
+if __name__ == '__main__':
+ if torch.cuda.is_available():
+ print ('Cuda is available.')
+ cuda_available = torch.cuda.is_available()
+ args = parse_config()
+ device = torch.device('cuda')
+
+ import os
+ if os.path.exists(args.save_index_prefix):
+ pass
+ else: # recursively construct directory
+ os.makedirs(args.save_index_prefix, exist_ok=True)
+
+ print ('Loading CLIP...')
+ from clip import CLIP
+ model = CLIP(args.clip_name)
+ if cuda_available:
+ model = model.cuda(device)
+ model.eval()
+ print ('CLIP loaded!')
+
+ print ('Loading text data...')
+ batch_text_list = load_batch_text(args.text_file_path, args.batch_size)
+ print ('Text data loaded.')
+
+ res_text_vec_list, res_text_list = [], []
+ batch_num = len(batch_text_list)
+ print ('Number of batches is {}'.format(batch_num))
+ print ('Start inference...')
+ p = progressbar.ProgressBar(batch_num)
+ p.start()
+ with torch.no_grad():
+ for p_idx in range(batch_num):
+ p.update(p_idx)
+ one_text_batch = batch_text_list[p_idx]
+ one_batch_vec = model.compute_batch_index_text_representation(one_text_batch).detach().cpu()
+ one_batch_vec_list = one_batch_vec.unbind(dim=0)
+ bsz = len(one_batch_vec_list)
+ for k in range(bsz):
+ res_text_vec_list.append(one_batch_vec_list[k].numpy())
+ res_text_list.append(one_text_batch[k])
+ p.finish()
+ assert len(res_text_vec_list) == len(res_text_list)
+ print ('Inference completed!')
+
+ index_text_mapping_dict = {}
+ for k in range(len(res_text_list)):
+ index_text_mapping_dict[k] = res_text_list[k]
+ mapping_list_save_path = args.save_index_prefix + '/' + args.save_mapping_dict_name
+ import json
+ with open(mapping_list_save_path, 'w') as outfile:
+ json.dump(index_text_mapping_dict, outfile, indent=4)
+ print ('Mapping dictionary saved!')
+
+ print ('Start buiding index...')
+ index_save_path = args.save_index_prefix + '/' + args.save_index_name
+ with open(index_save_path, 'w', encoding = 'utf8') as o:
+ for vec in res_text_vec_list:
+ one_text = ' '.join([str(num) for num in vec]).strip()
+ o.writelines(one_text + '\n')
+ print ('Index completed!')
diff --git a/clip/clip.py b/clip/clip.py
new file mode 100644
index 0000000..05b1c1c
--- /dev/null
+++ b/clip/clip.py
@@ -0,0 +1,146 @@
+import torch
+import requests
+from torch import nn
+from PIL import Image
+
+class CLIP(nn.Module):
+ def __init__(self, model_name):
+ super(CLIP, self).__init__()
+ # model name: e.g. openai/clip-vit-base-patch32
+ print ('Initializing CLIP model...')
+ from transformers import CLIPProcessor, CLIPModel
+ self.model = CLIPModel.from_pretrained(model_name)
+ self.model.eval()
+ self.processor = CLIPProcessor.from_pretrained(model_name)
+ from transformers import CLIPTokenizer
+ self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
+ self.cuda_has_been_checked = False
+ print ('CLIP model initialized.')
+
+ def check_cuda(self):
+ self.cuda_available = next(self.model.parameters()).is_cuda
+ self.device = next(self.model.parameters()).get_device()
+ if self.cuda_available:
+ print ('Cuda is available.')
+ print ('Device is {}'.format(self.device))
+ else:
+ print ('Cuda is not available.')
+ print ('Device is {}'.format(self.device))
+
+ @torch.no_grad()
+ def compute_image_representation_from_image_path(self, image_path):
+ if not self.cuda_has_been_checked:
+ self.check_cuda()
+ self.cuda_has_been_checked = True
+ else:
+ pass
+ # image_path: the path of the image
+ image = Image.open(image_path)
+ inputs = self.processor(images=image, return_tensors="pt")
+ pixel_values = inputs['pixel_values']
+ if self.cuda_available:
+ pixel_values = pixel_values.cuda(self.device)
+ visual_outputs = self.model.vision_model(pixel_values=pixel_values)
+ image_embeds = visual_outputs[1]
+ image_embeds = self.model.visual_projection(image_embeds) # [1 x embed_dim]
+ return image_embeds
+
+ def compute_image_representation_from_image_instance(self, image):
+ if not self.cuda_has_been_checked:
+ self.check_cuda()
+ self.cuda_has_been_checked = True
+ else:
+ pass
+ # image_path: the path of the image
+ inputs = self.processor(images=image, return_tensors="pt")
+ pixel_values = inputs['pixel_values']
+ if self.cuda_available:
+ pixel_values = pixel_values.cuda(self.device)
+ visual_outputs = self.model.vision_model(pixel_values=pixel_values)
+ image_embeds = visual_outputs[1]
+ image_embeds = self.model.visual_projection(image_embeds) # [1 x embed_dim]
+ return image_embeds
+
+ def compute_text_representation(self, text_list):
+ if not self.cuda_has_been_checked:
+ self.check_cuda()
+ self.cuda_has_been_checked = True
+ else:
+ pass
+ # text_list: a list of text
+ text_inputs = self.tokenizer(text_list, padding=True, return_tensors="pt",
+ max_length=self.tokenizer.max_len_single_sentence + 2, truncation=True)
+ # self.tokenizer.max_len_single_sentence + 2 = 77
+ input_ids, attention_mask = text_inputs['input_ids'], text_inputs['attention_mask']
+ if self.cuda_available:
+ input_ids = input_ids.cuda(self.device)
+ attention_mask = attention_mask.cuda(self.device)
+ text_outputs = self.model.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask
+ )
+ text_embeds = text_outputs[1]
+ text_embeds = self.model.text_projection(text_embeds)
+ return text_embeds
+
+ def compute_image_text_similarity_via_embeddings(self, image_embeds, text_embeds):
+ '''
+ image_embeds: 1 x embed_dim
+ text_embeds: len(text_list) x embed_dim
+ '''
+ image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
+ text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
+ logit_scale = self.model.logit_scale.exp()
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
+ logits_per_image = logits_per_text.T
+ return logits_per_image.softmax(dim=1) # 1 x len(text_list)
+
+ def compute_image_text_similarity_via_raw_text(self, image_embeds, text_list):
+ text_embeds = self.compute_text_representation(text_list)
+ return self.compute_image_text_similarity_via_embeddings(image_embeds, text_embeds)
+
+ ### -------------------- functions for building index ---------------------- ###
+ def compute_batch_index_image_features(self, image_list):
+ '''
+ # list of image instances
+ '''
+ if not self.cuda_has_been_checked:
+ self.check_cuda()
+ self.cuda_has_been_checked = True
+ else:
+ pass
+ # image_path: the path of the image
+ inputs = self.processor(images=image_list, return_tensors="pt")
+ pixel_values = inputs['pixel_values']
+ if self.cuda_available:
+ pixel_values = pixel_values.cuda(self.device)
+ visual_outputs = self.model.vision_model(pixel_values=pixel_values)
+ image_embeds = visual_outputs[1]
+ image_embeds = self.model.visual_projection(image_embeds) # [1 x embed_dim]
+ return image_embeds # len(image_list) x embed_dim
+
+ def compute_batch_index_text_representation(self, text_list):
+ if not self.cuda_has_been_checked:
+ self.check_cuda()
+ self.cuda_has_been_checked = True
+ else:
+ pass
+ # text_list: a list of text
+ #text_inputs = self.tokenizer(text_list, padding=True, return_tensors="pt")
+ text_inputs = self.tokenizer(text_list, padding=True, return_tensors="pt",
+ max_length=self.tokenizer.max_len_single_sentence + 2, truncation=True)
+ input_ids, attention_mask = text_inputs['input_ids'], text_inputs['attention_mask']
+ if self.cuda_available:
+ input_ids = input_ids.cuda(self.device)
+ attention_mask = attention_mask.cuda(self.device)
+ text_outputs = self.model.text_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask
+ )
+ text_embeds = text_outputs[1]
+ text_embeds = self.model.text_projection(text_embeds)
+ return text_embeds
+ #logit_scale = self.model.logit_scale.exp()
+ #text_embeds = text_embeds * logit_scale
+ #return text_embeds
+
diff --git a/clip/clipretrieval.py b/clip/clipretrieval.py
new file mode 100644
index 0000000..bd77cbe
--- /dev/null
+++ b/clip/clipretrieval.py
@@ -0,0 +1,135 @@
+import json
+import copy
+import torch
+import progressbar
+import numpy as np
+from PIL import Image
+
+class CLIPIndex:
+ def __init__(self, index_matrix_path, mapping_dict_path, clip):
+ '''
+ index_path: the pre-trained index
+ mapping_dict_path: the pre-indexed mapping dictionary
+ clip: the pre-trained clip model
+ '''
+ print ('Loading index...')
+ self.index_matrix = self.normalization(self.load_matrix(index_matrix_path))
+ print ('Index loaded.')
+ print (self.index_matrix.shape)
+ with open(mapping_dict_path) as f:
+ self.mapping_dict = json.load(f)
+ self.clip = clip
+
+ def load_matrix(self, in_f):
+ matrix_list = []
+ with open(in_f, 'r', encoding = 'utf8') as i:
+ lines = i.readlines()
+ for l in lines:
+ one_vec = [float(num) for num in l.strip('\n').split()]
+ matrix_list.append(one_vec)
+ return np.array(matrix_list)
+
+ def normalization(self, matrix):
+ '''
+ matrix: num_instance x num_feature
+ '''
+ return matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
+
+ def get_image_representation(self, image_path):
+ image_instance = Image.open(image_path)
+ image_vec = self.clip.compute_batch_index_image_features([image_instance]).detach().cpu().numpy()
+ image_vec = self.normalization(image_vec)
+ return image_vec
+
+ def search_text(self, image_path):
+ image_vec = self.get_image_representation(image_path)
+ sort_idx_list = np.matmul(image_vec, self.index_matrix.transpose())[0].argsort()[::-1]
+ top_idx = sort_idx_list[0]
+ return self.mapping_dict[str(top_idx)]
+
+
+def parse_config():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--clip_name", type=str)
+ parser.add_argument("--test_image_prefix_path", type=str, help="the folder that stores all test images")
+ parser.add_argument("--test_path", type=str)
+ # index configuration
+ parser.add_argument("--index_matrix_path", type=str)
+ parser.add_argument("--mapping_dict_path", type=str)
+ # save configuration
+ parser.add_argument("--save_path_prefix", type=str, help="save the result in which directory")
+ parser.add_argument("--save_name", type=str, help="the name of the saved file")
+ return parser.parse_args()
+
+import argparse
+if __name__ == '__main__':
+ if torch.cuda.is_available():
+ print ('Cuda is available.')
+ cuda_available = torch.cuda.is_available()
+ args = parse_config()
+ device = torch.device('cuda')
+
+ save_path_prefix = args.save_path_prefix
+ import os
+ if os.path.exists(save_path_prefix):
+ pass
+ else: # recursively construct directory
+ os.makedirs(save_path_prefix, exist_ok=True)
+ # parse save name
+ save_name = args.save_name
+ full_save_path = save_path_prefix + '/' + save_name
+ print ('full save path is {}'.format(full_save_path))
+
+ print ('Loading CLIP...')
+ from clip import CLIP
+ clip = CLIP(args.clip_name)
+ if cuda_available:
+ clip = clip.cuda(device)
+ clip.eval()
+ print ('CLIP loaded!')
+
+ clipindex = CLIPIndex(args.index_matrix_path, args.mapping_dict_path, clip)
+
+ print ('Loading data...')
+ import json
+ with open(args.test_path) as f:
+ item_list = json.load(f)
+ print ('Data loaded.')
+ print ('Number of test instances is {}'.format(len(item_list)))
+
+ result_list = []
+ invalid_num = 0
+ print ('----------------------------------------------------------------')
+ with torch.no_grad():
+ test_num = len(item_list)
+ #test_num = 10
+ print ('Number of inference instances is {}'.format(test_num))
+ p = progressbar.ProgressBar(test_num)
+ p.start()
+ for p_idx in range(test_num):
+ p.update(p_idx)
+ one_test_dict = item_list[p_idx]
+
+ one_res_dict = {
+ 'split':one_test_dict['split'],
+ 'image_name':one_test_dict['image_name'],
+ #'file_path':one_test_dict['file_path'],
+ 'captions':one_test_dict['captions']
+ }
+
+ image_full_path = args.test_image_prefix_path + '/' + one_test_dict['image_name']
+ try:
+ output_text = clipindex.search_text(image_full_path)
+ one_res_dict['prediction'] = output_text
+ result_list.append(one_res_dict)
+ except:
+ invalid_num += 1
+ print ('invalid number is {}'.format(invalid_num))
+ continue
+ p.finish()
+ print ('Inference completed!')
+
+ import json
+ with open(full_save_path, 'w') as outfile:
+ json.dump(result_list, outfile, indent=4)
+
diff --git a/clip/flickr30k_clip_retrieval.sh b/clip/flickr30k_clip_retrieval.sh
new file mode 100644
index 0000000..0dba975
--- /dev/null
+++ b/clip/flickr30k_clip_retrieval.sh
@@ -0,0 +1,8 @@
+CUDA_VISIBLE_DEVICES=1 python clipretrieval.py\
+ --clip_name openai/clip-vit-base-patch32\
+ --test_image_prefix_path ../data/flickr30k/test_images/\
+ --test_path ../data/flickr30k/flickr30k_test.json\
+ --index_matrix_path ./flickr30k_index/index_matrix.txt\
+ --mapping_dict_path ./flickr30k_index/text_mapping.json\
+ --save_path_prefix ../inference_result/flickr30k/baselines/\
+ --save_name flickr30k_in_domain_clipretrieval.json
\ No newline at end of file
diff --git a/clip/mscoco_clip_retrieval.sh b/clip/mscoco_clip_retrieval.sh
new file mode 100644
index 0000000..cf4d893
--- /dev/null
+++ b/clip/mscoco_clip_retrieval.sh
@@ -0,0 +1,8 @@
+CUDA_VISIBLE_DEVICES=0 python clipretrieval.py\
+ --clip_name openai/clip-vit-base-patch32\
+ --test_image_prefix_path ../data/mscoco/test_images/\
+ --test_path ../data/mscoco/mscoco_test.json\
+ --index_matrix_path ./mscoco_index/index_matrix.txt\
+ --mapping_dict_path ./mscoco_index/text_mapping.json\
+ --save_path_prefix ../inference_result/mscoco/baselines/\
+ --save_name mscoco_in_domain_clipretrieval.json
\ No newline at end of file
diff --git a/clip/source_flickr30k_target_mscoco_clip_retrieval.sh b/clip/source_flickr30k_target_mscoco_clip_retrieval.sh
new file mode 100644
index 0000000..105f1c2
--- /dev/null
+++ b/clip/source_flickr30k_target_mscoco_clip_retrieval.sh
@@ -0,0 +1,8 @@
+CUDA_VISIBLE_DEVICES=1 python clipretrieval.py\
+ --clip_name openai/clip-vit-base-patch32\
+ --test_image_prefix_path ../data/mscoco/test_images/\
+ --test_path ../data/mscoco/mscoco_test.json\
+ --index_matrix_path ./flickr30k_index/index_matrix.txt\
+ --mapping_dict_path ./flickr30k_index/text_mapping.json\
+ --save_path_prefix ../inference_result/flickr30k_model_to_mscoco/\
+ --save_name source_flickr30k_target_mscoco_clip_retrieval.json
\ No newline at end of file
diff --git a/clip/source_mscoco_target_flickr30k_clip_retrieval.sh b/clip/source_mscoco_target_flickr30k_clip_retrieval.sh
new file mode 100644
index 0000000..9902cda
--- /dev/null
+++ b/clip/source_mscoco_target_flickr30k_clip_retrieval.sh
@@ -0,0 +1,8 @@
+CUDA_VISIBLE_DEVICES=1 python clipretrieval.py\
+ --clip_name openai/clip-vit-base-patch32\
+ --test_image_prefix_path ../data/flickr30k/test_images/\
+ --test_path ../data/flickr30k/flickr30k_test.json\
+ --index_matrix_path ./mscoco_index/index_matrix.txt\
+ --mapping_dict_path ./mscoco_index/text_mapping.json\
+ --save_path_prefix ../inference_result/mscoco_model_to_flickr30k/\
+ --save_name source_mscoco_target_flickr30k_clip_retrieval.json
\ No newline at end of file
diff --git a/magic.py b/magic.py
new file mode 100644
index 0000000..85f46a7
--- /dev/null
+++ b/magic.py
@@ -0,0 +1,99 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from re import I
+import sys
+import os
+import pathlib
+import pickle
+from argparse import Namespace
+
+import torch
+import torchvision
+from torchvision import transforms
+from transformers import GPT2Tokenizer
+
+from towhee.types.arg import arg, to_image_color
+from towhee.types.image_utils import to_pil
+from towhee.operator.base import NNOperator, OperatorFlag
+from towhee import register
+from towhee.models import clip
+
+class Magic(NNOperator):
+ """
+ Magic image captioning operator
+ """
+ def __init__(self, model_name: str):
+ super().__init__()
+ path = str(pathlib.Path(__file__).parent)
+ sys.path.append(path)
+ from clip import CLIP
+ from simctg import SimCTG
+ sys.path.pop()
+
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ # Load Language Model
+ language_model_name = r'cambridgeltl/magic_mscoco' # or r'/path/to/downloaded/cambridgeltl/magic_mscoco'
+ sos_token, pad_token = r'<-start_of_text->', r'<-pad->'
+ self.generation_model = SimCTG(language_model_name, sos_token, pad_token).to(self.device)
+ self.generation_model.eval()
+
+ model_name = r"openai/clip-vit-base-patch32" # or r"/path/to/downloaded/openai/clip-vit-base-patch32"
+ self.clip = CLIP(model_name).to(self.device)
+ self.clip.eval()
+
+
+ def _preprocess(self, img):
+ img = to_pil(img)
+ processed_img = self.transf_1(img)
+ processed_img = self.transf_2(processed_img)
+ processed_img = processed_img.to(self.device)
+ return processed_img
+
+ @arg(1, to_image_color('RGB'))
+ def inference_single_data(self, data):
+ text = self._inference_from_image(data)
+ return text
+
+ def __call__(self, data):
+ if not isinstance(data, list):
+ data = [data]
+ else:
+ data = data
+ results = []
+ for single_data in data:
+ result = self.inference_single_data(single_data)
+ results.append(result)
+ if len(data) == 1:
+ return results[0]
+ else:
+ return results
+
+ @arg(1, to_image_color('RGB'))
+ def _inference_from_image(self, img):
+ #img = self._preprocess(img).unsqueeze(0)
+ k, alpha, beta, decoding_len = 45, 0.1, 2.0, 16
+ eos_token = '<|endoftext|>'
+ with torch.no_grad():
+ output = generation_model.magic_search(input_ids, k,
+ alpha, decoding_len, beta, image_instance, clip, 60)
+
+ return out
+
+ def _configs(self):
+ config = {}
+ config['expansionnet_rf'] = {}
+ config['expansionnet_rf']['weights'] = 'rf_model.pth'
+ return config
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..e69de29
diff --git a/zerocap/README.md b/zerocap/README.md
new file mode 100644
index 0000000..1839550
--- /dev/null
+++ b/zerocap/README.md
@@ -0,0 +1,89 @@
+### Our Implementation of the ZeroCap Baseline Model
+
+****
+### Catalogue:
+* 1. Environment Preparation
+* 2. Image Captioning on MSCOCO
+* 3. Image Captioning on Flickr30k
+* 4. Cross Domain Image Captioning on MSCOCO
+* 5. Cross Domain Image Captioning on Flickr30k
+* 6. Citation
+* 7. Acknowledgements
+
+****
+
+
+
+#### 1. Environment Preparation:
+To install the correct environment, please run the following command:
+```yaml
+pip install -r requirements.txt
+```
+
+****
+
+
+
+#### 2. Image Captioning on MSCOCO:
+To perform image captioning on MSCOCO, please run the following command:
+```yaml
+chmod +x ./mscoco_zerocap.sh
+./mscoco_zerocap.sh
+```
+
+****
+
+
+
+#### 3. Image Captioning on Flickr30k:
+To perform image captioning on Flickr30k, please run the following command:
+```yaml
+chmod +x ./flickr30k_zerocap.sh
+./flickr30k_zerocap.sh
+```
+
+****
+
+
+
+#### 4. Cross Domain Image Captioning on MSCOCO:
+To perform image captioning on MSCOCO with the language model from Flickr30k domain, please run the following command:
+```yaml
+chmod +x ./flickr30k_to_mscoco_zerocap.sh
+./flickr30k_to_mscoco_zerocap.sh
+```
+
+****
+
+
+
+#### 5. Cross Domain Image Captioning on Flickr30k:
+To perform image captioning on Flickr30k with the language model from MSCOCO domain, please run the following command:
+```yaml
+chmod +x ./mscoco_to_flickr30k_zerocap.sh
+./mscoco_to_flickr30k_zerocap.sh
+```
+
+****
+
+
+
+#### 6. Citation:
+If you find our code helpful, please cite the original paper as
+
+```bibtex
+@article{tewel2021zero,
+ title={Zero-Shot Image-to-Text Generation for Visual-Semantic Arithmetic},
+ author={Tewel, Yoad and Shalev, Yoav and Schwartz, Idan and Wolf, Lior},
+ journal={arXiv preprint arXiv:2111.14447},
+ year={2021}
+}
+```
+
+****
+
+
+
+#### 7. Acknowledgements:
+We thank the authors for releasing their code. Our reimplementation of the baseline is based on their original codebase [[here]](https://github.com/yoadtew/zero-shot-image-to-text).
+
diff --git a/zerocap/cog.yaml b/zerocap/cog.yaml
new file mode 100644
index 0000000..92f13da
--- /dev/null
+++ b/zerocap/cog.yaml
@@ -0,0 +1,12 @@
+build:
+ gpu: true
+ python_version: "3.8"
+ system_packages:
+ - "libgl1-mesa-glx"
+ - "libglib2.0-0"
+ python_packages:
+ - "git+https://github.com/openai/CLIP.git"
+ - "git+https://github.com/YoadTew/zero-shot-image-to-text.git"
+
+predict: "predict.py:Predictor"
+#predict: "predict_arithmetic.py:Predictor"
\ No newline at end of file
diff --git a/zerocap/flickr30k_zerocap.sh b/zerocap/flickr30k_zerocap.sh
new file mode 100755
index 0000000..b727b12
--- /dev/null
+++ b/zerocap/flickr30k_zerocap.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+# lm_model:
+# 1. cambridgeltl/magic_mscoco
+# 2. cambridgeltl/magic_flickr30k
+CUDA_VISIBLE_DEVICES=1 python run.py \
+ --beam_size 1 \
+ --target_seq_length 16 \
+ --reset_context_delta \
+ --lm_model cambridgeltl/magic_flickr30k \
+ --test_image_prefix_path ../data/flickr30k/test_images \
+ --test_path ../data/flickr30k/flickr30k_test.json \
+ --save_path_prefix ../inference_result/flickr30k/baselines/ \
+ --save_name zerocap_result.json
diff --git a/zerocap/forbidden_tokens.npy b/zerocap/forbidden_tokens.npy
new file mode 100644
index 0000000..aeed51b
Binary files /dev/null and b/zerocap/forbidden_tokens.npy differ
diff --git a/zerocap/model/ZeroCLIP.py b/zerocap/model/ZeroCLIP.py
new file mode 100644
index 0000000..2c36fd2
--- /dev/null
+++ b/zerocap/model/ZeroCLIP.py
@@ -0,0 +1,389 @@
+import numpy as np
+from torch import nn
+from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer
+from transformers.models.gpt_neo import GPTNeoForCausalLM
+import torch
+import clip
+from PIL import Image
+from datetime import datetime
+import sys
+
+
+def log_info(text, verbose=True):
+ if verbose:
+ dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
+ print(f'{dt_string} | {text}')
+ sys.stdout.flush()
+
+
+def add_context(x, y):
+ return (x[0] + y[0], x[1] + y[1])
+
+
+def convert_models_to_fp32(model):
+ for p in model.parameters():
+ p.data = p.data.float()
+
+
+class CLIPTextGenerator:
+ def __init__(self,
+ seed=0,
+ lm_model='gpt-2',
+ forbidden_tokens_file_path='./forbidden_tokens.npy',
+ clip_checkpoints='./clip_checkpoints',
+ target_seq_length=15,
+ reset_context_delta=True,
+ num_iterations=5,
+ clip_loss_temperature=0.01,
+ clip_scale=1.,
+ ce_scale=0.2,
+ stepsize=0.3,
+ grad_norm_factor=0.9,
+ fusion_factor=0.99,
+ repetition_penalty=1.,
+ end_token='.',
+ end_factor=1.01,
+ forbidden_factor=20,
+ **kwargs):
+
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # set Random seed
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+ # Initialize Language model
+ self.context_prefix = ''
+
+ self.lm_tokenizer = GPT2Tokenizer.from_pretrained(lm_model)
+ self.lm_model = GPT2LMHeadModel.from_pretrained(lm_model, output_hidden_states=True)
+ self.context_prefix = self.lm_tokenizer.bos_token
+
+ self.lm_model.to(self.device)
+ self.lm_model.eval()
+
+ self.forbidden_tokens = np.load(forbidden_tokens_file_path)
+ self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if
+ (x[0] == 'Ġ' and len(x) > 1 and x[1].isupper())]
+
+ # Freeze LM weights
+ for param in self.lm_model.parameters():
+ param.requires_grad = False
+
+ # Initialize CLIP
+ self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device,
+ download_root=clip_checkpoints, jit=False)
+ # convert_models_to_fp32(self.clip)
+
+ # Init arguments
+ self.target_seq_length = target_seq_length
+ self.reset_context_delta = reset_context_delta
+ self.num_iterations = num_iterations
+ self.clip_loss_temperature = clip_loss_temperature
+ self.clip_scale = clip_scale
+ self.ce_scale = ce_scale
+ self.stepsize = stepsize
+ self.grad_norm_factor = grad_norm_factor
+ self.fusion_factor = fusion_factor
+ self.repetition_penalty = repetition_penalty
+ self.end_token = self.lm_tokenizer.encode(end_token)[0]
+ self.end_factor = end_factor
+ self.ef_idx = 1
+ self.forbidden_factor = forbidden_factor
+
+ def get_img_feature(self, img_path, weights):
+ imgs = [Image.open(x) for x in img_path]
+ clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]
+
+ with torch.no_grad():
+ image_fts = [self.clip.encode_image(x) for x in clip_imgs]
+
+ if weights is not None:
+ image_features = sum([x * weights[i] for i, x in enumerate(image_fts)])
+ else:
+ image_features = sum(image_fts)
+
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+ return image_features.detach()
+
+ def get_txt_features(self, text):
+ clip_texts = clip.tokenize(text).to(self.device)
+
+ with torch.no_grad():
+ text_features = self.clip.encode_text(clip_texts)
+
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+ return text_features.detach()
+
+ def get_combined_feature(self, img_path, texts, weights_i, weights_t):
+ imgs = [Image.open(x) for x in img_path]
+ clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]
+ clip_texts = [clip.tokenize(x).to(self.device) for x in texts]
+
+ with torch.no_grad():
+ image_fts = [self.clip.encode_image(x) for x in clip_imgs]
+ text_fts = [self.clip.encode_text(x) for x in clip_texts]
+
+ features = sum([x * weights_i[i] for i, x in enumerate(image_fts)])
+ if weights_t is not None:
+ features += sum([x * weights_t[i] for i, x in enumerate(text_fts)])
+
+ features = features / features.norm(dim=-1, keepdim=True)
+ return features.detach()
+
+ def run(self, image_features, cond_text, beam_size):
+ self.image_features = image_features
+
+ context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text)
+
+ output_tokens, output_text = self.generate_text(context_tokens, beam_size)
+
+ return output_text
+
+ def generate_text(self, context_tokens, beam_size):
+ context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0)
+
+ gen_tokens = None
+ scores = None
+ seq_lengths = torch.ones(beam_size, device=self.device)
+ is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool)
+
+ for i in range(self.target_seq_length):
+ probs = self.get_next_probs(i, context_tokens)
+ logits = probs.log()
+
+ if scores is None:
+ scores, next_tokens = logits.topk(beam_size, -1)
+ context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:])
+ next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
+
+ if gen_tokens is None:
+ gen_tokens = next_tokens
+ else:
+ gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:])
+ gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1)
+ else:
+ logits[is_stopped] = -float(np.inf)
+ logits[is_stopped, 0] = 0
+ scores_sum = scores[:, None] + logits
+ seq_lengths[~is_stopped] += 1
+ scores_sum_average = scores_sum / seq_lengths[:, None]
+ scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
+ beam_size, -1)
+ next_tokens_source = next_tokens // scores_sum.shape[1]
+ seq_lengths = seq_lengths[next_tokens_source]
+ next_tokens = next_tokens % scores_sum.shape[1]
+ next_tokens = next_tokens.unsqueeze(1)
+ gen_tokens = gen_tokens[next_tokens_source]
+ gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1)
+ context_tokens = context_tokens[next_tokens_source]
+ scores = scores_sum_average * seq_lengths
+ is_stopped = is_stopped[next_tokens_source]
+
+ context_tokens = torch.cat((context_tokens, next_tokens), dim=1)
+ is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze()
+
+ ####
+ tmp_scores = scores / seq_lengths
+ tmp_output_list = gen_tokens.cpu().numpy()
+ tmp_output_texts = [
+ self.lm_tokenizer.decode(tmp_output)
+ for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths)
+ ]
+ tmp_order = tmp_scores.argsort(descending=True)
+ tmp_output_texts = [tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order]
+ log_info(tmp_output_texts, verbose=True)
+ ####
+
+ if is_stopped.all():
+ break
+
+ scores = scores / seq_lengths
+ output_list = gen_tokens.cpu().numpy()
+ output_texts = [
+ self.lm_tokenizer.decode(output[: int(length)])
+ for output, length in zip(output_list, seq_lengths)
+ ]
+ order = scores.argsort(descending=True)
+ output_texts = [output_texts[i] for i in order]
+
+ return context_tokens, output_texts
+
+ def get_next_probs(self, i, context_tokens):
+ last_token = context_tokens[:, -1:]
+
+ if self.reset_context_delta and context_tokens.size(1) > 1:
+ context = self.lm_model(context_tokens[:, :-1])["past_key_values"]
+
+ # Logits of LM with unshifted context
+ logits_before_shift = self.lm_model(context_tokens)["logits"]
+ logits_before_shift = logits_before_shift[:, -1, :]
+ probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1)
+
+ if context:
+ context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift)
+
+ lm_output = self.lm_model(last_token, past_key_values=context)
+ logits, past = (
+ lm_output["logits"],
+ lm_output["past_key_values"],
+ )
+ logits = logits[:, -1, :]
+
+ logits = self.update_special_tokens_logits(context_tokens, i, logits)
+
+ probs = nn.functional.softmax(logits, dim=-1)
+ probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor))
+ probs = probs / probs.sum()
+
+ return probs
+
+ def shift_context(self, i, context, last_token, context_tokens, probs_before_shift):
+ context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context]
+
+ window_mask = torch.ones_like(context[0][0]).to(self.device)
+
+ for i in range(self.num_iterations):
+ curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in
+ context_delta]
+
+ for p0, p1 in curr_shift:
+ p0.retain_grad()
+ p1.retain_grad()
+
+ shifted_context = list(map(add_context, context, curr_shift))
+
+ shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context)
+ logits = shifted_outputs["logits"][:, -1, :]
+ probs = nn.functional.softmax(logits, dim=-1)
+
+ loss = 0.0
+
+ # CLIP LOSS
+ clip_loss, clip_losses = self.clip_loss(probs, context_tokens)
+ loss += self.clip_scale * clip_loss
+
+ # CE/Fluency loss
+ ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * probs_before_shift.log())).sum(-1)
+ loss += ce_loss.sum()
+
+ loss.backward()
+
+ # ---------- Weights ----------
+ combined_scores_k = -(ce_loss)
+ combined_scores_c = -(self.clip_scale * torch.stack(clip_losses))
+
+ # minmax
+ if combined_scores_k.shape[0] == 1:
+ tmp_weights_c = tmp_weights_k = torch.ones(*combined_scores_k.shape).to(self.device)
+ else:
+ tmp_weights_k = ((combined_scores_k - combined_scores_k.min())) / (
+ combined_scores_k.max() - combined_scores_k.min())
+ tmp_weights_c = ((combined_scores_c - combined_scores_c.min())) / (
+ combined_scores_c.max() - combined_scores_c.min())
+
+ tmp_weights = 0.5 * tmp_weights_k + 0.5 * tmp_weights_c
+ tmp_weights = tmp_weights.view(tmp_weights.shape[0], 1, 1, 1)
+
+ factor = 1
+
+ # --------- Specific Gen ---------
+ sep_grads = None
+
+ for b in range(context_tokens.shape[0]):
+ tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_]
+ for p_ in curr_shift]
+
+ # normalize gradients
+ tmp_grad = [tuple([-self.stepsize * factor * (
+ x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][
+ j] ** self.grad_norm_factor).data.cpu().numpy()
+ for j, x in enumerate(p_)])
+ for i, p_ in enumerate(curr_shift)]
+ if sep_grads is None:
+ sep_grads = tmp_grad
+ else:
+ for l_index in range(len(sep_grads)):
+ sep_grads[l_index] = list(sep_grads[l_index])
+ for k_index in range(len(sep_grads[0])):
+ sep_grads[l_index][k_index] = np.concatenate(
+ (sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0)
+ sep_grads[l_index] = tuple(sep_grads[l_index])
+ final_grads = sep_grads
+
+ # --------- update context ---------
+ context_delta = list(map(add_context, final_grads, context_delta))
+
+ for p0, p1 in curr_shift:
+ p0.grad.data.zero_()
+ p1.grad.data.zero_()
+
+ new_context = []
+ for p0, p1 in context:
+ new_context.append((p0.detach(), p1.detach()))
+ context = new_context
+
+ context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_])
+ for p_ in context_delta]
+ context = list(map(add_context, context, context_delta))
+
+ new_context = []
+ for p0, p1 in context:
+ new_context.append((p0.detach(), p1.detach()))
+ context = new_context
+
+ return context
+
+ def update_special_tokens_logits(self, context_tokens, i, logits):
+ for beam_id in range(context_tokens.shape[0]):
+ for token_idx in set(context_tokens[beam_id][-4:].tolist()):
+ factor = self.repetition_penalty if logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty)
+ logits[beam_id, token_idx] /= factor
+
+ if i >= self.ef_idx:
+ factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor)
+ logits[beam_id, self.end_token] *= factor
+ if i == 0:
+ start_factor = 1.6
+ factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor)
+ logits[beam_id, self.end_token] /= factor
+
+ for token_idx in list(self.forbidden_tokens):
+ factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor)
+ logits[beam_id, token_idx] /= factor
+
+ return logits
+
+ def clip_loss(self, probs, context_tokens):
+ for p_ in self.clip.transformer.parameters():
+ if p_.grad is not None:
+ p_.grad.data.zero_()
+
+ top_size = 512
+ _, top_indices = probs.topk(top_size, -1)
+
+ prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens]
+
+ clip_loss = 0
+ losses = []
+ for idx_p in range(probs.shape[0]):
+ top_texts = []
+ prefix_text = prefix_texts[idx_p]
+ for x in top_indices[idx_p]:
+ top_texts.append(prefix_text + self.lm_tokenizer.decode(x))
+ text_features = self.get_txt_features(top_texts)
+
+ with torch.no_grad():
+ similiraties = (self.image_features @ text_features.T)
+ target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach()
+ target_probs = target_probs.type(torch.float32)
+
+ target = torch.zeros_like(probs[idx_p])
+ target[top_indices[idx_p]] = target_probs[0]
+ target = target.unsqueeze(0)
+ cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)])))
+
+ clip_loss += cur_clip_loss
+ losses.append(cur_clip_loss)
+
+ return clip_loss, losses
diff --git a/zerocap/model/ZeroCLIP_batched.py b/zerocap/model/ZeroCLIP_batched.py
new file mode 100644
index 0000000..2c0209f
--- /dev/null
+++ b/zerocap/model/ZeroCLIP_batched.py
@@ -0,0 +1,449 @@
+import numpy as np
+from torch import nn
+from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer
+from transformers.models.gpt_neo import GPTNeoForCausalLM
+import torch
+import clip
+from PIL import Image
+from datetime import datetime
+import sys
+
+class TextCLIP(nn.Module):
+ def __init__(self, model):
+ super(TextCLIP, self).__init__()
+ self.model = model
+
+ def forward(self, text):
+ return self.model.encode_text(text)
+
+
+class ImageCLIP(nn.Module):
+ def __init__(self, model):
+ super(ImageCLIP, self).__init__()
+ self.model = model
+
+ def forward(self, image):
+ return self.model.encode_image(image)
+
+def log_info(text, verbose=True):
+ if verbose:
+ dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
+ print(f'{dt_string} | {text}')
+ sys.stdout.flush()
+
+
+def add_context(x, y):
+ return (x[0] + y[0], x[1] + y[1])
+
+
+def convert_models_to_fp32(model):
+ for p in model.parameters():
+ p.data = p.data.float()
+
+
+class CLIPTextGenerator:
+ def __init__(self,
+ seed=0,
+ lm_model='gpt-2',
+ forbidden_tokens_file_path='./forbidden_tokens.npy',
+ clip_checkpoints='./clip_checkpoints',
+ target_seq_length=15,
+ reset_context_delta=True,
+ num_iterations=5,
+ clip_loss_temperature=0.01,
+ clip_scale=1.,
+ ce_scale=0.2,
+ stepsize=0.3,
+ grad_norm_factor=0.9,
+ fusion_factor=0.99,
+ repetition_penalty=1.,
+ end_token='.',
+ end_factor=1.01,
+ forbidden_factor=20,
+ **kwargs):
+
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # set Random seed
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+ # Initialize Language model
+ self.context_prefix = ''
+
+ if lm_model == 'gpt-neo':
+ self.lm_tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-125M')
+ self.lm_model = GPTNeoForCausalLM.from_pretrained('EleutherAI/gpt-neo-125M', output_hidden_states=True)
+ elif lm_model == 'gpt-2':
+ self.lm_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
+ self.lm_model = GPT2LMHeadModel.from_pretrained('gpt2-medium', output_hidden_states=True)
+ self.context_prefix = self.lm_tokenizer.bos_token
+
+ self.lm_model.to(self.device)
+ self.lm_model.eval()
+
+ self.forbidden_tokens = np.load(forbidden_tokens_file_path)
+ self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if
+ (x[0] == 'Ġ' and len(x) > 1 and x[1].isupper())]
+
+ # Freeze LM weights
+ for param in self.lm_model.parameters():
+ param.requires_grad = False
+
+ # Initialize CLIP
+ self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device,
+ download_root=clip_checkpoints, jit=False)
+ self.clip_image = ImageCLIP(self.clip)
+ self.clip_image = torch.nn.DataParallel(self.clip_image)
+ self.clip_text = TextCLIP(self.clip)
+ self.clip_text = torch.nn.DataParallel(self.clip_text)
+
+ # Init arguments
+ self.target_seq_length = target_seq_length
+ self.reset_context_delta = reset_context_delta
+ self.num_iterations = num_iterations
+ self.clip_loss_temperature = clip_loss_temperature
+ self.clip_scale = clip_scale
+ self.ce_scale = ce_scale
+ self.stepsize = stepsize
+ self.grad_norm_factor = grad_norm_factor
+ self.fusion_factor = fusion_factor
+ self.repetition_penalty = repetition_penalty
+ self.end_token = self.lm_tokenizer.encode(end_token)[0]
+ self.end_factor = end_factor
+ self.ef_idx = 1
+ self.forbidden_factor = forbidden_factor
+
+ def get_img_feature(self, img_path, weights):
+ imgs = [Image.open(x) for x in img_path]
+ clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]
+
+ with torch.no_grad():
+ image_fts = [self.clip_image(x) for x in clip_imgs]
+
+ if weights is not None:
+ image_features = sum([x * weights[i] for i, x in enumerate(image_fts)])
+ else:
+ image_features = sum(image_fts)
+
+ image_features = torch.nn.functional.normalize(image_features, dim=-1)
+ return image_features.detach()
+
+ def get_txt_features(self, text):
+ clip_texts = clip.tokenize(text).to(self.device)
+
+ with torch.no_grad():
+ text_features = self.clip_text(clip_texts)
+
+ text_features = torch.nn.functional.normalize(text_features, dim=-1)
+ return text_features.detach()
+
+ def get_combined_feature(self, img_path, texts, weights_i, weights_t):
+ imgs = [Image.open(x) for x in img_path]
+ clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]
+ clip_texts = [clip.tokenize(x).to(self.device) for x in texts]
+
+ with torch.no_grad():
+ image_fts = [self.clip.encode_image(x) for x in clip_imgs]
+ text_fts = [self.clip.encode_text(x) for x in clip_texts]
+
+ features = sum([x * weights_i[i] for i, x in enumerate(image_fts)])
+ if weights_t is not None:
+ features += sum([x * weights_t[i] for i, x in enumerate(text_fts)])
+
+ features = features / features.norm(dim=-1, keepdim=True)
+ return features.detach()
+
+ def run(self, image_features, cond_text, beam_size):
+ self.image_features = image_features
+
+ context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text)
+
+ output_tokens, output_text = self.generate_text(context_tokens, beam_size)
+
+ return output_text
+
+ def generate_text(self, context_tokens, beam_size):
+ context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0)
+
+ gen_tokens = None
+ scores = None
+ seq_lengths = torch.ones(beam_size, device=self.device)
+ is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool)
+
+ for i in range(self.target_seq_length):
+ probs = self.get_next_probs(i, context_tokens)
+ logits = probs.log()
+
+ if scores is None:
+ scores, next_tokens = logits.topk(beam_size, -1)
+ context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:])
+ next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
+
+ if gen_tokens is None:
+ gen_tokens = next_tokens
+ else:
+ gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:])
+ gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1)
+ else:
+ logits[is_stopped] = -float(np.inf)
+ logits[is_stopped, 0] = 0
+ scores_sum = scores[:, None] + logits
+ seq_lengths[~is_stopped] += 1
+ scores_sum_average = scores_sum / seq_lengths[:, None]
+ scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
+ beam_size, -1)
+ next_tokens_source = next_tokens // scores_sum.shape[1]
+ seq_lengths = seq_lengths[next_tokens_source]
+ next_tokens = next_tokens % scores_sum.shape[1]
+ next_tokens = next_tokens.unsqueeze(1)
+ gen_tokens = gen_tokens[next_tokens_source]
+ gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1)
+ context_tokens = context_tokens[next_tokens_source]
+ scores = scores_sum_average * seq_lengths
+ is_stopped = is_stopped[next_tokens_source]
+
+ context_tokens = torch.cat((context_tokens, next_tokens), dim=1)
+ is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze()
+
+ ####
+ tmp_scores = scores / seq_lengths
+ tmp_output_list = gen_tokens.cpu().numpy()
+ tmp_output_texts = [
+ self.lm_tokenizer.decode(tmp_output)
+ for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths)
+ ]
+ tmp_order = tmp_scores.argsort(descending=True)
+ tmp_output_texts = [tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order]
+ log_info(tmp_output_texts, verbose=True)
+ ####
+
+ if is_stopped.all():
+ break
+
+ scores = scores / seq_lengths
+ output_list = gen_tokens.cpu().numpy()
+ output_texts = [
+ self.lm_tokenizer.decode(output[: int(length)])
+ for output, length in zip(output_list, seq_lengths)
+ ]
+ order = scores.argsort(descending=True)
+ output_texts = [output_texts[i] for i in order]
+
+ return context_tokens, output_texts
+
+ def get_next_probs(self, i, context_tokens):
+ last_token = context_tokens[:, -1:]
+
+ if self.reset_context_delta and context_tokens.size(1) > 1:
+ context = self.lm_model(context_tokens[:, :-1])["past_key_values"]
+
+ # Logits of LM with unshifted context
+ logits_before_shift = self.lm_model(context_tokens)["logits"]
+ logits_before_shift = logits_before_shift[:, -1, :]
+ probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1)
+
+ if context:
+ context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift)
+
+ lm_output = self.lm_model(last_token, past_key_values=context)
+ logits, past = (
+ lm_output["logits"],
+ lm_output["past_key_values"],
+ )
+ logits = logits[:, -1, :]
+
+ logits = self.update_special_tokens_logits(context_tokens, i, logits)
+
+ probs = nn.functional.softmax(logits, dim=-1)
+ probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor))
+ probs = probs / probs.sum()
+
+ return probs
+
+ def shift_context(self, i, context, last_token, context_tokens, probs_before_shift):
+ context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context]
+
+ for i in range(self.num_iterations):
+ curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in
+ context_delta]
+
+ for p0, p1 in curr_shift:
+ p0.retain_grad()
+ p1.retain_grad()
+
+ shifted_context = list(map(add_context, context, curr_shift))
+
+ shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context)
+ logits = shifted_outputs["logits"][:, -1, :]
+ probs = nn.functional.softmax(logits, dim=-1)
+
+ loss = 0.0
+
+ # CLIP LOSS
+ clip_loss, clip_losses = self.clip_loss(probs, context_tokens)
+ loss += self.clip_scale * clip_loss
+
+ # CE/Fluency loss
+ ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * probs_before_shift.log())).sum(-1)
+ loss += ce_loss.sum()
+
+ loss.backward()
+
+ # --------- Specific Gen ---------
+ final_grads = self.norm_grad(context, context_tokens, curr_shift)
+
+ # --------- update context ---------
+ context_delta = list(map(add_context, final_grads, context_delta))
+
+ for p0, p1 in curr_shift:
+ p0.grad.data.zero_()
+ p1.grad.data.zero_()
+
+ new_context = []
+ for p0, p1 in context:
+ new_context.append((p0.detach(), p1.detach()))
+ context = new_context
+
+ context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_])
+ for p_ in context_delta]
+ context = list(map(add_context, context, context_delta))
+
+ new_context = []
+ for p0, p1 in context:
+ new_context.append((p0.detach(), p1.detach()))
+ context = new_context
+
+ return context
+
+ def norm_grad(self, context, context_tokens, curr_shift, ):
+ factor = 1
+ sep_grads = None
+ window_mask = torch.ones_like(context[0][0]).to(self.device)
+
+ for b in range(context_tokens.shape[0]):
+ tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_]
+ for p_ in curr_shift]
+
+ # normalize gradients
+ tmp_grad = [tuple([-self.stepsize * factor * (
+ x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][
+ j] ** self.grad_norm_factor).data.cpu().numpy()
+ for j, x in enumerate(p_)])
+ for i, p_ in enumerate(curr_shift)]
+ if sep_grads is None:
+ sep_grads = tmp_grad
+ else:
+ for l_index in range(len(sep_grads)):
+ sep_grads[l_index] = list(sep_grads[l_index])
+ for k_index in range(len(sep_grads[0])):
+ sep_grads[l_index][k_index] = np.concatenate(
+ (sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0)
+ sep_grads[l_index] = tuple(sep_grads[l_index])
+ final_grads = sep_grads
+
+ return final_grads
+
+ def update_special_tokens_logits(self, context_tokens, i, logits):
+ for beam_id in range(context_tokens.shape[0]):
+ for token_idx in set(context_tokens[beam_id][-4:].tolist()):
+ factor = self.repetition_penalty if logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty)
+ logits[beam_id, token_idx] /= factor
+
+ if i >= self.ef_idx:
+ factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor)
+ logits[beam_id, self.end_token] *= factor
+ if i == 0:
+ start_factor = 1.6
+ factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor)
+ logits[beam_id, self.end_token] /= factor
+
+ for token_idx in list(self.forbidden_tokens):
+ factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor)
+ logits[beam_id, token_idx] /= factor
+
+ return logits
+
+ def clip_loss(self, probs, context_tokens):
+ for p_ in self.clip.transformer.parameters():
+ if p_.grad is not None:
+ p_.grad.data.zero_()
+
+ top_size = 512
+ top_probs, top_indices = probs.topk(top_size, -1)
+
+ prefix_texts = [self.lm_tokenizer.decode(x, skip_special_tokens=True) for x in context_tokens]
+
+ clip_loss = 0
+ losses = []
+
+ top_texts = []
+ for idx_p in range(probs.shape[0]):
+ prefix_text = prefix_texts[idx_p]
+ for x in top_indices[idx_p]:
+ top_texts.append(prefix_text + self.lm_tokenizer.decode(x))
+
+ text_features = self.get_txt_features(top_texts)#.reshape(probs.size(0), top_size, -1)
+
+ with torch.no_grad():
+ similiraties = (self.image_features @ text_features.T).reshape(probs.size(0), -1)
+ similiraties = similiraties.reshape(probs.size(0), -1)
+ target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach()
+ target_probs = target_probs.type(torch.float32)
+
+ clip_loss += torch.sum(-(target_probs * torch.log(top_probs)))
+ # for idx_p in range(probs.shape[0]):
+ # top_texts = []
+ # prefix_text = prefix_texts[idx_p]
+ # for x in top_indices[idx_p]:
+ # top_texts.append(prefix_text + self.lm_tokenizer.decode(x))
+ # text_features = self.get_txt_features(top_texts)
+ #
+ # with torch.no_grad():
+ # similiraties = (self.image_features @ text_features.T)
+ # target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach()
+ # target_probs = target_probs.type(torch.float32)
+ #
+ # target = torch.zeros_like(probs[idx_p])
+ # target[top_indices[idx_p]] = target_probs[0]
+ # target = target.unsqueeze(0)
+ # cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)])))
+ #
+ # clip_loss += cur_clip_loss
+ # losses.append(cur_clip_loss)
+
+ return clip_loss, losses
+
+ def clip_loss_old(self, probs, context_tokens):
+ for p_ in self.clip.transformer.parameters():
+ if p_.grad is not None:
+ p_.grad.data.zero_()
+
+ top_size = 512
+ _, top_indices = probs.topk(top_size, -1)
+
+ prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens]
+
+ clip_loss = 0
+ losses = []
+ for idx_p in range(probs.shape[0]):
+ top_texts = []
+ prefix_text = prefix_texts[idx_p]
+ for x in top_indices[idx_p]:
+ top_texts.append(prefix_text + self.lm_tokenizer.decode(x))
+ text_features = self.get_txt_features(top_texts)
+
+ with torch.no_grad():
+ similiraties = (self.image_features @ text_features.T)
+ target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach()
+ target_probs = target_probs.type(torch.float32)
+
+ target = torch.zeros_like(probs[idx_p])
+ target[top_indices[idx_p]] = target_probs[0]
+ target = target.unsqueeze(0)
+ cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)])))
+
+ clip_loss += cur_clip_loss
+ losses.append(cur_clip_loss)
+
+ return clip_loss, losses
\ No newline at end of file
diff --git a/zerocap/model/__init__.py b/zerocap/model/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/zerocap/model/__pycache__/ZeroCLIP.cpython-36.pyc b/zerocap/model/__pycache__/ZeroCLIP.cpython-36.pyc
new file mode 100644
index 0000000..093e083
Binary files /dev/null and b/zerocap/model/__pycache__/ZeroCLIP.cpython-36.pyc differ
diff --git a/zerocap/model/__pycache__/ZeroCLIP.cpython-37.pyc b/zerocap/model/__pycache__/ZeroCLIP.cpython-37.pyc
new file mode 100644
index 0000000..5f08c9e
Binary files /dev/null and b/zerocap/model/__pycache__/ZeroCLIP.cpython-37.pyc differ
diff --git a/zerocap/model/__pycache__/ZeroCLIP_batched.cpython-36.pyc b/zerocap/model/__pycache__/ZeroCLIP_batched.cpython-36.pyc
new file mode 100644
index 0000000..aa91bfc
Binary files /dev/null and b/zerocap/model/__pycache__/ZeroCLIP_batched.cpython-36.pyc differ
diff --git a/zerocap/model/__pycache__/ZeroCLIP_batched.cpython-37.pyc b/zerocap/model/__pycache__/ZeroCLIP_batched.cpython-37.pyc
new file mode 100644
index 0000000..816cae9
Binary files /dev/null and b/zerocap/model/__pycache__/ZeroCLIP_batched.cpython-37.pyc differ
diff --git a/zerocap/model/__pycache__/__init__.cpython-36.pyc b/zerocap/model/__pycache__/__init__.cpython-36.pyc
new file mode 100644
index 0000000..5a18c45
Binary files /dev/null and b/zerocap/model/__pycache__/__init__.cpython-36.pyc differ
diff --git a/zerocap/model/__pycache__/__init__.cpython-37.pyc b/zerocap/model/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000..3108a5a
Binary files /dev/null and b/zerocap/model/__pycache__/__init__.cpython-37.pyc differ
diff --git a/zerocap/mscoco_zerocap.sh b/zerocap/mscoco_zerocap.sh
new file mode 100755
index 0000000..a06ae50
--- /dev/null
+++ b/zerocap/mscoco_zerocap.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+# lm_model:
+# 1. cambridgeltl/magic_mscoco
+# 2. cambridgeltl/magic_flickr30k
+CUDA_VISIBLE_DEVICES=1 python run.py \
+ --beam_size 1 \
+ --target_seq_length 16 \
+ --reset_context_delta \
+ --lm_model cambridgeltl/magic_mscoco \
+ --test_image_prefix_path ../data/mscoco/test_images \
+ --test_path ../data/mscoco/mscoco_test.json \
+ --save_path_prefix ../inference_result/mscoco/baselines/ \
+ --save_name zerocap_result.json
diff --git a/zerocap/predict.py b/zerocap/predict.py
new file mode 100644
index 0000000..46271f4
--- /dev/null
+++ b/zerocap/predict.py
@@ -0,0 +1,117 @@
+import os
+import tempfile
+import sys
+sys.path.append('CLIP')
+from pathlib import Path
+import cog
+import argparse
+import torch
+import clip
+from model.ZeroCLIP import CLIPTextGenerator
+
+def perplexity_score(text, lm_model, lm_tokenizer, device):
+ encodings = lm_tokenizer(f'{lm_tokenizer.bos_token + text}', return_tensors='pt')
+ input_ids = encodings.input_ids.to(device)
+ target_ids = input_ids.clone()
+
+ outputs = lm_model(input_ids, labels=target_ids)
+ log_likelihood = outputs[0]
+ ll = log_likelihood.item()
+
+ return ll
+
+class Predictor(cog.Predictor):
+ def setup(self):
+ self.args = get_args()
+ self.args.reset_context_delta = True
+ self.text_generator = CLIPTextGenerator(**vars(self.args))
+
+ @cog.input(
+ "image",
+ type=Path,
+ help="input image"
+ )
+ @cog.input(
+ "cond_text",
+ type=str,
+ default='Image of a',
+ help="conditional text",
+ )
+ @cog.input(
+ "beam_size",
+ type=int,
+ default=5, min=1, max=10,
+ help="Number of beams to use",
+ )
+ @cog.input(
+ "end_factor",
+ type=float,
+ default=1.01, min=1.0, max=1.10,
+ help="Higher value for shorter captions",
+ )
+ @cog.input(
+ "max_seq_length",
+ type=int,
+ default=15, min=1, max=20,
+ help="Maximum number of tokens to generate",
+ )
+ @cog.input(
+ "ce_loss_scale",
+ type=float,
+ default=0.2, min=0.0, max=0.6,
+ help="Scale of cross-entropy loss with un-shifted language model",
+ )
+ def predict(self, image, cond_text, beam_size, end_factor, max_seq_length, ce_loss_scale):
+ self.args.cond_text = cond_text
+ self.text_generator.end_factor = end_factor
+ self.text_generator.target_seq_length = max_seq_length
+ self.text_generator.ce_scale = ce_loss_scale
+
+ image_features = self.text_generator.get_img_feature([str(image)], None)
+ captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size)
+
+ # CLIP SCORE
+ encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device))
+ for c in captions]
+ encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
+ best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item()
+
+ # Perplexity SCORE
+ ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions]
+ best_ppl_index = torch.tensor(ppl_scores).argmin().item()
+
+ best_clip_caption = self.args.cond_text + captions[best_clip_idx]
+ best_mixed = self.args.cond_text + captions[0]
+ best_PPL = self.args.cond_text + captions[best_ppl_index]
+
+ final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}'
+
+ return final
+ # return self.args.cond_text + captions[best_clip_idx]
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--seed", type=int, default=0)
+ parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
+ parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
+ parser.add_argument("--target_seq_length", type=int, default=15)
+ parser.add_argument("--cond_text", type=str, default="Image of a")
+ parser.add_argument("--reset_context_delta", action="store_true",
+ help="Should we reset the context at each token gen")
+ parser.add_argument("--num_iterations", type=int, default=5)
+ parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
+ parser.add_argument("--clip_scale", type=float, default=1)
+ parser.add_argument("--ce_scale", type=float, default=0.2)
+ parser.add_argument("--stepsize", type=float, default=0.3)
+ parser.add_argument("--grad_norm_factor", type=float, default=0.9)
+ parser.add_argument("--fusion_factor", type=float, default=0.99)
+ parser.add_argument("--repetition_penalty", type=float, default=1)
+ parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
+ parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
+ parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
+ parser.add_argument("--beam_size", type=int, default=5)
+
+ args = parser.parse_args('')
+ return args
diff --git a/zerocap/predict_arithmetic.py b/zerocap/predict_arithmetic.py
new file mode 100644
index 0000000..1e2ade2
--- /dev/null
+++ b/zerocap/predict_arithmetic.py
@@ -0,0 +1,129 @@
+import os
+import tempfile
+import sys
+sys.path.append('CLIP')
+from pathlib import Path
+import cog
+import argparse
+import torch
+import clip
+from model.ZeroCLIP import CLIPTextGenerator
+
+def perplexity_score(text, lm_model, lm_tokenizer, device):
+ encodings = lm_tokenizer(f'{lm_tokenizer.bos_token + text}', return_tensors='pt')
+ input_ids = encodings.input_ids.to(device)
+ target_ids = input_ids.clone()
+
+ outputs = lm_model(input_ids, labels=target_ids)
+ log_likelihood = outputs[0]
+ ll = log_likelihood.item()
+
+ return ll
+
+class Predictor(cog.Predictor):
+ def setup(self):
+ self.args = get_args()
+ self.args.reset_context_delta = True
+ self.text_generator = CLIPTextGenerator(**vars(self.args))
+
+ @cog.input(
+ "image1",
+ type=Path,
+ help="Final result will be: image1 + (image2 - image3)"
+ )
+ @cog.input(
+ "image2",
+ type=Path,
+ help="Final result will be: image1 + (image2 - image3)"
+ )
+ @cog.input(
+ "image3",
+ type=Path,
+ help="Final result will be: image1 + (image2 - image3)"
+ )
+ @cog.input(
+ "cond_text",
+ type=str,
+ default='Image of a',
+ help="conditional text",
+ )
+ @cog.input(
+ "beam_size",
+ type=int,
+ default=3, min=1, max=10,
+ help="Number of beams to use",
+ )
+ @cog.input(
+ "end_factors",
+ type=float,
+ default=1.06, min=1.0, max=1.10,
+ help="Higher value for shorter captions",
+ )
+ @cog.input(
+ "max_seq_lengths",
+ type=int,
+ default=3, min=1, max=20,
+ help="Maximum number of tokens to generate",
+ )
+ @cog.input(
+ "ce_loss_scale",
+ type=float,
+ default=0.2, min=0.0, max=0.6,
+ help="Scale of cross-entropy loss with un-shifted language model",
+ )
+ def predict(self, image1, image2, image3, cond_text, beam_size, end_factors, max_seq_lengths, ce_loss_scale):
+ self.args.cond_text = cond_text
+ self.text_generator.end_factor = end_factors
+ self.text_generator.target_seq_length = max_seq_lengths
+ self.text_generator.ce_scale = ce_loss_scale
+ self.text_generator.fusion_factor = 0.95
+ self.text_generator.grad_norm_factor = 0.95
+
+ image_features = self.text_generator.get_combined_feature([str(image1), str(image2), str(image3)], [], [1, 1, -1], None)
+ captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size)
+
+ # CLIP SCORE
+ encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device))
+ for c in captions]
+ encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
+ best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item()
+
+ # Perplexity SCORE
+ ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions]
+ best_ppl_index = torch.tensor(ppl_scores).argmin().item()
+
+ best_clip_caption = self.args.cond_text + captions[best_clip_idx]
+ best_mixed = self.args.cond_text + captions[0]
+ best_PPL = self.args.cond_text + captions[best_ppl_index]
+
+ final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}'
+
+ return final
+ # return self.args.cond_text + captions[best_clip_idx]
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--seed", type=int, default=0)
+ parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
+ parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
+ parser.add_argument("--target_seq_length", type=int, default=15)
+ parser.add_argument("--cond_text", type=str, default="Image of a")
+ parser.add_argument("--reset_context_delta", action="store_true",
+ help="Should we reset the context at each token gen")
+ parser.add_argument("--num_iterations", type=int, default=5)
+ parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
+ parser.add_argument("--clip_scale", type=float, default=1)
+ parser.add_argument("--ce_scale", type=float, default=0.2)
+ parser.add_argument("--stepsize", type=float, default=0.3)
+ parser.add_argument("--grad_norm_factor", type=float, default=0.95)
+ parser.add_argument("--fusion_factor", type=float, default=0.95)
+ parser.add_argument("--repetition_penalty", type=float, default=1)
+ parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
+ parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
+ parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
+ parser.add_argument("--beam_size", type=int, default=5)
+
+ args = parser.parse_args('')
+ return args
diff --git a/zerocap/requirements.txt b/zerocap/requirements.txt
new file mode 100644
index 0000000..0eaf0ad
--- /dev/null
+++ b/zerocap/requirements.txt
@@ -0,0 +1,3 @@
+ftfy
+regex
+tqdm
diff --git a/zerocap/run.py b/zerocap/run.py
new file mode 100644
index 0000000..fab33b9
--- /dev/null
+++ b/zerocap/run.py
@@ -0,0 +1,131 @@
+import argparse
+import ipdb
+from tqdm import tqdm
+import progressbar
+import torch
+import ipdb
+import clip
+from model.ZeroCLIP import CLIPTextGenerator
+from model.ZeroCLIP_batched import CLIPTextGenerator as CLIPTextGenerator_multigpu
+
+def get_args():
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument("--test_image_prefix_path", type=str, help="the folder that stores all test images")
+ parser.add_argument("--test_path", type=str)
+ parser.add_argument("--save_path_prefix", type=str, help="save the result in which directory")
+ parser.add_argument("--save_name", type=str, help="the name of the saved file")
+
+ parser.add_argument("--seed", type=int, default=0)
+ parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
+ parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
+ parser.add_argument("--target_seq_length", type=int, default=15)
+ parser.add_argument("--cond_text", type=str, default="Image of a")
+ parser.add_argument("--reset_context_delta", action="store_true",
+ help="Should we reset the context at each token gen")
+ parser.add_argument("--num_iterations", type=int, default=5)
+ parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
+ parser.add_argument("--clip_scale", type=float, default=1)
+ parser.add_argument("--ce_scale", type=float, default=0.2)
+ parser.add_argument("--stepsize", type=float, default=0.3)
+ parser.add_argument("--grad_norm_factor", type=float, default=0.9)
+ parser.add_argument("--fusion_factor", type=float, default=0.99)
+ parser.add_argument("--repetition_penalty", type=float, default=1)
+ parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
+ parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
+ parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
+ parser.add_argument("--beam_size", type=int, default=1)
+
+ parser.add_argument("--multi_gpu", action="store_true")
+
+ parser.add_argument('--run_type',
+ default='caption',
+ nargs='?',
+ choices=['caption', 'arithmetics'])
+
+ parser.add_argument("--caption_img_path", type=str, default='example_images/captions/COCO_val2014_000000008775.jpg',
+ help="Path to image for captioning")
+
+ parser.add_argument("--arithmetics_imgs", nargs="+",
+ default=['example_images/arithmetics/woman2.jpg',
+ 'example_images/arithmetics/king2.jpg',
+ 'example_images/arithmetics/man2.jpg'])
+ parser.add_argument("--arithmetics_weights", nargs="+", default=[1, 1, -1])
+
+ args = parser.parse_args()
+
+ return args
+
+def run(args, text_generator, img_path):
+ image_features = text_generator.get_img_feature([img_path], None)
+ captions = text_generator.run(image_features, args.cond_text, beam_size=args.beam_size)
+
+ encoded_captions = [text_generator.clip.encode_text(clip.tokenize(c).to(text_generator.device)) for c in captions]
+ encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
+ best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item()
+ return captions
+
+
+if __name__ == '__main__':
+ if torch.cuda.is_available():
+ print ('Cuda is available.')
+ cuda_available = torch.cuda.is_available()
+ args = get_args()
+ device = torch.device('cuda')
+
+ save_path_prefix = args.save_path_prefix
+ import os
+ if os.path.exists(save_path_prefix):
+ pass
+ else: # recursively construct directory
+ os.makedirs(save_path_prefix, exist_ok=True)
+ # parse save name
+ save_name = args.save_name
+ full_save_path = save_path_prefix + '/' + save_name
+ print ('full save path is {}'.format(full_save_path))
+
+ print ('Loading data...')
+ import json
+ with open(args.test_path) as f:
+ item_list = json.load(f)
+ print ('Data loaded.')
+ print ('Number of test instances is {}'.format(len(item_list)))
+
+ # ZeroCap generator
+ text_generator = CLIPTextGenerator(**vars(args))
+
+ result_list = []
+ invalid_num = 0
+ print ('----------------------------------------------------------------')
+ test_num = len(item_list)
+ #test_num = 10
+ print ('Number of inference instances is {}'.format(test_num))
+ p = progressbar.ProgressBar(test_num)
+ p.start()
+ for p_idx in tqdm(range(test_num)):
+ p.update(p_idx)
+ one_test_dict = item_list[p_idx]
+
+ one_res_dict = {
+ 'split':one_test_dict['split'],
+ 'image_name':one_test_dict['image_name'],
+ #'file_path':one_test_dict['file_path'],
+ 'captions':one_test_dict['captions']
+ }
+
+ image_full_path = args.test_image_prefix_path + '/' + one_test_dict['image_name']
+ try:
+ output_text = run(args, text_generator, img_path=image_full_path)
+ one_res_dict['prediction'] = output_text[0]
+ result_list.append(one_res_dict)
+ except Exception as error:
+ print(f'[!] ERROR:', error)
+ invalid_num += 1
+ print ('invalid number is {}'.format(invalid_num))
+ continue
+ p.finish()
+ print ('Inference completed!')
+
+ import json
+ with open(full_save_path, 'w') as outfile:
+ json.dump(result_list, outfile, indent=4)
diff --git a/zerocap/setup.py b/zerocap/setup.py
new file mode 100644
index 0000000..8ae2efe
--- /dev/null
+++ b/zerocap/setup.py
@@ -0,0 +1,19 @@
+import os
+
+import pkg_resources
+from setuptools import setup, find_packages
+
+setup(
+ name="zero-shot-image-to-text",
+ py_modules=["zero-shot-image-to-text"],
+ version="1.0",
+ description="",
+ packages=find_packages(),
+ install_requires=[
+ str(r)
+ for r in pkg_resources.parse_requirements(
+ open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
+ )
+ ],
+ include_package_data=True
+)
\ No newline at end of file