magic
copied
wxywb
2 years ago
32 changed files with 2056 additions and 0 deletions
@@ -0,0 +1,18 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .magic import Magic


def magic(model_name: str):
    return Magic(model_name)
@@ -0,0 +1,141 @@
## CLIP
This folder illustrates how to use CLIP to build the text index and to conduct the cross-modal retrieval baseline.

****

## Catalogue:
* <a href='#index'>1. Build Text Index</a>
    * <a href='#mscoco'>1.1. Build Text Index for MSCOCO</a>
        * <a href='#download_mscoco_index'>1.1.1. Download Our Built Index</a>
        * <a href='#process_mscoco_index'>1.1.2. Construct the Index by Yourself</a>
    * <a href='#flickr30k'>1.2. Build Text Index for Flickr30k</a>
        * <a href='#download_flickr30k_index'>1.2.1. Download Our Built Index</a>
        * <a href='#process_flickr30k_index'>1.2.2. Construct the Index by Yourself</a>
* <a href='#baseline'>2. CLIP Retrieval Baseline</a>
    * <a href='#in_domain_baseline'>2.1. In Domain CLIP Retrieval</a>
    * <a href='#cross_domain_baseline'>2.2. Cross Domain CLIP Retrieval</a>

****

<span id='index'/>

### 1. Build Text Index:
We show how to use CLIP to build the text index from which captions are retrieved.

<span id='mscoco'/>

#### 1.1. Build Text Index for MSCOCO:
First, we demonstrate how to build the text index for MSCOCO.

<span id='download_mscoco_index'/>

#### 1.1.1. Download Our Post-processed Index:
We share our built index for MSCOCO via this [[link]](https://drive.google.com/file/d/1Dx_RPeAmydS6ZYuiJ-dLlK9-DjDZkxAh/view?usp=sharing). After downloading, unzip the downloaded file **mscoco_index.zip** under the current directory.

> The resulting directory looks like:

    .
    ├── ./mscoco_index/
        ├── index_matrix.txt   # Stores the representations of captions from the MSCOCO training set. Each row is the vector of one training caption.
        └── text_mapping.json  # Stores the mapping between each representation (row index) and the corresponding caption.
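
For reference, a minimal sketch of how these two files can be read back (it mirrors the loading logic in `clipretrieval.py` in this folder; the file names are the ones produced by the build scripts):

```python
import json
import numpy as np

# One whitespace-separated caption vector per line.
index_matrix = np.loadtxt('./mscoco_index/index_matrix.txt')  # [num_captions, embed_dim]

# Row index (stored as a string key) -> caption text.
with open('./mscoco_index/text_mapping.json') as f:
    text_mapping = json.load(f)

print(index_matrix.shape)
print(text_mapping['0'])  # caption stored in the first row of the index
```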

<span id='process_mscoco_index'/>

#### 1.1.2. Construct the Index by Yourself:

You can also rebuild the index yourself. First, make sure you have downloaded the MSCOCO data following the instructions [[here]](https://github.com/yxuansu/MAGIC/tree/main/image_captioning/data#1-mscoco-benchmark). Then, run the following command to build the index.
```bash
chmod +x ./build_mscoco_index.sh
./build_mscoco_index.sh
```
The arguments are as follows:
* `--clip_name`: The configuration of the pre-trained CLIP model from huggingface.
* `--text_file_path`: The path of the training text corpus.
* `--save_index_prefix`: The directory in which to store the index files.
* `--save_index_name`: The saved name of the caption representations.
* `--save_mapping_dict_name`: The saved name of the mapping dictionary between representations and captions.
* `--batch_size`: The inference batch size.

<span id='flickr30k'/>

#### 1.2. Build Text Index for Flickr30k:
Next, we demonstrate how to build the text index for Flickr30k.

<span id='download_flickr30k_index'/>

#### 1.2.1. Download Our Post-processed Index:
We share our built index for Flickr30k via this [[link]](https://drive.google.com/file/d/1hS58_ir5pdZZPckApCtlz2RyasCQbrPf/view?usp=sharing). After downloading, unzip the downloaded file **flickr30k_index.zip** under the current directory.

> The resulting directory looks like:

    .
    ├── ./flickr30k_index/
        ├── index_matrix.txt   # Stores the representations of captions from the Flickr30k training set. Each row is the vector of one training caption.
        └── text_mapping.json  # Stores the mapping between each representation (row index) and the corresponding caption.

<span id='process_flickr30k_index'/>

#### 1.2.2. Construct the Index by Yourself:

You can also rebuild the index yourself. First, make sure you have downloaded the Flickr30k data following the instructions [[here]](https://github.com/yxuansu/MAGIC/tree/main/image_captioning/data#2-flickr30k-benchmark). Then, run the following command to build the index.
```bash
chmod +x ./build_flickr30k_index.sh
./build_flickr30k_index.sh
```
The arguments are as follows:
* `--clip_name`: The configuration of the pre-trained CLIP model from huggingface.
* `--text_file_path`: The path of the training text corpus.
* `--save_index_prefix`: The directory in which to store the index files.
* `--save_index_name`: The saved name of the caption representations.
* `--save_mapping_dict_name`: The saved name of the mapping dictionary between representations and captions.
* `--batch_size`: The inference batch size.

****

<span id='baseline'/>

### 2. CLIP Retrieval Baseline:
Here, we show how to conduct the CLIP retrieval baseline.

<span id='in_domain_baseline'/>

#### 2.1. In Domain CLIP Retrieval:
To retrieve the captions from the in-domain training set, run the following command:
```bash
chmod +x ./X_clip_retrieval.sh
./X_clip_retrieval.sh
```
Here, X is in ['mscoco', 'flickr30k'], corresponding to the MSCOCO and Flickr30k benchmarks.

The arguments are as follows:
* `--clip_name`: The configuration of the pre-trained CLIP model from huggingface.
* `--test_image_prefix_path`: The directory where the test set images are stored.
* `--test_path`: The path of the file containing the reference test captions.
* `--index_matrix_path`: The path of the representation index file.
* `--mapping_dict_path`: The path of the mapping dictionary between representations and captions.
* `--save_path_prefix`: The directory in which to save the inference result.
* `--save_name`: The saved name of the inference result.

**[Note]** As we are conducting in-domain CLIP retrieval, the test images and the caption index should come from the same benchmark.
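
Conceptually, every retrieval run boils down to a nearest-neighbour search over the index built in Section 1. A minimal sketch of that core step (the full, batched version is `clipretrieval.py` in this folder; `clip_model` is an instance of the `CLIP` wrapper from `clip.py`):

```python
import numpy as np
from PIL import Image

def retrieve_caption(image_path, clip_model, index_matrix, text_mapping):
    """Return the training caption whose CLIP embedding is closest to the image."""
    # Encode and L2-normalise the query image (1 x embed_dim).
    image = Image.open(image_path)
    image_vec = clip_model.compute_batch_index_image_features([image]).detach().cpu().numpy()
    image_vec = image_vec / np.linalg.norm(image_vec, axis=1, keepdims=True)

    # Rows of index_matrix are L2-normalised caption embeddings, so the dot
    # product is the cosine similarity; take the best-scoring caption.
    scores = np.matmul(image_vec, index_matrix.T)[0]
    return text_mapping[str(int(scores.argmax()))]
```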

<span id='cross_domain_baseline'/>

#### 2.2. Cross Domain CLIP Retrieval:
To retrieve the captions from the cross-domain training set, run the following command:
```bash
chmod +x ./source_X_target_Y_clip_retrieval.sh
./source_X_target_Y_clip_retrieval.sh
```
Here, X is the source domain from ['mscoco', 'flickr30k'] and Y is the target domain from ['flickr30k', 'mscoco'].

The arguments are as follows:
* `--clip_name`: The configuration of the pre-trained CLIP model from huggingface.
* `--test_image_prefix_path`: The directory where the test set images are stored.
* `--test_path`: The path of the file containing the reference test captions.
* `--index_matrix_path`: The path of the representation index file.
* `--mapping_dict_path`: The path of the mapping dictionary between representations and captions.
* `--save_path_prefix`: The directory in which to save the inference result.
* `--save_name`: The saved name of the inference result.

**[Note]** As we are conducting cross-domain CLIP retrieval, the test images and the caption index should come from **different** benchmarks.
@@ -0,0 +1,7 @@
CUDA_VISIBLE_DEVICES=1 python build_text_index.py\
    --clip_name openai/clip-vit-base-patch32\
    --text_file_path ../data/flickr30k/flickr30k_train.json\
    --save_index_prefix ./flickr30k_index/\
    --save_index_name index_matrix.txt\
    --save_mapping_dict_name text_mapping.json\
    --batch_size 128
@@ -0,0 +1,7 @@
CUDA_VISIBLE_DEVICES=0 python build_text_index.py\
    --clip_name openai/clip-vit-base-patch32\
    --text_file_path ../data/mscoco/mscoco_train.json\
    --save_index_prefix ./mscoco_index/\
    --save_index_name index_matrix.txt\
    --save_mapping_dict_name text_mapping.json\
    --batch_size 128
@@ -0,0 +1,105 @@
import sys
import os
import json
import argparse
import torch
import numpy as np
import progressbar


def parse_config():
    parser = argparse.ArgumentParser()
    parser.add_argument("--clip_name", type=str, default="openai/clip-vit-base-patch32")
    parser.add_argument("--text_file_path", type=str)
    # save configuration
    parser.add_argument("--save_index_prefix", type=str, help='where to save the mips index')
    parser.add_argument("--save_index_name", type=str)
    parser.add_argument("--save_mapping_dict_name", type=str,
        help="a json file that stores a dictionary mapping between the mips index and the caption text")
    # inference configuration
    parser.add_argument("--batch_size", type=int, help="the batch size used to conduct inference with CLIP")
    return parser.parse_args()


def load_batch_text(text_file_path, batch_size):
    with open(text_file_path) as f:
        item_list = json.load(f)

    text_list = []
    for item in item_list:
        captions = item["captions"]
        for cap in captions:
            text_list.append(cap)
    print ('Number of text instances is {}'.format(len(text_list)))

    data_num = len(text_list)
    batch_num = data_num // batch_size
    batch_text_list = []
    s_idx, e_idx = 0, batch_size
    for p_idx in range(batch_num):
        one_batch_text_list = []
        for idx in range(s_idx, e_idx):
            one_batch_text_list.append(text_list[idx])
        batch_text_list.append(one_batch_text_list)
        # advance the batch window; note that the final data_num % batch_size
        # captions do not form a full batch and are not indexed
        s_idx += batch_size
        e_idx += batch_size
    return batch_text_list


if __name__ == '__main__':
    if torch.cuda.is_available():
        print ('Cuda is available.')
    cuda_available = torch.cuda.is_available()
    args = parse_config()
    device = torch.device('cuda')

    # recursively construct the save directory if it does not exist
    os.makedirs(args.save_index_prefix, exist_ok=True)

    print ('Loading CLIP...')
    from clip import CLIP
    model = CLIP(args.clip_name)
    if cuda_available:
        model = model.cuda(device)
    model.eval()
    print ('CLIP loaded!')

    print ('Loading text data...')
    batch_text_list = load_batch_text(args.text_file_path, args.batch_size)
    print ('Text data loaded.')

    res_text_vec_list, res_text_list = [], []
    batch_num = len(batch_text_list)
    print ('Number of batches is {}'.format(batch_num))
    print ('Start inference...')
    p = progressbar.ProgressBar(batch_num)
    p.start()
    with torch.no_grad():
        for p_idx in range(batch_num):
            p.update(p_idx)
            one_text_batch = batch_text_list[p_idx]
            one_batch_vec = model.compute_batch_index_text_representation(one_text_batch).detach().cpu()
            one_batch_vec_list = one_batch_vec.unbind(dim=0)
            bsz = len(one_batch_vec_list)
            for k in range(bsz):
                res_text_vec_list.append(one_batch_vec_list[k].numpy())
                res_text_list.append(one_text_batch[k])
    p.finish()
    assert len(res_text_vec_list) == len(res_text_list)
    print ('Inference completed!')

    index_text_mapping_dict = {}
    for k in range(len(res_text_list)):
        index_text_mapping_dict[k] = res_text_list[k]
    mapping_list_save_path = args.save_index_prefix + '/' + args.save_mapping_dict_name
    with open(mapping_list_save_path, 'w') as outfile:
        json.dump(index_text_mapping_dict, outfile, indent=4)
    print ('Mapping dictionary saved!')

    print ('Start building index...')
    index_save_path = args.save_index_prefix + '/' + args.save_index_name
    with open(index_save_path, 'w', encoding = 'utf8') as o:
        for vec in res_text_vec_list:
            one_text = ' '.join([str(num) for num in vec]).strip()
            o.writelines(one_text + '\n')
    print ('Index completed!')
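
For clarity, the only structure `build_text_index.py` assumes in the `--text_file_path` JSON is a list of items that each carry a `captions` list (`clipretrieval.py` additionally reads `split` and `image_name`). A minimal, hypothetical entry:

```python
# Hypothetical entry of mscoco_train.json / flickr30k_train.json; only the keys
# actually read by the scripts in this folder are shown, and the values are illustrative.
example_item = {
    "split": "train",
    "image_name": "COCO_train2014_000000000009.jpg",
    "captions": [
        "A plate of food with broccoli and bread.",
        "A close up of a dinner plate on a table.",
    ],
}
```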
@@ -0,0 +1,146 @@
import torch
import requests
from torch import nn
from PIL import Image


class CLIP(nn.Module):
    def __init__(self, model_name):
        super(CLIP, self).__init__()
        # model name: e.g. openai/clip-vit-base-patch32
        print ('Initializing CLIP model...')
        from transformers import CLIPProcessor, CLIPModel
        self.model = CLIPModel.from_pretrained(model_name)
        self.model.eval()
        self.processor = CLIPProcessor.from_pretrained(model_name)
        from transformers import CLIPTokenizer
        self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
        self.cuda_has_been_checked = False
        print ('CLIP model initialized.')

    def check_cuda(self):
        self.cuda_available = next(self.model.parameters()).is_cuda
        self.device = next(self.model.parameters()).get_device()
        if self.cuda_available:
            print ('Cuda is available.')
            print ('Device is {}'.format(self.device))
        else:
            print ('Cuda is not available.')
            print ('Device is {}'.format(self.device))

    @torch.no_grad()
    def compute_image_representation_from_image_path(self, image_path):
        if not self.cuda_has_been_checked:
            self.check_cuda()
            self.cuda_has_been_checked = True
        else:
            pass
        # image_path: the path of the image
        image = Image.open(image_path)
        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs['pixel_values']
        if self.cuda_available:
            pixel_values = pixel_values.cuda(self.device)
        visual_outputs = self.model.vision_model(pixel_values=pixel_values)
        image_embeds = visual_outputs[1]
        image_embeds = self.model.visual_projection(image_embeds) # [1 x embed_dim]
        return image_embeds

    def compute_image_representation_from_image_instance(self, image):
        if not self.cuda_has_been_checked:
            self.check_cuda()
            self.cuda_has_been_checked = True
        else:
            pass
        # image: a PIL image instance
        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs['pixel_values']
        if self.cuda_available:
            pixel_values = pixel_values.cuda(self.device)
        visual_outputs = self.model.vision_model(pixel_values=pixel_values)
        image_embeds = visual_outputs[1]
        image_embeds = self.model.visual_projection(image_embeds) # [1 x embed_dim]
        return image_embeds

    def compute_text_representation(self, text_list):
        if not self.cuda_has_been_checked:
            self.check_cuda()
            self.cuda_has_been_checked = True
        else:
            pass
        # text_list: a list of text
        text_inputs = self.tokenizer(text_list, padding=True, return_tensors="pt",
            max_length=self.tokenizer.max_len_single_sentence + 2, truncation=True)
        # self.tokenizer.max_len_single_sentence + 2 = 77
        input_ids, attention_mask = text_inputs['input_ids'], text_inputs['attention_mask']
        if self.cuda_available:
            input_ids = input_ids.cuda(self.device)
            attention_mask = attention_mask.cuda(self.device)
        text_outputs = self.model.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        text_embeds = text_outputs[1]
        text_embeds = self.model.text_projection(text_embeds)
        return text_embeds

    def compute_image_text_similarity_via_embeddings(self, image_embeds, text_embeds):
        '''
            image_embeds: 1 x embed_dim
            text_embeds: len(text_list) x embed_dim
        '''
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
        logit_scale = self.model.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.T
        return logits_per_image.softmax(dim=1) # 1 x len(text_list)

    def compute_image_text_similarity_via_raw_text(self, image_embeds, text_list):
        text_embeds = self.compute_text_representation(text_list)
        return self.compute_image_text_similarity_via_embeddings(image_embeds, text_embeds)

    ### -------------------- functions for building index ---------------------- ###
    def compute_batch_index_image_features(self, image_list):
        '''
            image_list: a list of image instances
        '''
        if not self.cuda_has_been_checked:
            self.check_cuda()
            self.cuda_has_been_checked = True
        else:
            pass
        inputs = self.processor(images=image_list, return_tensors="pt")
        pixel_values = inputs['pixel_values']
        if self.cuda_available:
            pixel_values = pixel_values.cuda(self.device)
        visual_outputs = self.model.vision_model(pixel_values=pixel_values)
        image_embeds = visual_outputs[1]
        image_embeds = self.model.visual_projection(image_embeds)
        return image_embeds # len(image_list) x embed_dim

    def compute_batch_index_text_representation(self, text_list):
        if not self.cuda_has_been_checked:
            self.check_cuda()
            self.cuda_has_been_checked = True
        else:
            pass
        # text_list: a list of text
        #text_inputs = self.tokenizer(text_list, padding=True, return_tensors="pt")
        text_inputs = self.tokenizer(text_list, padding=True, return_tensors="pt",
            max_length=self.tokenizer.max_len_single_sentence + 2, truncation=True)
        input_ids, attention_mask = text_inputs['input_ids'], text_inputs['attention_mask']
        if self.cuda_available:
            input_ids = input_ids.cuda(self.device)
            attention_mask = attention_mask.cuda(self.device)
        text_outputs = self.model.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        text_embeds = text_outputs[1]
        text_embeds = self.model.text_projection(text_embeds)
        return text_embeds
        #logit_scale = self.model.logit_scale.exp()
        #text_embeds = text_embeds * logit_scale
        #return text_embeds
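
A minimal usage sketch of the wrapper above (the image path and the candidate captions are placeholders; the methods are the ones defined in this file):

```python
import torch
from clip import CLIP

model = CLIP('openai/clip-vit-base-patch32')
model.eval()

with torch.no_grad():
    # 1 x embed_dim image embedding
    image_embeds = model.compute_image_representation_from_image_path('example.jpg')  # placeholder path
    # probability of each candidate caption given the image (1 x 2)
    probs = model.compute_image_text_similarity_via_raw_text(
        image_embeds, ['a dog on the grass', 'a plate of food'])
print(probs)
```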
@@ -0,0 +1,135 @@
import argparse
import json
import copy
import torch
import progressbar
import numpy as np
from PIL import Image


class CLIPIndex:
    def __init__(self, index_matrix_path, mapping_dict_path, clip):
        '''
            index_matrix_path: the pre-built index
            mapping_dict_path: the pre-built mapping dictionary
            clip: the pre-trained clip model
        '''
        print ('Loading index...')
        self.index_matrix = self.normalization(self.load_matrix(index_matrix_path))
        print ('Index loaded.')
        print (self.index_matrix.shape)
        with open(mapping_dict_path) as f:
            self.mapping_dict = json.load(f)
        self.clip = clip

    def load_matrix(self, in_f):
        matrix_list = []
        with open(in_f, 'r', encoding = 'utf8') as i:
            lines = i.readlines()
            for l in lines:
                one_vec = [float(num) for num in l.strip('\n').split()]
                matrix_list.append(one_vec)
        return np.array(matrix_list)

    def normalization(self, matrix):
        '''
            matrix: num_instance x num_feature
        '''
        return matrix / np.linalg.norm(matrix, axis=1, keepdims=True)

    def get_image_representation(self, image_path):
        image_instance = Image.open(image_path)
        image_vec = self.clip.compute_batch_index_image_features([image_instance]).detach().cpu().numpy()
        image_vec = self.normalization(image_vec)
        return image_vec

    def search_text(self, image_path):
        image_vec = self.get_image_representation(image_path)
        sort_idx_list = np.matmul(image_vec, self.index_matrix.transpose())[0].argsort()[::-1]
        top_idx = sort_idx_list[0]
        return self.mapping_dict[str(top_idx)]


def parse_config():
    parser = argparse.ArgumentParser()
    parser.add_argument("--clip_name", type=str)
    parser.add_argument("--test_image_prefix_path", type=str, help="the folder that stores all test images")
    parser.add_argument("--test_path", type=str)
    # index configuration
    parser.add_argument("--index_matrix_path", type=str)
    parser.add_argument("--mapping_dict_path", type=str)
    # save configuration
    parser.add_argument("--save_path_prefix", type=str, help="the directory in which to save the result")
    parser.add_argument("--save_name", type=str, help="the name of the saved file")
    return parser.parse_args()


if __name__ == '__main__':
    if torch.cuda.is_available():
        print ('Cuda is available.')
    cuda_available = torch.cuda.is_available()
    args = parse_config()
    device = torch.device('cuda')

    save_path_prefix = args.save_path_prefix
    import os
    # recursively construct the save directory if it does not exist
    os.makedirs(save_path_prefix, exist_ok=True)
    # parse save name
    save_name = args.save_name
    full_save_path = save_path_prefix + '/' + save_name
    print ('full save path is {}'.format(full_save_path))

    print ('Loading CLIP...')
    from clip import CLIP
    clip = CLIP(args.clip_name)
    if cuda_available:
        clip = clip.cuda(device)
    clip.eval()
    print ('CLIP loaded!')

    clipindex = CLIPIndex(args.index_matrix_path, args.mapping_dict_path, clip)

    print ('Loading data...')
    with open(args.test_path) as f:
        item_list = json.load(f)
    print ('Data loaded.')
    print ('Number of test instances is {}'.format(len(item_list)))

    result_list = []
    invalid_num = 0
    print ('----------------------------------------------------------------')
    with torch.no_grad():
        test_num = len(item_list)
        #test_num = 10
        print ('Number of inference instances is {}'.format(test_num))
        p = progressbar.ProgressBar(test_num)
        p.start()
        for p_idx in range(test_num):
            p.update(p_idx)
            one_test_dict = item_list[p_idx]

            one_res_dict = {
                'split':one_test_dict['split'],
                'image_name':one_test_dict['image_name'],
                #'file_path':one_test_dict['file_path'],
                'captions':one_test_dict['captions']
            }

            image_full_path = args.test_image_prefix_path + '/' + one_test_dict['image_name']
            try:
                output_text = clipindex.search_text(image_full_path)
                one_res_dict['prediction'] = output_text
                result_list.append(one_res_dict)
            except Exception:
                invalid_num += 1
                print ('invalid number is {}'.format(invalid_num))
                continue
        p.finish()
    print ('Inference completed!')

    with open(full_save_path, 'w') as outfile:
        json.dump(result_list, outfile, indent=4)
@@ -0,0 +1,8 @@
CUDA_VISIBLE_DEVICES=1 python clipretrieval.py\
    --clip_name openai/clip-vit-base-patch32\
    --test_image_prefix_path ../data/flickr30k/test_images/\
    --test_path ../data/flickr30k/flickr30k_test.json\
    --index_matrix_path ./flickr30k_index/index_matrix.txt\
    --mapping_dict_path ./flickr30k_index/text_mapping.json\
    --save_path_prefix ../inference_result/flickr30k/baselines/\
    --save_name flickr30k_in_domain_clipretrieval.json
@@ -0,0 +1,8 @@
CUDA_VISIBLE_DEVICES=0 python clipretrieval.py\
    --clip_name openai/clip-vit-base-patch32\
    --test_image_prefix_path ../data/mscoco/test_images/\
    --test_path ../data/mscoco/mscoco_test.json\
    --index_matrix_path ./mscoco_index/index_matrix.txt\
    --mapping_dict_path ./mscoco_index/text_mapping.json\
    --save_path_prefix ../inference_result/mscoco/baselines/\
    --save_name mscoco_in_domain_clipretrieval.json
@@ -0,0 +1,8 @@
CUDA_VISIBLE_DEVICES=1 python clipretrieval.py\
    --clip_name openai/clip-vit-base-patch32\
    --test_image_prefix_path ../data/mscoco/test_images/\
    --test_path ../data/mscoco/mscoco_test.json\
    --index_matrix_path ./flickr30k_index/index_matrix.txt\
    --mapping_dict_path ./flickr30k_index/text_mapping.json\
    --save_path_prefix ../inference_result/flickr30k_model_to_mscoco/\
    --save_name source_flickr30k_target_mscoco_clip_retrieval.json
@@ -0,0 +1,8 @@
CUDA_VISIBLE_DEVICES=1 python clipretrieval.py\
    --clip_name openai/clip-vit-base-patch32\
    --test_image_prefix_path ../data/flickr30k/test_images/\
    --test_path ../data/flickr30k/flickr30k_test.json\
    --index_matrix_path ./mscoco_index/index_matrix.txt\
    --mapping_dict_path ./mscoco_index/text_mapping.json\
    --save_path_prefix ../inference_result/mscoco_model_to_flickr30k/\
    --save_name source_mscoco_target_flickr30k_clip_retrieval.json
@@ -0,0 +1,99 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import sys
import os
import pathlib
import pickle
from argparse import Namespace

import torch
import torchvision
from torchvision import transforms
from transformers import GPT2Tokenizer

from towhee.types.arg import arg, to_image_color
from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee import register
from towhee.models import clip


class Magic(NNOperator):
    """
    Magic image captioning operator
    """
    def __init__(self, model_name: str):
        super().__init__()
        path = str(pathlib.Path(__file__).parent)
        sys.path.append(path)
        from clip import CLIP
        from simctg import SimCTG
        sys.path.pop()

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Load Language Model
        language_model_name = r'cambridgeltl/magic_mscoco' # or r'/path/to/downloaded/cambridgeltl/magic_mscoco'
        sos_token, pad_token = r'<-start_of_text->', r'<-pad->'
        self.generation_model = SimCTG(language_model_name, sos_token, pad_token).to(self.device)
        self.generation_model.eval()

        clip_model_name = r"openai/clip-vit-base-patch32" # or r"/path/to/downloaded/openai/clip-vit-base-patch32"
        self.clip = CLIP(clip_model_name).to(self.device)
        self.clip.eval()

    def _preprocess(self, img):
        # Note: currently unused; self.transf_1 / self.transf_2 are not defined in this operator.
        img = to_pil(img)
        processed_img = self.transf_1(img)
        processed_img = self.transf_2(processed_img)
        processed_img = processed_img.to(self.device)
        return processed_img

    @arg(1, to_image_color('RGB'))
    def inference_single_data(self, data):
        text = self._inference_from_image(data)
        return text

    def __call__(self, data):
        if not isinstance(data, list):
            data = [data]
        results = []
        for single_data in data:
            result = self.inference_single_data(single_data)
            results.append(result)
        if len(data) == 1:
            return results[0]
        else:
            return results

    @arg(1, to_image_color('RGB'))
    def _inference_from_image(self, img):
        # Convert the input into the PIL instance expected by the CLIP wrapper.
        image_instance = to_pil(img)
        k, alpha, beta, decoding_len = 45, 0.1, 2.0, 16
        eos_token = '<|endoftext|>'
        # Build the start-of-text prompt ids expected by SimCTG's magic_search
        # (this mirrors the inference script of the MAGIC repository).
        sos_token = r'<-start_of_text->'
        start_token = self.generation_model.tokenizer.tokenize(sos_token)
        start_token_id = self.generation_model.tokenizer.convert_tokens_to_ids(start_token)
        input_ids = torch.LongTensor(start_token_id).view(1, -1).to(self.device)
        with torch.no_grad():
            output = self.generation_model.magic_search(input_ids, k,
                alpha, decoding_len, beta, image_instance, self.clip, 60)
        return output

    def _configs(self):
        config = {}
        config['expansionnet_rf'] = {}
        config['expansionnet_rf']['weights'] = 'rf_model.pth'
        return config
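
A usage sketch for the operator above. It is only a sketch: whether the `@arg` conversion accepts a plain PIL image directly (rather than a Towhee image type) depends on the Towhee version, and the image path is a placeholder.

```python
# Minimal sketch, assuming the operator files sit in the current directory and
# that the @arg conversion accepts a plain RGB PIL image; 'example.jpg' is a placeholder.
from PIL import Image
from magic import magic  # factory defined in __init__.py above

op = magic('magic_mscoco')
image = Image.open('example.jpg').convert('RGB')
caption = op(image)      # __call__ returns a single caption string for a single input
print(caption)
```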
@@ -0,0 +1,89 @@
### Our Implementation of the ZeroCap Baseline Model

****
### Catalogue:
* <a href='#environment'>1. Environment Preparation</a>
* <a href='#mscoco'>2. Image Captioning on MSCOCO</a>
* <a href='#flickr30k'>3. Image Captioning on Flickr30k</a>
* <a href='#flickr30k_to_mscoco'>4. Cross Domain Image Captioning on MSCOCO</a>
* <a href='#mscoco_to_flickr30k'>5. Cross Domain Image Captioning on Flickr30k</a>
* <a href='#citation'>6. Citation</a>
* <a href='#acknowledgements'>7. Acknowledgements</a>

****

<span id='environment'/>

#### 1. Environment Preparation:
To install the correct environment, run the following command:
```bash
pip install -r requirements.txt
```

****

<span id='mscoco'/>

#### 2. Image Captioning on MSCOCO:
To perform image captioning on MSCOCO, run the following command:
```bash
chmod +x ./mscoco_zerocap.sh
./mscoco_zerocap.sh
```

****

<span id='flickr30k'/>

#### 3. Image Captioning on Flickr30k:
To perform image captioning on Flickr30k, run the following command:
```bash
chmod +x ./flickr30k_zerocap.sh
./flickr30k_zerocap.sh
```

****

<span id='flickr30k_to_mscoco'/>

#### 4. Cross Domain Image Captioning on MSCOCO:
To perform image captioning on MSCOCO with the language model from the Flickr30k domain, run the following command:
```bash
chmod +x ./flickr30k_to_mscoco_zerocap.sh
./flickr30k_to_mscoco_zerocap.sh
```

****

<span id='mscoco_to_flickr30k'/>

#### 5. Cross Domain Image Captioning on Flickr30k:
To perform image captioning on Flickr30k with the language model from the MSCOCO domain, run the following command:
```bash
chmod +x ./mscoco_to_flickr30k_zerocap.sh
./mscoco_to_flickr30k_zerocap.sh
```

****

<span id='citation'/>

#### 6. Citation:
If you find our code helpful, please cite the original paper as follows:

```bibtex
@article{tewel2021zero,
  title={Zero-Shot Image-to-Text Generation for Visual-Semantic Arithmetic},
  author={Tewel, Yoad and Shalev, Yoav and Schwartz, Idan and Wolf, Lior},
  journal={arXiv preprint arXiv:2111.14447},
  year={2021}
}
```

****

<span id='acknowledgements'/>

#### 7. Acknowledgements:
We thank the authors for releasing their code. Our reimplementation of the baseline is based on their original codebase [[here]](https://github.com/yoadtew/zero-shot-image-to-text).
@@ -0,0 +1,12 @@
build:
  gpu: true
  python_version: "3.8"
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
  python_packages:
    - "git+https://github.com/openai/CLIP.git"
    - "git+https://github.com/YoadTew/zero-shot-image-to-text.git"

predict: "predict.py:Predictor"
#predict: "predict_arithmetic.py:Predictor"
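
The `predict` entry points at a `Predictor` class in `predict.py`, which is not included in this excerpt. A minimal, hypothetical skeleton of what such a class looks like with Cog's Python API (the module path, parameter names, and the captioning call are assumptions, not the actual implementation):

```python
# Hypothetical skeleton of predict.py for the cog.yaml above.
from cog import BasePredictor, Input, Path

from model.ZeroCLIP import CLIPTextGenerator  # module path assumed from the ZeroCap package


class Predictor(BasePredictor):
    def setup(self):
        # Load the CLIP-guided text generator once per container start.
        self.generator = CLIPTextGenerator(lm_model='gpt-2')

    def predict(self, image: Path = Input(description="Input image")) -> str:
        # Encode the image and run CLIP-guided decoding from ZeroCap's prompt.
        image_features = self.generator.get_img_feature([str(image)], None)
        captions = self.generator.run(image_features, "Image of a", beam_size=5)
        return captions[0]
```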
@@ -0,0 +1,14 @@
#!/bin/bash

# lm_model:
#  1. cambridgeltl/magic_mscoco
#  2. cambridgeltl/magic_flickr30k
CUDA_VISIBLE_DEVICES=1 python run.py \
    --beam_size 1 \
    --target_seq_length 16 \
    --reset_context_delta \
    --lm_model cambridgeltl/magic_flickr30k \
    --test_image_prefix_path ../data/flickr30k/test_images \
    --test_path ../data/flickr30k/flickr30k_test.json \
    --save_path_prefix ../inference_result/flickr30k/baselines/ \
    --save_name zerocap_result.json
Binary file not shown.
@@ -0,0 +1,389 @@
import numpy as np
from torch import nn
from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer
from transformers.models.gpt_neo import GPTNeoForCausalLM
import torch
import clip
from PIL import Image
from datetime import datetime
import sys


def log_info(text, verbose=True):
    if verbose:
        dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
        print(f'{dt_string} | {text}')
        sys.stdout.flush()


def add_context(x, y):
    return (x[0] + y[0], x[1] + y[1])


def convert_models_to_fp32(model):
    for p in model.parameters():
        p.data = p.data.float()


class CLIPTextGenerator:
    def __init__(self,
                 seed=0,
                 lm_model='gpt-2',
                 forbidden_tokens_file_path='./forbidden_tokens.npy',
                 clip_checkpoints='./clip_checkpoints',
                 target_seq_length=15,
                 reset_context_delta=True,
                 num_iterations=5,
                 clip_loss_temperature=0.01,
                 clip_scale=1.,
                 ce_scale=0.2,
                 stepsize=0.3,
                 grad_norm_factor=0.9,
                 fusion_factor=0.99,
                 repetition_penalty=1.,
                 end_token='.',
                 end_factor=1.01,
                 forbidden_factor=20,
                 **kwargs):

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # set Random seed
        torch.manual_seed(seed)
        np.random.seed(seed)

        # Initialize Language model
        self.context_prefix = ''

        self.lm_tokenizer = GPT2Tokenizer.from_pretrained(lm_model)
        self.lm_model = GPT2LMHeadModel.from_pretrained(lm_model, output_hidden_states=True)
        self.context_prefix = self.lm_tokenizer.bos_token

        self.lm_model.to(self.device)
        self.lm_model.eval()

        self.forbidden_tokens = np.load(forbidden_tokens_file_path)
        self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if
                                      (x[0] == 'Ġ' and len(x) > 1 and x[1].isupper())]

        # Freeze LM weights
        for param in self.lm_model.parameters():
            param.requires_grad = False

        # Initialize CLIP
        self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device,
                                                    download_root=clip_checkpoints, jit=False)
        # convert_models_to_fp32(self.clip)

        # Init arguments
        self.target_seq_length = target_seq_length
        self.reset_context_delta = reset_context_delta
        self.num_iterations = num_iterations
        self.clip_loss_temperature = clip_loss_temperature
        self.clip_scale = clip_scale
        self.ce_scale = ce_scale
        self.stepsize = stepsize
        self.grad_norm_factor = grad_norm_factor
        self.fusion_factor = fusion_factor
        self.repetition_penalty = repetition_penalty
        self.end_token = self.lm_tokenizer.encode(end_token)[0]
        self.end_factor = end_factor
        self.ef_idx = 1
        self.forbidden_factor = forbidden_factor

    def get_img_feature(self, img_path, weights):
        imgs = [Image.open(x) for x in img_path]
        clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]

        with torch.no_grad():
            image_fts = [self.clip.encode_image(x) for x in clip_imgs]

            if weights is not None:
                image_features = sum([x * weights[i] for i, x in enumerate(image_fts)])
            else:
                image_features = sum(image_fts)

            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            return image_features.detach()

    def get_txt_features(self, text):
        clip_texts = clip.tokenize(text).to(self.device)

        with torch.no_grad():
            text_features = self.clip.encode_text(clip_texts)

            text_features = text_features / text_features.norm(dim=-1, keepdim=True)
            return text_features.detach()

    def get_combined_feature(self, img_path, texts, weights_i, weights_t):
        imgs = [Image.open(x) for x in img_path]
        clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]
        clip_texts = [clip.tokenize(x).to(self.device) for x in texts]

        with torch.no_grad():
            image_fts = [self.clip.encode_image(x) for x in clip_imgs]
            text_fts = [self.clip.encode_text(x) for x in clip_texts]

            features = sum([x * weights_i[i] for i, x in enumerate(image_fts)])
            if weights_t is not None:
                features += sum([x * weights_t[i] for i, x in enumerate(text_fts)])

            features = features / features.norm(dim=-1, keepdim=True)
            return features.detach()

    def run(self, image_features, cond_text, beam_size):
        self.image_features = image_features

        context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text)

        output_tokens, output_text = self.generate_text(context_tokens, beam_size)

        return output_text

    def generate_text(self, context_tokens, beam_size):
        context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0)

        gen_tokens = None
        scores = None
        seq_lengths = torch.ones(beam_size, device=self.device)
        is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool)

        for i in range(self.target_seq_length):
            probs = self.get_next_probs(i, context_tokens)
            logits = probs.log()

            if scores is None:
                scores, next_tokens = logits.topk(beam_size, -1)
                context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:])
                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)

                if gen_tokens is None:
                    gen_tokens = next_tokens
                else:
                    gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:])
                    gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1)
            else:
                logits[is_stopped] = -float(np.inf)
                logits[is_stopped, 0] = 0
                scores_sum = scores[:, None] + logits
                seq_lengths[~is_stopped] += 1
                scores_sum_average = scores_sum / seq_lengths[:, None]
                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
                    beam_size, -1)
                next_tokens_source = next_tokens // scores_sum.shape[1]
                seq_lengths = seq_lengths[next_tokens_source]
                next_tokens = next_tokens % scores_sum.shape[1]
                next_tokens = next_tokens.unsqueeze(1)
                gen_tokens = gen_tokens[next_tokens_source]
                gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1)
                context_tokens = context_tokens[next_tokens_source]
                scores = scores_sum_average * seq_lengths
                is_stopped = is_stopped[next_tokens_source]

            context_tokens = torch.cat((context_tokens, next_tokens), dim=1)
            is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze()

            ####
            tmp_scores = scores / seq_lengths
            tmp_output_list = gen_tokens.cpu().numpy()
            tmp_output_texts = [
                self.lm_tokenizer.decode(tmp_output)
                for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths)
            ]
            tmp_order = tmp_scores.argsort(descending=True)
            tmp_output_texts = [tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order]
            log_info(tmp_output_texts, verbose=True)
            ####

            if is_stopped.all():
                break

        scores = scores / seq_lengths
        output_list = gen_tokens.cpu().numpy()
        output_texts = [
            self.lm_tokenizer.decode(output[: int(length)])
            for output, length in zip(output_list, seq_lengths)
        ]
        order = scores.argsort(descending=True)
        output_texts = [output_texts[i] for i in order]

        return context_tokens, output_texts

    def get_next_probs(self, i, context_tokens):
        last_token = context_tokens[:, -1:]

        context = None
        if self.reset_context_delta and context_tokens.size(1) > 1:
            context = self.lm_model(context_tokens[:, :-1])["past_key_values"]

        # Logits of LM with unshifted context
        logits_before_shift = self.lm_model(context_tokens)["logits"]
        logits_before_shift = logits_before_shift[:, -1, :]
        probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1)

        if context:
            context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift)

        lm_output = self.lm_model(last_token, past_key_values=context)
        logits, past = (
            lm_output["logits"],
            lm_output["past_key_values"],
        )
        logits = logits[:, -1, :]

        logits = self.update_special_tokens_logits(context_tokens, i, logits)

        probs = nn.functional.softmax(logits, dim=-1)
        probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor))
        probs = probs / probs.sum()

        return probs

    def shift_context(self, i, context, last_token, context_tokens, probs_before_shift):
        context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context]

        window_mask = torch.ones_like(context[0][0]).to(self.device)

        for i in range(self.num_iterations):
            curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in
                          context_delta]

            for p0, p1 in curr_shift:
                p0.retain_grad()
                p1.retain_grad()

            shifted_context = list(map(add_context, context, curr_shift))

            shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context)
            logits = shifted_outputs["logits"][:, -1, :]
            probs = nn.functional.softmax(logits, dim=-1)

            loss = 0.0

            # CLIP LOSS
            clip_loss, clip_losses = self.clip_loss(probs, context_tokens)
            loss += self.clip_scale * clip_loss

            # CE/Fluency loss
            ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * probs_before_shift.log())).sum(-1)
            loss += ce_loss.sum()

            loss.backward()

            # ---------- Weights ----------
            combined_scores_k = -(ce_loss)
            combined_scores_c = -(self.clip_scale * torch.stack(clip_losses))

            # minmax
            if combined_scores_k.shape[0] == 1:
                tmp_weights_c = tmp_weights_k = torch.ones(*combined_scores_k.shape).to(self.device)
            else:
                tmp_weights_k = ((combined_scores_k - combined_scores_k.min())) / (
                        combined_scores_k.max() - combined_scores_k.min())
                tmp_weights_c = ((combined_scores_c - combined_scores_c.min())) / (
                        combined_scores_c.max() - combined_scores_c.min())

            tmp_weights = 0.5 * tmp_weights_k + 0.5 * tmp_weights_c
            tmp_weights = tmp_weights.view(tmp_weights.shape[0], 1, 1, 1)

            factor = 1

            # --------- Specific Gen ---------
            sep_grads = None

            for b in range(context_tokens.shape[0]):
                tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_]
                                 for p_ in curr_shift]

                # normalize gradients
                tmp_grad = [tuple([-self.stepsize * factor * (
                        x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][
                    j] ** self.grad_norm_factor).data.cpu().numpy()
                                   for j, x in enumerate(p_)])
                            for i, p_ in enumerate(curr_shift)]
                if sep_grads is None:
                    sep_grads = tmp_grad
                else:
                    for l_index in range(len(sep_grads)):
                        sep_grads[l_index] = list(sep_grads[l_index])
                        for k_index in range(len(sep_grads[0])):
                            sep_grads[l_index][k_index] = np.concatenate(
                                (sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0)
                        sep_grads[l_index] = tuple(sep_grads[l_index])
            final_grads = sep_grads

            # --------- update context ---------
            context_delta = list(map(add_context, final_grads, context_delta))

            for p0, p1 in curr_shift:
                p0.grad.data.zero_()
                p1.grad.data.zero_()

            new_context = []
            for p0, p1 in context:
                new_context.append((p0.detach(), p1.detach()))
            context = new_context

        context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_])
                         for p_ in context_delta]
        context = list(map(add_context, context, context_delta))

        new_context = []
        for p0, p1 in context:
            new_context.append((p0.detach(), p1.detach()))
        context = new_context

        return context

    def update_special_tokens_logits(self, context_tokens, i, logits):
        for beam_id in range(context_tokens.shape[0]):
            for token_idx in set(context_tokens[beam_id][-4:].tolist()):
                factor = self.repetition_penalty if logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty)
                logits[beam_id, token_idx] /= factor

            if i >= self.ef_idx:
                factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor)
                logits[beam_id, self.end_token] *= factor
            if i == 0:
                start_factor = 1.6
                factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor)
                logits[beam_id, self.end_token] /= factor

            for token_idx in list(self.forbidden_tokens):
                factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor)
                logits[beam_id, token_idx] /= factor

        return logits

    def clip_loss(self, probs, context_tokens):
        for p_ in self.clip.transformer.parameters():
            if p_.grad is not None:
                p_.grad.data.zero_()

        top_size = 512
        _, top_indices = probs.topk(top_size, -1)

        prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens]

        clip_loss = 0
        losses = []
        for idx_p in range(probs.shape[0]):
            top_texts = []
            prefix_text = prefix_texts[idx_p]
            for x in top_indices[idx_p]:
                top_texts.append(prefix_text + self.lm_tokenizer.decode(x))
            text_features = self.get_txt_features(top_texts)

            with torch.no_grad():
                similarities = (self.image_features @ text_features.T)
                target_probs = nn.functional.softmax(similarities / self.clip_loss_temperature, dim=-1).detach()
                target_probs = target_probs.type(torch.float32)

            target = torch.zeros_like(probs[idx_p])
            target[top_indices[idx_p]] = target_probs[0]
            target = target.unsqueeze(0)
            cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)])))

            clip_loss += cur_clip_loss
            losses.append(cur_clip_loss)

        return clip_loss, losses
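
A minimal sketch of driving the generator above, following the way the ZeroCap run script uses it; the image path is a placeholder and the module name `model` is assumed from this file's location:

```python
from model import CLIPTextGenerator  # the class defined above; module name assumed

generator = CLIPTextGenerator(lm_model='gpt-2', target_seq_length=16)

# Encode the image(s); weights=None simply sums the features if several paths are given.
image_features = generator.get_img_feature(['example.jpg'], None)  # placeholder path

# CLIP-guided decoding from ZeroCap's conditioning prompt; returns candidate
# continuations sorted by the length-normalised language-model score.
captions = generator.run(image_features, 'Image of a', beam_size=1)
print(captions[0])
```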
@@ -0,0 +1,449 @@
import numpy as np
from torch import nn
from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer
from transformers.models.gpt_neo import GPTNeoForCausalLM
import torch
import clip
from PIL import Image
from datetime import datetime
import sys


class TextCLIP(nn.Module):
    def __init__(self, model):
        super(TextCLIP, self).__init__()
        self.model = model

    def forward(self, text):
        return self.model.encode_text(text)


class ImageCLIP(nn.Module):
    def __init__(self, model):
        super(ImageCLIP, self).__init__()
        self.model = model

    def forward(self, image):
        return self.model.encode_image(image)


def log_info(text, verbose=True):
    if verbose:
        dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
        print(f'{dt_string} | {text}')
        sys.stdout.flush()


def add_context(x, y):
    return (x[0] + y[0], x[1] + y[1])


def convert_models_to_fp32(model):
    for p in model.parameters():
        p.data = p.data.float()


class CLIPTextGenerator:
    def __init__(self,
                 seed=0,
                 lm_model='gpt-2',
                 forbidden_tokens_file_path='./forbidden_tokens.npy',
                 clip_checkpoints='./clip_checkpoints',
                 target_seq_length=15,
                 reset_context_delta=True,
                 num_iterations=5,
                 clip_loss_temperature=0.01,
                 clip_scale=1.,
                 ce_scale=0.2,
                 stepsize=0.3,
                 grad_norm_factor=0.9,
                 fusion_factor=0.99,
                 repetition_penalty=1.,
                 end_token='.',
                 end_factor=1.01,
                 forbidden_factor=20,
                 **kwargs):

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # set Random seed
        torch.manual_seed(seed)
        np.random.seed(seed)

        # Initialize Language model
        self.context_prefix = ''

        if lm_model == 'gpt-neo':
            self.lm_tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-125M')
            self.lm_model = GPTNeoForCausalLM.from_pretrained('EleutherAI/gpt-neo-125M', output_hidden_states=True)
        elif lm_model == 'gpt-2':
            self.lm_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
            self.lm_model = GPT2LMHeadModel.from_pretrained('gpt2-medium', output_hidden_states=True)
        self.context_prefix = self.lm_tokenizer.bos_token

        self.lm_model.to(self.device)
        self.lm_model.eval()

        self.forbidden_tokens = np.load(forbidden_tokens_file_path)
        self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if
                                      (x[0] == 'Ġ' and len(x) > 1 and x[1].isupper())]

        # Freeze LM weights
        for param in self.lm_model.parameters():
            param.requires_grad = False

        # Initialize CLIP
        self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device,
                                                    download_root=clip_checkpoints, jit=False)
        self.clip_image = ImageCLIP(self.clip)
        self.clip_image = torch.nn.DataParallel(self.clip_image)
        self.clip_text = TextCLIP(self.clip)
        self.clip_text = torch.nn.DataParallel(self.clip_text)

        # Init arguments
        self.target_seq_length = target_seq_length
        self.reset_context_delta = reset_context_delta
        self.num_iterations = num_iterations
        self.clip_loss_temperature = clip_loss_temperature
        self.clip_scale = clip_scale
        self.ce_scale = ce_scale
        self.stepsize = stepsize
        self.grad_norm_factor = grad_norm_factor
        self.fusion_factor = fusion_factor
        self.repetition_penalty = repetition_penalty
        self.end_token = self.lm_tokenizer.encode(end_token)[0]
        self.end_factor = end_factor
        self.ef_idx = 1
        self.forbidden_factor = forbidden_factor

    def get_img_feature(self, img_path, weights):
        imgs = [Image.open(x) for x in img_path]
        clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]

        with torch.no_grad():
            image_fts = [self.clip_image(x) for x in clip_imgs]

            if weights is not None:
                image_features = sum([x * weights[i] for i, x in enumerate(image_fts)])
            else:
                image_features = sum(image_fts)

            image_features = torch.nn.functional.normalize(image_features, dim=-1)
            return image_features.detach()

    def get_txt_features(self, text):
        clip_texts = clip.tokenize(text).to(self.device)

        with torch.no_grad():
            text_features = self.clip_text(clip_texts)

            text_features = torch.nn.functional.normalize(text_features, dim=-1)
            return text_features.detach()

    def get_combined_feature(self, img_path, texts, weights_i, weights_t):
        imgs = [Image.open(x) for x in img_path]
        clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]
        clip_texts = [clip.tokenize(x).to(self.device) for x in texts]

        with torch.no_grad():
            image_fts = [self.clip.encode_image(x) for x in clip_imgs]
            text_fts = [self.clip.encode_text(x) for x in clip_texts]

            features = sum([x * weights_i[i] for i, x in enumerate(image_fts)])
            if weights_t is not None:
                features += sum([x * weights_t[i] for i, x in enumerate(text_fts)])

            features = features / features.norm(dim=-1, keepdim=True)
            return features.detach()

    def run(self, image_features, cond_text, beam_size):
        self.image_features = image_features

        context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text)

        output_tokens, output_text = self.generate_text(context_tokens, beam_size)

        return output_text

    def generate_text(self, context_tokens, beam_size):
        context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0)

        gen_tokens = None
        scores = None
        seq_lengths = torch.ones(beam_size, device=self.device)
        is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool)

        for i in range(self.target_seq_length):
            probs = self.get_next_probs(i, context_tokens)
            logits = probs.log()

            if scores is None:
                scores, next_tokens = logits.topk(beam_size, -1)
                context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:])
                next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)

                if gen_tokens is None:
                    gen_tokens = next_tokens
                else:
                    gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:])
                    gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1)
            else:
                logits[is_stopped] = -float(np.inf)
                logits[is_stopped, 0] = 0
                scores_sum = scores[:, None] + logits
                seq_lengths[~is_stopped] += 1
                scores_sum_average = scores_sum / seq_lengths[:, None]
                scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
                    beam_size, -1)
                next_tokens_source = next_tokens // scores_sum.shape[1]
                seq_lengths = seq_lengths[next_tokens_source]
                next_tokens = next_tokens % scores_sum.shape[1]
                next_tokens = next_tokens.unsqueeze(1)
                gen_tokens = gen_tokens[next_tokens_source]
                gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1)
                context_tokens = context_tokens[next_tokens_source]
                scores = scores_sum_average * seq_lengths
                is_stopped = is_stopped[next_tokens_source]

            context_tokens = torch.cat((context_tokens, next_tokens), dim=1)
            is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze()

            ####
            tmp_scores = scores / seq_lengths
            tmp_output_list = gen_tokens.cpu().numpy()
            tmp_output_texts = [
                self.lm_tokenizer.decode(tmp_output)
                for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths)
            ]
            tmp_order = tmp_scores.argsort(descending=True)
            tmp_output_texts = [tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order]
            log_info(tmp_output_texts, verbose=True)
            ####

            if is_stopped.all():
                break

        scores = scores / seq_lengths
        output_list = gen_tokens.cpu().numpy()
        output_texts = [
            self.lm_tokenizer.decode(output[: int(length)])
            for output, length in zip(output_list, seq_lengths)
        ]
        order = scores.argsort(descending=True)
        output_texts = [output_texts[i] for i in order]

        return context_tokens, output_texts

    def get_next_probs(self, i, context_tokens):
||||
|
last_token = context_tokens[:, -1:] |
||||
|
|
||||
|
if self.reset_context_delta and context_tokens.size(1) > 1: |
||||
|
context = self.lm_model(context_tokens[:, :-1])["past_key_values"] |
||||
|
|
||||
|
# Logits of LM with unshifted context |
||||
|
logits_before_shift = self.lm_model(context_tokens)["logits"] |
||||
|
logits_before_shift = logits_before_shift[:, -1, :] |
||||
|
probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1) |
||||
|
|
||||
|
if context: |
||||
|
context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift) |
||||
|
|
||||
|
lm_output = self.lm_model(last_token, past_key_values=context) |
||||
|
logits, past = ( |
||||
|
lm_output["logits"], |
||||
|
lm_output["past_key_values"], |
||||
|
) |
||||
|
logits = logits[:, -1, :] |
||||
|
|
||||
|
logits = self.update_special_tokens_logits(context_tokens, i, logits) |
||||
|
|
||||
|
probs = nn.functional.softmax(logits, dim=-1) |
||||
|
probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor)) |
||||
|
probs = probs / probs.sum() |
||||
|
|
||||
|
return probs |
||||
|
|
||||
|
def shift_context(self, i, context, last_token, context_tokens, probs_before_shift): |
||||
|
context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context] |
||||
|
|
||||
|
for i in range(self.num_iterations): |
||||
|
curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in |
||||
|
context_delta] |
||||
|
|
||||
|
for p0, p1 in curr_shift: |
||||
|
p0.retain_grad() |
||||
|
p1.retain_grad() |
||||
|
|
||||
|
shifted_context = list(map(add_context, context, curr_shift)) |
||||
|
|
||||
|
shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context) |
||||
|
logits = shifted_outputs["logits"][:, -1, :] |
||||
|
probs = nn.functional.softmax(logits, dim=-1) |
||||
|
|
||||
|
loss = 0.0 |
||||
|
|
||||
|
# CLIP LOSS |
||||
|
clip_loss, clip_losses = self.clip_loss(probs, context_tokens) |
||||
|
loss += self.clip_scale * clip_loss |
||||
|
|
||||
|
# CE/Fluency loss |
||||
|
ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * probs_before_shift.log())).sum(-1) |
||||
|
loss += ce_loss.sum() |
||||
|
|
||||
|
loss.backward() |
||||
|
|
||||
|
# --------- Specific Gen --------- |
||||
|
final_grads = self.norm_grad(context, context_tokens, curr_shift) |
||||
|
|
||||
|
# --------- update context --------- |
||||
|
context_delta = list(map(add_context, final_grads, context_delta)) |
||||
|
|
||||
|
for p0, p1 in curr_shift: |
||||
|
p0.grad.data.zero_() |
||||
|
p1.grad.data.zero_() |
||||
|
|
||||
|
new_context = [] |
||||
|
for p0, p1 in context: |
||||
|
new_context.append((p0.detach(), p1.detach())) |
||||
|
context = new_context |
||||
|
|
||||
|
context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) |
||||
|
for p_ in context_delta] |
||||
|
context = list(map(add_context, context, context_delta)) |
||||
|
|
||||
|
new_context = [] |
||||
|
for p0, p1 in context: |
||||
|
new_context.append((p0.detach(), p1.detach())) |
||||
|
context = new_context |
||||
|
|
||||
|
return context |
||||
|
|
||||
|
def norm_grad(self, context, context_tokens, curr_shift, ): |
||||
|
factor = 1 |
||||
|
sep_grads = None |
||||
|
window_mask = torch.ones_like(context[0][0]).to(self.device) |
||||
|
|
||||
|
for b in range(context_tokens.shape[0]): |
||||
|
tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_] |
||||
|
for p_ in curr_shift] |
||||
|
|
||||
|
# normalize gradients |
||||
|
tmp_grad = [tuple([-self.stepsize * factor * ( |
||||
|
x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][ |
||||
|
j] ** self.grad_norm_factor).data.cpu().numpy() |
||||
|
for j, x in enumerate(p_)]) |
||||
|
for i, p_ in enumerate(curr_shift)] |
||||
|
if sep_grads is None: |
||||
|
sep_grads = tmp_grad |
||||
|
else: |
||||
|
for l_index in range(len(sep_grads)): |
||||
|
sep_grads[l_index] = list(sep_grads[l_index]) |
||||
|
for k_index in range(len(sep_grads[0])): |
||||
|
sep_grads[l_index][k_index] = np.concatenate( |
||||
|
(sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0) |
||||
|
sep_grads[l_index] = tuple(sep_grads[l_index]) |
||||
|
final_grads = sep_grads |
||||
|
|
||||
|
return final_grads |
||||
|
|
||||
|
def update_special_tokens_logits(self, context_tokens, i, logits): |
||||
|
for beam_id in range(context_tokens.shape[0]): |
||||
|
for token_idx in set(context_tokens[beam_id][-4:].tolist()): |
||||
|
factor = self.repetition_penalty if logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty) |
||||
|
logits[beam_id, token_idx] /= factor |
||||
|
|
||||
|
if i >= self.ef_idx: |
||||
|
factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor) |
||||
|
logits[beam_id, self.end_token] *= factor |
||||
|
if i == 0: |
||||
|
start_factor = 1.6 |
||||
|
factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor) |
||||
|
logits[beam_id, self.end_token] /= factor |
||||
|
|
||||
|
for token_idx in list(self.forbidden_tokens): |
||||
|
factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor) |
||||
|
logits[beam_id, token_idx] /= factor |
||||
|
|
||||
|
return logits |
||||
|
|
||||
|
def clip_loss(self, probs, context_tokens): |
||||
|
for p_ in self.clip.transformer.parameters(): |
||||
|
if p_.grad is not None: |
||||
|
p_.grad.data.zero_() |
||||
|
|
||||
|
top_size = 512 |
||||
|
top_probs, top_indices = probs.topk(top_size, -1) |
||||
|
|
||||
|
prefix_texts = [self.lm_tokenizer.decode(x, skip_special_tokens=True) for x in context_tokens] |
||||
|
|
||||
|
clip_loss = 0 |
||||
|
losses = [] |
||||
|
|
||||
|
top_texts = [] |
||||
|
for idx_p in range(probs.shape[0]): |
||||
|
prefix_text = prefix_texts[idx_p] |
||||
|
for x in top_indices[idx_p]: |
||||
|
top_texts.append(prefix_text + self.lm_tokenizer.decode(x)) |
||||
|
|
||||
|
text_features = self.get_txt_features(top_texts)#.reshape(probs.size(0), top_size, -1) |
||||
|
|
||||
|
with torch.no_grad(): |
||||
|
similiraties = (self.image_features @ text_features.T).reshape(probs.size(0), -1) |
||||
|
similiraties = similiraties.reshape(probs.size(0), -1) |
||||
|
target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() |
||||
|
target_probs = target_probs.type(torch.float32) |
||||
|
|
||||
|
clip_loss += torch.sum(-(target_probs * torch.log(top_probs))) |
||||
|
# for idx_p in range(probs.shape[0]): |
||||
|
# top_texts = [] |
||||
|
# prefix_text = prefix_texts[idx_p] |
||||
|
# for x in top_indices[idx_p]: |
||||
|
# top_texts.append(prefix_text + self.lm_tokenizer.decode(x)) |
||||
|
# text_features = self.get_txt_features(top_texts) |
||||
|
# |
||||
|
# with torch.no_grad(): |
||||
|
# similiraties = (self.image_features @ text_features.T) |
||||
|
# target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() |
||||
|
# target_probs = target_probs.type(torch.float32) |
||||
|
# |
||||
|
# target = torch.zeros_like(probs[idx_p]) |
||||
|
# target[top_indices[idx_p]] = target_probs[0] |
||||
|
# target = target.unsqueeze(0) |
||||
|
# cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)]))) |
||||
|
# |
||||
|
# clip_loss += cur_clip_loss |
||||
|
# losses.append(cur_clip_loss) |
||||
|
|
||||
|
return clip_loss, losses |
||||
|
|
||||
|
def clip_loss_old(self, probs, context_tokens): |
||||
|
for p_ in self.clip.transformer.parameters(): |
||||
|
if p_.grad is not None: |
||||
|
p_.grad.data.zero_() |
||||
|
|
||||
|
top_size = 512 |
||||
|
_, top_indices = probs.topk(top_size, -1) |
||||
|
|
||||
|
prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens] |
||||
|
|
||||
|
clip_loss = 0 |
||||
|
losses = [] |
||||
|
for idx_p in range(probs.shape[0]): |
||||
|
top_texts = [] |
||||
|
prefix_text = prefix_texts[idx_p] |
||||
|
for x in top_indices[idx_p]: |
||||
|
top_texts.append(prefix_text + self.lm_tokenizer.decode(x)) |
||||
|
text_features = self.get_txt_features(top_texts) |
||||
|
|
||||
|
with torch.no_grad(): |
||||
|
similiraties = (self.image_features @ text_features.T) |
||||
|
target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() |
||||
|
target_probs = target_probs.type(torch.float32) |
||||
|
|
||||
|
target = torch.zeros_like(probs[idx_p]) |
||||
|
target[top_indices[idx_p]] = target_probs[0] |
||||
|
target = target.unsqueeze(0) |
||||
|
cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)]))) |
||||
|
|
||||
|
clip_loss += cur_clip_loss |
||||
|
losses.append(cur_clip_loss) |
||||
|
|
||||
|
return clip_loss, losses |
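
For orientation, the snippet below sketches how a `CLIPTextGenerator` is driven end to end; it mirrors the `run()` helper in `run.py` further down this diff. Importing `get_args` from `run.py` and reusing its default `caption_img_path` are assumptions made for illustration, not part of the changed files.

```python
# Minimal usage sketch (assumes run.py sits alongside this module and that the CLIP
# checkpoint and forbidden_tokens.npy are available as in the repository layout).
import clip
import torch

from run import get_args                      # argparse defaults defined in run.py below
from model.ZeroCLIP import CLIPTextGenerator

args = get_args()
text_generator = CLIPTextGenerator(**vars(args))

# Encode one image and decode beam_size candidate captions conditioned on "Image of a".
image_features = text_generator.get_img_feature([args.caption_img_path], None)
captions = text_generator.run(image_features, args.cond_text, beam_size=args.beam_size)

# Rank candidates by CLIP image-text similarity, as run.py's run() helper does.
encoded = [text_generator.clip.encode_text(clip.tokenize(c).to(text_generator.device)) for c in captions]
encoded = [x / x.norm(dim=-1, keepdim=True) for x in encoded]
best_idx = (torch.cat(encoded) @ image_features.t()).squeeze().argmax().item()
print(args.cond_text + captions[best_idx])
```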
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,14 @@
#!/bin/bash

# lm_model:
# 1. cambridgeltl/magic_mscoco
# 2. cambridgeltl/magic_flickr30k
CUDA_VISIBLE_DEVICES=1 python run.py \
    --beam_size 1 \
    --target_seq_length 16 \
    --reset_context_delta \
    --lm_model cambridgeltl/magic_mscoco \
    --test_image_prefix_path ../data/mscoco/test_images \
    --test_path ../data/mscoco/mscoco_test.json \
    --save_path_prefix ../inference_result/mscoco/baselines/ \
    --save_name zerocap_result.json
@ -0,0 +1,117 @@
import os
import tempfile
import sys
sys.path.append('CLIP')
from pathlib import Path
import cog
import argparse
import torch
import clip
from model.ZeroCLIP import CLIPTextGenerator

def perplexity_score(text, lm_model, lm_tokenizer, device):
    encodings = lm_tokenizer(f'{lm_tokenizer.bos_token + text}', return_tensors='pt')
    input_ids = encodings.input_ids.to(device)
    target_ids = input_ids.clone()

    outputs = lm_model(input_ids, labels=target_ids)
    log_likelihood = outputs[0]
    ll = log_likelihood.item()

    return ll

class Predictor(cog.Predictor):
    def setup(self):
        self.args = get_args()
        self.args.reset_context_delta = True
        self.text_generator = CLIPTextGenerator(**vars(self.args))

    @cog.input(
        "image",
        type=Path,
        help="input image"
    )
    @cog.input(
        "cond_text",
        type=str,
        default='Image of a',
        help="conditional text",
    )
    @cog.input(
        "beam_size",
        type=int,
        default=5, min=1, max=10,
        help="Number of beams to use",
    )
    @cog.input(
        "end_factor",
        type=float,
        default=1.01, min=1.0, max=1.10,
        help="Higher value for shorter captions",
    )
    @cog.input(
        "max_seq_length",
        type=int,
        default=15, min=1, max=20,
        help="Maximum number of tokens to generate",
    )
    @cog.input(
        "ce_loss_scale",
        type=float,
        default=0.2, min=0.0, max=0.6,
        help="Scale of cross-entropy loss with un-shifted language model",
    )
    def predict(self, image, cond_text, beam_size, end_factor, max_seq_length, ce_loss_scale):
        self.args.cond_text = cond_text
        self.text_generator.end_factor = end_factor
        self.text_generator.target_seq_length = max_seq_length
        self.text_generator.ce_scale = ce_loss_scale

        image_features = self.text_generator.get_img_feature([str(image)], None)
        captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size)

        # CLIP SCORE
        encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device))
                            for c in captions]
        encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
        best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item()

        # Perplexity SCORE
        ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions]
        best_ppl_index = torch.tensor(ppl_scores).argmin().item()

        best_clip_caption = self.args.cond_text + captions[best_clip_idx]
        best_mixed = self.args.cond_text + captions[0]
        best_PPL = self.args.cond_text + captions[best_ppl_index]

        final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}'

        return final
        # return self.args.cond_text + captions[best_clip_idx]


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
    parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
    parser.add_argument("--target_seq_length", type=int, default=15)
    parser.add_argument("--cond_text", type=str, default="Image of a")
    parser.add_argument("--reset_context_delta", action="store_true",
                        help="Should we reset the context at each token gen")
    parser.add_argument("--num_iterations", type=int, default=5)
    parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
    parser.add_argument("--clip_scale", type=float, default=1)
    parser.add_argument("--ce_scale", type=float, default=0.2)
    parser.add_argument("--stepsize", type=float, default=0.3)
    parser.add_argument("--grad_norm_factor", type=float, default=0.9)
    parser.add_argument("--fusion_factor", type=float, default=0.99)
    parser.add_argument("--repetition_penalty", type=float, default=1)
    parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
    parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
    parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
    parser.add_argument("--beam_size", type=int, default=5)

    args = parser.parse_args('')
    return args
@ -0,0 +1,129 @@
import os
import tempfile
import sys
sys.path.append('CLIP')
from pathlib import Path
import cog
import argparse
import torch
import clip
from model.ZeroCLIP import CLIPTextGenerator

def perplexity_score(text, lm_model, lm_tokenizer, device):
    encodings = lm_tokenizer(f'{lm_tokenizer.bos_token + text}', return_tensors='pt')
    input_ids = encodings.input_ids.to(device)
    target_ids = input_ids.clone()

    outputs = lm_model(input_ids, labels=target_ids)
    log_likelihood = outputs[0]
    ll = log_likelihood.item()

    return ll

class Predictor(cog.Predictor):
    def setup(self):
        self.args = get_args()
        self.args.reset_context_delta = True
        self.text_generator = CLIPTextGenerator(**vars(self.args))

    @cog.input(
        "image1",
        type=Path,
        help="Final result will be: image1 + (image2 - image3)"
    )
    @cog.input(
        "image2",
        type=Path,
        help="Final result will be: image1 + (image2 - image3)"
    )
    @cog.input(
        "image3",
        type=Path,
        help="Final result will be: image1 + (image2 - image3)"
    )
    @cog.input(
        "cond_text",
        type=str,
        default='Image of a',
        help="conditional text",
    )
    @cog.input(
        "beam_size",
        type=int,
        default=3, min=1, max=10,
        help="Number of beams to use",
    )
    @cog.input(
        "end_factors",
        type=float,
        default=1.06, min=1.0, max=1.10,
        help="Higher value for shorter captions",
    )
    @cog.input(
        "max_seq_lengths",
        type=int,
        default=3, min=1, max=20,
        help="Maximum number of tokens to generate",
    )
    @cog.input(
        "ce_loss_scale",
        type=float,
        default=0.2, min=0.0, max=0.6,
        help="Scale of cross-entropy loss with un-shifted language model",
    )
    def predict(self, image1, image2, image3, cond_text, beam_size, end_factors, max_seq_lengths, ce_loss_scale):
        self.args.cond_text = cond_text
        self.text_generator.end_factor = end_factors
        self.text_generator.target_seq_length = max_seq_lengths
        self.text_generator.ce_scale = ce_loss_scale
        self.text_generator.fusion_factor = 0.95
        self.text_generator.grad_norm_factor = 0.95

        image_features = self.text_generator.get_combined_feature([str(image1), str(image2), str(image3)], [], [1, 1, -1], None)
        captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size)

        # CLIP SCORE
        encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device))
                            for c in captions]
        encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
        best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item()

        # Perplexity SCORE
        ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions]
        best_ppl_index = torch.tensor(ppl_scores).argmin().item()

        best_clip_caption = self.args.cond_text + captions[best_clip_idx]
        best_mixed = self.args.cond_text + captions[0]
        best_PPL = self.args.cond_text + captions[best_ppl_index]

        final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}'

        return final
        # return self.args.cond_text + captions[best_clip_idx]


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
    parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
    parser.add_argument("--target_seq_length", type=int, default=15)
    parser.add_argument("--cond_text", type=str, default="Image of a")
    parser.add_argument("--reset_context_delta", action="store_true",
                        help="Should we reset the context at each token gen")
    parser.add_argument("--num_iterations", type=int, default=5)
    parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
    parser.add_argument("--clip_scale", type=float, default=1)
    parser.add_argument("--ce_scale", type=float, default=0.2)
    parser.add_argument("--stepsize", type=float, default=0.3)
    parser.add_argument("--grad_norm_factor", type=float, default=0.95)
    parser.add_argument("--fusion_factor", type=float, default=0.95)
    parser.add_argument("--repetition_penalty", type=float, default=1)
    parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
    parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
    parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
    parser.add_argument("--beam_size", type=int, default=5)

    args = parser.parse_args('')
    return args
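
The arithmetic predictor above composes CLIP image embeddings with weights `[1, 1, -1]` before decoding, i.e. image1 + (image2 - image3). Below is a minimal sketch of the same call outside of cog; the image paths are the `--arithmetics_imgs` defaults from `run.py` further down this diff, and importing `get_args` from `run.py` is an assumption made for illustration.

```python
# Minimal sketch of image-arithmetic captioning (assumes run.py and its example images
# are present as in the repository layout).
from run import get_args                      # argparse defaults defined in run.py below
from model.ZeroCLIP import CLIPTextGenerator

args = get_args()
args.reset_context_delta = True
text_generator = CLIPTextGenerator(**vars(args))
text_generator.fusion_factor = 0.95           # values used by predict() above
text_generator.grad_norm_factor = 0.95

# image1 + (image2 - image3) in CLIP embedding space, then caption the combined feature.
paths = ['example_images/arithmetics/woman2.jpg',
         'example_images/arithmetics/king2.jpg',
         'example_images/arithmetics/man2.jpg']
image_features = text_generator.get_combined_feature(paths, [], [1, 1, -1], None)
captions = text_generator.run(image_features, args.cond_text, beam_size=3)
print(args.cond_text + captions[0])
```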
@ -0,0 +1,3 @@
ftfy
regex
tqdm
@ -0,0 +1,131 @@
import argparse
import ipdb
from tqdm import tqdm
import progressbar
import torch
import ipdb
import clip
from model.ZeroCLIP import CLIPTextGenerator
from model.ZeroCLIP_batched import CLIPTextGenerator as CLIPTextGenerator_multigpu

def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--test_image_prefix_path", type=str, help="the folder that stores all test images")
    parser.add_argument("--test_path", type=str)
    parser.add_argument("--save_path_prefix", type=str, help="save the result in which directory")
    parser.add_argument("--save_name", type=str, help="the name of the saved file")

    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
    parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
    parser.add_argument("--target_seq_length", type=int, default=15)
    parser.add_argument("--cond_text", type=str, default="Image of a")
    parser.add_argument("--reset_context_delta", action="store_true",
                        help="Should we reset the context at each token gen")
    parser.add_argument("--num_iterations", type=int, default=5)
    parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
    parser.add_argument("--clip_scale", type=float, default=1)
    parser.add_argument("--ce_scale", type=float, default=0.2)
    parser.add_argument("--stepsize", type=float, default=0.3)
    parser.add_argument("--grad_norm_factor", type=float, default=0.9)
    parser.add_argument("--fusion_factor", type=float, default=0.99)
    parser.add_argument("--repetition_penalty", type=float, default=1)
    parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
    parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
    parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
    parser.add_argument("--beam_size", type=int, default=1)

    parser.add_argument("--multi_gpu", action="store_true")

    parser.add_argument('--run_type',
                        default='caption',
                        nargs='?',
                        choices=['caption', 'arithmetics'])

    parser.add_argument("--caption_img_path", type=str, default='example_images/captions/COCO_val2014_000000008775.jpg',
                        help="Path to image for captioning")

    parser.add_argument("--arithmetics_imgs", nargs="+",
                        default=['example_images/arithmetics/woman2.jpg',
                                 'example_images/arithmetics/king2.jpg',
                                 'example_images/arithmetics/man2.jpg'])
    parser.add_argument("--arithmetics_weights", nargs="+", default=[1, 1, -1])

    args = parser.parse_args()

    return args

def run(args, text_generator, img_path):
    image_features = text_generator.get_img_feature([img_path], None)
    captions = text_generator.run(image_features, args.cond_text, beam_size=args.beam_size)

    encoded_captions = [text_generator.clip.encode_text(clip.tokenize(c).to(text_generator.device)) for c in captions]
    encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
    best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item()
    return captions


if __name__ == '__main__':
    if torch.cuda.is_available():
        print ('Cuda is available.')
    cuda_available = torch.cuda.is_available()
    args = get_args()
    device = torch.device('cuda')

    save_path_prefix = args.save_path_prefix
    import os
    if os.path.exists(save_path_prefix):
        pass
    else: # recursively construct directory
        os.makedirs(save_path_prefix, exist_ok=True)
    # parse save name
    save_name = args.save_name
    full_save_path = save_path_prefix + '/' + save_name
    print ('full save path is {}'.format(full_save_path))

    print ('Loading data...')
    import json
    with open(args.test_path) as f:
        item_list = json.load(f)
    print ('Data loaded.')
    print ('Number of test instances is {}'.format(len(item_list)))

    # ZeroCap generator
    text_generator = CLIPTextGenerator(**vars(args))

    result_list = []
    invalid_num = 0
    print ('----------------------------------------------------------------')
    test_num = len(item_list)
    #test_num = 10
    print ('Number of inference instances is {}'.format(test_num))
    p = progressbar.ProgressBar(test_num)
    p.start()
    for p_idx in tqdm(range(test_num)):
        p.update(p_idx)
        one_test_dict = item_list[p_idx]

        one_res_dict = {
            'split':one_test_dict['split'],
            'image_name':one_test_dict['image_name'],
            #'file_path':one_test_dict['file_path'],
            'captions':one_test_dict['captions']
        }

        image_full_path = args.test_image_prefix_path + '/' + one_test_dict['image_name']
        try:
            output_text = run(args, text_generator, img_path=image_full_path)
            one_res_dict['prediction'] = output_text[0]
            result_list.append(one_res_dict)
        except Exception as error:
            print('[!] ERROR:', error)
            invalid_num += 1
            print ('invalid number is {}'.format(invalid_num))
            continue
    p.finish()
    print ('Inference completed!')

    import json
    with open(full_save_path, 'w') as outfile:
        json.dump(result_list, outfile, indent=4)
@ -0,0 +1,19 @@
import os

import pkg_resources
from setuptools import setup, find_packages

setup(
    name="zero-shot-image-to-text",
    py_modules=["zero-shot-image-to-text"],
    version="1.0",
    description="",
    packages=find_packages(),
    install_requires=[
        str(r)
        for r in pkg_resources.parse_requirements(
            open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
        )
    ],
    include_package_data=True
)