
init the operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
Branch: main
wxywb committed 2 years ago
commit 7e9e351cfc
  1. __init__.py (+18)
  2. clip/README.md (+141)
  3. clip/build_flickr30k_index.sh (+7)
  4. clip/build_mscoco_index.sh (+7)
  5. clip/build_text_index.py (+105)
  6. clip/clip.py (+146)
  7. clip/clipretrieval.py (+135)
  8. clip/flickr30k_clip_retrieval.sh (+8)
  9. clip/mscoco_clip_retrieval.sh (+8)
  10. clip/source_flickr30k_target_mscoco_clip_retrieval.sh (+8)
  11. clip/source_mscoco_target_flickr30k_clip_retrieval.sh (+8)
  12. magic.py (+99)
  13. requirements.txt (+0)
  14. zerocap/README.md (+89)
  15. zerocap/cog.yaml (+12)
  16. zerocap/flickr30k_zerocap.sh (+14)
  17. zerocap/forbidden_tokens.npy (BIN)
  18. zerocap/model/ZeroCLIP.py (+389)
  19. zerocap/model/ZeroCLIP_batched.py (+449)
  20. zerocap/model/__init__.py (+0)
  21. zerocap/model/__pycache__/ZeroCLIP.cpython-36.pyc (BIN)
  22. zerocap/model/__pycache__/ZeroCLIP.cpython-37.pyc (BIN)
  23. zerocap/model/__pycache__/ZeroCLIP_batched.cpython-36.pyc (BIN)
  24. zerocap/model/__pycache__/ZeroCLIP_batched.cpython-37.pyc (BIN)
  25. zerocap/model/__pycache__/__init__.cpython-36.pyc (BIN)
  26. zerocap/model/__pycache__/__init__.cpython-37.pyc (BIN)
  27. zerocap/mscoco_zerocap.sh (+14)
  28. zerocap/predict.py (+117)
  29. zerocap/predict_arithmetic.py (+129)
  30. zerocap/requirements.txt (+3)
  31. zerocap/run.py (+131)
  32. zerocap/setup.py (+19)

18
__init__.py

@@ -0,0 +1,18 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .magic import Magic
def magic(model_name: str):
    return Magic(model_name)

141
clip/README.md

@@ -0,0 +1,141 @@
## CLIP
This folder illustrates how to use CLIP to build the text index and to run the cross-modal retrieval baseline.
****
## Catalogue:
* <a href='#index'>1. Build Text Index</a>
* <a href='#mscoco'>1.1. Build Text Index for MSCOCO</a>
* <a href='#download_mscoco_index'>1.1.1. Download Our Built Index</a>
* <a href='#process_mscoco_index'>1.1.2. Construct the Index by Yourself</a>
* <a href='#flickr30k'>1.2. Build Text Index for Flickr30k</a>
* <a href='#download_flickr30k_index'>1.2.1. Download Our Built Index</a>
* <a href='#process_flickr30k_index'>1.2.2. Construct the Index by Yourself</a>
* <a href='#baseline'>2. CLIP Retrieval Baseline</a>
* <a href='#in_domain_baseline'>2.1. In Domain CLIP Retrieval</a>
* <a href='#cross_domain_baseline'>2.2. Cross Domain CLIP Retrieval</a>
****
<span id='index'/>
### 1. Build Text Index:
We show how to use CLIP to build the text index from which captions are retrieved.
<span id='mscoco'/>
#### 1.1. Build Text Index for MSCOCO:
First, we demonstrate how to build the text index for MSCOCO.
<span id='download_mscoco_index'/>
#### 1.1.1. Download Our Post-processed Index:
We share our built index for MSCOCO via this [[link]](https://drive.google.com/file/d/1Dx_RPeAmydS6ZYuiJ-dLlK9-DjDZkxAh/view?usp=sharing). After downloading, unzip the downloaded file **mscoco_index.zip** under the current directory.
> **** The resulting directory looks like:
.
├── ./mscoco_index/
├── index_matrix.txt # The file that stores the representations of captions from the training set of MSCOCO. Each row is a vector that corresponds to a specific caption from the training set.
└── text_mapping.json # The file that stores the mappings between the representation and the corresponding caption.
<span id='process_mscoco_index'/>
#### 1.1.2. Construct the Index by Yourself:
You can also rebuild the index by yourself. First, you should make sure you have downloaded the MSCOCO data following instructions [[here]](https://github.com/yxuansu/MAGIC/tree/main/image_captioning/data#1-mscoco-benchmark). Then, you can run the following command to build the index.
```yaml
chmod +x ./build_mscoco_index.sh
./build_mscoco_index.sh
```
The arguments are as follows (a sketch of the expected input format follows this list):
* `--clip_name`: The configuration of the pre-trained CLIP model from huggingface.
* `--text_file_path`: Where the training text corpus is stored.
* `--save_index_prefix`: The directory in which to store your index files.
* `--save_index_name`: The saved name of the caption representations.
* `--save_mapping_dict_name`: The saved name of the mapping dictionary between representations and captions.
* `--batch_size`: The inference batch size.
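For reference, `build_text_index.py` only reads the `captions` field of each entry in the `--text_file_path` JSON, so the training corpus is expected to look roughly like the sketch below (the output file name and caption texts are illustrative; the scripts themselves point at e.g. `../data/mscoco/mscoco_train.json`):
```python
# Illustrative layout of the --text_file_path JSON consumed by
# build_text_index.py; only the "captions" list of each item is read,
# and the file name and caption texts below are made up.
import json

train_items = [
    {"captions": ["a man riding a wave on top of a surfboard",
                  "a surfer rides a small wave in the ocean"]},
    {"captions": ["two dogs playing in the snow"]},
]
with open("example_train.json", "w") as f:
    json.dump(train_items, f, indent=4)
```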
<span id='flickr30k'/>
#### 1.2. Build Text Index for Flickr30k:
Next, we demonstrate how to build the text index for Flickr30k.
<span id='download_flickr30k_index'/>
#### 1.2.1. Download Our Post-processed Index:
We share our built index for Flickr30k via this [[link]](https://drive.google.com/file/d/1hS58_ir5pdZZPckApCtlz2RyasCQbrPf/view?usp=sharing). After downloading, unzip the downloaded file **flickr30k_index.zip** under the current directory.
> **** The resulting directory looks like:
.
├── ./flickr30k_index/
├── index_matrix.txt # The file that stores the representations of captions from the training set of Flickr30k. Each row is a vector that corresponds to a specific caption from the training set.
└── text_mapping.json # The file that stores the mappings between the representation and the corresponding caption.
<span id='process_flickr30k_index'/>
#### 1.2.2. Construct the Index by Yourself:
You can also rebuild the index by yourself. First, you should make sure you have downloaded the Flickr30k data following instructions [[here]](https://github.com/yxuansu/MAGIC/tree/main/image_captioning/data#2-flickr30k-benchmark). Then, you can run the following command to build the index.
```yaml
chmod +x ./build_flickr30k_index.sh
./build_flickr30k_index.sh
```
The arguments are as follows (a sketch of loading the built index back follows this list):
* `--clip_name`: The configuration of the pre-trained CLIP model from huggingface.
* `--text_file_path`: Where the training text corpus is stored.
* `--save_index_prefix`: The directory in which to store your index files.
* `--save_index_name`: The saved name of the caption representations.
* `--save_mapping_dict_name`: The saved name of the mapping dictionary between representations and captions.
* `--batch_size`: The inference batch size.
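Once built, the index is plain text plus JSON and can be loaded back with `numpy` and `json`, which is essentially what `clipretrieval.py` does later; a minimal sketch, assuming the Flickr30k index directory created above:
```python
# Minimal sketch of loading the built index back
# (paths assume the Flickr30k index directory created above).
import json
import numpy as np

# Each line of index_matrix.txt is one whitespace-separated caption vector.
index_matrix = np.loadtxt("./flickr30k_index/index_matrix.txt")

# text_mapping.json maps the row number (as a string) to its caption.
with open("./flickr30k_index/text_mapping.json") as f:
    text_mapping = json.load(f)

print(index_matrix.shape)   # (num_captions, embed_dim)
print(text_mapping["0"])    # the caption behind the first row
```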
****
<span id='baseline'/>
### 2. CLIP Retrieval Baseline:
Here, we show how to conduct the CLIP retrieval baseline.
<span id='in_domain_baseline'/>
#### 2.1. In Domain CLIP Retrieval:
To retrieve captions from the in-domain training set, run the following command:
```yaml
chmod +x ./X_clip_retrieval.sh
./X_clip_retrieval.sh
```
Here, X is in ['mscoco', 'flickr30k'], corresponding to the MSCOCO and Flickr30k benchmarks, respectively.
The arguments are as follows:
* `--clip_name`: The configuration of the pre-trained CLIP model from huggingface.
* `--test_image_prefix_path`: Where the test set images are stored.
* `--test_path`: Where the reference test captions file is stored.
* `--index_matrix_path`: The path of the representation index file.
* `--mapping_dict_path`: The path of the mapping dictionary between representations and captions.
* `--save_path_prefix`: Where to save the inference result.
* `--save_name`: The saved name of the inference result (one saved entry is sketched below the note).
**[Note]** As we are conducting in-domain CLIP retrieval, the test images and the caption index should come from the same benchmark.
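Each entry of the saved file follows the dictionary assembled in `clipretrieval.py`; roughly, with made-up values:
```python
# Illustrative shape of one entry in the saved result JSON
# (field names come from clipretrieval.py; the values are made up).
one_res_dict = {
    "split": "test",
    "image_name": "COCO_val2014_000000391895.jpg",
    "captions": ["a reference caption", "another reference caption"],
    "prediction": "the single caption retrieved from the index",
}
```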
<span id='cross_domain_baseline'/>
#### 2.2. Cross Domain CLIP Retrieval:
To retrieve captions from the cross-domain training set, run the following command:
```yaml
chmod +x ./source_X_target_Y_clip_retrieval.sh
./source_X_target_Y_clip_retrieval.sh
```
Here, X is the source domain and Y is the target domain, where X and Y are drawn from ['mscoco', 'flickr30k'] and differ from each other.
The arguments are as follows:
* `--clip_name`: The configuration of the pre-trained CLIP model from huggingface.
* `--test_image_prefix_path`: Where the test set images are stored.
* `--test_path`: Where the reference test captions file is stored.
* `--index_matrix_path`: The path of the representation index file.
* `--mapping_dict_path`: The path of the mapping dictionary between representations and captions.
* `--save_path_prefix`: Where to save the inference result.
* `--save_name`: The saved name of the inference result.
**[Note]** As we are conducting cross-domain CLIP retrieval, the test images and the caption index should come from **different** benchmarks (see the retrieval sketch below).
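Both settings boil down to scoring one test image against every caption vector in the index and keeping the best match, as `CLIPIndex.search_text` does; a minimal self-contained sketch, where the paths and the test image name are illustrative:
```python
# Minimal sketch of one retrieval step (mirrors CLIPIndex.search_text in
# clipretrieval.py; paths and the test image name are illustrative).
import json
import numpy as np
from PIL import Image
from clip import CLIP   # the wrapper defined in clip/clip.py

clip_model = CLIP("openai/clip-vit-base-patch32")
clip_model.eval()

index_matrix = np.loadtxt("./mscoco_index/index_matrix.txt")
index_matrix /= np.linalg.norm(index_matrix, axis=1, keepdims=True)
with open("./mscoco_index/text_mapping.json") as f:
    mapping = json.load(f)

image = Image.open("../data/flickr30k/test_images/example.jpg")
image_vec = clip_model.compute_batch_index_image_features([image]).detach().cpu().numpy()
image_vec /= np.linalg.norm(image_vec, axis=1, keepdims=True)

top_idx = int((image_vec @ index_matrix.T)[0].argmax())
print(mapping[str(top_idx)])   # the retrieved caption
```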

7
clip/build_flickr30k_index.sh

@@ -0,0 +1,7 @@
CUDA_VISIBLE_DEVICES=1 python build_text_index.py\
    --clip_name openai/clip-vit-base-patch32\
    --text_file_path ../data/flickr30k/flickr30k_train.json\
    --save_index_prefix ./flickr30k_index/\
    --save_index_name index_matrix.txt\
    --save_mapping_dict_name text_mapping.json\
    --batch_size 128

7
clip/build_mscoco_index.sh

@@ -0,0 +1,7 @@
CUDA_VISIBLE_DEVICES=0 python build_text_index.py\
    --clip_name openai/clip-vit-base-patch32\
    --text_file_path ../data/mscoco/mscoco_train.json\
    --save_index_prefix ./mscoco_index/\
    --save_index_name index_matrix.txt\
    --save_mapping_dict_name text_mapping.json\
    --batch_size 128

105
clip/build_text_index.py

@@ -0,0 +1,105 @@
import sys
import torch
import numpy as np
import progressbar
import os
def parse_config():
parser = argparse.ArgumentParser()
parser.add_argument("--clip_name", type=str, default="openai/clip-vit-base-patch32")
parser.add_argument("--text_file_path", type=str)
# save configuration
parser.add_argument("--save_index_prefix", type=str, help='where to save the mips index')
parser.add_argument("--save_index_name", type=str)
parser.add_argument("--save_mapping_dict_name", type=str,
help="a json file that stores a dictory. The dictory contains mapping between mips index and caption text")
# inference configuration
parser.add_argument("--batch_size", type=int, help="the batch size used to conduct inference with CLIP")
return parser.parse_args()
def load_batch_text(text_file_path, batch_size):
import json
with open(text_file_path) as f:
item_list = json.load(f)
text_list = []
for item in item_list:
captions = item["captions"]
for cap in captions:
text_list.append(cap)
print ('Number of text instances is {}'.format(len(text_list)))
data_num = len(text_list)
batch_num = data_num // batch_size
batch_text_list = []
s_idx, e_idx = 0, batch_size
for p_idx in range(batch_num):
one_batch_text_list = []
for idx in range(s_idx, e_idx):
one_batch_text_list.append(text_list[idx])
batch_text_list.append(one_batch_text_list)
return batch_text_list
import argparse
if __name__ == '__main__':
if torch.cuda.is_available():
print ('Cuda is available.')
cuda_available = torch.cuda.is_available()
args = parse_config()
device = torch.device('cuda')
import os
if os.path.exists(args.save_index_prefix):
pass
else: # recursively construct directory
os.makedirs(args.save_index_prefix, exist_ok=True)
print ('Loading CLIP...')
from clip import CLIP
model = CLIP(args.clip_name)
if cuda_available:
model = model.cuda(device)
model.eval()
print ('CLIP loaded!')
print ('Loading text data...')
batch_text_list = load_batch_text(args.text_file_path, args.batch_size)
print ('Text data loaded.')
res_text_vec_list, res_text_list = [], []
batch_num = len(batch_text_list)
print ('Number of batches is {}'.format(batch_num))
print ('Start inference...')
p = progressbar.ProgressBar(batch_num)
p.start()
with torch.no_grad():
for p_idx in range(batch_num):
p.update(p_idx)
one_text_batch = batch_text_list[p_idx]
one_batch_vec = model.compute_batch_index_text_representation(one_text_batch).detach().cpu()
one_batch_vec_list = one_batch_vec.unbind(dim=0)
bsz = len(one_batch_vec_list)
for k in range(bsz):
res_text_vec_list.append(one_batch_vec_list[k].numpy())
res_text_list.append(one_text_batch[k])
p.finish()
assert len(res_text_vec_list) == len(res_text_list)
print ('Inference completed!')
index_text_mapping_dict = {}
for k in range(len(res_text_list)):
index_text_mapping_dict[k] = res_text_list[k]
mapping_list_save_path = args.save_index_prefix + '/' + args.save_mapping_dict_name
import json
with open(mapping_list_save_path, 'w') as outfile:
json.dump(index_text_mapping_dict, outfile, indent=4)
print ('Mapping dictionary saved!')
print ('Start building index...')
index_save_path = args.save_index_prefix + '/' + args.save_index_name
with open(index_save_path, 'w', encoding = 'utf8') as o:
for vec in res_text_vec_list:
one_text = ' '.join([str(num) for num in vec]).strip()
o.writelines(one_text + '\n')
print ('Index completed!')

146
clip/clip.py

@@ -0,0 +1,146 @@
import torch
import requests
from torch import nn
from PIL import Image
class CLIP(nn.Module):
def __init__(self, model_name):
super(CLIP, self).__init__()
# model name: e.g. openai/clip-vit-base-patch32
print ('Initializing CLIP model...')
from transformers import CLIPProcessor, CLIPModel
self.model = CLIPModel.from_pretrained(model_name)
self.model.eval()
self.processor = CLIPProcessor.from_pretrained(model_name)
from transformers import CLIPTokenizer
self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
self.cuda_has_been_checked = False
print ('CLIP model initialized.')
def check_cuda(self):
self.cuda_available = next(self.model.parameters()).is_cuda
self.device = next(self.model.parameters()).get_device()
if self.cuda_available:
print ('Cuda is available.')
print ('Device is {}'.format(self.device))
else:
print ('Cuda is not available.')
print ('Device is {}'.format(self.device))
@torch.no_grad()
def compute_image_representation_from_image_path(self, image_path):
if not self.cuda_has_been_checked:
self.check_cuda()
self.cuda_has_been_checked = True
else:
pass
# image_path: the path of the image
image = Image.open(image_path)
inputs = self.processor(images=image, return_tensors="pt")
pixel_values = inputs['pixel_values']
if self.cuda_available:
pixel_values = pixel_values.cuda(self.device)
visual_outputs = self.model.vision_model(pixel_values=pixel_values)
image_embeds = visual_outputs[1]
image_embeds = self.model.visual_projection(image_embeds) # [1 x embed_dim]
return image_embeds
def compute_image_representation_from_image_instance(self, image):
if not self.cuda_has_been_checked:
self.check_cuda()
self.cuda_has_been_checked = True
else:
pass
# image: a PIL image instance
inputs = self.processor(images=image, return_tensors="pt")
pixel_values = inputs['pixel_values']
if self.cuda_available:
pixel_values = pixel_values.cuda(self.device)
visual_outputs = self.model.vision_model(pixel_values=pixel_values)
image_embeds = visual_outputs[1]
image_embeds = self.model.visual_projection(image_embeds) # [1 x embed_dim]
return image_embeds
def compute_text_representation(self, text_list):
if not self.cuda_has_been_checked:
self.check_cuda()
self.cuda_has_been_checked = True
else:
pass
# text_list: a list of text
text_inputs = self.tokenizer(text_list, padding=True, return_tensors="pt",
max_length=self.tokenizer.max_len_single_sentence + 2, truncation=True)
# self.tokenizer.max_len_single_sentence + 2 = 77
input_ids, attention_mask = text_inputs['input_ids'], text_inputs['attention_mask']
if self.cuda_available:
input_ids = input_ids.cuda(self.device)
attention_mask = attention_mask.cuda(self.device)
text_outputs = self.model.text_model(
input_ids=input_ids,
attention_mask=attention_mask
)
text_embeds = text_outputs[1]
text_embeds = self.model.text_projection(text_embeds)
return text_embeds
def compute_image_text_similarity_via_embeddings(self, image_embeds, text_embeds):
'''
image_embeds: 1 x embed_dim
text_embeds: len(text_list) x embed_dim
'''
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
logit_scale = self.model.logit_scale.exp()
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
logits_per_image = logits_per_text.T
return logits_per_image.softmax(dim=1) # 1 x len(text_list)
def compute_image_text_similarity_via_raw_text(self, image_embeds, text_list):
text_embeds = self.compute_text_representation(text_list)
return self.compute_image_text_similarity_via_embeddings(image_embeds, text_embeds)
### -------------------- functions for building index ---------------------- ###
def compute_batch_index_image_features(self, image_list):
'''
# list of image instances
'''
if not self.cuda_has_been_checked:
self.check_cuda()
self.cuda_has_been_checked = True
else:
pass
# image_list: a list of PIL image instances
inputs = self.processor(images=image_list, return_tensors="pt")
pixel_values = inputs['pixel_values']
if self.cuda_available:
pixel_values = pixel_values.cuda(self.device)
visual_outputs = self.model.vision_model(pixel_values=pixel_values)
image_embeds = visual_outputs[1]
image_embeds = self.model.visual_projection(image_embeds) # [len(image_list) x embed_dim]
return image_embeds # len(image_list) x embed_dim
def compute_batch_index_text_representation(self, text_list):
if not self.cuda_has_been_checked:
self.check_cuda()
self.cuda_has_been_checked = True
else:
pass
# text_list: a list of text
#text_inputs = self.tokenizer(text_list, padding=True, return_tensors="pt")
text_inputs = self.tokenizer(text_list, padding=True, return_tensors="pt",
max_length=self.tokenizer.max_len_single_sentence + 2, truncation=True)
input_ids, attention_mask = text_inputs['input_ids'], text_inputs['attention_mask']
if self.cuda_available:
input_ids = input_ids.cuda(self.device)
attention_mask = attention_mask.cuda(self.device)
text_outputs = self.model.text_model(
input_ids=input_ids,
attention_mask=attention_mask
)
text_embeds = text_outputs[1]
text_embeds = self.model.text_projection(text_embeds)
return text_embeds
#logit_scale = self.model.logit_scale.exp()
#text_embeds = text_embeds * logit_scale
#return text_embeds
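For orientation, a minimal sketch of using this wrapper on its own; the image path and the candidate captions are illustrative:
```python
# Hypothetical quick check of the CLIP wrapper defined above
# (the image path and candidate captions are made up).
from clip import CLIP

model = CLIP("openai/clip-vit-base-patch32")
model.eval()

image_embeds = model.compute_image_representation_from_image_path("example.jpg")
probs = model.compute_image_text_similarity_via_raw_text(
    image_embeds, ["a dog running on the beach", "a plate of food"])
print(probs)   # a 1 x 2 tensor of softmax-normalized similarities
```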

135
clip/clipretrieval.py

@@ -0,0 +1,135 @@
import json
import copy
import torch
import progressbar
import numpy as np
from PIL import Image
class CLIPIndex:
def __init__(self, index_matrix_path, mapping_dict_path, clip):
'''
index_matrix_path: path to the pre-built index matrix file
mapping_dict_path: path to the pre-built mapping dictionary
clip: the pre-trained CLIP model
'''
print ('Loading index...')
self.index_matrix = self.normalization(self.load_matrix(index_matrix_path))
print ('Index loaded.')
print (self.index_matrix.shape)
with open(mapping_dict_path) as f:
self.mapping_dict = json.load(f)
self.clip = clip
def load_matrix(self, in_f):
matrix_list = []
with open(in_f, 'r', encoding = 'utf8') as i:
lines = i.readlines()
for l in lines:
one_vec = [float(num) for num in l.strip('\n').split()]
matrix_list.append(one_vec)
return np.array(matrix_list)
def normalization(self, matrix):
'''
matrix: num_instance x num_feature
'''
return matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
def get_image_representation(self, image_path):
image_instance = Image.open(image_path)
image_vec = self.clip.compute_batch_index_image_features([image_instance]).detach().cpu().numpy()
image_vec = self.normalization(image_vec)
return image_vec
def search_text(self, image_path):
image_vec = self.get_image_representation(image_path)
sort_idx_list = np.matmul(image_vec, self.index_matrix.transpose())[0].argsort()[::-1]
top_idx = sort_idx_list[0]
return self.mapping_dict[str(top_idx)]
def parse_config():
parser = argparse.ArgumentParser()
parser.add_argument("--clip_name", type=str)
parser.add_argument("--test_image_prefix_path", type=str, help="the folder that stores all test images")
parser.add_argument("--test_path", type=str)
# index configuration
parser.add_argument("--index_matrix_path", type=str)
parser.add_argument("--mapping_dict_path", type=str)
# save configuration
parser.add_argument("--save_path_prefix", type=str, help="save the result in which directory")
parser.add_argument("--save_name", type=str, help="the name of the saved file")
return parser.parse_args()
import argparse
if __name__ == '__main__':
if torch.cuda.is_available():
print ('Cuda is available.')
cuda_available = torch.cuda.is_available()
args = parse_config()
device = torch.device('cuda')
save_path_prefix = args.save_path_prefix
import os
if os.path.exists(save_path_prefix):
pass
else: # recursively construct directory
os.makedirs(save_path_prefix, exist_ok=True)
# parse save name
save_name = args.save_name
full_save_path = save_path_prefix + '/' + save_name
print ('full save path is {}'.format(full_save_path))
print ('Loading CLIP...')
from clip import CLIP
clip = CLIP(args.clip_name)
if cuda_available:
clip = clip.cuda(device)
clip.eval()
print ('CLIP loaded!')
clipindex = CLIPIndex(args.index_matrix_path, args.mapping_dict_path, clip)
print ('Loading data...')
import json
with open(args.test_path) as f:
item_list = json.load(f)
print ('Data loaded.')
print ('Number of test instances is {}'.format(len(item_list)))
result_list = []
invalid_num = 0
print ('----------------------------------------------------------------')
with torch.no_grad():
test_num = len(item_list)
#test_num = 10
print ('Number of inference instances is {}'.format(test_num))
p = progressbar.ProgressBar(test_num)
p.start()
for p_idx in range(test_num):
p.update(p_idx)
one_test_dict = item_list[p_idx]
one_res_dict = {
'split':one_test_dict['split'],
'image_name':one_test_dict['image_name'],
#'file_path':one_test_dict['file_path'],
'captions':one_test_dict['captions']
}
image_full_path = args.test_image_prefix_path + '/' + one_test_dict['image_name']
try:
output_text = clipindex.search_text(image_full_path)
one_res_dict['prediction'] = output_text
result_list.append(one_res_dict)
except:
invalid_num += 1
print ('invalid number is {}'.format(invalid_num))
continue
p.finish()
print ('Inference completed!')
import json
with open(full_save_path, 'w') as outfile:
json.dump(result_list, outfile, indent=4)

8
clip/flickr30k_clip_retrieval.sh

@@ -0,0 +1,8 @@
CUDA_VISIBLE_DEVICES=1 python clipretrieval.py\
    --clip_name openai/clip-vit-base-patch32\
    --test_image_prefix_path ../data/flickr30k/test_images/\
    --test_path ../data/flickr30k/flickr30k_test.json\
    --index_matrix_path ./flickr30k_index/index_matrix.txt\
    --mapping_dict_path ./flickr30k_index/text_mapping.json\
    --save_path_prefix ../inference_result/flickr30k/baselines/\
    --save_name flickr30k_in_domain_clipretrieval.json

8
clip/mscoco_clip_retrieval.sh

@@ -0,0 +1,8 @@
CUDA_VISIBLE_DEVICES=0 python clipretrieval.py\
    --clip_name openai/clip-vit-base-patch32\
    --test_image_prefix_path ../data/mscoco/test_images/\
    --test_path ../data/mscoco/mscoco_test.json\
    --index_matrix_path ./mscoco_index/index_matrix.txt\
    --mapping_dict_path ./mscoco_index/text_mapping.json\
    --save_path_prefix ../inference_result/mscoco/baselines/\
    --save_name mscoco_in_domain_clipretrieval.json

8
clip/source_flickr30k_target_mscoco_clip_retrieval.sh

@@ -0,0 +1,8 @@
CUDA_VISIBLE_DEVICES=1 python clipretrieval.py\
    --clip_name openai/clip-vit-base-patch32\
    --test_image_prefix_path ../data/mscoco/test_images/\
    --test_path ../data/mscoco/mscoco_test.json\
    --index_matrix_path ./flickr30k_index/index_matrix.txt\
    --mapping_dict_path ./flickr30k_index/text_mapping.json\
    --save_path_prefix ../inference_result/flickr30k_model_to_mscoco/\
    --save_name source_flickr30k_target_mscoco_clip_retrieval.json

8
clip/source_mscoco_target_flickr30k_clip_retrieval.sh

@@ -0,0 +1,8 @@
CUDA_VISIBLE_DEVICES=1 python clipretrieval.py\
    --clip_name openai/clip-vit-base-patch32\
    --test_image_prefix_path ../data/flickr30k/test_images/\
    --test_path ../data/flickr30k/flickr30k_test.json\
    --index_matrix_path ./mscoco_index/index_matrix.txt\
    --mapping_dict_path ./mscoco_index/text_mapping.json\
    --save_path_prefix ../inference_result/mscoco_model_to_flickr30k/\
    --save_name source_mscoco_target_flickr30k_clip_retrieval.json

99
magic.py

@@ -0,0 +1,99 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from re import I
import sys
import os
import pathlib
import pickle
from argparse import Namespace
import torch
import torchvision
from torchvision import transforms
from transformers import GPT2Tokenizer
from towhee.types.arg import arg, to_image_color
from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee import register
from towhee.models import clip
class Magic(NNOperator):
"""
Magic image captioning operator
"""
def __init__(self, model_name: str):
super().__init__()
path = str(pathlib.Path(__file__).parent)
sys.path.append(path)
from clip import CLIP
from simctg import SimCTG
sys.path.pop()
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# Load Language Model
language_model_name = r'cambridgeltl/magic_mscoco' # or r'/path/to/downloaded/cambridgeltl/magic_mscoco'
sos_token, pad_token = r'<-start_of_text->', r'<-pad->'
self.generation_model = SimCTG(language_model_name, sos_token, pad_token).to(self.device)
self.generation_model.eval()
model_name = r"openai/clip-vit-base-patch32" # or r"/path/to/downloaded/openai/clip-vit-base-patch32"
self.clip = CLIP(model_name).to(self.device)
self.clip.eval()
def _preprocess(self, img):
img = to_pil(img)
processed_img = self.transf_1(img)
processed_img = self.transf_2(processed_img)
processed_img = processed_img.to(self.device)
return processed_img
@arg(1, to_image_color('RGB'))
def inference_single_data(self, data):
text = self._inference_from_image(data)
return text
def __call__(self, data):
if not isinstance(data, list):
data = [data]
results = []
for single_data in data:
result = self.inference_single_data(single_data)
results.append(result)
if len(data) == 1:
return results[0]
else:
return results
@arg(1, to_image_color('RGB'))
def _inference_from_image(self, img):
image_instance = to_pil(img)
k, alpha, beta, decoding_len = 45, 0.1, 2.0, 16
# build the start-of-text prompt expected by SimCTG (follows the MAGIC inference demo)
sos_token = r'<-start_of_text->'
start_token = self.generation_model.tokenizer.tokenize(sos_token)
start_token_id = self.generation_model.tokenizer.convert_tokens_to_ids(start_token)
input_ids = torch.LongTensor(start_token_id).view(1, -1).to(self.device)
with torch.no_grad():
output = self.generation_model.magic_search(input_ids, k,
alpha, decoding_len, beta, image_instance, self.clip, 60)
return output
def _configs(self):
config = {}
config['expansionnet_rf'] = {}
config['expansionnet_rf']['weights'] = 'rf_model.pth'
return config

0
requirements.txt

89
zerocap/README.md

@@ -0,0 +1,89 @@
### Our Implementation of the ZeroCap Baseline Model
****
### Catalogue:
* <a href='#environment'>1. Environment Preparation</a>
* <a href='#mscoco'>2. Image Captioning on MSCOCO</a>
* <a href='#flickr30k'>3. Image Captioning on Flickr30k</a>
* <a href='#flickr30k_to_mscoco'>4. Cross Domain Image Captioning on MSCOCO</a>
* <a href='#mscoco_to_flickr30k'>5. Cross Domain Image Captioning on Flickr30k</a>
* <a href='#citation'>6. Citation</a>
* <a href='#acknowledgements'>7. Acknowledgements</a>
****
<span id='environment'/>
#### 1. Environment Preparation:
To install the correct environment, please run the following command:
```yaml
pip install -r requirements.txt
```
****
<span id='mscoco'/>
#### 2. Image Captioning on MSCOCO:
To perform image captioning on MSCOCO, please run the following command:
```yaml
chmod +x ./mscoco_zerocap.sh
./mscoco_zerocap.sh
```
****
<span id='flickr30k'/>
#### 3. Image Captioning on Flickr30k:
To perform image captioning on Flickr30k, please run the following command:
```yaml
chmod +x ./flickr30k_zerocap.sh
./flickr30k_zerocap.sh
```
****
<span id='flickr30k_to_mscoco'/>
#### 4. Cross Domain Image Captioning on MSCOCO:
To perform image captioning on MSCOCO with the language model from the Flickr30k domain, please run the following command:
```yaml
chmod +x ./flickr30k_to_mscoco_zerocap.sh
./flickr30k_to_mscoco_zerocap.sh
```
****
<span id='mscoco_to_flickr30k'/>
#### 5. Cross Domain Image Captioning on Flickr30k:
To perform image captioning on Flickr30k with the language model from the MSCOCO domain, please run the following command:
```yaml
chmod +x ./mscoco_to_flickr30k_zerocap.sh
./mscoco_to_flickr30k_zerocap.sh
```
****
<span id='citation'/>
#### 6. Citation:
If you find our code helpful, please cite the original paper as follows:
```bibtex
@article{tewel2021zero,
title={Zero-Shot Image-to-Text Generation for Visual-Semantic Arithmetic},
author={Tewel, Yoad and Shalev, Yoav and Schwartz, Idan and Wolf, Lior},
journal={arXiv preprint arXiv:2111.14447},
year={2021}
}
```
****
<span id='acknowledgements'/>
#### 7. Acknowledgements:
We thank the authors for releasing their code. Our reimplementation of the baseline is based on their original codebase [[here]](https://github.com/yoadtew/zero-shot-image-to-text).

12
zerocap/cog.yaml

@@ -0,0 +1,12 @@
build:
  gpu: true
  python_version: "3.8"
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
  python_packages:
    - "git+https://github.com/openai/CLIP.git"
    - "git+https://github.com/YoadTew/zero-shot-image-to-text.git"
predict: "predict.py:Predictor"
#predict: "predict_arithmetic.py:Predictor"

14
zerocap/flickr30k_zerocap.sh

@@ -0,0 +1,14 @@
#!/bin/bash
# lm_model:
# 1. cambridgeltl/magic_mscoco
# 2. cambridgeltl/magic_flickr30k
CUDA_VISIBLE_DEVICES=1 python run.py \
--beam_size 1 \
--target_seq_length 16 \
--reset_context_delta \
--lm_model cambridgeltl/magic_flickr30k \
--test_image_prefix_path ../data/flickr30k/test_images \
--test_path ../data/flickr30k/flickr30k_test.json \
--save_path_prefix ../inference_result/flickr30k/baselines/ \
--save_name zerocap_result.json

BIN
zerocap/forbidden_tokens.npy

Binary file not shown.

389
zerocap/model/ZeroCLIP.py

@@ -0,0 +1,389 @@
import numpy as np
from torch import nn
from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer
from transformers.models.gpt_neo import GPTNeoForCausalLM
import torch
import clip
from PIL import Image
from datetime import datetime
import sys
def log_info(text, verbose=True):
if verbose:
dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
print(f'{dt_string} | {text}')
sys.stdout.flush()
def add_context(x, y):
return (x[0] + y[0], x[1] + y[1])
def convert_models_to_fp32(model):
for p in model.parameters():
p.data = p.data.float()
class CLIPTextGenerator:
def __init__(self,
seed=0,
lm_model='gpt-2',
forbidden_tokens_file_path='./forbidden_tokens.npy',
clip_checkpoints='./clip_checkpoints',
target_seq_length=15,
reset_context_delta=True,
num_iterations=5,
clip_loss_temperature=0.01,
clip_scale=1.,
ce_scale=0.2,
stepsize=0.3,
grad_norm_factor=0.9,
fusion_factor=0.99,
repetition_penalty=1.,
end_token='.',
end_factor=1.01,
forbidden_factor=20,
**kwargs):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# set Random seed
torch.manual_seed(seed)
np.random.seed(seed)
# Initialize Language model
self.context_prefix = ''
self.lm_tokenizer = GPT2Tokenizer.from_pretrained(lm_model)
self.lm_model = GPT2LMHeadModel.from_pretrained(lm_model, output_hidden_states=True)
self.context_prefix = self.lm_tokenizer.bos_token
self.lm_model.to(self.device)
self.lm_model.eval()
self.forbidden_tokens = np.load(forbidden_tokens_file_path)
self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if
(x[0] == 'Ġ' and len(x) > 1 and x[1].isupper())]
# Freeze LM weights
for param in self.lm_model.parameters():
param.requires_grad = False
# Initialize CLIP
self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device,
download_root=clip_checkpoints, jit=False)
# convert_models_to_fp32(self.clip)
# Init arguments
self.target_seq_length = target_seq_length
self.reset_context_delta = reset_context_delta
self.num_iterations = num_iterations
self.clip_loss_temperature = clip_loss_temperature
self.clip_scale = clip_scale
self.ce_scale = ce_scale
self.stepsize = stepsize
self.grad_norm_factor = grad_norm_factor
self.fusion_factor = fusion_factor
self.repetition_penalty = repetition_penalty
self.end_token = self.lm_tokenizer.encode(end_token)[0]
self.end_factor = end_factor
self.ef_idx = 1
self.forbidden_factor = forbidden_factor
def get_img_feature(self, img_path, weights):
imgs = [Image.open(x) for x in img_path]
clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]
with torch.no_grad():
image_fts = [self.clip.encode_image(x) for x in clip_imgs]
if weights is not None:
image_features = sum([x * weights[i] for i, x in enumerate(image_fts)])
else:
image_features = sum(image_fts)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
return image_features.detach()
def get_txt_features(self, text):
clip_texts = clip.tokenize(text).to(self.device)
with torch.no_grad():
text_features = self.clip.encode_text(clip_texts)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
return text_features.detach()
def get_combined_feature(self, img_path, texts, weights_i, weights_t):
imgs = [Image.open(x) for x in img_path]
clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]
clip_texts = [clip.tokenize(x).to(self.device) for x in texts]
with torch.no_grad():
image_fts = [self.clip.encode_image(x) for x in clip_imgs]
text_fts = [self.clip.encode_text(x) for x in clip_texts]
features = sum([x * weights_i[i] for i, x in enumerate(image_fts)])
if weights_t is not None:
features += sum([x * weights_t[i] for i, x in enumerate(text_fts)])
features = features / features.norm(dim=-1, keepdim=True)
return features.detach()
def run(self, image_features, cond_text, beam_size):
self.image_features = image_features
context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text)
output_tokens, output_text = self.generate_text(context_tokens, beam_size)
return output_text
def generate_text(self, context_tokens, beam_size):
context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0)
gen_tokens = None
scores = None
seq_lengths = torch.ones(beam_size, device=self.device)
is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool)
for i in range(self.target_seq_length):
probs = self.get_next_probs(i, context_tokens)
logits = probs.log()
if scores is None:
scores, next_tokens = logits.topk(beam_size, -1)
context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:])
next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
if gen_tokens is None:
gen_tokens = next_tokens
else:
gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:])
gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1)
else:
logits[is_stopped] = -float(np.inf)
logits[is_stopped, 0] = 0
scores_sum = scores[:, None] + logits
seq_lengths[~is_stopped] += 1
scores_sum_average = scores_sum / seq_lengths[:, None]
scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
beam_size, -1)
next_tokens_source = next_tokens // scores_sum.shape[1]
seq_lengths = seq_lengths[next_tokens_source]
next_tokens = next_tokens % scores_sum.shape[1]
next_tokens = next_tokens.unsqueeze(1)
gen_tokens = gen_tokens[next_tokens_source]
gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1)
context_tokens = context_tokens[next_tokens_source]
scores = scores_sum_average * seq_lengths
is_stopped = is_stopped[next_tokens_source]
context_tokens = torch.cat((context_tokens, next_tokens), dim=1)
is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze()
####
tmp_scores = scores / seq_lengths
tmp_output_list = gen_tokens.cpu().numpy()
tmp_output_texts = [
self.lm_tokenizer.decode(tmp_output)
for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths)
]
tmp_order = tmp_scores.argsort(descending=True)
tmp_output_texts = [tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order]
log_info(tmp_output_texts, verbose=True)
####
if is_stopped.all():
break
scores = scores / seq_lengths
output_list = gen_tokens.cpu().numpy()
output_texts = [
self.lm_tokenizer.decode(output[: int(length)])
for output, length in zip(output_list, seq_lengths)
]
order = scores.argsort(descending=True)
output_texts = [output_texts[i] for i in order]
return context_tokens, output_texts
def get_next_probs(self, i, context_tokens):
last_token = context_tokens[:, -1:]
if self.reset_context_delta and context_tokens.size(1) > 1:
context = self.lm_model(context_tokens[:, :-1])["past_key_values"]
# Logits of LM with unshifted context
logits_before_shift = self.lm_model(context_tokens)["logits"]
logits_before_shift = logits_before_shift[:, -1, :]
probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1)
if context:
context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift)
lm_output = self.lm_model(last_token, past_key_values=context)
logits, past = (
lm_output["logits"],
lm_output["past_key_values"],
)
logits = logits[:, -1, :]
logits = self.update_special_tokens_logits(context_tokens, i, logits)
probs = nn.functional.softmax(logits, dim=-1)
probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor))
probs = probs / probs.sum()
return probs
def shift_context(self, i, context, last_token, context_tokens, probs_before_shift):
context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context]
window_mask = torch.ones_like(context[0][0]).to(self.device)
for i in range(self.num_iterations):
curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in
context_delta]
for p0, p1 in curr_shift:
p0.retain_grad()
p1.retain_grad()
shifted_context = list(map(add_context, context, curr_shift))
shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context)
logits = shifted_outputs["logits"][:, -1, :]
probs = nn.functional.softmax(logits, dim=-1)
loss = 0.0
# CLIP LOSS
clip_loss, clip_losses = self.clip_loss(probs, context_tokens)
loss += self.clip_scale * clip_loss
# CE/Fluency loss
ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * probs_before_shift.log())).sum(-1)
loss += ce_loss.sum()
loss.backward()
# ---------- Weights ----------
combined_scores_k = -(ce_loss)
combined_scores_c = -(self.clip_scale * torch.stack(clip_losses))
# minmax
if combined_scores_k.shape[0] == 1:
tmp_weights_c = tmp_weights_k = torch.ones(*combined_scores_k.shape).to(self.device)
else:
tmp_weights_k = ((combined_scores_k - combined_scores_k.min())) / (
combined_scores_k.max() - combined_scores_k.min())
tmp_weights_c = ((combined_scores_c - combined_scores_c.min())) / (
combined_scores_c.max() - combined_scores_c.min())
tmp_weights = 0.5 * tmp_weights_k + 0.5 * tmp_weights_c
tmp_weights = tmp_weights.view(tmp_weights.shape[0], 1, 1, 1)
factor = 1
# --------- Specific Gen ---------
sep_grads = None
for b in range(context_tokens.shape[0]):
tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_]
for p_ in curr_shift]
# normalize gradients
tmp_grad = [tuple([-self.stepsize * factor * (
x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][
j] ** self.grad_norm_factor).data.cpu().numpy()
for j, x in enumerate(p_)])
for i, p_ in enumerate(curr_shift)]
if sep_grads is None:
sep_grads = tmp_grad
else:
for l_index in range(len(sep_grads)):
sep_grads[l_index] = list(sep_grads[l_index])
for k_index in range(len(sep_grads[0])):
sep_grads[l_index][k_index] = np.concatenate(
(sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0)
sep_grads[l_index] = tuple(sep_grads[l_index])
final_grads = sep_grads
# --------- update context ---------
context_delta = list(map(add_context, final_grads, context_delta))
for p0, p1 in curr_shift:
p0.grad.data.zero_()
p1.grad.data.zero_()
new_context = []
for p0, p1 in context:
new_context.append((p0.detach(), p1.detach()))
context = new_context
context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_])
for p_ in context_delta]
context = list(map(add_context, context, context_delta))
new_context = []
for p0, p1 in context:
new_context.append((p0.detach(), p1.detach()))
context = new_context
return context
def update_special_tokens_logits(self, context_tokens, i, logits):
for beam_id in range(context_tokens.shape[0]):
for token_idx in set(context_tokens[beam_id][-4:].tolist()):
factor = self.repetition_penalty if logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty)
logits[beam_id, token_idx] /= factor
if i >= self.ef_idx:
factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor)
logits[beam_id, self.end_token] *= factor
if i == 0:
start_factor = 1.6
factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor)
logits[beam_id, self.end_token] /= factor
for token_idx in list(self.forbidden_tokens):
factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor)
logits[beam_id, token_idx] /= factor
return logits
def clip_loss(self, probs, context_tokens):
for p_ in self.clip.transformer.parameters():
if p_.grad is not None:
p_.grad.data.zero_()
top_size = 512
_, top_indices = probs.topk(top_size, -1)
prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens]
clip_loss = 0
losses = []
for idx_p in range(probs.shape[0]):
top_texts = []
prefix_text = prefix_texts[idx_p]
for x in top_indices[idx_p]:
top_texts.append(prefix_text + self.lm_tokenizer.decode(x))
text_features = self.get_txt_features(top_texts)
with torch.no_grad():
similiraties = (self.image_features @ text_features.T)
target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach()
target_probs = target_probs.type(torch.float32)
target = torch.zeros_like(probs[idx_p])
target[top_indices[idx_p]] = target_probs[0]
target = target.unsqueeze(0)
cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)])))
clip_loss += cur_clip_loss
losses.append(cur_clip_loss)
return clip_loss, losses
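For orientation, a minimal sketch of driving this generator, following the call sequence used in zerocap/predict.py; the language model name, image path, conditioning text, and beam size are illustrative:
```python
# Hypothetical driver for CLIPTextGenerator, mirroring zerocap/predict.py
# (the lm_model name, image path, conditioning text and beam size are illustrative).
from model.ZeroCLIP import CLIPTextGenerator

generator = CLIPTextGenerator(lm_model='cambridgeltl/magic_mscoco', target_seq_length=16)
image_features = generator.get_img_feature(['example.jpg'], None)
captions = generator.run(image_features, 'Image of a', beam_size=5)
print(captions[0])   # candidate captions are returned best-scored first
```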

449
zerocap/model/ZeroCLIP_batched.py

@@ -0,0 +1,449 @@
import numpy as np
from torch import nn
from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer
from transformers.models.gpt_neo import GPTNeoForCausalLM
import torch
import clip
from PIL import Image
from datetime import datetime
import sys
class TextCLIP(nn.Module):
def __init__(self, model):
super(TextCLIP, self).__init__()
self.model = model
def forward(self, text):
return self.model.encode_text(text)
class ImageCLIP(nn.Module):
def __init__(self, model):
super(ImageCLIP, self).__init__()
self.model = model
def forward(self, image):
return self.model.encode_image(image)
def log_info(text, verbose=True):
if verbose:
dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
print(f'{dt_string} | {text}')
sys.stdout.flush()
def add_context(x, y):
return (x[0] + y[0], x[1] + y[1])
def convert_models_to_fp32(model):
for p in model.parameters():
p.data = p.data.float()
class CLIPTextGenerator:
def __init__(self,
seed=0,
lm_model='gpt-2',
forbidden_tokens_file_path='./forbidden_tokens.npy',
clip_checkpoints='./clip_checkpoints',
target_seq_length=15,
reset_context_delta=True,
num_iterations=5,
clip_loss_temperature=0.01,
clip_scale=1.,
ce_scale=0.2,
stepsize=0.3,
grad_norm_factor=0.9,
fusion_factor=0.99,
repetition_penalty=1.,
end_token='.',
end_factor=1.01,
forbidden_factor=20,
**kwargs):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# set Random seed
torch.manual_seed(seed)
np.random.seed(seed)
# Initialize Language model
self.context_prefix = ''
if lm_model == 'gpt-neo':
self.lm_tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-125M')
self.lm_model = GPTNeoForCausalLM.from_pretrained('EleutherAI/gpt-neo-125M', output_hidden_states=True)
elif lm_model == 'gpt-2':
self.lm_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
self.lm_model = GPT2LMHeadModel.from_pretrained('gpt2-medium', output_hidden_states=True)
self.context_prefix = self.lm_tokenizer.bos_token
self.lm_model.to(self.device)
self.lm_model.eval()
self.forbidden_tokens = np.load(forbidden_tokens_file_path)
self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if
(x[0] == 'Ġ' and len(x) > 1 and x[1].isupper())]
# Freeze LM weights
for param in self.lm_model.parameters():
param.requires_grad = False
# Initialize CLIP
self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device,
download_root=clip_checkpoints, jit=False)
self.clip_image = ImageCLIP(self.clip)
self.clip_image = torch.nn.DataParallel(self.clip_image)
self.clip_text = TextCLIP(self.clip)
self.clip_text = torch.nn.DataParallel(self.clip_text)
# Init arguments
self.target_seq_length = target_seq_length
self.reset_context_delta = reset_context_delta
self.num_iterations = num_iterations
self.clip_loss_temperature = clip_loss_temperature
self.clip_scale = clip_scale
self.ce_scale = ce_scale
self.stepsize = stepsize
self.grad_norm_factor = grad_norm_factor
self.fusion_factor = fusion_factor
self.repetition_penalty = repetition_penalty
self.end_token = self.lm_tokenizer.encode(end_token)[0]
self.end_factor = end_factor
self.ef_idx = 1
self.forbidden_factor = forbidden_factor
def get_img_feature(self, img_path, weights):
imgs = [Image.open(x) for x in img_path]
clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]
with torch.no_grad():
image_fts = [self.clip_image(x) for x in clip_imgs]
if weights is not None:
image_features = sum([x * weights[i] for i, x in enumerate(image_fts)])
else:
image_features = sum(image_fts)
image_features = torch.nn.functional.normalize(image_features, dim=-1)
return image_features.detach()
def get_txt_features(self, text):
clip_texts = clip.tokenize(text).to(self.device)
with torch.no_grad():
text_features = self.clip_text(clip_texts)
text_features = torch.nn.functional.normalize(text_features, dim=-1)
return text_features.detach()
def get_combined_feature(self, img_path, texts, weights_i, weights_t):
imgs = [Image.open(x) for x in img_path]
clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]
clip_texts = [clip.tokenize(x).to(self.device) for x in texts]
with torch.no_grad():
image_fts = [self.clip.encode_image(x) for x in clip_imgs]
text_fts = [self.clip.encode_text(x) for x in clip_texts]
features = sum([x * weights_i[i] for i, x in enumerate(image_fts)])
if weights_t is not None:
features += sum([x * weights_t[i] for i, x in enumerate(text_fts)])
features = features / features.norm(dim=-1, keepdim=True)
return features.detach()
def run(self, image_features, cond_text, beam_size):
self.image_features = image_features
context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text)
output_tokens, output_text = self.generate_text(context_tokens, beam_size)
return output_text
def generate_text(self, context_tokens, beam_size):
context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0)
gen_tokens = None
scores = None
seq_lengths = torch.ones(beam_size, device=self.device)
is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool)
for i in range(self.target_seq_length):
probs = self.get_next_probs(i, context_tokens)
logits = probs.log()
if scores is None:
scores, next_tokens = logits.topk(beam_size, -1)
context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:])
next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
if gen_tokens is None:
gen_tokens = next_tokens
else:
gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:])
gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1)
else:
logits[is_stopped] = -float(np.inf)
logits[is_stopped, 0] = 0
scores_sum = scores[:, None] + logits
seq_lengths[~is_stopped] += 1
scores_sum_average = scores_sum / seq_lengths[:, None]
scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
beam_size, -1)
next_tokens_source = next_tokens // scores_sum.shape[1]
seq_lengths = seq_lengths[next_tokens_source]
next_tokens = next_tokens % scores_sum.shape[1]
next_tokens = next_tokens.unsqueeze(1)
gen_tokens = gen_tokens[next_tokens_source]
gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1)
context_tokens = context_tokens[next_tokens_source]
scores = scores_sum_average * seq_lengths
is_stopped = is_stopped[next_tokens_source]
context_tokens = torch.cat((context_tokens, next_tokens), dim=1)
is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze()
####
tmp_scores = scores / seq_lengths
tmp_output_list = gen_tokens.cpu().numpy()
tmp_output_texts = [
self.lm_tokenizer.decode(tmp_output)
for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths)
]
tmp_order = tmp_scores.argsort(descending=True)
tmp_output_texts = [tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order]
log_info(tmp_output_texts, verbose=True)
####
if is_stopped.all():
break
scores = scores / seq_lengths
output_list = gen_tokens.cpu().numpy()
output_texts = [
self.lm_tokenizer.decode(output[: int(length)])
for output, length in zip(output_list, seq_lengths)
]
order = scores.argsort(descending=True)
output_texts = [output_texts[i] for i in order]
return context_tokens, output_texts
def get_next_probs(self, i, context_tokens):
last_token = context_tokens[:, -1:]
if self.reset_context_delta and context_tokens.size(1) > 1:
context = self.lm_model(context_tokens[:, :-1])["past_key_values"]
# Logits of LM with unshifted context
logits_before_shift = self.lm_model(context_tokens)["logits"]
logits_before_shift = logits_before_shift[:, -1, :]
probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1)
if context:
context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift)
lm_output = self.lm_model(last_token, past_key_values=context)
logits, past = (
lm_output["logits"],
lm_output["past_key_values"],
)
logits = logits[:, -1, :]
logits = self.update_special_tokens_logits(context_tokens, i, logits)
probs = nn.functional.softmax(logits, dim=-1)
probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor))
probs = probs / probs.sum()
return probs
def shift_context(self, i, context, last_token, context_tokens, probs_before_shift):
context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context]
for i in range(self.num_iterations):
curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in
context_delta]
for p0, p1 in curr_shift:
p0.retain_grad()
p1.retain_grad()
shifted_context = list(map(add_context, context, curr_shift))
shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context)
logits = shifted_outputs["logits"][:, -1, :]
probs = nn.functional.softmax(logits, dim=-1)
loss = 0.0
# CLIP LOSS
clip_loss, clip_losses = self.clip_loss(probs, context_tokens)
loss += self.clip_scale * clip_loss
# CE/Fluency loss
ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * probs_before_shift.log())).sum(-1)
loss += ce_loss.sum()
loss.backward()
# --------- Specific Gen ---------
final_grads = self.norm_grad(context, context_tokens, curr_shift)
# --------- update context ---------
context_delta = list(map(add_context, final_grads, context_delta))
for p0, p1 in curr_shift:
p0.grad.data.zero_()
p1.grad.data.zero_()
new_context = []
for p0, p1 in context:
new_context.append((p0.detach(), p1.detach()))
context = new_context
context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_])
for p_ in context_delta]
context = list(map(add_context, context, context_delta))
new_context = []
for p0, p1 in context:
new_context.append((p0.detach(), p1.detach()))
context = new_context
return context
def norm_grad(self, context, context_tokens, curr_shift, ):
factor = 1
sep_grads = None
window_mask = torch.ones_like(context[0][0]).to(self.device)
for b in range(context_tokens.shape[0]):
tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_]
for p_ in curr_shift]
# normalize gradients
tmp_grad = [tuple([-self.stepsize * factor * (
x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][
j] ** self.grad_norm_factor).data.cpu().numpy()
for j, x in enumerate(p_)])
for i, p_ in enumerate(curr_shift)]
if sep_grads is None:
sep_grads = tmp_grad
else:
for l_index in range(len(sep_grads)):
sep_grads[l_index] = list(sep_grads[l_index])
for k_index in range(len(sep_grads[0])):
sep_grads[l_index][k_index] = np.concatenate(
(sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0)
sep_grads[l_index] = tuple(sep_grads[l_index])
final_grads = sep_grads
return final_grads
def update_special_tokens_logits(self, context_tokens, i, logits):
for beam_id in range(context_tokens.shape[0]):
for token_idx in set(context_tokens[beam_id][-4:].tolist()):
factor = self.repetition_penalty if logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty)
logits[beam_id, token_idx] /= factor
if i >= self.ef_idx:
factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor)
logits[beam_id, self.end_token] *= factor
if i == 0:
start_factor = 1.6
factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor)
logits[beam_id, self.end_token] /= factor
for token_idx in list(self.forbidden_tokens):
factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor)
logits[beam_id, token_idx] /= factor
return logits
def clip_loss(self, probs, context_tokens):
for p_ in self.clip.transformer.parameters():
if p_.grad is not None:
p_.grad.data.zero_()
top_size = 512
top_probs, top_indices = probs.topk(top_size, -1)
prefix_texts = [self.lm_tokenizer.decode(x, skip_special_tokens=True) for x in context_tokens]
clip_loss = 0
losses = []
top_texts = []
for idx_p in range(probs.shape[0]):
prefix_text = prefix_texts[idx_p]
for x in top_indices[idx_p]:
top_texts.append(prefix_text + self.lm_tokenizer.decode(x))
text_features = self.get_txt_features(top_texts)#.reshape(probs.size(0), top_size, -1)
with torch.no_grad():
similiraties = (self.image_features @ text_features.T).reshape(probs.size(0), -1)
similiraties = similiraties.reshape(probs.size(0), -1)
target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach()
target_probs = target_probs.type(torch.float32)
clip_loss += torch.sum(-(target_probs * torch.log(top_probs)))
# for idx_p in range(probs.shape[0]):
# top_texts = []
# prefix_text = prefix_texts[idx_p]
# for x in top_indices[idx_p]:
# top_texts.append(prefix_text + self.lm_tokenizer.decode(x))
# text_features = self.get_txt_features(top_texts)
#
# with torch.no_grad():
# similiraties = (self.image_features @ text_features.T)
# target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach()
# target_probs = target_probs.type(torch.float32)
#
# target = torch.zeros_like(probs[idx_p])
# target[top_indices[idx_p]] = target_probs[0]
# target = target.unsqueeze(0)
# cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)])))
#
# clip_loss += cur_clip_loss
# losses.append(cur_clip_loss)
return clip_loss, losses
def clip_loss_old(self, probs, context_tokens):
for p_ in self.clip.transformer.parameters():
if p_.grad is not None:
p_.grad.data.zero_()
top_size = 512
_, top_indices = probs.topk(top_size, -1)
prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens]
clip_loss = 0
losses = []
for idx_p in range(probs.shape[0]):
top_texts = []
prefix_text = prefix_texts[idx_p]
for x in top_indices[idx_p]:
top_texts.append(prefix_text + self.lm_tokenizer.decode(x))
text_features = self.get_txt_features(top_texts)
with torch.no_grad():
similiraties = (self.image_features @ text_features.T)
target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach()
target_probs = target_probs.type(torch.float32)
target = torch.zeros_like(probs[idx_p])
target[top_indices[idx_p]] = target_probs[0]
target = target.unsqueeze(0)
cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)])))
clip_loss += cur_clip_loss
losses.append(cur_clip_loss)
return clip_loss, losses

0
zerocap/model/__init__.py

BIN
zerocap/model/__pycache__/ZeroCLIP.cpython-36.pyc

Binary file not shown.

BIN
zerocap/model/__pycache__/ZeroCLIP.cpython-37.pyc

Binary file not shown.

BIN
zerocap/model/__pycache__/ZeroCLIP_batched.cpython-36.pyc

Binary file not shown.

BIN
zerocap/model/__pycache__/ZeroCLIP_batched.cpython-37.pyc

Binary file not shown.

BIN
zerocap/model/__pycache__/__init__.cpython-36.pyc

Binary file not shown.

BIN
zerocap/model/__pycache__/__init__.cpython-37.pyc

Binary file not shown.

14
zerocap/mscoco_zerocap.sh

@ -0,0 +1,14 @@
#!/bin/bash
# lm_model:
# 1. cambridgeltl/magic_mscoco
# 2. cambridgeltl/magic_flickr30k
CUDA_VISIBLE_DEVICES=1 python run.py \
--beam_size 1 \
--target_seq_length 16 \
--reset_context_delta \
--lm_model cambridgeltl/magic_mscoco \
--test_image_prefix_path ../data/mscoco/test_images \
--test_path ../data/mscoco/mscoco_test.json \
--save_path_prefix ../inference_result/mscoco/baselines/ \
--save_name zerocap_result.json

117
zerocap/predict.py

@ -0,0 +1,117 @@
import os
import tempfile
import sys
sys.path.append('CLIP')
from pathlib import Path
import cog
import argparse
import torch
import clip
from model.ZeroCLIP import CLIPTextGenerator


def perplexity_score(text, lm_model, lm_tokenizer, device):
    encodings = lm_tokenizer(f'{lm_tokenizer.bos_token + text}', return_tensors='pt')
    input_ids = encodings.input_ids.to(device)
    target_ids = input_ids.clone()

    outputs = lm_model(input_ids, labels=target_ids)
    log_likelihood = outputs[0]
    ll = log_likelihood.item()
    return ll


class Predictor(cog.Predictor):
    def setup(self):
        self.args = get_args()
        self.args.reset_context_delta = True
        self.text_generator = CLIPTextGenerator(**vars(self.args))

    @cog.input(
        "image",
        type=Path,
        help="input image"
    )
    @cog.input(
        "cond_text",
        type=str,
        default='Image of a',
        help="conditional text",
    )
    @cog.input(
        "beam_size",
        type=int,
        default=5, min=1, max=10,
        help="Number of beams to use",
    )
    @cog.input(
        "end_factor",
        type=float,
        default=1.01, min=1.0, max=1.10,
        help="Higher value for shorter captions",
    )
    @cog.input(
        "max_seq_length",
        type=int,
        default=15, min=1, max=20,
        help="Maximum number of tokens to generate",
    )
    @cog.input(
        "ce_loss_scale",
        type=float,
        default=0.2, min=0.0, max=0.6,
        help="Scale of cross-entropy loss with un-shifted language model",
    )
    def predict(self, image, cond_text, beam_size, end_factor, max_seq_length, ce_loss_scale):
        self.args.cond_text = cond_text
        self.text_generator.end_factor = end_factor
        self.text_generator.target_seq_length = max_seq_length
        self.text_generator.ce_scale = ce_loss_scale

        image_features = self.text_generator.get_img_feature([str(image)], None)
        captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size)

        # CLIP SCORE
        encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device))
                            for c in captions]
        encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
        best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item()

        # Perplexity SCORE
        ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions]
        best_ppl_index = torch.tensor(ppl_scores).argmin().item()

        best_clip_caption = self.args.cond_text + captions[best_clip_idx]
        best_mixed = self.args.cond_text + captions[0]
        best_PPL = self.args.cond_text + captions[best_ppl_index]

        final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}'

        return final
        # return self.args.cond_text + captions[best_clip_idx]


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
    parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
    parser.add_argument("--target_seq_length", type=int, default=15)
    parser.add_argument("--cond_text", type=str, default="Image of a")
    parser.add_argument("--reset_context_delta", action="store_true",
                        help="Should we reset the context at each token gen")
    parser.add_argument("--num_iterations", type=int, default=5)
    parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
    parser.add_argument("--clip_scale", type=float, default=1)
    parser.add_argument("--ce_scale", type=float, default=0.2)
    parser.add_argument("--stepsize", type=float, default=0.3)
    parser.add_argument("--grad_norm_factor", type=float, default=0.9)
    parser.add_argument("--fusion_factor", type=float, default=0.99)
    parser.add_argument("--repetition_penalty", type=float, default=1)
    parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
    parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
    parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
    parser.add_argument("--beam_size", type=int, default=5)

    args = parser.parse_args('')
    return args
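
Note that `outputs[0]` in `perplexity_score` is the mean token cross-entropy that HuggingFace language models return when `labels` are supplied, so `argmin` over these scores picks the most fluent caption; exponentiating the value gives the usual perplexity. A small illustrative helper (not part of the repository):

```python
import math

def perplexity_from_loss(mean_nll: float) -> float:
    """Convert a language model's mean negative log-likelihood into perplexity."""
    return math.exp(mean_nll)

# A caption with mean NLL 2.0 has perplexity ~7.39; lower is more fluent.
print(perplexity_from_loss(2.0))
```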

129
zerocap/predict_arithmetic.py

@ -0,0 +1,129 @@
import os
import tempfile
import sys
sys.path.append('CLIP')
from pathlib import Path
import cog
import argparse
import torch
import clip
from model.ZeroCLIP import CLIPTextGenerator


def perplexity_score(text, lm_model, lm_tokenizer, device):
    encodings = lm_tokenizer(f'{lm_tokenizer.bos_token + text}', return_tensors='pt')
    input_ids = encodings.input_ids.to(device)
    target_ids = input_ids.clone()

    outputs = lm_model(input_ids, labels=target_ids)
    log_likelihood = outputs[0]
    ll = log_likelihood.item()
    return ll


class Predictor(cog.Predictor):
    def setup(self):
        self.args = get_args()
        self.args.reset_context_delta = True
        self.text_generator = CLIPTextGenerator(**vars(self.args))

    @cog.input(
        "image1",
        type=Path,
        help="Final result will be: image1 + (image2 - image3)"
    )
    @cog.input(
        "image2",
        type=Path,
        help="Final result will be: image1 + (image2 - image3)"
    )
    @cog.input(
        "image3",
        type=Path,
        help="Final result will be: image1 + (image2 - image3)"
    )
    @cog.input(
        "cond_text",
        type=str,
        default='Image of a',
        help="conditional text",
    )
    @cog.input(
        "beam_size",
        type=int,
        default=3, min=1, max=10,
        help="Number of beams to use",
    )
    @cog.input(
        "end_factors",
        type=float,
        default=1.06, min=1.0, max=1.10,
        help="Higher value for shorter captions",
    )
    @cog.input(
        "max_seq_lengths",
        type=int,
        default=3, min=1, max=20,
        help="Maximum number of tokens to generate",
    )
    @cog.input(
        "ce_loss_scale",
        type=float,
        default=0.2, min=0.0, max=0.6,
        help="Scale of cross-entropy loss with un-shifted language model",
    )
    def predict(self, image1, image2, image3, cond_text, beam_size, end_factors, max_seq_lengths, ce_loss_scale):
        self.args.cond_text = cond_text
        self.text_generator.end_factor = end_factors
        self.text_generator.target_seq_length = max_seq_lengths
        self.text_generator.ce_scale = ce_loss_scale
        self.text_generator.fusion_factor = 0.95
        self.text_generator.grad_norm_factor = 0.95

        image_features = self.text_generator.get_combined_feature([str(image1), str(image2), str(image3)], [], [1, 1, -1], None)
        captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size)

        # CLIP SCORE
        encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device))
                            for c in captions]
        encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
        best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item()

        # Perplexity SCORE
        ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions]
        best_ppl_index = torch.tensor(ppl_scores).argmin().item()

        best_clip_caption = self.args.cond_text + captions[best_clip_idx]
        best_mixed = self.args.cond_text + captions[0]
        best_PPL = self.args.cond_text + captions[best_ppl_index]

        final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}'

        return final
        # return self.args.cond_text + captions[best_clip_idx]


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
    parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
    parser.add_argument("--target_seq_length", type=int, default=15)
    parser.add_argument("--cond_text", type=str, default="Image of a")
    parser.add_argument("--reset_context_delta", action="store_true",
                        help="Should we reset the context at each token gen")
    parser.add_argument("--num_iterations", type=int, default=5)
    parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
    parser.add_argument("--clip_scale", type=float, default=1)
    parser.add_argument("--ce_scale", type=float, default=0.2)
    parser.add_argument("--stepsize", type=float, default=0.3)
    parser.add_argument("--grad_norm_factor", type=float, default=0.95)
    parser.add_argument("--fusion_factor", type=float, default=0.95)
    parser.add_argument("--repetition_penalty", type=float, default=1)
    parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
    parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
    parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
    parser.add_argument("--beam_size", type=int, default=5)

    args = parser.parse_args('')
    return args
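
`predict_arithmetic.py` captions a composed embedding rather than a single image: `get_combined_feature` mixes the three CLIP image features with weights `[1, 1, -1]`, i.e. image1 + (image2 - image3). A rough sketch of that composition on already normalized features (the helper below is an illustration under that assumption, not the repository's implementation):

```python
import torch

def combine_image_features(features, weights):
    """Weighted sum of L2-normalized CLIP image features, renormalized so the
    result can be scored against text features with a plain dot product."""
    combined = sum(w * f for w, f in zip(weights, features))
    return combined / combined.norm(dim=-1, keepdim=True)

# Toy stand-ins for the CLIP embeddings of image1, image2, image3.
feats = [torch.randn(1, 8) for _ in range(3)]
feats = [f / f.norm(dim=-1, keepdim=True) for f in feats]
target = combine_image_features(feats, [1, 1, -1])  # image1 + (image2 - image3)
print(target.norm(dim=-1))  # ~1.0
```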

3
zerocap/requirements.txt

@ -0,0 +1,3 @@
ftfy
regex
tqdm

131
zerocap/run.py

@ -0,0 +1,131 @@
import argparse
import json
import os

import torch
import clip
import progressbar
from tqdm import tqdm

from model.ZeroCLIP import CLIPTextGenerator
from model.ZeroCLIP_batched import CLIPTextGenerator as CLIPTextGenerator_multigpu


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--test_image_prefix_path", type=str, help="the folder that stores all test images")
    parser.add_argument("--test_path", type=str)
    parser.add_argument("--save_path_prefix", type=str, help="save the result in which directory")
    parser.add_argument("--save_name", type=str, help="the name of the saved file")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
    parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
    parser.add_argument("--target_seq_length", type=int, default=15)
    parser.add_argument("--cond_text", type=str, default="Image of a")
    parser.add_argument("--reset_context_delta", action="store_true",
                        help="Should we reset the context at each token gen")
    parser.add_argument("--num_iterations", type=int, default=5)
    parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
    parser.add_argument("--clip_scale", type=float, default=1)
    parser.add_argument("--ce_scale", type=float, default=0.2)
    parser.add_argument("--stepsize", type=float, default=0.3)
    parser.add_argument("--grad_norm_factor", type=float, default=0.9)
    parser.add_argument("--fusion_factor", type=float, default=0.99)
    parser.add_argument("--repetition_penalty", type=float, default=1)
    parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
    parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
    parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
    parser.add_argument("--beam_size", type=int, default=1)
    parser.add_argument("--multi_gpu", action="store_true")

    parser.add_argument('--run_type',
                        default='caption',
                        nargs='?',
                        choices=['caption', 'arithmetics'])
    parser.add_argument("--caption_img_path", type=str, default='example_images/captions/COCO_val2014_000000008775.jpg',
                        help="Path to image for captioning")
    parser.add_argument("--arithmetics_imgs", nargs="+",
                        default=['example_images/arithmetics/woman2.jpg',
                                 'example_images/arithmetics/king2.jpg',
                                 'example_images/arithmetics/man2.jpg'])
    parser.add_argument("--arithmetics_weights", nargs="+", default=[1, 1, -1])

    args = parser.parse_args()
    return args


def run(args, text_generator, img_path):
    image_features = text_generator.get_img_feature([img_path], None)
    captions = text_generator.run(image_features, args.cond_text, beam_size=args.beam_size)

    encoded_captions = [text_generator.clip.encode_text(clip.tokenize(c).to(text_generator.device)) for c in captions]
    encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
    best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item()
    return captions


if __name__ == '__main__':
    if torch.cuda.is_available():
        print ('Cuda is available.')
    cuda_available = torch.cuda.is_available()
    args = get_args()
    device = torch.device('cuda')

    save_path_prefix = args.save_path_prefix
    # recursively construct directory
    os.makedirs(save_path_prefix, exist_ok=True)
    # parse save name
    save_name = args.save_name
    full_save_path = save_path_prefix + '/' + save_name
    print ('full save path is {}'.format(full_save_path))

    print ('Loading data...')
    with open(args.test_path) as f:
        item_list = json.load(f)
    print ('Data loaded.')
    print ('Number of test instances is {}'.format(len(item_list)))

    # ZeroCap generator
    text_generator = CLIPTextGenerator(**vars(args))

    result_list = []
    invalid_num = 0
    print ('----------------------------------------------------------------')
    test_num = len(item_list)
    #test_num = 10
    print ('Number of inference instances is {}'.format(test_num))
    p = progressbar.ProgressBar(test_num)
    p.start()
    for p_idx in tqdm(range(test_num)):
        p.update(p_idx)
        one_test_dict = item_list[p_idx]

        one_res_dict = {
            'split':one_test_dict['split'],
            'image_name':one_test_dict['image_name'],
            #'file_path':one_test_dict['file_path'],
            'captions':one_test_dict['captions']
        }

        image_full_path = args.test_image_prefix_path + '/' + one_test_dict['image_name']
        try:
            output_text = run(args, text_generator, img_path=image_full_path)
            one_res_dict['prediction'] = output_text[0]
            result_list.append(one_res_dict)
        except Exception as error:
            print('[!] ERROR:', error)
            invalid_num += 1
            print ('invalid number is {}'.format(invalid_num))
            continue
    p.finish()
    print ('Inference completed!')

    with open(full_save_path, 'w') as outfile:
        json.dump(result_list, outfile, indent=4)
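
Each record that `run.py` appends to `result_list` (and dumps to `full_save_path`) keeps the reference captions alongside the ZeroCap prediction, so evaluation scripts can consume the file directly. The layout looks like this (field values are placeholders):

```python
one_res_dict = {
    "split": "test",
    "image_name": "example.jpg",        # placeholder image file name
    "captions": ["a reference caption", "another reference caption"],
    "prediction": "a generated caption" # first beam returned by run()
}
```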

19
zerocap/setup.py

@ -0,0 +1,19 @@
import os

import pkg_resources
from setuptools import setup, find_packages

setup(
    name="zero-shot-image-to-text",
    py_modules=["zero-shot-image-to-text"],
    version="1.0",
    description="",
    packages=find_packages(),
    install_requires=[
        str(r)
        for r in pkg_resources.parse_requirements(
            open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
        )
    ],
    include_package_data=True
)