diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..627dc5d
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .magic import Magic
+
+def magic(model_name: str):
+    return Magic(model_name)
diff --git a/clip/README.md b/clip/README.md
new file mode 100644
index 0000000..dd7a695
--- /dev/null
+++ b/clip/README.md
@@ -0,0 +1,141 @@
+## CLIP
+This folder illustrates how to use CLIP to build the text index and to run the cross-modal retrieval baseline.
+
+****
+## Catalogue:
+* 1. Build Text Index
+    * 1.1. Build Text Index for MSCOCO
+        * 1.1.1. Download Our Post-processed Index
+        * 1.1.2. Construct the Index by Yourself
+    * 1.2. Build Text Index for Flickr30k
+        * 1.2.1. Download Our Post-processed Index
+        * 1.2.2. Construct the Index by Yourself
+* 2. CLIP Retrieval Baseline
+    * 2.1. In Domain CLIP Retrieval
+    * 2.2. Cross Domain CLIP Retrieval
+
+****
+
+### 1. Build Text Index:
+We show how to use CLIP to build the text index from which captions are retrieved.
+
+#### 1.1. Build Text Index for MSCOCO:
+First, we demonstrate how to build the text index for MSCOCO.
+
+#### 1.1.1. Download Our Post-processed Index:
+We share our built index for MSCOCO via this [[link]](https://drive.google.com/file/d/1Dx_RPeAmydS6ZYuiJ-dLlK9-DjDZkxAh/view?usp=sharing). After downloading, unzip the downloaded file **mscoco_index.zip** under the current directory.
+
+> **** The resulting directory looks like:
+
+    .
+    ├── ./mscoco_index/
+        ├── index_matrix.txt # The file that stores the representations of captions from the training set of MSCOCO. Each row is a vector that corresponds to a specific caption from the training set.
+        └── text_mapping.json # The file that stores the mapping between each representation and its corresponding caption.
+
+#### 1.1.2. Construct the Index by Yourself:
+
+You can also rebuild the index by yourself. First, make sure you have downloaded the MSCOCO data following the instructions [[here]](https://github.com/yxuansu/MAGIC/tree/main/image_captioning/data#1-mscoco-benchmark). Then, run the following command to build the index.
+```yaml
+chmod +x ./build_mscoco_index.sh
+./build_mscoco_index.sh
+```
+The arguments are as follows:
+* `--clip_name`: The huggingface model name of the pre-trained CLIP model.
+* `--text_file_path`: The path of the training text corpus.
+* `--save_index_prefix`: The directory in which to store the index files.
+* `--save_index_name`: The file name of the saved caption representations.
+* `--save_mapping_dict_name`: The file name of the saved mapping dictionary between representations and captions.
+* `--batch_size`: The inference batch size.
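+
+After the script finishes, the two output files can be loaded back with a few lines of Python. The snippet below is a minimal sketch for sanity-checking the output; it assumes the default `./mscoco_index/` output directory used in `build_mscoco_index.sh`.
+```python
+import json
+import numpy as np
+
+# One whitespace-separated caption vector per line.
+index_matrix = np.loadtxt('./mscoco_index/index_matrix.txt')   # [num_captions x embed_dim]
+
+# Row id (stored as a string key) -> caption text.
+with open('./mscoco_index/text_mapping.json') as f:
+    text_mapping = json.load(f)
+
+print(index_matrix.shape)
+print(text_mapping['0'])  # the caption stored in the first row of the matrix
+```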
+
+#### 1.2. Build Text Index for Flickr30k:
+Next, we demonstrate how to build the text index for Flickr30k.
+
+#### 1.2.1. Download Our Post-processed Index:
+We share our built index for Flickr30k via this [[link]](https://drive.google.com/file/d/1hS58_ir5pdZZPckApCtlz2RyasCQbrPf/view?usp=sharing). After downloading, unzip the downloaded file **flickr30k_index.zip** under the current directory.
+
+> **** The resulting directory looks like:
+
+    .
+    ├── ./flickr30k_index/
+        ├── index_matrix.txt # The file that stores the representations of captions from the training set of Flickr30k. Each row is a vector that corresponds to a specific caption from the training set.
+        └── text_mapping.json # The file that stores the mapping between each representation and its corresponding caption.
+
+#### 1.2.2. Construct the Index by Yourself:
+
+You can also rebuild the index by yourself. First, make sure you have downloaded the Flickr30k data following the instructions [[here]](https://github.com/yxuansu/MAGIC/tree/main/image_captioning/data#2-flickr30k-benchmark). Then, run the following command to build the index.
+```yaml
+chmod +x ./build_flickr30k_index.sh
+./build_flickr30k_index.sh
+```
+The arguments are as follows:
+* `--clip_name`: The huggingface model name of the pre-trained CLIP model.
+* `--text_file_path`: The path of the training text corpus.
+* `--save_index_prefix`: The directory in which to store the index files.
+* `--save_index_name`: The file name of the saved caption representations.
+* `--save_mapping_dict_name`: The file name of the saved mapping dictionary between representations and captions.
+* `--batch_size`: The inference batch size.
+
+****
+
+### 2. CLIP Retrieval Baseline:
+Here, we show how to run the CLIP retrieval baseline.
+
+#### 2.1. In Domain CLIP Retrieval:
+To retrieve captions from the in-domain training set, run the following command:
+```yaml
+chmod +x ./X_clip_retrieval.sh
+./X_clip_retrieval.sh
+```
+Here, X is in ['mscoco', 'flickr30k'], corresponding to the MSCOCO and Flickr30k benchmarks.
+
+The arguments are as follows:
+* `--clip_name`: The huggingface model name of the pre-trained CLIP model.
+* `--test_image_prefix_path`: The directory that stores the test set images.
+* `--test_path`: The path of the file that stores the reference test captions.
+* `--index_matrix_path`: The path of the representation index file.
+* `--mapping_dict_path`: The path of the mapping dictionary between representations and captions.
+* `--save_path_prefix`: The directory in which to save the inference result.
+* `--save_name`: The file name of the saved inference result.
+
+**[Note]** As we are conducting in-domain CLIP retrieval, the test images and the caption index should come from the same benchmark.
+
+#### 2.2. Cross Domain CLIP Retrieval:
+To retrieve captions from a cross-domain training set, run the following command:
+```yaml
+chmod +x ./source_X_target_Y_clip_retrieval.sh
+./source_X_target_Y_clip_retrieval.sh
+```
+Here, X is the source domain from ['mscoco', 'flickr30k'] and Y is the target domain from ['flickr30k', 'mscoco'].
+
+The arguments are as follows:
+* `--clip_name`: The huggingface model name of the pre-trained CLIP model.
+* `--test_image_prefix_path`: The directory that stores the test set images.
+* `--test_path`: The path of the file that stores the reference test captions.
+* `--index_matrix_path`: The path of the representation index file.
+* `--mapping_dict_path`: The path of the mapping dictionary between representations and captions.
+* `--save_path_prefix`: The directory in which to save the inference result.
+* `--save_name`: The file name of the saved inference result.
+
+**[Note]** As we are conducting cross-domain CLIP retrieval, the test images and the caption index should come from **different** benchmarks.
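+
+Under the hood, both retrieval modes perform the same lookup: encode the test image with CLIP, L2-normalize the representation, and return the caption whose indexed representation has the highest cosine similarity. The sketch below mirrors what `clipretrieval.py` does, assuming the MSCOCO index built in Section 1; `example.jpg` is a placeholder for any test image.
+```python
+import json
+import numpy as np
+from PIL import Image
+from clip import CLIP  # the wrapper defined in clip/clip.py
+
+# Load and L2-normalize the pre-built caption index.
+index_matrix = np.loadtxt('./mscoco_index/index_matrix.txt')
+index_matrix = index_matrix / np.linalg.norm(index_matrix, axis=1, keepdims=True)
+with open('./mscoco_index/text_mapping.json') as f:
+    text_mapping = json.load(f)
+
+# Encode one test image and L2-normalize its representation.
+model = CLIP('openai/clip-vit-base-patch32')
+model.eval()
+image_vec = model.compute_batch_index_image_features([Image.open('example.jpg')]).detach().cpu().numpy()
+image_vec = image_vec / np.linalg.norm(image_vec, axis=1, keepdims=True)
+
+# The retrieved caption is the row with the highest cosine similarity.
+top_idx = int(np.matmul(image_vec, index_matrix.T)[0].argmax())
+print(text_mapping[str(top_idx)])
+```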
diff --git a/clip/build_flickr30k_index.sh b/clip/build_flickr30k_index.sh
new file mode 100644
index 0000000..5579beb
--- /dev/null
+++ b/clip/build_flickr30k_index.sh
@@ -0,0 +1,7 @@
+CUDA_VISIBLE_DEVICES=1 python build_text_index.py\
+    --clip_name openai/clip-vit-base-patch32\
+    --text_file_path ../data/flickr30k/flickr30k_train.json\
+    --save_index_prefix ./flickr30k_index/\
+    --save_index_name index_matrix.txt\
+    --save_mapping_dict_name text_mapping.json\
+    --batch_size 128
\ No newline at end of file
diff --git a/clip/build_mscoco_index.sh b/clip/build_mscoco_index.sh
new file mode 100644
index 0000000..c053f75
--- /dev/null
+++ b/clip/build_mscoco_index.sh
@@ -0,0 +1,7 @@
+CUDA_VISIBLE_DEVICES=0 python build_text_index.py\
+    --clip_name openai/clip-vit-base-patch32\
+    --text_file_path ../data/mscoco/mscoco_train.json\
+    --save_index_prefix ./mscoco_index/\
+    --save_index_name index_matrix.txt\
+    --save_mapping_dict_name text_mapping.json\
+    --batch_size 128
\ No newline at end of file
diff --git a/clip/build_text_index.py b/clip/build_text_index.py
new file mode 100644
index 0000000..98461a5
--- /dev/null
+++ b/clip/build_text_index.py
@@ -0,0 +1,105 @@
+import sys
+import os
+import json
+import argparse
+import torch
+import numpy as np
+import progressbar
+
+def parse_config():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--clip_name", type=str, default="openai/clip-vit-base-patch32")
+    parser.add_argument("--text_file_path", type=str)
+    # save configuration
+    parser.add_argument("--save_index_prefix", type=str, help='where to save the mips index')
+    parser.add_argument("--save_index_name", type=str)
+    parser.add_argument("--save_mapping_dict_name", type=str,
+        help="a json file that stores a dictionary. The dictionary contains the mapping between the mips index and the caption text")
+    # inference configuration
+    parser.add_argument("--batch_size", type=int, help="the batch size used to conduct inference with CLIP")
+    return parser.parse_args()
+
+def load_batch_text(text_file_path, batch_size):
+    with open(text_file_path) as f:
+        item_list = json.load(f)
+
+    text_list = []
+    for item in item_list:
+        captions = item["captions"]
+        for cap in captions:
+            text_list.append(cap)
+    print ('Number of text instances is {}'.format(len(text_list)))
+
+    # split the captions into consecutive batches; the final partial batch is kept as well
+    data_num = len(text_list)
+    batch_text_list = []
+    for s_idx in range(0, data_num, batch_size):
+        batch_text_list.append(text_list[s_idx:s_idx + batch_size])
+    return batch_text_list
+
+
+if __name__ == '__main__':
+    cuda_available = torch.cuda.is_available()
+    if cuda_available:
+        print ('Cuda is available.')
+    args = parse_config()
+    device = torch.device('cuda' if cuda_available else 'cpu')
+
+    # recursively construct the output directory if it does not exist
+    os.makedirs(args.save_index_prefix, exist_ok=True)
+
+    print ('Loading CLIP...')
+    from clip import CLIP
+    model = CLIP(args.clip_name)
+    if cuda_available:
+        model = model.cuda(device)
+    model.eval()
+    print ('CLIP loaded!')
+
+    print ('Loading text data...')
+    batch_text_list = load_batch_text(args.text_file_path, args.batch_size)
+    print ('Text data loaded.')
+
+    res_text_vec_list, res_text_list = [], []
+    batch_num = len(batch_text_list)
+    print ('Number of batches is {}'.format(batch_num))
+    print ('Start inference...')
+    p = progressbar.ProgressBar(batch_num)
+    p.start()
+    with torch.no_grad():
+        for p_idx in range(batch_num):
+            p.update(p_idx)
+            one_text_batch = batch_text_list[p_idx]
+            one_batch_vec = model.compute_batch_index_text_representation(one_text_batch).detach().cpu()
+            one_batch_vec_list = one_batch_vec.unbind(dim=0)
+            bsz = len(one_batch_vec_list)
+            for k in range(bsz):
+                res_text_vec_list.append(one_batch_vec_list[k].numpy())
+                res_text_list.append(one_text_batch[k])
+    p.finish()
+    assert len(res_text_vec_list) == len(res_text_list)
+    print ('Inference completed!')
+
+    index_text_mapping_dict = {}
+    for k in range(len(res_text_list)):
+        index_text_mapping_dict[k] = res_text_list[k]
+    mapping_list_save_path = args.save_index_prefix + '/' + args.save_mapping_dict_name
+    with open(mapping_list_save_path, 'w') as outfile:
+        json.dump(index_text_mapping_dict, outfile, indent=4)
+    print ('Mapping dictionary saved!')
+
+    print ('Start building index...')
+    index_save_path = args.save_index_prefix + '/' + args.save_index_name
+    with open(index_save_path, 'w', encoding = 'utf8') as o:
+        for vec in res_text_vec_list:
+            one_text = ' '.join([str(num) for num in vec]).strip()
+            o.writelines(one_text + '\n')
+    print ('Index building completed!')
diff --git a/clip/clip.py b/clip/clip.py
new file mode 100644
index 0000000..05b1c1c
--- /dev/null
+++ b/clip/clip.py
@@ -0,0 +1,146 @@
+import torch
+import requests
+from torch import nn
+from PIL import Image
+
+class CLIP(nn.Module):
+    def __init__(self, model_name):
+        super(CLIP, self).__init__()
+        # model name: e.g.
openai/clip-vit-base-patch32 + print ('Initializing CLIP model...') + from transformers import CLIPProcessor, CLIPModel + self.model = CLIPModel.from_pretrained(model_name) + self.model.eval() + self.processor = CLIPProcessor.from_pretrained(model_name) + from transformers import CLIPTokenizer + self.tokenizer = CLIPTokenizer.from_pretrained(model_name) + self.cuda_has_been_checked = False + print ('CLIP model initialized.') + + def check_cuda(self): + self.cuda_available = next(self.model.parameters()).is_cuda + self.device = next(self.model.parameters()).get_device() + if self.cuda_available: + print ('Cuda is available.') + print ('Device is {}'.format(self.device)) + else: + print ('Cuda is not available.') + print ('Device is {}'.format(self.device)) + + @torch.no_grad() + def compute_image_representation_from_image_path(self, image_path): + if not self.cuda_has_been_checked: + self.check_cuda() + self.cuda_has_been_checked = True + else: + pass + # image_path: the path of the image + image = Image.open(image_path) + inputs = self.processor(images=image, return_tensors="pt") + pixel_values = inputs['pixel_values'] + if self.cuda_available: + pixel_values = pixel_values.cuda(self.device) + visual_outputs = self.model.vision_model(pixel_values=pixel_values) + image_embeds = visual_outputs[1] + image_embeds = self.model.visual_projection(image_embeds) # [1 x embed_dim] + return image_embeds + + def compute_image_representation_from_image_instance(self, image): + if not self.cuda_has_been_checked: + self.check_cuda() + self.cuda_has_been_checked = True + else: + pass + # image_path: the path of the image + inputs = self.processor(images=image, return_tensors="pt") + pixel_values = inputs['pixel_values'] + if self.cuda_available: + pixel_values = pixel_values.cuda(self.device) + visual_outputs = self.model.vision_model(pixel_values=pixel_values) + image_embeds = visual_outputs[1] + image_embeds = self.model.visual_projection(image_embeds) # [1 x embed_dim] + return image_embeds + + def compute_text_representation(self, text_list): + if not self.cuda_has_been_checked: + self.check_cuda() + self.cuda_has_been_checked = True + else: + pass + # text_list: a list of text + text_inputs = self.tokenizer(text_list, padding=True, return_tensors="pt", + max_length=self.tokenizer.max_len_single_sentence + 2, truncation=True) + # self.tokenizer.max_len_single_sentence + 2 = 77 + input_ids, attention_mask = text_inputs['input_ids'], text_inputs['attention_mask'] + if self.cuda_available: + input_ids = input_ids.cuda(self.device) + attention_mask = attention_mask.cuda(self.device) + text_outputs = self.model.text_model( + input_ids=input_ids, + attention_mask=attention_mask + ) + text_embeds = text_outputs[1] + text_embeds = self.model.text_projection(text_embeds) + return text_embeds + + def compute_image_text_similarity_via_embeddings(self, image_embeds, text_embeds): + ''' + image_embeds: 1 x embed_dim + text_embeds: len(text_list) x embed_dim + ''' + image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True) + logit_scale = self.model.logit_scale.exp() + logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale + logits_per_image = logits_per_text.T + return logits_per_image.softmax(dim=1) # 1 x len(text_list) + + def compute_image_text_similarity_via_raw_text(self, image_embeds, text_list): + text_embeds = self.compute_text_representation(text_list) + return 
self.compute_image_text_similarity_via_embeddings(image_embeds, text_embeds) + + ### -------------------- functions for building index ---------------------- ### + def compute_batch_index_image_features(self, image_list): + ''' + # list of image instances + ''' + if not self.cuda_has_been_checked: + self.check_cuda() + self.cuda_has_been_checked = True + else: + pass + # image_path: the path of the image + inputs = self.processor(images=image_list, return_tensors="pt") + pixel_values = inputs['pixel_values'] + if self.cuda_available: + pixel_values = pixel_values.cuda(self.device) + visual_outputs = self.model.vision_model(pixel_values=pixel_values) + image_embeds = visual_outputs[1] + image_embeds = self.model.visual_projection(image_embeds) # [1 x embed_dim] + return image_embeds # len(image_list) x embed_dim + + def compute_batch_index_text_representation(self, text_list): + if not self.cuda_has_been_checked: + self.check_cuda() + self.cuda_has_been_checked = True + else: + pass + # text_list: a list of text + #text_inputs = self.tokenizer(text_list, padding=True, return_tensors="pt") + text_inputs = self.tokenizer(text_list, padding=True, return_tensors="pt", + max_length=self.tokenizer.max_len_single_sentence + 2, truncation=True) + input_ids, attention_mask = text_inputs['input_ids'], text_inputs['attention_mask'] + if self.cuda_available: + input_ids = input_ids.cuda(self.device) + attention_mask = attention_mask.cuda(self.device) + text_outputs = self.model.text_model( + input_ids=input_ids, + attention_mask=attention_mask + ) + text_embeds = text_outputs[1] + text_embeds = self.model.text_projection(text_embeds) + return text_embeds + #logit_scale = self.model.logit_scale.exp() + #text_embeds = text_embeds * logit_scale + #return text_embeds + diff --git a/clip/clipretrieval.py b/clip/clipretrieval.py new file mode 100644 index 0000000..bd77cbe --- /dev/null +++ b/clip/clipretrieval.py @@ -0,0 +1,135 @@ +import json +import copy +import torch +import progressbar +import numpy as np +from PIL import Image + +class CLIPIndex: + def __init__(self, index_matrix_path, mapping_dict_path, clip): + ''' + index_path: the pre-trained index + mapping_dict_path: the pre-indexed mapping dictionary + clip: the pre-trained clip model + ''' + print ('Loading index...') + self.index_matrix = self.normalization(self.load_matrix(index_matrix_path)) + print ('Index loaded.') + print (self.index_matrix.shape) + with open(mapping_dict_path) as f: + self.mapping_dict = json.load(f) + self.clip = clip + + def load_matrix(self, in_f): + matrix_list = [] + with open(in_f, 'r', encoding = 'utf8') as i: + lines = i.readlines() + for l in lines: + one_vec = [float(num) for num in l.strip('\n').split()] + matrix_list.append(one_vec) + return np.array(matrix_list) + + def normalization(self, matrix): + ''' + matrix: num_instance x num_feature + ''' + return matrix / np.linalg.norm(matrix, axis=1, keepdims=True) + + def get_image_representation(self, image_path): + image_instance = Image.open(image_path) + image_vec = self.clip.compute_batch_index_image_features([image_instance]).detach().cpu().numpy() + image_vec = self.normalization(image_vec) + return image_vec + + def search_text(self, image_path): + image_vec = self.get_image_representation(image_path) + sort_idx_list = np.matmul(image_vec, self.index_matrix.transpose())[0].argsort()[::-1] + top_idx = sort_idx_list[0] + return self.mapping_dict[str(top_idx)] + + +def parse_config(): + parser = argparse.ArgumentParser() + parser.add_argument("--clip_name", 
type=str) + parser.add_argument("--test_image_prefix_path", type=str, help="the folder that stores all test images") + parser.add_argument("--test_path", type=str) + # index configuration + parser.add_argument("--index_matrix_path", type=str) + parser.add_argument("--mapping_dict_path", type=str) + # save configuration + parser.add_argument("--save_path_prefix", type=str, help="save the result in which directory") + parser.add_argument("--save_name", type=str, help="the name of the saved file") + return parser.parse_args() + +import argparse +if __name__ == '__main__': + if torch.cuda.is_available(): + print ('Cuda is available.') + cuda_available = torch.cuda.is_available() + args = parse_config() + device = torch.device('cuda') + + save_path_prefix = args.save_path_prefix + import os + if os.path.exists(save_path_prefix): + pass + else: # recursively construct directory + os.makedirs(save_path_prefix, exist_ok=True) + # parse save name + save_name = args.save_name + full_save_path = save_path_prefix + '/' + save_name + print ('full save path is {}'.format(full_save_path)) + + print ('Loading CLIP...') + from clip import CLIP + clip = CLIP(args.clip_name) + if cuda_available: + clip = clip.cuda(device) + clip.eval() + print ('CLIP loaded!') + + clipindex = CLIPIndex(args.index_matrix_path, args.mapping_dict_path, clip) + + print ('Loading data...') + import json + with open(args.test_path) as f: + item_list = json.load(f) + print ('Data loaded.') + print ('Number of test instances is {}'.format(len(item_list))) + + result_list = [] + invalid_num = 0 + print ('----------------------------------------------------------------') + with torch.no_grad(): + test_num = len(item_list) + #test_num = 10 + print ('Number of inference instances is {}'.format(test_num)) + p = progressbar.ProgressBar(test_num) + p.start() + for p_idx in range(test_num): + p.update(p_idx) + one_test_dict = item_list[p_idx] + + one_res_dict = { + 'split':one_test_dict['split'], + 'image_name':one_test_dict['image_name'], + #'file_path':one_test_dict['file_path'], + 'captions':one_test_dict['captions'] + } + + image_full_path = args.test_image_prefix_path + '/' + one_test_dict['image_name'] + try: + output_text = clipindex.search_text(image_full_path) + one_res_dict['prediction'] = output_text + result_list.append(one_res_dict) + except: + invalid_num += 1 + print ('invalid number is {}'.format(invalid_num)) + continue + p.finish() + print ('Inference completed!') + + import json + with open(full_save_path, 'w') as outfile: + json.dump(result_list, outfile, indent=4) + diff --git a/clip/flickr30k_clip_retrieval.sh b/clip/flickr30k_clip_retrieval.sh new file mode 100644 index 0000000..0dba975 --- /dev/null +++ b/clip/flickr30k_clip_retrieval.sh @@ -0,0 +1,8 @@ +CUDA_VISIBLE_DEVICES=1 python clipretrieval.py\ + --clip_name openai/clip-vit-base-patch32\ + --test_image_prefix_path ../data/flickr30k/test_images/\ + --test_path ../data/flickr30k/flickr30k_test.json\ + --index_matrix_path ./flickr30k_index/index_matrix.txt\ + --mapping_dict_path ./flickr30k_index/text_mapping.json\ + --save_path_prefix ../inference_result/flickr30k/baselines/\ + --save_name flickr30k_in_domain_clipretrieval.json \ No newline at end of file diff --git a/clip/mscoco_clip_retrieval.sh b/clip/mscoco_clip_retrieval.sh new file mode 100644 index 0000000..cf4d893 --- /dev/null +++ b/clip/mscoco_clip_retrieval.sh @@ -0,0 +1,8 @@ +CUDA_VISIBLE_DEVICES=0 python clipretrieval.py\ + --clip_name openai/clip-vit-base-patch32\ + --test_image_prefix_path 
../data/mscoco/test_images/\ + --test_path ../data/mscoco/mscoco_test.json\ + --index_matrix_path ./mscoco_index/index_matrix.txt\ + --mapping_dict_path ./mscoco_index/text_mapping.json\ + --save_path_prefix ../inference_result/mscoco/baselines/\ + --save_name mscoco_in_domain_clipretrieval.json \ No newline at end of file diff --git a/clip/source_flickr30k_target_mscoco_clip_retrieval.sh b/clip/source_flickr30k_target_mscoco_clip_retrieval.sh new file mode 100644 index 0000000..105f1c2 --- /dev/null +++ b/clip/source_flickr30k_target_mscoco_clip_retrieval.sh @@ -0,0 +1,8 @@ +CUDA_VISIBLE_DEVICES=1 python clipretrieval.py\ + --clip_name openai/clip-vit-base-patch32\ + --test_image_prefix_path ../data/mscoco/test_images/\ + --test_path ../data/mscoco/mscoco_test.json\ + --index_matrix_path ./flickr30k_index/index_matrix.txt\ + --mapping_dict_path ./flickr30k_index/text_mapping.json\ + --save_path_prefix ../inference_result/flickr30k_model_to_mscoco/\ + --save_name source_flickr30k_target_mscoco_clip_retrieval.json \ No newline at end of file diff --git a/clip/source_mscoco_target_flickr30k_clip_retrieval.sh b/clip/source_mscoco_target_flickr30k_clip_retrieval.sh new file mode 100644 index 0000000..9902cda --- /dev/null +++ b/clip/source_mscoco_target_flickr30k_clip_retrieval.sh @@ -0,0 +1,8 @@ +CUDA_VISIBLE_DEVICES=1 python clipretrieval.py\ + --clip_name openai/clip-vit-base-patch32\ + --test_image_prefix_path ../data/flickr30k/test_images/\ + --test_path ../data/flickr30k/flickr30k_test.json\ + --index_matrix_path ./mscoco_index/index_matrix.txt\ + --mapping_dict_path ./mscoco_index/text_mapping.json\ + --save_path_prefix ../inference_result/mscoco_model_to_flickr30k/\ + --save_name source_mscoco_target_flickr30k_clip_retrieval.json \ No newline at end of file diff --git a/magic.py b/magic.py new file mode 100644 index 0000000..85f46a7 --- /dev/null +++ b/magic.py @@ -0,0 +1,99 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+import sys
+import os
+import pathlib
+import pickle
+from argparse import Namespace
+
+import torch
+import torchvision
+from torchvision import transforms
+from transformers import GPT2Tokenizer
+
+from towhee.types.arg import arg, to_image_color
+from towhee.types.image_utils import to_pil
+from towhee.operator.base import NNOperator, OperatorFlag
+from towhee import register
+from towhee.models import clip
+
+class Magic(NNOperator):
+    """
+    Magic image captioning operator
+    """
+    def __init__(self, model_name: str):
+        super().__init__()
+        path = str(pathlib.Path(__file__).parent)
+        sys.path.append(path)
+        from clip import CLIP
+        from simctg import SimCTG
+        sys.path.pop()
+
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        # Load Language Model
+        language_model_name = r'cambridgeltl/magic_mscoco' # or r'/path/to/downloaded/cambridgeltl/magic_mscoco'
+        sos_token, pad_token = r'<-start_of_text->', r'<-pad->'
+        self.generation_model = SimCTG(language_model_name, sos_token, pad_token).to(self.device)
+        self.generation_model.eval()
+
+        model_name = r"openai/clip-vit-base-patch32" # or r"/path/to/downloaded/openai/clip-vit-base-patch32"
+        self.clip = CLIP(model_name).to(self.device)
+        self.clip.eval()
+
+        # Torchvision transforms used by _preprocess; an assumed standard
+        # resize/center-crop followed by tensor conversion.
+        self.transf_1 = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)])
+        self.transf_2 = transforms.ToTensor()
+
+    def _preprocess(self, img):
+        img = to_pil(img)
+        processed_img = self.transf_1(img)
+        processed_img = self.transf_2(processed_img)
+        processed_img = processed_img.to(self.device)
+        return processed_img
+
+    @arg(1, to_image_color('RGB'))
+    def inference_single_data(self, data):
+        text = self._inference_from_image(data)
+        return text
+
+    def __call__(self, data):
+        if not isinstance(data, list):
+            data = [data]
+        results = []
+        for single_data in data:
+            result = self.inference_single_data(single_data)
+            results.append(result)
+        if len(data) == 1:
+            return results[0]
+        return results
+
+    @arg(1, to_image_color('RGB'))
+    def _inference_from_image(self, img):
+        # Magic search works on the PIL image directly; CLIP's processor handles resizing.
+        image_instance = to_pil(img)
+        k, alpha, beta, decoding_len = 45, 0.1, 2.0, 16
+        # The prompt is the start-of-text token of the fine-tuned language model.
+        sos_token = r'<-start_of_text->'
+        start_token = self.generation_model.tokenizer.tokenize(sos_token)
+        start_token_id = self.generation_model.tokenizer.convert_tokens_to_ids(start_token)
+        input_ids = torch.LongTensor(start_token_id).view(1, -1).to(self.device)
+        with torch.no_grad():
+            output = self.generation_model.magic_search(input_ids, k,
+                    alpha, decoding_len, beta, image_instance, self.clip, 60)
+        return output
+
+    def _configs(self):
+        config = {}
+        config['expansionnet_rf'] = {}
+        config['expansionnet_rf']['weights'] = 'rf_model.pth'
+        return config
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..e69de29
diff --git a/zerocap/README.md b/zerocap/README.md
new file mode 100644
index 0000000..1839550
--- /dev/null
+++ b/zerocap/README.md
@@ -0,0 +1,89 @@
+### Our Implementation of the ZeroCap Baseline Model
+
+****
+### Catalogue:
+* 1. Environment Preparation
+* 2. Image Captioning on MSCOCO
+* 3. Image Captioning on Flickr30k
+* 4. Cross Domain Image Captioning on MSCOCO
+* 5. Cross Domain Image Captioning on Flickr30k
+* 6. Citation
+* 7. Acknowledgements
+
+****
+
+#### 1. Environment Preparation:
+To install the required environment, run the following command:
+```yaml
+pip install -r requirements.txt
+```
+
+****
+
+#### 2. Image Captioning on MSCOCO:
+To perform image captioning on MSCOCO, run the following command:
+```yaml
+chmod +x ./mscoco_zerocap.sh
+./mscoco_zerocap.sh
+```
+
+****
+
+#### 3. Image Captioning on Flickr30k:
+To perform image captioning on Flickr30k, run the following command:
+```yaml
+chmod +x ./flickr30k_zerocap.sh
+./flickr30k_zerocap.sh
+```
+
+****
+
+#### 4.
Cross Domain Image Captioning on MSCOCO: +To perform image captioning on MSCOCO with the language model from Flickr30k domain, please run the following command: +```yaml +chmod +x ./flickr30k_to_mscoco_zerocap.sh +./flickr30k_to_mscoco_zerocap.sh +``` + +**** + + + +#### 5. Cross Domain Image Captioning on Flickr30k: +To perform image captioning on Flickr30k with the language model from MSCOCO domain, please run the following command: +```yaml +chmod +x ./mscoco_to_flickr30k_zerocap.sh +./mscoco_to_flickr30k_zerocap.sh +``` + +**** + + + +#### 6. Citation: +If you find our code helpful, please cite the original paper as + +```bibtex +@article{tewel2021zero, + title={Zero-Shot Image-to-Text Generation for Visual-Semantic Arithmetic}, + author={Tewel, Yoad and Shalev, Yoav and Schwartz, Idan and Wolf, Lior}, + journal={arXiv preprint arXiv:2111.14447}, + year={2021} +} +``` + +**** + + + +#### 7. Acknowledgements: +We thank the authors for releasing their code. Our reimplementation of the baseline is based on their original codebase [[here]](https://github.com/yoadtew/zero-shot-image-to-text). + diff --git a/zerocap/cog.yaml b/zerocap/cog.yaml new file mode 100644 index 0000000..92f13da --- /dev/null +++ b/zerocap/cog.yaml @@ -0,0 +1,12 @@ +build: + gpu: true + python_version: "3.8" + system_packages: + - "libgl1-mesa-glx" + - "libglib2.0-0" + python_packages: + - "git+https://github.com/openai/CLIP.git" + - "git+https://github.com/YoadTew/zero-shot-image-to-text.git" + +predict: "predict.py:Predictor" +#predict: "predict_arithmetic.py:Predictor" \ No newline at end of file diff --git a/zerocap/flickr30k_zerocap.sh b/zerocap/flickr30k_zerocap.sh new file mode 100755 index 0000000..b727b12 --- /dev/null +++ b/zerocap/flickr30k_zerocap.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +# lm_model: +# 1. cambridgeltl/magic_mscoco +# 2. 
cambridgeltl/magic_flickr30k +CUDA_VISIBLE_DEVICES=1 python run.py \ + --beam_size 1 \ + --target_seq_length 16 \ + --reset_context_delta \ + --lm_model cambridgeltl/magic_flickr30k \ + --test_image_prefix_path ../data/flickr30k/test_images \ + --test_path ../data/flickr30k/flickr30k_test.json \ + --save_path_prefix ../inference_result/flickr30k/baselines/ \ + --save_name zerocap_result.json diff --git a/zerocap/forbidden_tokens.npy b/zerocap/forbidden_tokens.npy new file mode 100644 index 0000000..aeed51b Binary files /dev/null and b/zerocap/forbidden_tokens.npy differ diff --git a/zerocap/model/ZeroCLIP.py b/zerocap/model/ZeroCLIP.py new file mode 100644 index 0000000..2c36fd2 --- /dev/null +++ b/zerocap/model/ZeroCLIP.py @@ -0,0 +1,389 @@ +import numpy as np +from torch import nn +from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer +from transformers.models.gpt_neo import GPTNeoForCausalLM +import torch +import clip +from PIL import Image +from datetime import datetime +import sys + + +def log_info(text, verbose=True): + if verbose: + dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S") + print(f'{dt_string} | {text}') + sys.stdout.flush() + + +def add_context(x, y): + return (x[0] + y[0], x[1] + y[1]) + + +def convert_models_to_fp32(model): + for p in model.parameters(): + p.data = p.data.float() + + +class CLIPTextGenerator: + def __init__(self, + seed=0, + lm_model='gpt-2', + forbidden_tokens_file_path='./forbidden_tokens.npy', + clip_checkpoints='./clip_checkpoints', + target_seq_length=15, + reset_context_delta=True, + num_iterations=5, + clip_loss_temperature=0.01, + clip_scale=1., + ce_scale=0.2, + stepsize=0.3, + grad_norm_factor=0.9, + fusion_factor=0.99, + repetition_penalty=1., + end_token='.', + end_factor=1.01, + forbidden_factor=20, + **kwargs): + + self.device = "cuda" if torch.cuda.is_available() else "cpu" + + # set Random seed + torch.manual_seed(seed) + np.random.seed(seed) + + # Initialize Language model + self.context_prefix = '' + + self.lm_tokenizer = GPT2Tokenizer.from_pretrained(lm_model) + self.lm_model = GPT2LMHeadModel.from_pretrained(lm_model, output_hidden_states=True) + self.context_prefix = self.lm_tokenizer.bos_token + + self.lm_model.to(self.device) + self.lm_model.eval() + + self.forbidden_tokens = np.load(forbidden_tokens_file_path) + self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if + (x[0] == 'Ġ' and len(x) > 1 and x[1].isupper())] + + # Freeze LM weights + for param in self.lm_model.parameters(): + param.requires_grad = False + + # Initialize CLIP + self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device, + download_root=clip_checkpoints, jit=False) + # convert_models_to_fp32(self.clip) + + # Init arguments + self.target_seq_length = target_seq_length + self.reset_context_delta = reset_context_delta + self.num_iterations = num_iterations + self.clip_loss_temperature = clip_loss_temperature + self.clip_scale = clip_scale + self.ce_scale = ce_scale + self.stepsize = stepsize + self.grad_norm_factor = grad_norm_factor + self.fusion_factor = fusion_factor + self.repetition_penalty = repetition_penalty + self.end_token = self.lm_tokenizer.encode(end_token)[0] + self.end_factor = end_factor + self.ef_idx = 1 + self.forbidden_factor = forbidden_factor + + def get_img_feature(self, img_path, weights): + imgs = [Image.open(x) for x in img_path] + clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs] + + with torch.no_grad(): + image_fts = 
[self.clip.encode_image(x) for x in clip_imgs] + + if weights is not None: + image_features = sum([x * weights[i] for i, x in enumerate(image_fts)]) + else: + image_features = sum(image_fts) + + image_features = image_features / image_features.norm(dim=-1, keepdim=True) + return image_features.detach() + + def get_txt_features(self, text): + clip_texts = clip.tokenize(text).to(self.device) + + with torch.no_grad(): + text_features = self.clip.encode_text(clip_texts) + + text_features = text_features / text_features.norm(dim=-1, keepdim=True) + return text_features.detach() + + def get_combined_feature(self, img_path, texts, weights_i, weights_t): + imgs = [Image.open(x) for x in img_path] + clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs] + clip_texts = [clip.tokenize(x).to(self.device) for x in texts] + + with torch.no_grad(): + image_fts = [self.clip.encode_image(x) for x in clip_imgs] + text_fts = [self.clip.encode_text(x) for x in clip_texts] + + features = sum([x * weights_i[i] for i, x in enumerate(image_fts)]) + if weights_t is not None: + features += sum([x * weights_t[i] for i, x in enumerate(text_fts)]) + + features = features / features.norm(dim=-1, keepdim=True) + return features.detach() + + def run(self, image_features, cond_text, beam_size): + self.image_features = image_features + + context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text) + + output_tokens, output_text = self.generate_text(context_tokens, beam_size) + + return output_text + + def generate_text(self, context_tokens, beam_size): + context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0) + + gen_tokens = None + scores = None + seq_lengths = torch.ones(beam_size, device=self.device) + is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool) + + for i in range(self.target_seq_length): + probs = self.get_next_probs(i, context_tokens) + logits = probs.log() + + if scores is None: + scores, next_tokens = logits.topk(beam_size, -1) + context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:]) + next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0) + + if gen_tokens is None: + gen_tokens = next_tokens + else: + gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:]) + gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1) + else: + logits[is_stopped] = -float(np.inf) + logits[is_stopped, 0] = 0 + scores_sum = scores[:, None] + logits + seq_lengths[~is_stopped] += 1 + scores_sum_average = scores_sum / seq_lengths[:, None] + scores_sum_average, next_tokens = scores_sum_average.view(-1).topk( + beam_size, -1) + next_tokens_source = next_tokens // scores_sum.shape[1] + seq_lengths = seq_lengths[next_tokens_source] + next_tokens = next_tokens % scores_sum.shape[1] + next_tokens = next_tokens.unsqueeze(1) + gen_tokens = gen_tokens[next_tokens_source] + gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1) + context_tokens = context_tokens[next_tokens_source] + scores = scores_sum_average * seq_lengths + is_stopped = is_stopped[next_tokens_source] + + context_tokens = torch.cat((context_tokens, next_tokens), dim=1) + is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze() + + #### + tmp_scores = scores / seq_lengths + tmp_output_list = gen_tokens.cpu().numpy() + tmp_output_texts = [ + self.lm_tokenizer.decode(tmp_output) + for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths) + ] + tmp_order = tmp_scores.argsort(descending=True) + tmp_output_texts = 
[tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order] + log_info(tmp_output_texts, verbose=True) + #### + + if is_stopped.all(): + break + + scores = scores / seq_lengths + output_list = gen_tokens.cpu().numpy() + output_texts = [ + self.lm_tokenizer.decode(output[: int(length)]) + for output, length in zip(output_list, seq_lengths) + ] + order = scores.argsort(descending=True) + output_texts = [output_texts[i] for i in order] + + return context_tokens, output_texts + + def get_next_probs(self, i, context_tokens): + last_token = context_tokens[:, -1:] + + if self.reset_context_delta and context_tokens.size(1) > 1: + context = self.lm_model(context_tokens[:, :-1])["past_key_values"] + + # Logits of LM with unshifted context + logits_before_shift = self.lm_model(context_tokens)["logits"] + logits_before_shift = logits_before_shift[:, -1, :] + probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1) + + if context: + context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift) + + lm_output = self.lm_model(last_token, past_key_values=context) + logits, past = ( + lm_output["logits"], + lm_output["past_key_values"], + ) + logits = logits[:, -1, :] + + logits = self.update_special_tokens_logits(context_tokens, i, logits) + + probs = nn.functional.softmax(logits, dim=-1) + probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor)) + probs = probs / probs.sum() + + return probs + + def shift_context(self, i, context, last_token, context_tokens, probs_before_shift): + context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context] + + window_mask = torch.ones_like(context[0][0]).to(self.device) + + for i in range(self.num_iterations): + curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in + context_delta] + + for p0, p1 in curr_shift: + p0.retain_grad() + p1.retain_grad() + + shifted_context = list(map(add_context, context, curr_shift)) + + shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context) + logits = shifted_outputs["logits"][:, -1, :] + probs = nn.functional.softmax(logits, dim=-1) + + loss = 0.0 + + # CLIP LOSS + clip_loss, clip_losses = self.clip_loss(probs, context_tokens) + loss += self.clip_scale * clip_loss + + # CE/Fluency loss + ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * probs_before_shift.log())).sum(-1) + loss += ce_loss.sum() + + loss.backward() + + # ---------- Weights ---------- + combined_scores_k = -(ce_loss) + combined_scores_c = -(self.clip_scale * torch.stack(clip_losses)) + + # minmax + if combined_scores_k.shape[0] == 1: + tmp_weights_c = tmp_weights_k = torch.ones(*combined_scores_k.shape).to(self.device) + else: + tmp_weights_k = ((combined_scores_k - combined_scores_k.min())) / ( + combined_scores_k.max() - combined_scores_k.min()) + tmp_weights_c = ((combined_scores_c - combined_scores_c.min())) / ( + combined_scores_c.max() - combined_scores_c.min()) + + tmp_weights = 0.5 * tmp_weights_k + 0.5 * tmp_weights_c + tmp_weights = tmp_weights.view(tmp_weights.shape[0], 1, 1, 1) + + factor = 1 + + # --------- Specific Gen --------- + sep_grads = None + + for b in range(context_tokens.shape[0]): + tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_] + for p_ in curr_shift] + + # normalize gradients + tmp_grad = [tuple([-self.stepsize * factor * ( + x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][ 
+ j] ** self.grad_norm_factor).data.cpu().numpy() + for j, x in enumerate(p_)]) + for i, p_ in enumerate(curr_shift)] + if sep_grads is None: + sep_grads = tmp_grad + else: + for l_index in range(len(sep_grads)): + sep_grads[l_index] = list(sep_grads[l_index]) + for k_index in range(len(sep_grads[0])): + sep_grads[l_index][k_index] = np.concatenate( + (sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0) + sep_grads[l_index] = tuple(sep_grads[l_index]) + final_grads = sep_grads + + # --------- update context --------- + context_delta = list(map(add_context, final_grads, context_delta)) + + for p0, p1 in curr_shift: + p0.grad.data.zero_() + p1.grad.data.zero_() + + new_context = [] + for p0, p1 in context: + new_context.append((p0.detach(), p1.detach())) + context = new_context + + context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) + for p_ in context_delta] + context = list(map(add_context, context, context_delta)) + + new_context = [] + for p0, p1 in context: + new_context.append((p0.detach(), p1.detach())) + context = new_context + + return context + + def update_special_tokens_logits(self, context_tokens, i, logits): + for beam_id in range(context_tokens.shape[0]): + for token_idx in set(context_tokens[beam_id][-4:].tolist()): + factor = self.repetition_penalty if logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty) + logits[beam_id, token_idx] /= factor + + if i >= self.ef_idx: + factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor) + logits[beam_id, self.end_token] *= factor + if i == 0: + start_factor = 1.6 + factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor) + logits[beam_id, self.end_token] /= factor + + for token_idx in list(self.forbidden_tokens): + factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor) + logits[beam_id, token_idx] /= factor + + return logits + + def clip_loss(self, probs, context_tokens): + for p_ in self.clip.transformer.parameters(): + if p_.grad is not None: + p_.grad.data.zero_() + + top_size = 512 + _, top_indices = probs.topk(top_size, -1) + + prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens] + + clip_loss = 0 + losses = [] + for idx_p in range(probs.shape[0]): + top_texts = [] + prefix_text = prefix_texts[idx_p] + for x in top_indices[idx_p]: + top_texts.append(prefix_text + self.lm_tokenizer.decode(x)) + text_features = self.get_txt_features(top_texts) + + with torch.no_grad(): + similiraties = (self.image_features @ text_features.T) + target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() + target_probs = target_probs.type(torch.float32) + + target = torch.zeros_like(probs[idx_p]) + target[top_indices[idx_p]] = target_probs[0] + target = target.unsqueeze(0) + cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)]))) + + clip_loss += cur_clip_loss + losses.append(cur_clip_loss) + + return clip_loss, losses diff --git a/zerocap/model/ZeroCLIP_batched.py b/zerocap/model/ZeroCLIP_batched.py new file mode 100644 index 0000000..2c0209f --- /dev/null +++ b/zerocap/model/ZeroCLIP_batched.py @@ -0,0 +1,449 @@ +import numpy as np +from torch import nn +from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer +from transformers.models.gpt_neo import GPTNeoForCausalLM +import torch +import clip +from PIL import Image +from datetime import datetime 
+import sys + +class TextCLIP(nn.Module): + def __init__(self, model): + super(TextCLIP, self).__init__() + self.model = model + + def forward(self, text): + return self.model.encode_text(text) + + +class ImageCLIP(nn.Module): + def __init__(self, model): + super(ImageCLIP, self).__init__() + self.model = model + + def forward(self, image): + return self.model.encode_image(image) + +def log_info(text, verbose=True): + if verbose: + dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S") + print(f'{dt_string} | {text}') + sys.stdout.flush() + + +def add_context(x, y): + return (x[0] + y[0], x[1] + y[1]) + + +def convert_models_to_fp32(model): + for p in model.parameters(): + p.data = p.data.float() + + +class CLIPTextGenerator: + def __init__(self, + seed=0, + lm_model='gpt-2', + forbidden_tokens_file_path='./forbidden_tokens.npy', + clip_checkpoints='./clip_checkpoints', + target_seq_length=15, + reset_context_delta=True, + num_iterations=5, + clip_loss_temperature=0.01, + clip_scale=1., + ce_scale=0.2, + stepsize=0.3, + grad_norm_factor=0.9, + fusion_factor=0.99, + repetition_penalty=1., + end_token='.', + end_factor=1.01, + forbidden_factor=20, + **kwargs): + + self.device = "cuda" if torch.cuda.is_available() else "cpu" + + # set Random seed + torch.manual_seed(seed) + np.random.seed(seed) + + # Initialize Language model + self.context_prefix = '' + + if lm_model == 'gpt-neo': + self.lm_tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-125M') + self.lm_model = GPTNeoForCausalLM.from_pretrained('EleutherAI/gpt-neo-125M', output_hidden_states=True) + elif lm_model == 'gpt-2': + self.lm_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium') + self.lm_model = GPT2LMHeadModel.from_pretrained('gpt2-medium', output_hidden_states=True) + self.context_prefix = self.lm_tokenizer.bos_token + + self.lm_model.to(self.device) + self.lm_model.eval() + + self.forbidden_tokens = np.load(forbidden_tokens_file_path) + self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if + (x[0] == 'Ġ' and len(x) > 1 and x[1].isupper())] + + # Freeze LM weights + for param in self.lm_model.parameters(): + param.requires_grad = False + + # Initialize CLIP + self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device, + download_root=clip_checkpoints, jit=False) + self.clip_image = ImageCLIP(self.clip) + self.clip_image = torch.nn.DataParallel(self.clip_image) + self.clip_text = TextCLIP(self.clip) + self.clip_text = torch.nn.DataParallel(self.clip_text) + + # Init arguments + self.target_seq_length = target_seq_length + self.reset_context_delta = reset_context_delta + self.num_iterations = num_iterations + self.clip_loss_temperature = clip_loss_temperature + self.clip_scale = clip_scale + self.ce_scale = ce_scale + self.stepsize = stepsize + self.grad_norm_factor = grad_norm_factor + self.fusion_factor = fusion_factor + self.repetition_penalty = repetition_penalty + self.end_token = self.lm_tokenizer.encode(end_token)[0] + self.end_factor = end_factor + self.ef_idx = 1 + self.forbidden_factor = forbidden_factor + + def get_img_feature(self, img_path, weights): + imgs = [Image.open(x) for x in img_path] + clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs] + + with torch.no_grad(): + image_fts = [self.clip_image(x) for x in clip_imgs] + + if weights is not None: + image_features = sum([x * weights[i] for i, x in enumerate(image_fts)]) + else: + image_features = sum(image_fts) + + image_features = 
torch.nn.functional.normalize(image_features, dim=-1) + return image_features.detach() + + def get_txt_features(self, text): + clip_texts = clip.tokenize(text).to(self.device) + + with torch.no_grad(): + text_features = self.clip_text(clip_texts) + + text_features = torch.nn.functional.normalize(text_features, dim=-1) + return text_features.detach() + + def get_combined_feature(self, img_path, texts, weights_i, weights_t): + imgs = [Image.open(x) for x in img_path] + clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs] + clip_texts = [clip.tokenize(x).to(self.device) for x in texts] + + with torch.no_grad(): + image_fts = [self.clip.encode_image(x) for x in clip_imgs] + text_fts = [self.clip.encode_text(x) for x in clip_texts] + + features = sum([x * weights_i[i] for i, x in enumerate(image_fts)]) + if weights_t is not None: + features += sum([x * weights_t[i] for i, x in enumerate(text_fts)]) + + features = features / features.norm(dim=-1, keepdim=True) + return features.detach() + + def run(self, image_features, cond_text, beam_size): + self.image_features = image_features + + context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text) + + output_tokens, output_text = self.generate_text(context_tokens, beam_size) + + return output_text + + def generate_text(self, context_tokens, beam_size): + context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0) + + gen_tokens = None + scores = None + seq_lengths = torch.ones(beam_size, device=self.device) + is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool) + + for i in range(self.target_seq_length): + probs = self.get_next_probs(i, context_tokens) + logits = probs.log() + + if scores is None: + scores, next_tokens = logits.topk(beam_size, -1) + context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:]) + next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0) + + if gen_tokens is None: + gen_tokens = next_tokens + else: + gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:]) + gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1) + else: + logits[is_stopped] = -float(np.inf) + logits[is_stopped, 0] = 0 + scores_sum = scores[:, None] + logits + seq_lengths[~is_stopped] += 1 + scores_sum_average = scores_sum / seq_lengths[:, None] + scores_sum_average, next_tokens = scores_sum_average.view(-1).topk( + beam_size, -1) + next_tokens_source = next_tokens // scores_sum.shape[1] + seq_lengths = seq_lengths[next_tokens_source] + next_tokens = next_tokens % scores_sum.shape[1] + next_tokens = next_tokens.unsqueeze(1) + gen_tokens = gen_tokens[next_tokens_source] + gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1) + context_tokens = context_tokens[next_tokens_source] + scores = scores_sum_average * seq_lengths + is_stopped = is_stopped[next_tokens_source] + + context_tokens = torch.cat((context_tokens, next_tokens), dim=1) + is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze() + + #### + tmp_scores = scores / seq_lengths + tmp_output_list = gen_tokens.cpu().numpy() + tmp_output_texts = [ + self.lm_tokenizer.decode(tmp_output) + for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths) + ] + tmp_order = tmp_scores.argsort(descending=True) + tmp_output_texts = [tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order] + log_info(tmp_output_texts, verbose=True) + #### + + if is_stopped.all(): + break + + scores = scores / seq_lengths + output_list = 
gen_tokens.cpu().numpy() + output_texts = [ + self.lm_tokenizer.decode(output[: int(length)]) + for output, length in zip(output_list, seq_lengths) + ] + order = scores.argsort(descending=True) + output_texts = [output_texts[i] for i in order] + + return context_tokens, output_texts + + def get_next_probs(self, i, context_tokens): + last_token = context_tokens[:, -1:] + + if self.reset_context_delta and context_tokens.size(1) > 1: + context = self.lm_model(context_tokens[:, :-1])["past_key_values"] + + # Logits of LM with unshifted context + logits_before_shift = self.lm_model(context_tokens)["logits"] + logits_before_shift = logits_before_shift[:, -1, :] + probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1) + + if context: + context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift) + + lm_output = self.lm_model(last_token, past_key_values=context) + logits, past = ( + lm_output["logits"], + lm_output["past_key_values"], + ) + logits = logits[:, -1, :] + + logits = self.update_special_tokens_logits(context_tokens, i, logits) + + probs = nn.functional.softmax(logits, dim=-1) + probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor)) + probs = probs / probs.sum() + + return probs + + def shift_context(self, i, context, last_token, context_tokens, probs_before_shift): + context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context] + + for i in range(self.num_iterations): + curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in + context_delta] + + for p0, p1 in curr_shift: + p0.retain_grad() + p1.retain_grad() + + shifted_context = list(map(add_context, context, curr_shift)) + + shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context) + logits = shifted_outputs["logits"][:, -1, :] + probs = nn.functional.softmax(logits, dim=-1) + + loss = 0.0 + + # CLIP LOSS + clip_loss, clip_losses = self.clip_loss(probs, context_tokens) + loss += self.clip_scale * clip_loss + + # CE/Fluency loss + ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * probs_before_shift.log())).sum(-1) + loss += ce_loss.sum() + + loss.backward() + + # --------- Specific Gen --------- + final_grads = self.norm_grad(context, context_tokens, curr_shift) + + # --------- update context --------- + context_delta = list(map(add_context, final_grads, context_delta)) + + for p0, p1 in curr_shift: + p0.grad.data.zero_() + p1.grad.data.zero_() + + new_context = [] + for p0, p1 in context: + new_context.append((p0.detach(), p1.detach())) + context = new_context + + context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) + for p_ in context_delta] + context = list(map(add_context, context, context_delta)) + + new_context = [] + for p0, p1 in context: + new_context.append((p0.detach(), p1.detach())) + context = new_context + + return context + + def norm_grad(self, context, context_tokens, curr_shift, ): + factor = 1 + sep_grads = None + window_mask = torch.ones_like(context[0][0]).to(self.device) + + for b in range(context_tokens.shape[0]): + tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_] + for p_ in curr_shift] + + # normalize gradients + tmp_grad = [tuple([-self.stepsize * factor * ( + x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][ + j] ** self.grad_norm_factor).data.cpu().numpy() + for j, x in enumerate(p_)]) + for i, p_ in 
enumerate(curr_shift)] + if sep_grads is None: + sep_grads = tmp_grad + else: + for l_index in range(len(sep_grads)): + sep_grads[l_index] = list(sep_grads[l_index]) + for k_index in range(len(sep_grads[0])): + sep_grads[l_index][k_index] = np.concatenate( + (sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0) + sep_grads[l_index] = tuple(sep_grads[l_index]) + final_grads = sep_grads + + return final_grads + + def update_special_tokens_logits(self, context_tokens, i, logits): + for beam_id in range(context_tokens.shape[0]): + for token_idx in set(context_tokens[beam_id][-4:].tolist()): + factor = self.repetition_penalty if logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty) + logits[beam_id, token_idx] /= factor + + if i >= self.ef_idx: + factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor) + logits[beam_id, self.end_token] *= factor + if i == 0: + start_factor = 1.6 + factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor) + logits[beam_id, self.end_token] /= factor + + for token_idx in list(self.forbidden_tokens): + factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor) + logits[beam_id, token_idx] /= factor + + return logits + + def clip_loss(self, probs, context_tokens): + for p_ in self.clip.transformer.parameters(): + if p_.grad is not None: + p_.grad.data.zero_() + + top_size = 512 + top_probs, top_indices = probs.topk(top_size, -1) + + prefix_texts = [self.lm_tokenizer.decode(x, skip_special_tokens=True) for x in context_tokens] + + clip_loss = 0 + losses = [] + + top_texts = [] + for idx_p in range(probs.shape[0]): + prefix_text = prefix_texts[idx_p] + for x in top_indices[idx_p]: + top_texts.append(prefix_text + self.lm_tokenizer.decode(x)) + + text_features = self.get_txt_features(top_texts)#.reshape(probs.size(0), top_size, -1) + + with torch.no_grad(): + similiraties = (self.image_features @ text_features.T).reshape(probs.size(0), -1) + similiraties = similiraties.reshape(probs.size(0), -1) + target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() + target_probs = target_probs.type(torch.float32) + + clip_loss += torch.sum(-(target_probs * torch.log(top_probs))) + # for idx_p in range(probs.shape[0]): + # top_texts = [] + # prefix_text = prefix_texts[idx_p] + # for x in top_indices[idx_p]: + # top_texts.append(prefix_text + self.lm_tokenizer.decode(x)) + # text_features = self.get_txt_features(top_texts) + # + # with torch.no_grad(): + # similiraties = (self.image_features @ text_features.T) + # target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() + # target_probs = target_probs.type(torch.float32) + # + # target = torch.zeros_like(probs[idx_p]) + # target[top_indices[idx_p]] = target_probs[0] + # target = target.unsqueeze(0) + # cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)]))) + # + # clip_loss += cur_clip_loss + # losses.append(cur_clip_loss) + + return clip_loss, losses + + def clip_loss_old(self, probs, context_tokens): + for p_ in self.clip.transformer.parameters(): + if p_.grad is not None: + p_.grad.data.zero_() + + top_size = 512 + _, top_indices = probs.topk(top_size, -1) + + prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens] + + clip_loss = 0 + losses = [] + for idx_p in range(probs.shape[0]): + top_texts = [] + prefix_text = prefix_texts[idx_p] + for x in 
top_indices[idx_p]: + top_texts.append(prefix_text + self.lm_tokenizer.decode(x)) + text_features = self.get_txt_features(top_texts) + + with torch.no_grad(): + similiraties = (self.image_features @ text_features.T) + target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() + target_probs = target_probs.type(torch.float32) + + target = torch.zeros_like(probs[idx_p]) + target[top_indices[idx_p]] = target_probs[0] + target = target.unsqueeze(0) + cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)]))) + + clip_loss += cur_clip_loss + losses.append(cur_clip_loss) + + return clip_loss, losses \ No newline at end of file diff --git a/zerocap/model/__init__.py b/zerocap/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/zerocap/model/__pycache__/ZeroCLIP.cpython-36.pyc b/zerocap/model/__pycache__/ZeroCLIP.cpython-36.pyc new file mode 100644 index 0000000..093e083 Binary files /dev/null and b/zerocap/model/__pycache__/ZeroCLIP.cpython-36.pyc differ diff --git a/zerocap/model/__pycache__/ZeroCLIP.cpython-37.pyc b/zerocap/model/__pycache__/ZeroCLIP.cpython-37.pyc new file mode 100644 index 0000000..5f08c9e Binary files /dev/null and b/zerocap/model/__pycache__/ZeroCLIP.cpython-37.pyc differ diff --git a/zerocap/model/__pycache__/ZeroCLIP_batched.cpython-36.pyc b/zerocap/model/__pycache__/ZeroCLIP_batched.cpython-36.pyc new file mode 100644 index 0000000..aa91bfc Binary files /dev/null and b/zerocap/model/__pycache__/ZeroCLIP_batched.cpython-36.pyc differ diff --git a/zerocap/model/__pycache__/ZeroCLIP_batched.cpython-37.pyc b/zerocap/model/__pycache__/ZeroCLIP_batched.cpython-37.pyc new file mode 100644 index 0000000..816cae9 Binary files /dev/null and b/zerocap/model/__pycache__/ZeroCLIP_batched.cpython-37.pyc differ diff --git a/zerocap/model/__pycache__/__init__.cpython-36.pyc b/zerocap/model/__pycache__/__init__.cpython-36.pyc new file mode 100644 index 0000000..5a18c45 Binary files /dev/null and b/zerocap/model/__pycache__/__init__.cpython-36.pyc differ diff --git a/zerocap/model/__pycache__/__init__.cpython-37.pyc b/zerocap/model/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000..3108a5a Binary files /dev/null and b/zerocap/model/__pycache__/__init__.cpython-37.pyc differ diff --git a/zerocap/mscoco_zerocap.sh b/zerocap/mscoco_zerocap.sh new file mode 100755 index 0000000..a06ae50 --- /dev/null +++ b/zerocap/mscoco_zerocap.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +# lm_model: +# 1. cambridgeltl/magic_mscoco +# 2. 
cambridgeltl/magic_flickr30k +CUDA_VISIBLE_DEVICES=1 python run.py \ + --beam_size 1 \ + --target_seq_length 16 \ + --reset_context_delta \ + --lm_model cambridgeltl/magic_mscoco \ + --test_image_prefix_path ../data/mscoco/test_images \ + --test_path ../data/mscoco/mscoco_test.json \ + --save_path_prefix ../inference_result/mscoco/baselines/ \ + --save_name zerocap_result.json diff --git a/zerocap/predict.py b/zerocap/predict.py new file mode 100644 index 0000000..46271f4 --- /dev/null +++ b/zerocap/predict.py @@ -0,0 +1,117 @@ +import os +import tempfile +import sys +sys.path.append('CLIP') +from pathlib import Path +import cog +import argparse +import torch +import clip +from model.ZeroCLIP import CLIPTextGenerator + +def perplexity_score(text, lm_model, lm_tokenizer, device): + encodings = lm_tokenizer(f'{lm_tokenizer.bos_token + text}', return_tensors='pt') + input_ids = encodings.input_ids.to(device) + target_ids = input_ids.clone() + + outputs = lm_model(input_ids, labels=target_ids) + log_likelihood = outputs[0] + ll = log_likelihood.item() + + return ll + +class Predictor(cog.Predictor): + def setup(self): + self.args = get_args() + self.args.reset_context_delta = True + self.text_generator = CLIPTextGenerator(**vars(self.args)) + + @cog.input( + "image", + type=Path, + help="input image" + ) + @cog.input( + "cond_text", + type=str, + default='Image of a', + help="conditional text", + ) + @cog.input( + "beam_size", + type=int, + default=5, min=1, max=10, + help="Number of beams to use", + ) + @cog.input( + "end_factor", + type=float, + default=1.01, min=1.0, max=1.10, + help="Higher value for shorter captions", + ) + @cog.input( + "max_seq_length", + type=int, + default=15, min=1, max=20, + help="Maximum number of tokens to generate", + ) + @cog.input( + "ce_loss_scale", + type=float, + default=0.2, min=0.0, max=0.6, + help="Scale of cross-entropy loss with un-shifted language model", + ) + def predict(self, image, cond_text, beam_size, end_factor, max_seq_length, ce_loss_scale): + self.args.cond_text = cond_text + self.text_generator.end_factor = end_factor + self.text_generator.target_seq_length = max_seq_length + self.text_generator.ce_scale = ce_loss_scale + + image_features = self.text_generator.get_img_feature([str(image)], None) + captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size) + + # CLIP SCORE + encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device)) + for c in captions] + encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions] + best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item() + + # Perplexity SCORE + ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions] + best_ppl_index = torch.tensor(ppl_scores).argmin().item() + + best_clip_caption = self.args.cond_text + captions[best_clip_idx] + best_mixed = self.args.cond_text + captions[0] + best_PPL = self.args.cond_text + captions[best_ppl_index] + + final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}' + + return final + # return self.args.cond_text + captions[best_clip_idx] + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo") + parser.add_argument("--clip_checkpoints", type=str, 
default="./clip_checkpoints", help="path to CLIP") + parser.add_argument("--target_seq_length", type=int, default=15) + parser.add_argument("--cond_text", type=str, default="Image of a") + parser.add_argument("--reset_context_delta", action="store_true", + help="Should we reset the context at each token gen") + parser.add_argument("--num_iterations", type=int, default=5) + parser.add_argument("--clip_loss_temperature", type=float, default=0.01) + parser.add_argument("--clip_scale", type=float, default=1) + parser.add_argument("--ce_scale", type=float, default=0.2) + parser.add_argument("--stepsize", type=float, default=0.3) + parser.add_argument("--grad_norm_factor", type=float, default=0.9) + parser.add_argument("--fusion_factor", type=float, default=0.99) + parser.add_argument("--repetition_penalty", type=float, default=1) + parser.add_argument("--end_token", type=str, default=".", help="Token to end text") + parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token") + parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens") + parser.add_argument("--beam_size", type=int, default=5) + + args = parser.parse_args('') + return args diff --git a/zerocap/predict_arithmetic.py b/zerocap/predict_arithmetic.py new file mode 100644 index 0000000..1e2ade2 --- /dev/null +++ b/zerocap/predict_arithmetic.py @@ -0,0 +1,129 @@ +import os +import tempfile +import sys +sys.path.append('CLIP') +from pathlib import Path +import cog +import argparse +import torch +import clip +from model.ZeroCLIP import CLIPTextGenerator + +def perplexity_score(text, lm_model, lm_tokenizer, device): + encodings = lm_tokenizer(f'{lm_tokenizer.bos_token + text}', return_tensors='pt') + input_ids = encodings.input_ids.to(device) + target_ids = input_ids.clone() + + outputs = lm_model(input_ids, labels=target_ids) + log_likelihood = outputs[0] + ll = log_likelihood.item() + + return ll + +class Predictor(cog.Predictor): + def setup(self): + self.args = get_args() + self.args.reset_context_delta = True + self.text_generator = CLIPTextGenerator(**vars(self.args)) + + @cog.input( + "image1", + type=Path, + help="Final result will be: image1 + (image2 - image3)" + ) + @cog.input( + "image2", + type=Path, + help="Final result will be: image1 + (image2 - image3)" + ) + @cog.input( + "image3", + type=Path, + help="Final result will be: image1 + (image2 - image3)" + ) + @cog.input( + "cond_text", + type=str, + default='Image of a', + help="conditional text", + ) + @cog.input( + "beam_size", + type=int, + default=3, min=1, max=10, + help="Number of beams to use", + ) + @cog.input( + "end_factors", + type=float, + default=1.06, min=1.0, max=1.10, + help="Higher value for shorter captions", + ) + @cog.input( + "max_seq_lengths", + type=int, + default=3, min=1, max=20, + help="Maximum number of tokens to generate", + ) + @cog.input( + "ce_loss_scale", + type=float, + default=0.2, min=0.0, max=0.6, + help="Scale of cross-entropy loss with un-shifted language model", + ) + def predict(self, image1, image2, image3, cond_text, beam_size, end_factors, max_seq_lengths, ce_loss_scale): + self.args.cond_text = cond_text + self.text_generator.end_factor = end_factors + self.text_generator.target_seq_length = max_seq_lengths + self.text_generator.ce_scale = ce_loss_scale + self.text_generator.fusion_factor = 0.95 + self.text_generator.grad_norm_factor = 0.95 + + image_features = self.text_generator.get_combined_feature([str(image1), str(image2), str(image3)], 
[], [1, 1, -1], None) + captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size) + + # CLIP SCORE + encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device)) + for c in captions] + encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions] + best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item() + + # Perplexity SCORE + ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions] + best_ppl_index = torch.tensor(ppl_scores).argmin().item() + + best_clip_caption = self.args.cond_text + captions[best_clip_idx] + best_mixed = self.args.cond_text + captions[0] + best_PPL = self.args.cond_text + captions[best_ppl_index] + + final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}' + + return final + # return self.args.cond_text + captions[best_clip_idx] + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo") + parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP") + parser.add_argument("--target_seq_length", type=int, default=15) + parser.add_argument("--cond_text", type=str, default="Image of a") + parser.add_argument("--reset_context_delta", action="store_true", + help="Should we reset the context at each token gen") + parser.add_argument("--num_iterations", type=int, default=5) + parser.add_argument("--clip_loss_temperature", type=float, default=0.01) + parser.add_argument("--clip_scale", type=float, default=1) + parser.add_argument("--ce_scale", type=float, default=0.2) + parser.add_argument("--stepsize", type=float, default=0.3) + parser.add_argument("--grad_norm_factor", type=float, default=0.95) + parser.add_argument("--fusion_factor", type=float, default=0.95) + parser.add_argument("--repetition_penalty", type=float, default=1) + parser.add_argument("--end_token", type=str, default=".", help="Token to end text") + parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token") + parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens") + parser.add_argument("--beam_size", type=int, default=5) + + args = parser.parse_args('') + return args diff --git a/zerocap/requirements.txt b/zerocap/requirements.txt new file mode 100644 index 0000000..0eaf0ad --- /dev/null +++ b/zerocap/requirements.txt @@ -0,0 +1,3 @@ +ftfy +regex +tqdm diff --git a/zerocap/run.py b/zerocap/run.py new file mode 100644 index 0000000..fab33b9 --- /dev/null +++ b/zerocap/run.py @@ -0,0 +1,131 @@ +import argparse +import ipdb +from tqdm import tqdm +import progressbar +import torch +import ipdb +import clip +from model.ZeroCLIP import CLIPTextGenerator +from model.ZeroCLIP_batched import CLIPTextGenerator as CLIPTextGenerator_multigpu + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument("--test_image_prefix_path", type=str, help="the folder that stores all test images") + parser.add_argument("--test_path", type=str) + parser.add_argument("--save_path_prefix", type=str, help="save the result in which directory") + parser.add_argument("--save_name", type=str, help="the name of the saved file") + + parser.add_argument("--seed", type=int, default=0) + 
parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo") + parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP") + parser.add_argument("--target_seq_length", type=int, default=15) + parser.add_argument("--cond_text", type=str, default="Image of a") + parser.add_argument("--reset_context_delta", action="store_true", + help="Should we reset the context at each token gen") + parser.add_argument("--num_iterations", type=int, default=5) + parser.add_argument("--clip_loss_temperature", type=float, default=0.01) + parser.add_argument("--clip_scale", type=float, default=1) + parser.add_argument("--ce_scale", type=float, default=0.2) + parser.add_argument("--stepsize", type=float, default=0.3) + parser.add_argument("--grad_norm_factor", type=float, default=0.9) + parser.add_argument("--fusion_factor", type=float, default=0.99) + parser.add_argument("--repetition_penalty", type=float, default=1) + parser.add_argument("--end_token", type=str, default=".", help="Token to end text") + parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token") + parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens") + parser.add_argument("--beam_size", type=int, default=1) + + parser.add_argument("--multi_gpu", action="store_true") + + parser.add_argument('--run_type', + default='caption', + nargs='?', + choices=['caption', 'arithmetics']) + + parser.add_argument("--caption_img_path", type=str, default='example_images/captions/COCO_val2014_000000008775.jpg', + help="Path to image for captioning") + + parser.add_argument("--arithmetics_imgs", nargs="+", + default=['example_images/arithmetics/woman2.jpg', + 'example_images/arithmetics/king2.jpg', + 'example_images/arithmetics/man2.jpg']) + parser.add_argument("--arithmetics_weights", nargs="+", default=[1, 1, -1]) + + args = parser.parse_args() + + return args + +def run(args, text_generator, img_path): + image_features = text_generator.get_img_feature([img_path], None) + captions = text_generator.run(image_features, args.cond_text, beam_size=args.beam_size) + + encoded_captions = [text_generator.clip.encode_text(clip.tokenize(c).to(text_generator.device)) for c in captions] + encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions] + best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item() + return captions + + +if __name__ == '__main__': + if torch.cuda.is_available(): + print ('Cuda is available.') + cuda_available = torch.cuda.is_available() + args = get_args() + device = torch.device('cuda') + + save_path_prefix = args.save_path_prefix + import os + if os.path.exists(save_path_prefix): + pass + else: # recursively construct directory + os.makedirs(save_path_prefix, exist_ok=True) + # parse save name + save_name = args.save_name + full_save_path = save_path_prefix + '/' + save_name + print ('full save path is {}'.format(full_save_path)) + + print ('Loading data...') + import json + with open(args.test_path) as f: + item_list = json.load(f) + print ('Data loaded.') + print ('Number of test instances is {}'.format(len(item_list))) + + # ZeroCap generator + text_generator = CLIPTextGenerator(**vars(args)) + + result_list = [] + invalid_num = 0 + print ('----------------------------------------------------------------') + test_num = len(item_list) + #test_num = 10 + print ('Number of inference instances is {}'.format(test_num)) + p = 
progressbar.ProgressBar(test_num) + p.start() + for p_idx in tqdm(range(test_num)): + p.update(p_idx) + one_test_dict = item_list[p_idx] + + one_res_dict = { + 'split':one_test_dict['split'], + 'image_name':one_test_dict['image_name'], + #'file_path':one_test_dict['file_path'], + 'captions':one_test_dict['captions'] + } + + image_full_path = args.test_image_prefix_path + '/' + one_test_dict['image_name'] + try: + output_text = run(args, text_generator, img_path=image_full_path) + one_res_dict['prediction'] = output_text[0] + result_list.append(one_res_dict) + except Exception as error: + print('[!] ERROR:', error) + invalid_num += 1 + print ('invalid number is {}'.format(invalid_num)) + continue + p.finish() + print ('Inference completed!') + + with open(full_save_path, 'w') as outfile: + json.dump(result_list, outfile, indent=4) diff --git a/zerocap/setup.py b/zerocap/setup.py new file mode 100644 index 0000000..8ae2efe --- /dev/null +++ b/zerocap/setup.py @@ -0,0 +1,19 @@ +import os + +import pkg_resources +from setuptools import setup, find_packages + +setup( + name="zero-shot-image-to-text", + py_modules=["zero-shot-image-to-text"], + version="1.0", + description="", + packages=find_packages(), + install_requires=[ + str(r) + for r in pkg_resources.parse_requirements( + open(os.path.join(os.path.dirname(__file__), "requirements.txt")) + ) + ], + include_package_data=True +) \ No newline at end of file
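Not part of the diff above: a minimal usage sketch for the added ZeroCap baseline, assuming the layout introduced here (zerocap/ with run.py, mscoco_zerocap.sh and requirements.txt), that the OpenAI CLIP package plus the other imports of run.py (progressbar, ipdb) are installed, and that the MSCOCO test images and annotations referenced by mscoco_zerocap.sh have been prepared as described elsewhere in the repository.

```bash
# Hypothetical usage sketch -- not part of the committed diff.
# Assumes CLIP, progressbar, and ipdb are installed in addition to requirements.txt,
# and that ../data/mscoco/ holds test_images/ and mscoco_test.json.
cd zerocap
pip install -r requirements.txt   # ftfy, regex, tqdm
chmod +x ./mscoco_zerocap.sh
./mscoco_zerocap.sh               # runs run.py with cambridgeltl/magic_mscoco and writes
                                  # ../inference_result/mscoco/baselines/zerocap_result.json
```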