magic
copied
wxywb
2 years ago
32 changed files with 1334 additions and 1376 deletions
Binary file not shown.
Binary file not shown.
@ -1,2 +1,81 @@ |
|||||
# magic |
|
||||
|
# Image Captioning with MAGIC |
||||
|
|
||||
|
*author: David Wang* |
||||
|
|
||||
|
|
||||
|
<br /> |
||||
|
|
||||
|
|
||||
|
## Description |
||||
|
|
||||
|
This operator generates the caption with [MAGIC](https://arxiv.org/abs/2205.02655) which describes the content of the given image. MAGIC is a simple yet efficient plug-and-play framework, which directly combines an off-the-shelf LM (i.e., GPT-2) and an image-text matching model (i.e., CLIP) for image-grounded text generation. During decoding, MAGIC influences the generation of the LM by introducing a CLIP-induced score, called magic score, which regularizes the generated result to be semantically related to a given image while being coherent to the previously generated context. This is an adaptation from [yxuansu / MAGIC](https://github.com/yxuansu/MAGIC). |
||||
|
|
||||
|
|
||||
|
<br /> |
||||
|
|
||||
|
|
||||
|
## Code Example |
||||
|
|
||||
|
Load an image from path './image.jpg' to generate the caption. |
||||
|
|
||||
|
*Write the pipeline in simplified style*: |
||||
|
|
||||
|
```python |
||||
|
import towhee |
||||
|
|
||||
|
towhee.glob('./image.jpg') \ |
||||
|
.image_decode() \ |
||||
|
.image_captioning.magic(model_name='expansionnet_rf') \ |
||||
|
.show() |
||||
|
``` |
||||
|
<img src="./cap.png" alt="result1" style="height:20px;"/> |
||||
|
|
||||
|
*Write a same pipeline with explicit inputs/outputs name specifications:* |
||||
|
|
||||
|
```python |
||||
|
import towhee |
||||
|
|
||||
|
towhee.glob['path']('./image.jpg') \ |
||||
|
.image_decode['path', 'img']() \ |
||||
|
.image_captioning.magic['img', 'text'](model_name='expansionnet_rf') \ |
||||
|
.select['img', 'text']() \ |
||||
|
.show() |
||||
|
``` |
||||
|
<img src="./tabular.png" alt="result2" style="height:60px;"/> |
||||
|
|
||||
|
|
||||
|
<br /> |
||||
|
|
||||
|
|
||||
|
## Factory Constructor |
||||
|
|
||||
|
Create the operator via the following factory method |
||||
|
|
||||
|
***expansionnet_v2(model_name)*** |
||||
|
|
||||
|
**Parameters:** |
||||
|
|
||||
|
​ ***model_name:*** *str* |
||||
|
|
||||
|
​ The model name of MAGIC. Supported model names: |
||||
|
- magic_mscoco |
||||
|
|
||||
|
<br /> |
||||
|
|
||||
|
## Interface |
||||
|
|
||||
|
An image-text embedding operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generate the correspoing caption. |
||||
|
|
||||
|
|
||||
|
**Parameters:** |
||||
|
|
||||
|
​ ***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)* |
||||
|
|
||||
|
​ The image to generate embedding. |
||||
|
|
||||
|
|
||||
|
|
||||
|
**Returns:** *str* |
||||
|
|
||||
|
​ The caption generated by model. |
||||
|
|
||||
|
@ -0,0 +1,167 @@ |
|||||
|
## Unsupervised Domain Adaptation of Language Model |
||||
|
**** |
||||
|
### Catalogue: |
||||
|
* <a href='#mscoco'>1. MSCOCO Benchmark</a> |
||||
|
* <a href='#mscoco_data_preparation'>1.1. MSCOCO Data Preparation</a> |
||||
|
* <a href='#mscoco_training'>1.2. Unsupervised Domain Adaptation on MSCOCO</a> |
||||
|
* <a href='#flickr30k'>2. Flickr30k Benchmark</a> |
||||
|
* <a href='#flickr30k_data_preparation'>2.1. Flickr30k Data Preparation</a> |
||||
|
* <a href='#flickr30k_training'>2.2. Unsupervised Domain Adaptation on Flickr30k</a> |
||||
|
* <a href='#unsupervised_baselines'>3. Unsupervised Baselines</a> |
||||
|
* <a href='#contrastive_search'>3.1. Contrastive Search</a> |
||||
|
* <a href='#top_k_sampling'>3.2. Top-k Sampling</a> |
||||
|
* <a href='#nucleus_sampling'>3.3. Nucleus Sampling</a> |
||||
|
|
||||
|
**** |
||||
|
<span id='mscoco'/> |
||||
|
|
||||
|
#### 1. MSCOCO Benchmark: |
||||
|
|
||||
|
We first describe how to perform unsupervised domain adaptation of language model on the text corpus of MSCOCO benchmark. |
||||
|
|
||||
|
<span id='mscoco_data_preparation'/> |
||||
|
|
||||
|
##### 1.1. MSCOCO Data Preparation: |
||||
|
|
||||
|
To prepare the MSCOCO benchmark, please follow the instructions [[here]](https://github.com/yxuansu/MAGIC/tree/main/image_captioning/data#1-mscoco-benchmark). |
||||
|
|
||||
|
<span id='mscoco_training'/> |
||||
|
|
||||
|
##### 1.2.Unsupervised Domain Adaptation on MSCOCO: |
||||
|
After preparing the MSCOCO data, run the following command to train the language model. |
||||
|
```yaml |
||||
|
chmod +x ./train_mscoco.sh |
||||
|
./train_mscoco.sh |
||||
|
``` |
||||
|
The arguments are as follows: |
||||
|
* `--model_name`: The name of huggingface pre-trained gpt model (e.g. gpt2, gpt-large). |
||||
|
* `--train_path`: The file path of training set. |
||||
|
* `--dev_path`: The file path of validation set. |
||||
|
* `--test_path`: The file path of test set. |
||||
|
* `--add_eos_token_to_data`: Whether adding an eos token at the end of text sequence. |
||||
|
* `--margin`: The contrastive margin $\rho$. |
||||
|
* `--max_len`: The maximum length of training samples. |
||||
|
* `--number_of_gpu`: The number of available GPUs. |
||||
|
* `--batch_size_per_gpu`: The batch size for each GPU. |
||||
|
* `--gradient_accumulation_steps`: How many forward computations between two gradient updates. |
||||
|
* `--effective_batch_size`: The overall batch size. It equals to batch_size_per_gpu x gradient_accumulation_steps x number_of_gpu. |
||||
|
* `--total_steps`: The number of total gradient update steps. |
||||
|
* `--print_every`: Have many steps to show the intermediate results. |
||||
|
* `--save_every`: How many steps to save one checkpoint. |
||||
|
* `--learning_rate`: The learning rate. |
||||
|
* `--save_path_prefix`: Where to save the checkpoints. |
||||
|
|
||||
|
**** |
||||
|
<span id='flickr30k'/> |
||||
|
|
||||
|
#### 2. Flickr30k Benchmark: |
||||
|
|
||||
|
We then describe how to perform unsupervised domain adaptation of language model on the text corpus of Flickr30k benchmark. |
||||
|
|
||||
|
<span id='flickr30k_data_preparation'/> |
||||
|
|
||||
|
##### 2.1. Flickr30k Data Preparation: |
||||
|
|
||||
|
To prepare the Flickr30k benchmark, please follow the instructions [[here]](https://github.com/yxuansu/MAGIC/tree/main/image_captioning/data#2-flickr30k-benchmark). |
||||
|
|
||||
|
<span id='flickr30k_training'/> |
||||
|
|
||||
|
##### 2.2. Unsupervised Domain Adaptation on Flickr30k: |
||||
|
After preparing the Flickr30k data, run the following command to train the language model. |
||||
|
```yaml |
||||
|
chmod +x ./train_flickr30k.sh |
||||
|
./train_flickr30k.sh |
||||
|
``` |
||||
|
|
||||
|
**** |
||||
|
<span id='unsupervised_baselines'/> |
||||
|
|
||||
|
#### 3. Unsupervised Baselines: |
||||
|
|
||||
|
Here, we illustrate how to use the language model to perform unsupervised baselines as described in our paper. Note that, all these methods are **unsupervised** as the language model is a text-only model and does not take image as input. |
||||
|
|
||||
|
```python |
||||
|
# first, load the language model |
||||
|
import torch |
||||
|
from simctg import SimCTG |
||||
|
sos_token, pad_token = r'<-start_of_text->', r'<-pad->' |
||||
|
# we use the language model adapted on MSCOCO as an example. |
||||
|
language_model_name = r'cambridgeltl/magic_mscoco' |
||||
|
generation_model = SimCTG(language_model_name, sos_token, pad_token) |
||||
|
generation_model.eval() |
||||
|
|
||||
|
# then, prepare the input ids. Note that, the text is always generated from the same start of sentence token. |
||||
|
tokens = generation_model.tokenizer.tokenize(sos_token) |
||||
|
input_ids = generation_model.tokenizer.convert_tokens_to_ids(tokens) |
||||
|
input_ids = torch.LongTensor(input_ids).view(1,-1) |
||||
|
``` |
||||
|
|
||||
|
<span id='contrastive_search'/> |
||||
|
|
||||
|
##### 3.1. Contrastive Search : |
||||
|
```python |
||||
|
''' |
||||
|
use contrastive search to generate the result. |
||||
|
note that, contrastive search is a deterministic decoding method, thus the generated text is always the same. |
||||
|
''' |
||||
|
beam_width, alpha, decoding_len = 45, 0.1, 16 |
||||
|
output_text = generation_model.fast_contrastive_search(input_ids, beam_width, alpha, decoding_len) |
||||
|
print (output_text) |
||||
|
''' |
||||
|
A man is riding a skateboard down a street. |
||||
|
''' |
||||
|
``` |
||||
|
The arguments are as follows: |
||||
|
* `--input_ids`: The id of the start of sentence token. |
||||
|
* `--beam_width`: k in the contrastive search. |
||||
|
* `--alpha`: alpha in the contrastive search. |
||||
|
* `--decoding_len`: Number of tokens to generate. |
||||
|
|
||||
|
<span id='top_k_sampling'/> |
||||
|
|
||||
|
##### 3.2. Top-k Sampling : |
||||
|
```python |
||||
|
''' |
||||
|
use top-k sampling to generate the result. |
||||
|
note that, the this method is a stochastic method, thus the generated text is always different. |
||||
|
''' |
||||
|
top_k, decoding_len = 40, 16 |
||||
|
output_text = generation_model.top_k_sampling(input_ids, top_k, decoding_len) |
||||
|
print (output_text) |
||||
|
''' |
||||
|
some very different types of vases with flowers together |
||||
|
''' |
||||
|
``` |
||||
|
The arguments are as follows: |
||||
|
* `--input_ids`: The id of the start of sentence token. |
||||
|
* `--k`: The k in top-k sampling. |
||||
|
* `--decoding_len`: Number of tokens to generate. |
||||
|
|
||||
|
<span id='nucleus_sampling'/> |
||||
|
|
||||
|
##### 3.3. Nucleus Sampling : |
||||
|
```python |
||||
|
''' |
||||
|
use nucleus sampling to generate the result. |
||||
|
note that, the this method is a stochastic method, thus the generated text is always different. |
||||
|
''' |
||||
|
nucleus_p, decoding_len = 0.95, 16 |
||||
|
output_text = generation_model.nucleus_sampling(input_ids, nucleus_p, decoding_len) |
||||
|
print (output_text) |
||||
|
''' |
||||
|
Two young girls enjoying a hot dog hot dog bun. |
||||
|
''' |
||||
|
``` |
||||
|
The arguments are as follows: |
||||
|
* `--input_ids`: The id of the start of sentence token. |
||||
|
* `--nucleus_p`: The probability in nucleus sampling. |
||||
|
* `--decoding_len`: Number of tokens to generate. |
||||
|
|
||||
|
|
||||
|
|
||||
|
|
||||
|
|
||||
|
|
||||
|
|
||||
|
|
||||
|
|
@ -0,0 +1,157 @@ |
|||||
|
import json |
||||
|
import random |
||||
|
import torch |
||||
|
import numpy as np |
||||
|
import progressbar |
||||
|
from torch.nn.utils import rnn |
||||
|
|
||||
|
class Data: |
||||
|
def __init__(self, model_name, train_path, dev_path, test_path, max_len, |
||||
|
sos_token, pad_token, add_eos_token_to_data): |
||||
|
''' |
||||
|
model_name: gpt2 |
||||
|
train_path: training data path |
||||
|
dev_path: validation data path |
||||
|
test_path: test data path |
||||
|
max_len: maximum length for training sequences |
||||
|
sos_token: initialized sos token <-start_of_text-> |
||||
|
pad_token: used to pad the sequences <-pad-> |
||||
|
add_eos_token_to_data: whether we want to the model learn to generate eos token; |
||||
|
if so, the model could automatically stop generation by generating eos token |
||||
|
''' |
||||
|
from transformers import GPT2TokenizerFast |
||||
|
self.tokenizer = GPT2TokenizerFast.from_pretrained(model_name) |
||||
|
self.sos_token, self.sos_token_id = self.add_special_token(sos_token) |
||||
|
print ('sos token is {}, sos token id is {}'.format(self.sos_token, self.sos_token_id)) |
||||
|
self.pad_token, self.pad_token_id = self.add_special_token(pad_token) |
||||
|
print ('pad token is {}, pad token id is {}'.format(self.pad_token, self.pad_token_id)) |
||||
|
self.eos_token, self.eos_token_id = self.tokenizer.bos_token, self.tokenizer.bos_token_id |
||||
|
print ('eos token is {}, eos token id is {}'.format(self.eos_token, self.eos_token_id)) |
||||
|
self.add_eos_token_to_data = add_eos_token_to_data |
||||
|
|
||||
|
self.max_len = max_len |
||||
|
self.train_token_list, self.train_token_id_list = self.process_one_file(train_path) |
||||
|
self.dev_token_list, self.dev_token_id_list = self.process_one_file(dev_path) |
||||
|
self.test_token_list, self.test_token_id_list = self.process_one_file(test_path) |
||||
|
self.train_num, self.dev_num, self.test_num = len(self.train_token_list), len(self.dev_token_list), \ |
||||
|
len(self.test_token_list) |
||||
|
print ('train number:{}, dev number:{}, test number:{}'.format(self.train_num, self.dev_num, self.test_num)) |
||||
|
|
||||
|
self.train_idx_list = [i for i in range(self.train_num)] |
||||
|
random.shuffle(self.train_idx_list) |
||||
|
self.dev_idx_list = [j for j in range(self.dev_num)] |
||||
|
self.test_idx_list = [j for j in range(self.test_num)] |
||||
|
self.dev_current_idx, self.test_current_idx = 0, 0 |
||||
|
|
||||
|
def add_special_token(self, special_token): |
||||
|
if special_token in self.tokenizer.vocab: |
||||
|
print (special_token + ' token exists.') |
||||
|
else: |
||||
|
print ('Add token to the tokenizer.') |
||||
|
print ('Original vocabulary size is {}'.format(len(self.tokenizer))) |
||||
|
self.tokenizer.add_tokens([special_token]) |
||||
|
print ('Vocabulary size after extension is {}'.format(len(self.tokenizer))) |
||||
|
assert len(self.tokenizer.convert_tokens_to_ids([special_token])) == 1 |
||||
|
special_token_id = self.tokenizer.convert_tokens_to_ids([special_token])[0] |
||||
|
return special_token, special_token_id |
||||
|
|
||||
|
def process_one_file(self, path): |
||||
|
print ('Processing {}'.format(path)) |
||||
|
with open(path) as f: |
||||
|
item_list = json.load(f) |
||||
|
lines = [] |
||||
|
for item in item_list: |
||||
|
captions_list = item['captions'] |
||||
|
for one_caption in captions_list: |
||||
|
lines.append(one_caption.strip()) |
||||
|
|
||||
|
res_token_list, res_token_id_list = [], [] |
||||
|
n = len(lines) |
||||
|
p = progressbar.ProgressBar(n) |
||||
|
p.start() |
||||
|
for i in range(n): |
||||
|
p.update(i) |
||||
|
text = lines[i].strip('\n') |
||||
|
self.process_one_text(text, res_token_list, res_token_id_list) |
||||
|
p.finish() |
||||
|
print ('{} processed!'.format(path)) |
||||
|
return res_token_list, res_token_id_list |
||||
|
|
||||
|
def process_one_text(self, text, res_token_list, res_token_id_list): |
||||
|
tokens = self.tokenizer.tokenize(text, max_length=self.max_len, truncation=True) |
||||
|
if len(tokens) <= 1: # filter out too short sequence |
||||
|
return |
||||
|
tokens = [self.sos_token] + tokens[:self.max_len] |
||||
|
if self.add_eos_token_to_data: |
||||
|
tokens = tokens + [self.eos_token] |
||||
|
token_ids = self.tokenizer.convert_tokens_to_ids(tokens) |
||||
|
res_token_list.append(tokens) |
||||
|
res_token_id_list.append(token_ids) |
||||
|
return |
||||
|
|
||||
|
def pad_batch(self, batch_id_list): |
||||
|
batch_id_list = [torch.LongTensor(item) for item in batch_id_list] |
||||
|
batch_tensor = rnn.pad_sequence(batch_id_list, batch_first=True, padding_value=self.pad_token_id) |
||||
|
batch_mask = torch.ones_like(batch_tensor) |
||||
|
batch_mask = batch_mask.masked_fill(batch_tensor.eq(self.pad_token_id), 0.0).type(torch.FloatTensor) |
||||
|
return batch_tensor, batch_mask |
||||
|
|
||||
|
def process_output(self, batch_tgt_id_list): |
||||
|
batch_tgt_id_list = [torch.LongTensor(item) for item in batch_tgt_id_list] |
||||
|
batch_tgt_tensor, _ = self.pad_batch(batch_tgt_id_list) # padded target sequence |
||||
|
batch_tgt_input_tensor = batch_tgt_tensor[:, :-1].clone() |
||||
|
batch_tgt_output_tensor = batch_tgt_tensor[:, 1:].clone() |
||||
|
return batch_tgt_input_tensor, batch_tgt_output_tensor |
||||
|
|
||||
|
def parse_batch(self, batch_id_list): |
||||
|
batch_input, batch_labels = self.process_output(batch_id_list) |
||||
|
batch_labels[batch_labels[:, :] == self.pad_token_id] = -100 |
||||
|
return batch_input, batch_labels |
||||
|
|
||||
|
def get_next_train_batch(self, batch_size): |
||||
|
batch_idx_list = random.sample(self.train_idx_list, batch_size) |
||||
|
batch_id_list, batch_token_list = [], [] |
||||
|
|
||||
|
for idx in batch_idx_list: |
||||
|
batch_id_list.append(self.train_token_id_list[idx]) |
||||
|
batch_token_list.append(self.train_token_list[idx]) |
||||
|
batch_input_tensor, batch_labels = self.parse_batch(batch_id_list) |
||||
|
return batch_input_tensor, batch_labels, batch_token_list |
||||
|
|
||||
|
def get_next_validation_batch(self, batch_size, mode): |
||||
|
batch_id_list, batch_token_list = [], [] |
||||
|
if mode == 'dev': |
||||
|
curr_select_idx, instance_num = self.dev_current_idx, self.dev_num |
||||
|
tgt_token_id_list, tgt_token_list = self.dev_token_id_list, self.dev_token_list |
||||
|
elif mode == 'test': |
||||
|
curr_select_idx, instance_num = self.test_current_idx, self.test_num |
||||
|
tgt_token_id_list, tgt_token_list = self.test_token_id_list, self.test_token_list |
||||
|
else: |
||||
|
raise Exception('Wrong Validation Mode!!!') |
||||
|
|
||||
|
if curr_select_idx + batch_size < instance_num: |
||||
|
for i in range(batch_size): |
||||
|
curr_idx = curr_select_idx + i |
||||
|
batch_id_list.append(tgt_token_id_list[curr_idx]) |
||||
|
batch_token_list.append(tgt_token_list[curr_idx]) |
||||
|
if mode == 'dev': |
||||
|
self.dev_current_idx += batch_size |
||||
|
else: |
||||
|
self.test_current_idx += batch_size |
||||
|
else: |
||||
|
for i in range(batch_size): |
||||
|
curr_idx = curr_select_idx + i |
||||
|
if curr_idx > instance_num - 1: |
||||
|
curr_idx = 0 |
||||
|
if mode == 'dev': |
||||
|
self.dev_current_idx = 0 |
||||
|
else: |
||||
|
self.test_current_idx = 0 |
||||
|
batch_id_list.append(tgt_token_id_list[curr_idx]) |
||||
|
batch_token_list.append(tgt_token_list[curr_idx]) |
||||
|
if mode == 'dev': |
||||
|
self.dev_current_idx = 0 |
||||
|
else: |
||||
|
self.test_current_idx = 0 |
||||
|
batch_input_tensor, batch_labels = self.parse_batch(batch_id_list) |
||||
|
return batch_input_tensor, batch_labels, batch_token_list |
@ -0,0 +1,80 @@ |
|||||
|
import torch |
||||
|
|
||||
|
def compute_valid_token_num(valid_len_list): |
||||
|
res = 0 |
||||
|
for one_len in valid_len_list: |
||||
|
res += one_len * (one_len - 1) |
||||
|
return res |
||||
|
|
||||
|
def build_mask_matrix(seqlen, valid_len_list, prefix_len = 0): |
||||
|
''' |
||||
|
prefix_len: the length of prefix that we do not want to compute CL loss for. |
||||
|
|
||||
|
(1) if a sequence of length 4 contains zero padding token (i.e., the valid length is 4), |
||||
|
then the loss padding matrix looks like |
||||
|
[0., 1., 1., 1.], |
||||
|
[1., 0., 1., 1.], |
||||
|
[1., 1., 0., 1.], |
||||
|
[1., 1., 1., 0.] |
||||
|
|
||||
|
(2) if a sequence of length 4 contains 1 padding token (i.e., the valid length is 3), |
||||
|
then the loss padding matrix looks like |
||||
|
[0., 1., 1., 0.], |
||||
|
[1., 0., 1., 0.], |
||||
|
[1., 1., 0., 0.], |
||||
|
[0., 0., 0., 0.] |
||||
|
''' |
||||
|
res_list = [] |
||||
|
base_mask = torch.ones(seqlen, seqlen) - torch.eye(seqlen, seqlen) |
||||
|
base_mask = base_mask.type(torch.FloatTensor) |
||||
|
bsz = len(valid_len_list) |
||||
|
for i in range(bsz): |
||||
|
one_base_mask = base_mask.clone() |
||||
|
one_valid_len = valid_len_list[i] |
||||
|
one_base_mask[:,one_valid_len:] = 0. |
||||
|
one_base_mask[one_valid_len:, :] = 0. |
||||
|
if prefix_len > 0: |
||||
|
one_base_mask[:prefix_len, :prefix_len] = 0. |
||||
|
res_list.append(one_base_mask) |
||||
|
res_mask = torch.stack(res_list, dim = 0)#torch.FloatTensor(res_list) |
||||
|
#print (res_mask) |
||||
|
assert res_mask.size() == torch.Size([bsz, seqlen, seqlen]) |
||||
|
return res_mask |
||||
|
|
||||
|
def contrastive_loss(margin, score_matrix, input_ids, pad_token_id, prefix_len=0): |
||||
|
''' |
||||
|
margin: predefined margin to push similarity score away |
||||
|
score_matrix: bsz x seqlen x seqlen |
||||
|
input_ids: bsz x seqlen |
||||
|
pad_token_id: indicating which tokens are padding token |
||||
|
''' |
||||
|
bsz, seqlen, _ = score_matrix.size() |
||||
|
gold_score = torch.diagonal(score_matrix, offset=0, dim1=1, dim2=2) # bsz x seqlen |
||||
|
gold_score = torch.unsqueeze(gold_score, -1) |
||||
|
assert gold_score.size() == torch.Size([bsz, seqlen, 1]) |
||||
|
difference_matrix = gold_score - score_matrix |
||||
|
assert difference_matrix.size() == torch.Size([bsz, seqlen, seqlen]) |
||||
|
loss_matrix = margin - difference_matrix # bsz x seqlen x seqlen |
||||
|
loss_matrix = torch.nn.functional.relu(loss_matrix) |
||||
|
|
||||
|
### input mask |
||||
|
input_mask = torch.ones_like(input_ids).type(torch.FloatTensor) |
||||
|
if loss_matrix.is_cuda: |
||||
|
input_mask = input_mask.cuda(loss_matrix.get_device()) |
||||
|
input_mask = input_mask.masked_fill(input_ids.eq(pad_token_id), 0.0) |
||||
|
|
||||
|
if loss_matrix.is_cuda: |
||||
|
input_mask = input_mask.cuda(loss_matrix.get_device()) |
||||
|
|
||||
|
valid_len_list = torch.sum(input_mask, dim = -1).tolist() |
||||
|
loss_mask = build_mask_matrix(seqlen, [int(item) for item in valid_len_list], prefix_len) |
||||
|
if score_matrix.is_cuda: |
||||
|
loss_mask = loss_mask.cuda(score_matrix.get_device()) |
||||
|
masked_loss_matrix = loss_matrix * loss_mask |
||||
|
|
||||
|
loss_matrix = torch.sum(masked_loss_matrix, dim = -1) |
||||
|
assert loss_matrix.size() == input_ids.size() |
||||
|
loss_matrix = loss_matrix * input_mask |
||||
|
cl_loss = torch.sum(loss_matrix) / torch.sum(loss_mask) |
||||
|
return cl_loss |
||||
|
|
@ -0,0 +1,233 @@ |
|||||
|
import os |
||||
|
import sys |
||||
|
import operator |
||||
|
from tqdm import tqdm |
||||
|
from operator import itemgetter |
||||
|
import torch |
||||
|
from torch import nn |
||||
|
import random |
||||
|
import argparse |
||||
|
import numpy as np |
||||
|
import torch.nn.functional as F |
||||
|
from torch.nn import CrossEntropyLoss |
||||
|
from loss_func import contrastive_loss |
||||
|
from utlis import PlugAndPlayContrastiveDecodingOneStepFast |
||||
|
|
||||
|
import seaborn as sns |
||||
|
import matplotlib.pyplot as plt |
||||
|
import pandas as pd |
||||
|
import datetime |
||||
|
|
||||
|
train_fct = CrossEntropyLoss() |
||||
|
val_fct = CrossEntropyLoss(reduction='none') |
||||
|
class SimCTG(nn.Module): |
||||
|
def __init__(self, model_name, sos_token, pad_token): |
||||
|
super(SimCTG, self).__init__() |
||||
|
from transformers import AutoTokenizer, GPT2LMHeadModel |
||||
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
||||
|
self.sos_token, self.sos_token_id = self.add_special_token(sos_token) |
||||
|
print ('sos token is {}, sos token id is {}'.format(self.sos_token, self.sos_token_id)) |
||||
|
self.pad_token, self.pad_token_id = self.add_special_token(pad_token) |
||||
|
print ('pad token is {}, pad token id is {}'.format(self.pad_token, self.pad_token_id)) |
||||
|
self.eos_token, self.eos_token_id = self.tokenizer.bos_token, self.tokenizer.bos_token_id |
||||
|
print ('eos token is {}, eos token id is {}'.format(self.eos_token, self.eos_token_id)) |
||||
|
self.model = GPT2LMHeadModel.from_pretrained(model_name) |
||||
|
self.vocab_size = len(self.tokenizer) |
||||
|
print ('Resizing model embedding...') |
||||
|
self.model.resize_token_embeddings(len(self.tokenizer)) |
||||
|
print ('Model embedding resized!') |
||||
|
self.embed_dim = self.model.config.hidden_size |
||||
|
|
||||
|
def add_special_token(self, special_token): |
||||
|
if special_token in self.tokenizer.vocab: |
||||
|
print (special_token + ' token exists.') |
||||
|
else: |
||||
|
print ('Add token to the tokenizer.') |
||||
|
print ('Original vocabulary size is {}'.format(len(self.tokenizer))) |
||||
|
self.tokenizer.add_tokens([special_token]) |
||||
|
print ('Vocabulary size after extension is {}'.format(len(self.tokenizer))) |
||||
|
assert len(self.tokenizer.convert_tokens_to_ids([special_token])) == 1 |
||||
|
special_token_id = self.tokenizer.convert_tokens_to_ids([special_token])[0] |
||||
|
return special_token, special_token_id |
||||
|
|
||||
|
def compute_logits_and_hidden_states(self, input_ids): |
||||
|
# used for advanced decoding |
||||
|
# input_ids: 1 x seqlen |
||||
|
outputs = self.model(input_ids=input_ids, output_hidden_states=True) |
||||
|
last_hidden_states = outputs.hidden_states[-1] |
||||
|
logits = outputs.logits |
||||
|
return last_hidden_states, logits |
||||
|
|
||||
|
def forward(self, input_ids, labels, margin): |
||||
|
bsz, seqlen = input_ids.size() |
||||
|
outputs = self.model(input_ids=input_ids, output_hidden_states=True) |
||||
|
logits = outputs.logits |
||||
|
assert logits.size() == torch.Size([bsz, seqlen, self.vocab_size]) |
||||
|
last_hidden_states = outputs.hidden_states[-1] |
||||
|
assert last_hidden_states.size() == torch.Size([bsz, seqlen, self.embed_dim]) |
||||
|
mle_loss = train_fct(logits.view(-1, self.vocab_size), labels.view(-1)) |
||||
|
|
||||
|
norm_rep = last_hidden_states / last_hidden_states.norm(dim=2, keepdim=True) |
||||
|
cosine_scores = torch.matmul(norm_rep, norm_rep.transpose(1,2)) |
||||
|
assert cosine_scores.size() == torch.Size([bsz, seqlen, seqlen]) |
||||
|
cl_loss = contrastive_loss(margin, cosine_scores, input_ids, self.pad_token_id, prefix_len=0) |
||||
|
return mle_loss, cl_loss |
||||
|
|
||||
|
def eval_loss(self, input_ids, labels): |
||||
|
bsz, seqlen = input_ids.size() |
||||
|
outputs = self.model(input_ids=input_ids, output_hidden_states=True) |
||||
|
logits = outputs.logits |
||||
|
assert logits.size() == torch.Size([bsz, seqlen, self.vocab_size]) |
||||
|
last_hidden_states = outputs.hidden_states[-1] |
||||
|
assert last_hidden_states.size() == torch.Size([bsz, seqlen, self.embed_dim]) |
||||
|
mle_loss = val_fct(logits.view(-1, self.vocab_size), labels.view(-1)) |
||||
|
assert mle_loss.size() == torch.Size([bsz * seqlen]) |
||||
|
mask_tmp = labels.masked_fill(~labels.eq(-100), 1.0) |
||||
|
mask = mask_tmp.masked_fill(mask_tmp.eq(-100), 0.0) |
||||
|
# sum |
||||
|
mle_loss_sum = torch.sum(mle_loss) |
||||
|
token_num_sum = torch.sum(mask) |
||||
|
return mle_loss_sum, token_num_sum |
||||
|
|
||||
|
def save_model(self, ckpt_save_path): |
||||
|
import os |
||||
|
if os.path.exists(ckpt_save_path): |
||||
|
pass |
||||
|
else: # recursively construct directory |
||||
|
os.makedirs(ckpt_save_path, exist_ok=True) |
||||
|
# save model |
||||
|
self.model.save_pretrained(ckpt_save_path) |
||||
|
# save tokenizer |
||||
|
self.tokenizer.save_pretrained(ckpt_save_path) |
||||
|
|
||||
|
def parse_sentences(self, text, num_of_sentences_to_keep): |
||||
|
item_list = text.split('.') |
||||
|
res_list = item_list[:num_of_sentences_to_keep] |
||||
|
if len(item_list) > num_of_sentences_to_keep: |
||||
|
res_text = '.'.join(res_list).strip('.') + '.' |
||||
|
else: |
||||
|
res_text = '.'.join(res_list).strip('.').strip() |
||||
|
return res_text |
||||
|
|
||||
|
def parse_generated_result(self, output, num_of_sentences_to_keep): |
||||
|
output_text = self.tokenizer.decode(output) |
||||
|
item_list = output_text.split(self.eos_token) |
||||
|
full_text = self.eos_token.join(item_list[:2]).strip() |
||||
|
full_text = self.parse_sentences(full_text, num_of_sentences_to_keep) |
||||
|
generated_text = item_list[1].strip() |
||||
|
generated_text = self.parse_sentences(generated_text, num_of_sentences_to_keep) |
||||
|
return full_text, generated_text |
||||
|
|
||||
|
# decoding functions |
||||
|
# ------------------------------------------------------- # |
||||
|
|
||||
|
def parse_output_token_list(self, output): |
||||
|
output = output.tolist() |
||||
|
res_list = [] |
||||
|
for token_id in output: |
||||
|
if token_id == self.sos_token_id: |
||||
|
continue |
||||
|
elif token_id == self.eos_token_id: |
||||
|
break |
||||
|
else: |
||||
|
res_list.append(token_id) |
||||
|
text = self.tokenizer.decode(res_list).strip() |
||||
|
return ' '.join(text.split()).strip() |
||||
|
|
||||
|
@torch.no_grad() |
||||
|
def magic_search(self, input_ids, beam_width, alpha, decoding_len, beta, image_instance, clip, |
||||
|
clip_text_max_len):#, add_token_level_score=False): |
||||
|
prefix_len = input_ids.size()[1] |
||||
|
#from utlis import PlugAndPlayContrastiveDecodingOneStepFast |
||||
|
past_key_values, last_hidden_states, logits = None, None, None |
||||
|
generated = [item for item in input_ids.tolist()] |
||||
|
input_ids_for_class = input_ids.clone() |
||||
|
|
||||
|
image_embeds = clip.compute_image_representation_from_image_instance(image_instance) |
||||
|
|
||||
|
start_time = datetime.datetime.now() |
||||
|
|
||||
|
# the maximum supported length of generation for SimCTG is 256 |
||||
|
# to support longer generated length, you can re-train the SimCTG model with longer sequences |
||||
|
decoding_len = decoding_len - prefix_len |
||||
|
for step in range(decoding_len): |
||||
|
input_ids, past_key_values, last_hidden_states, logits, input_ids_for_class = \ |
||||
|
PlugAndPlayContrastiveDecodingOneStepFast( |
||||
|
self.model, |
||||
|
input_ids, |
||||
|
prefix_len, |
||||
|
beam_width, |
||||
|
alpha, |
||||
|
beta, |
||||
|
self.tokenizer, |
||||
|
image_embeds, |
||||
|
clip, |
||||
|
clip_text_max_len, |
||||
|
past_key_values, |
||||
|
last_hidden_states, |
||||
|
logits, |
||||
|
first_step=step==0, |
||||
|
input_ids_for_class=input_ids_for_class, |
||||
|
) |
||||
|
end_time = datetime.datetime.now() |
||||
|
time_diff = (end_time - start_time) |
||||
|
execution_time = time_diff.total_seconds() * 1000 |
||||
|
return self.parse_output_token_list(input_ids_for_class[0]) |
||||
|
|
||||
|
def fast_contrastive_search(self, input_ids, beam_width, alpha, decoding_len): |
||||
|
''' |
||||
|
input_ids: prefix input; 1 x prefix_len |
||||
|
decoding_len: how many tokens to generate |
||||
|
beam_width: size of candidate pool during decoding |
||||
|
alpha: regulates importance of model confidence and degeneration penalty |
||||
|
''' |
||||
|
self.model.eval() |
||||
|
#from utlis import ContrastiveDecodingOneStepFast |
||||
|
# sanity check |
||||
|
assert alpha >= 0. and alpha <= 1.0 |
||||
|
|
||||
|
# fast mode |
||||
|
prefix_len = input_ids.size()[1] |
||||
|
batch_size, seqlen = input_ids.size() |
||||
|
#generated = [[] for _ in range(batch_size)] |
||||
|
generated = [item for item in input_ids.tolist()] |
||||
|
past_key_values = None |
||||
|
last_hidden_states = None |
||||
|
logits = None |
||||
|
decoding_len = decoding_len - prefix_len |
||||
|
for step in range(decoding_len): |
||||
|
input_ids, past_key_values, last_hidden_states, logits = ContrastiveDecodingOneStepFast( |
||||
|
self.model, |
||||
|
input_ids, |
||||
|
beam_width, |
||||
|
alpha, |
||||
|
past_key_values, |
||||
|
last_hidden_states, |
||||
|
self.tokenizer, |
||||
|
logits, |
||||
|
first_step=step == 0, |
||||
|
) |
||||
|
tokens = input_ids.squeeze(dim=-1).tolist() |
||||
|
for idx, t in enumerate(tokens): |
||||
|
generated[idx].append(t) |
||||
|
return self.parse_output_token_list(torch.LongTensor(generated[0])) |
||||
|
|
||||
|
def top_k_sampling(self, input_ids, k, decoding_len): |
||||
|
_, prefix_len = input_ids.size() |
||||
|
output = self.model.generate( |
||||
|
input_ids, |
||||
|
do_sample=True, |
||||
|
max_length=decoding_len, |
||||
|
top_p=1.0, |
||||
|
top_k=k) |
||||
|
return self.parse_output_token_list(output[0]) |
||||
|
|
||||
|
def nucleus_sampling(self, input_ids, nucleus_p, decoding_len): |
||||
|
_, prefix_len = input_ids.size() |
||||
|
output = self.model.generate( |
||||
|
input_ids, |
||||
|
do_sample=True, |
||||
|
max_length=decoding_len, |
||||
|
top_p=nucleus_p, |
||||
|
top_k=0) |
||||
|
return self.parse_output_token_list(output[0]) |
@ -0,0 +1,107 @@ |
|||||
|
# coding=utf-8 |
||||
|
import torch |
||||
|
import torch.nn as nn |
||||
|
import torch.nn.functional as F |
||||
|
import torch.multiprocessing as mp |
||||
|
import argparse, os |
||||
|
import random |
||||
|
import numpy as np |
||||
|
import time |
||||
|
import logging |
||||
|
import progressbar |
||||
|
|
||||
|
import logging |
||||
|
logging.getLogger('transformers.generation_utils').disabled = True |
||||
|
|
||||
|
def parse_config(): |
||||
|
parser = argparse.ArgumentParser() |
||||
|
# data configuration |
||||
|
parser.add_argument("--model_name", type=str, default='gpt2') |
||||
|
parser.add_argument("--train_path", type=str) |
||||
|
parser.add_argument("--dev_path", type=str) |
||||
|
parser.add_argument("--test_path", type=str) |
||||
|
parser.add_argument("--max_len", type=int) |
||||
|
parser.add_argument("--add_eos_token_to_data", type=str) |
||||
|
# mini-batch training configuration |
||||
|
parser.add_argument("--number_of_gpu", type=int, help="Number of available GPUs.") |
||||
|
parser.add_argument("--batch_size_per_gpu", type=int, help='batch size for each gpu.') |
||||
|
parser.add_argument("--gradient_accumulation_steps", type=int, help="gradient accumulation step.") |
||||
|
parser.add_argument("--effective_batch_size", type=int, |
||||
|
help="effective_bsz = batch_size_per_gpu x number_of_gpu x gradient_accumulation_steps") |
||||
|
# pre-training configuration |
||||
|
parser.add_argument("--total_steps", type=int, |
||||
|
help="total effective training steps") |
||||
|
parser.add_argument("--print_every", type=int, |
||||
|
help="how many update steps to print one intermediate result") |
||||
|
parser.add_argument("--save_every", type=int, |
||||
|
help="how many update steps to save one model") |
||||
|
# learning configuration |
||||
|
parser.add_argument("--learning_rate", type=float, default=2e-5) |
||||
|
parser.add_argument("--margin", type=float) |
||||
|
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") |
||||
|
parser.add_argument("--save_path_prefix", type=str, help="directory to save the model parameters.") |
||||
|
return parser.parse_args() |
||||
|
|
||||
|
def load_previous_best_model(path): |
||||
|
import os |
||||
|
filenames = os.listdir(path) |
||||
|
for file in filenames: |
||||
|
if file.startswith('training_step'): |
||||
|
return path + '/' + file |
||||
|
raise Exception('No best model found!') |
||||
|
|
||||
|
import argparse |
||||
|
if __name__ == '__main__': |
||||
|
if torch.cuda.is_available(): |
||||
|
print ('Cuda is available.') |
||||
|
cuda_available = torch.cuda.is_available() |
||||
|
multi_gpu_training = False |
||||
|
if cuda_available: |
||||
|
if torch.cuda.device_count() > 1: |
||||
|
multi_gpu_training = True |
||||
|
print ('Using Multi-GPU training, number of GPU is {}'.format(torch.cuda.device_count())) |
||||
|
else: |
||||
|
print ('Using single GPU training.') |
||||
|
else: |
||||
|
pass |
||||
|
args = parse_config() |
||||
|
device = torch.device('cuda') |
||||
|
model_name = args.model_name |
||||
|
|
||||
|
sos_token, pad_token = r'<-start_of_text->', r'<-pad->' |
||||
|
add_eos_token_to_data = args.add_eos_token_to_data |
||||
|
if add_eos_token_to_data == 'True': |
||||
|
add_eos_token_to_data = True |
||||
|
print ('Add eos token to data!') |
||||
|
elif add_eos_token_to_data == 'False': |
||||
|
add_eos_token_to_data = False |
||||
|
print ('Do not add eos token to data!') |
||||
|
else: |
||||
|
raise Exception('Wrong eos configuration for data!!!') |
||||
|
print ('Loading data...') |
||||
|
from dataclass import Data |
||||
|
data = Data(model_name, args.train_path, args.dev_path, args.test_path, args.max_len, |
||||
|
sos_token, pad_token, add_eos_token_to_data) |
||||
|
print ('Data loaded.') |
||||
|
|
||||
|
from trainer import model_training |
||||
|
print ('############################################################') |
||||
|
print ('Start Training...') |
||||
|
from simctg import SimCTG |
||||
|
print ('Initializaing SimCTG model...') |
||||
|
model = SimCTG(model_name, sos_token, pad_token) |
||||
|
if cuda_available: |
||||
|
if multi_gpu_training: |
||||
|
model = nn.DataParallel(model) # multi-gpu training |
||||
|
else: |
||||
|
pass |
||||
|
model = model.to(device) |
||||
|
else: |
||||
|
pass |
||||
|
print ('Model loaded') |
||||
|
total_steps, print_every, save_every = args.total_steps, args.print_every, args.save_every |
||||
|
ckpt_save_path = args.save_path_prefix |
||||
|
model = model_training(args, data, model, total_steps, print_every, save_every, |
||||
|
ckpt_save_path, cuda_available, device) |
||||
|
print ('Training stage completed!') |
||||
|
print ('############################################################') |
@ -0,0 +1,17 @@ |
|||||
|
CUDA_VISIBLE_DEVICES=0 python train.py\ |
||||
|
--model_name gpt2\ |
||||
|
--train_path ../data/flickr30k/flickr30k_train.json\ |
||||
|
--dev_path ../data/flickr30k/flickr30k_val.json\ |
||||
|
--test_path ../data/flickr30k/flickr30k_test.json\ |
||||
|
--add_eos_token_to_data True\ |
||||
|
--margin 0.5\ |
||||
|
--max_len 64\ |
||||
|
--number_of_gpu 1\ |
||||
|
--batch_size_per_gpu 32\ |
||||
|
--gradient_accumulation_steps 4\ |
||||
|
--effective_batch_size 128\ |
||||
|
--total_steps 10000\ |
||||
|
--print_every 50\ |
||||
|
--save_every 250\ |
||||
|
--learning_rate 2e-5\ |
||||
|
--save_path_prefix ./magic_flickr30k/ |
@ -0,0 +1,17 @@ |
|||||
|
CUDA_VISIBLE_DEVICES=0 python train.py\ |
||||
|
--model_name gpt2\ |
||||
|
--train_path ../data/mscoco/mscoco_train.json\ |
||||
|
--dev_path ../data/mscoco/mscoco_val.json\ |
||||
|
--test_path ../data/mscoco/mscoco_test.json\ |
||||
|
--add_eos_token_to_data True\ |
||||
|
--margin 0.5\ |
||||
|
--max_len 64\ |
||||
|
--number_of_gpu 1\ |
||||
|
--batch_size_per_gpu 32\ |
||||
|
--gradient_accumulation_steps 4\ |
||||
|
--effective_batch_size 128\ |
||||
|
--total_steps 20000\ |
||||
|
--print_every 100\ |
||||
|
--save_every 500\ |
||||
|
--learning_rate 2e-5\ |
||||
|
--save_path_prefix ./magic_mscoco/ |
@ -0,0 +1,165 @@ |
|||||
|
# coding=utf-8 |
||||
|
import torch |
||||
|
import torch.nn as nn |
||||
|
import torch.nn.functional as F |
||||
|
import torch.multiprocessing as mp |
||||
|
import argparse, os |
||||
|
import random |
||||
|
import numpy as np |
||||
|
import time |
||||
|
import logging |
||||
|
import progressbar |
||||
|
|
||||
|
import logging |
||||
|
logging.getLogger('transformers.generation_utils').disabled = True |
||||
|
|
||||
|
def eval_model(args, model, data, cuda_available, device): |
||||
|
dataset_batch_size = args.batch_size_per_gpu * args.number_of_gpu |
||||
|
eval_step = int(data.test_num / dataset_batch_size) + 1 |
||||
|
val_loss, token_sum = 0., 0. |
||||
|
model.eval() |
||||
|
with torch.no_grad(): |
||||
|
p = progressbar.ProgressBar(eval_step) |
||||
|
p.start() |
||||
|
for idx in range(eval_step): |
||||
|
p.update(idx) |
||||
|
batch_input_tensor, batch_labels, _ = \ |
||||
|
data.get_next_validation_batch(batch_size=dataset_batch_size, mode='test') |
||||
|
if cuda_available: |
||||
|
batch_input_tensor = batch_input_tensor.cuda(device) |
||||
|
batch_labels = batch_labels.cuda(device) |
||||
|
one_val_loss, one_val_token_sum = model.eval_loss(batch_input_tensor, batch_labels) |
||||
|
one_val_loss = torch.sum(one_val_loss) |
||||
|
one_val_token_sum = torch.sum(one_val_token_sum) |
||||
|
val_loss += one_val_loss.item() |
||||
|
token_sum += one_val_token_sum.item() |
||||
|
p.finish() |
||||
|
model.train() |
||||
|
val_loss = val_loss / token_sum |
||||
|
return val_loss |
||||
|
|
||||
|
def model_training(args, data, model, total_steps, print_every, save_every, ckpt_save_path, cuda_available, device): |
||||
|
import os |
||||
|
if os.path.exists(ckpt_save_path): |
||||
|
pass |
||||
|
else: # recursively construct directory |
||||
|
os.makedirs(ckpt_save_path, exist_ok=True) |
||||
|
|
||||
|
max_save_num = 1 |
||||
|
|
||||
|
batch_size_per_gpu, gradient_accumulation_steps, number_of_gpu, effective_batch_size = \ |
||||
|
args.batch_size_per_gpu, args.gradient_accumulation_steps, args.number_of_gpu, args.effective_batch_size |
||||
|
assert effective_batch_size == batch_size_per_gpu * gradient_accumulation_steps * number_of_gpu |
||||
|
|
||||
|
warmup_steps = int(0.1 * total_steps) # 10% of training steps are used for warmup |
||||
|
print ('total training steps is {}, warmup steps is {}'.format(total_steps, warmup_steps)) |
||||
|
from transformers.optimization import AdamW, get_linear_schedule_with_warmup |
||||
|
optimizer = AdamW(model.parameters(), lr=args.learning_rate) |
||||
|
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps) |
||||
|
optimizer.zero_grad() |
||||
|
|
||||
|
effective_batch_acm = 0 |
||||
|
all_batch_step = 1 |
||||
|
print_valid, save_valid = False, False |
||||
|
train_loss, train_cl_loss, min_val_loss = 0., 0., 1e10 |
||||
|
train_ave_bleu = 0. |
||||
|
|
||||
|
print ('--------------------------------------------------------------------------') |
||||
|
print ('Start Training:') |
||||
|
model.train() |
||||
|
number_of_saves = 0 |
||||
|
|
||||
|
while effective_batch_acm < total_steps: |
||||
|
all_batch_step += 1 |
||||
|
train_batch_input_tensor, train_batch_labels, _ = data.get_next_train_batch(batch_size_per_gpu * number_of_gpu) |
||||
|
if cuda_available: |
||||
|
train_batch_input_tensor = train_batch_input_tensor.cuda(device) |
||||
|
train_batch_labels = train_batch_labels.cuda(device) |
||||
|
mle_loss, cl_loss = model(train_batch_input_tensor, train_batch_labels, args.margin) |
||||
|
|
||||
|
loss = mle_loss + cl_loss |
||||
|
loss = loss.mean() |
||||
|
loss.backward() |
||||
|
train_loss += mle_loss.item() |
||||
|
train_cl_loss += cl_loss.item() |
||||
|
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) |
||||
|
|
||||
|
# parameter update |
||||
|
if all_batch_step % gradient_accumulation_steps == 0: |
||||
|
optimizer.step() |
||||
|
scheduler.step() |
||||
|
optimizer.zero_grad() |
||||
|
effective_batch_acm += 1 |
||||
|
print_valid, save_valid = True, True |
||||
|
|
||||
|
# print intermediate result |
||||
|
if effective_batch_acm % print_every == 0 and print_valid: |
||||
|
denominator = (effective_batch_acm - (number_of_saves * save_every)) * gradient_accumulation_steps |
||||
|
one_train_loss = train_loss / denominator |
||||
|
one_train_cl_loss = train_cl_loss / denominator |
||||
|
print ('At training steps {}, training MLE loss is {}, train CL loss is {}'.format(effective_batch_acm, |
||||
|
one_train_loss, one_train_cl_loss)) |
||||
|
print_valid = False |
||||
|
|
||||
|
# saving result |
||||
|
if effective_batch_acm % save_every == 0 and save_valid: |
||||
|
number_of_saves += 1 |
||||
|
|
||||
|
save_valid = False |
||||
|
one_train_loss = train_loss / (save_every * gradient_accumulation_steps) |
||||
|
one_train_cl_loss = train_cl_loss / (save_every * gradient_accumulation_steps) |
||||
|
|
||||
|
model.eval() |
||||
|
one_val_loss = eval_model(args, model, data, cuda_available, device) |
||||
|
model.train() |
||||
|
|
||||
|
print ('At training steps {}, training MLE loss is {}, train CL loss is {}, validation loss is {}'.format(effective_batch_acm, |
||||
|
one_train_loss, one_train_cl_loss, one_val_loss)) |
||||
|
|
||||
|
train_loss, train_cl_loss = 0., 0. |
||||
|
|
||||
|
if one_val_loss < min_val_loss: |
||||
|
# in finetuning stage, we always save the model |
||||
|
min_val_loss = min(one_val_loss, min_val_loss) |
||||
|
print ('Saving model...') |
||||
|
one_val_ppl = np.exp(one_val_loss) |
||||
|
one_val_ppl = round(one_val_ppl, 3) |
||||
|
save_name = 'training_step_{}_train_mle_loss_{}_train_cl_loss_{}_dev_loss_{}_dev_ppl_{}'.format(effective_batch_acm, |
||||
|
round(one_train_loss,5), round(one_train_cl_loss,5), round(one_val_loss,5), one_val_ppl) |
||||
|
|
||||
|
model_save_path = ckpt_save_path + '/' + save_name |
||||
|
import os |
||||
|
if os.path.exists(model_save_path): |
||||
|
pass |
||||
|
else: # recursively construct directory |
||||
|
os.makedirs(model_save_path, exist_ok=True) |
||||
|
if cuda_available and torch.cuda.device_count() > 1: |
||||
|
model.module.save_model(model_save_path) |
||||
|
else: |
||||
|
model.save_model(model_save_path) |
||||
|
print ('Model Saved!') |
||||
|
|
||||
|
# --------------------------------------------------------------------------------------------- # |
||||
|
# removing extra checkpoints... |
||||
|
import os |
||||
|
from operator import itemgetter |
||||
|
fileData = {} |
||||
|
test_output_dir = ckpt_save_path |
||||
|
for fname in os.listdir(test_output_dir): |
||||
|
if fname.startswith('training_step'): |
||||
|
fileData[fname] = os.stat(test_output_dir + '/' + fname).st_mtime |
||||
|
else: |
||||
|
pass |
||||
|
sortedFiles = sorted(fileData.items(), key=itemgetter(1)) |
||||
|
|
||||
|
if len(sortedFiles) < max_save_num: |
||||
|
pass |
||||
|
else: |
||||
|
delete = len(sortedFiles) - max_save_num |
||||
|
for x in range(0, delete): |
||||
|
one_folder_name = test_output_dir + '/' + sortedFiles[x][0] |
||||
|
os.system('rm -r ' + one_folder_name) |
||||
|
print ('-----------------------------------') |
||||
|
# --------------------------------------------------------------------------------------------- # |
||||
|
return model |
||||
|
|
@ -0,0 +1,291 @@ |
|||||
|
import sys |
||||
|
import os |
||||
|
import operator |
||||
|
from operator import itemgetter |
||||
|
import torch |
||||
|
from torch import nn |
||||
|
import torch.nn.functional as F |
||||
|
import random |
||||
|
import numpy as np |
||||
|
import argparse |
||||
|
import random |
||||
|
|
||||
|
def parse_prompt(text): |
||||
|
''' |
||||
|
process the prompt text; |
||||
|
''' |
||||
|
eos_token = '<|endoftext|>' |
||||
|
text = text.strip(eos_token).strip() |
||||
|
left_bracket_idx, right_bracket_idx = -1, -1 |
||||
|
for idx in range(len(text)): |
||||
|
char = text[idx] |
||||
|
if char == '[' and left_bracket_idx == -1: # first [ is met |
||||
|
left_bracket_idx = idx |
||||
|
elif char == ']' and right_bracket_idx == -1: # first ] is met |
||||
|
right_bracket_idx = idx |
||||
|
else: |
||||
|
pass |
||||
|
res_text = '' |
||||
|
remove = False |
||||
|
if left_bracket_idx > -1 and right_bracket_idx > left_bracket_idx: |
||||
|
if right_bracket_idx - left_bracket_idx <= 6: |
||||
|
remove = True |
||||
|
else: |
||||
|
pass |
||||
|
|
||||
|
for idx in range(len(text)): |
||||
|
if remove: |
||||
|
if idx >= left_bracket_idx and idx <= right_bracket_idx: |
||||
|
continue |
||||
|
else: |
||||
|
res_text += text[idx] |
||||
|
else: |
||||
|
res_text += text[idx] |
||||
|
res_text = res_text.strip() |
||||
|
res_text = ' '.join(res_text.split()).strip() |
||||
|
return res_text |
||||
|
|
||||
|
def typical_filtering(scores, mass, min_tokens_to_keep, filter_value): |
||||
|
# calculate entropy |
||||
|
normalized = torch.nn.functional.log_softmax(scores, dim=-1) |
||||
|
p = torch.exp(normalized) |
||||
|
ent = -(normalized * p).nansum(-1, keepdim=True) |
||||
|
|
||||
|
# shift and sort |
||||
|
shifted_scores = torch.abs((-normalized) - ent) |
||||
|
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False) |
||||
|
sorted_logits = scores.gather(-1, sorted_indices) |
||||
|
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) |
||||
|
|
||||
|
# Remove tokens with cumulative mass above the threshold |
||||
|
last_ind = (cumulative_probs < mass).sum(dim=1) |
||||
|
last_ind[last_ind < 0] = 0 |
||||
|
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) |
||||
|
if min_tokens_to_keep > 1: |
||||
|
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) |
||||
|
sorted_indices_to_remove[..., : min_tokens_to_keep] = 0 |
||||
|
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) |
||||
|
|
||||
|
scores = scores.masked_fill(indices_to_remove, filter_value) |
||||
|
return scores |
||||
|
|
||||
|
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, threshold=-float('Inf'), filter_value=-np.inf): |
||||
|
assert logits.dim() == 1 |
||||
|
top_k = min(top_k, logits.size(-1)) |
||||
|
if top_k > 0: |
||||
|
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] |
||||
|
logits[indices_to_remove] = filter_value |
||||
|
if top_p > 0.0: |
||||
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True) |
||||
|
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) |
||||
|
sorted_indices_to_remove = cumulative_probs > top_p |
||||
|
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() |
||||
|
sorted_indices_to_remove[..., 0] = 0 |
||||
|
|
||||
|
indices_to_remove = sorted_indices[sorted_indices_to_remove] |
||||
|
logits[indices_to_remove] = filter_value |
||||
|
|
||||
|
indices_to_remove = logits < threshold |
||||
|
logits[indices_to_remove] = filter_value |
||||
|
return logits |
||||
|
|
||||
|
# ========== batch version ========= # |
||||
|
def ranking_fast(context_hidden, next_hidden, next_top_k_probs, alpha, beam_width): |
||||
|
''' |
||||
|
context_hidden: bsz*beam x seqlen x embed_dim |
||||
|
next_hidden: bsz*beam x 1 x embed_dim |
||||
|
next_top_k_probs: bsz x beam |
||||
|
''' |
||||
|
_, context_len, embed_dim = context_hidden.size() |
||||
|
norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True) |
||||
|
norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True) |
||||
|
cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1) # [B*K, S] |
||||
|
scores, _ = torch.max(cosine_matrix, dim=-1) # [B*K] |
||||
|
next_top_k_probs = next_top_k_probs.view(-1) # [B*K] |
||||
|
scores = (1.0 - alpha) * next_top_k_probs - alpha * scores |
||||
|
scores = torch.stack(torch.split(scores, beam_width)) # [B, K] |
||||
|
selected_idx = scores.max(dim=-1)[1] # [B] |
||||
|
return selected_idx |
||||
|
|
||||
|
def ContrastiveDecodingOneStepFast( |
||||
|
model, |
||||
|
ids, |
||||
|
beam_width, |
||||
|
alpha, |
||||
|
past_key_values, |
||||
|
last_hidden_states, |
||||
|
vocab, |
||||
|
logit_for_next_step, |
||||
|
first_step=False, |
||||
|
): |
||||
|
# input_ids: [B, S] |
||||
|
if first_step: |
||||
|
output = model( |
||||
|
input_ids=ids, |
||||
|
past_key_values=past_key_values, |
||||
|
use_cache=True, |
||||
|
output_hidden_states=True |
||||
|
) |
||||
|
past_key_values = output.past_key_values |
||||
|
last_hidden_states = output.hidden_states[-1] # [B, S, E] |
||||
|
logit_for_next_step = output.logits[:, -1, :] # [B, V] |
||||
|
bsz, seqlen, embed_dim = last_hidden_states.size() |
||||
|
p = random.uniform(0, 1) |
||||
|
|
||||
|
next_probs = F.softmax(logit_for_next_step, dim=-1) |
||||
|
_, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width) # [B, K] |
||||
|
top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids) # [B, K] |
||||
|
# compute new hidden |
||||
|
past_key_values = enlarge_past_key_values(past_key_values, beam_width) |
||||
|
output = model( |
||||
|
input_ids=top_k_ids.view(-1, 1), |
||||
|
attention_mask=torch.ones_like(top_k_ids.view(-1, 1)), |
||||
|
past_key_values=past_key_values, |
||||
|
output_hidden_states=True, |
||||
|
use_cache=True, |
||||
|
) |
||||
|
past_key_values = output.past_key_values |
||||
|
logits = output.logits[:, -1, :] # [B*K, V] |
||||
|
next_hidden = output.hidden_states[-1] # [B*K, 1, E] |
||||
|
context_hidden = last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz*beam_width, seqlen, embed_dim) # [B*K, S, E] |
||||
|
|
||||
|
selected_idx = ranking_fast( |
||||
|
context_hidden, |
||||
|
next_hidden, |
||||
|
top_k_probs, # [B, K] |
||||
|
alpha, |
||||
|
beam_width, |
||||
|
) # [B] |
||||
|
# prepare for the next step |
||||
|
next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1) # [B, 1] |
||||
|
next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width)) # [B, K, E] |
||||
|
next_hidden = next_hidden[range(bsz), selected_idx, :] # [B, E] |
||||
|
last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) # [B, S, E] |
||||
|
past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx) |
||||
|
logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :] # [B, V] |
||||
|
# next_id: [B, 1] |
||||
|
return next_id, past_key_values, last_hidden_states, logits |
||||
|
|
||||
|
def enlarge_past_key_values(past_key_values, beam_width): |
||||
|
# from [B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz] |
||||
|
new_key_values = [] |
||||
|
for layer in past_key_values: |
||||
|
items = [] |
||||
|
for item in layer: |
||||
|
# item is the key and value matrix |
||||
|
bsz, num_head, seq_len, esz = item.size() |
||||
|
item = item.unsqueeze(1).expand(-1, beam_width, -1, -1, -1).reshape(bsz*beam_width, num_head, seq_len, esz) # [bsz*beam, num_head, seq_len, esz] |
||||
|
items.append(item) |
||||
|
new_key_values.append(items) |
||||
|
return new_key_values |
||||
|
|
||||
|
def select_past_key_values(past_key_values, beam_width, selected_idx): |
||||
|
'''select_idx: [B]''' |
||||
|
new_key_values = [] |
||||
|
for layer in past_key_values: |
||||
|
items = [] |
||||
|
for item in layer: |
||||
|
bsz_and_beam, num_head, seq_len, esz = item.size() |
||||
|
bsz = int(bsz_and_beam//beam_width) |
||||
|
item = torch.stack(torch.split(item, beam_width, dim=0)) # [B, K, num_head, seq_len, esz] |
||||
|
item = item[range(bsz), selected_idx, :, :, :] # [B, num_head, seq_len, esz] |
||||
|
items.append(item) |
||||
|
new_key_values.append(items) |
||||
|
return new_key_values |
||||
|
|
||||
|
# ========== fast plug and play version ========= # |
||||
|
def plug_and_play_fast_ranking( |
||||
|
context_hidden, |
||||
|
next_hidden, |
||||
|
next_top_k_ids, |
||||
|
next_top_k_probs, |
||||
|
alpha, |
||||
|
beta, |
||||
|
batch_class_score, |
||||
|
beam_width, |
||||
|
): |
||||
|
''' |
||||
|
context_hidden: beam_width x context_len x embed_dim |
||||
|
next_hidden: beam_width x 1 x embed_dim |
||||
|
next_top_k_ids: beam_width x 1 |
||||
|
batch_class_score: beam_width x 1 |
||||
|
''' |
||||
|
_, context_len, embed_dim = context_hidden.size() |
||||
|
norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True) |
||||
|
norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True) |
||||
|
cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1) |
||||
|
scores, _ = torch.max(cosine_matrix, dim = -1) |
||||
|
next_top_k_probs = next_top_k_probs.view(-1) |
||||
|
scores = (1.0 - alpha) * next_top_k_probs - alpha * scores + beta * batch_class_score.view([beam_width]) |
||||
|
scores = torch.stack(torch.split(scores, beam_width)) |
||||
|
selected_idx = scores.max(dim=-1)[1] |
||||
|
return selected_idx |
||||
|
|
||||
|
def PlugAndPlayContrastiveDecodingOneStepFast(model, input_ids, prefix_len, beam_width, alpha, beta, |
||||
|
simctg_tokenizer, image_embeds, clip, clip_text_max_len, past_key_values, last_hidden_states, |
||||
|
logit_for_next_step, first_step=False, input_ids_for_class=None):#, add_token_level_score=False): |
||||
|
''' |
||||
|
model: the generation model, e.g., gpt2 |
||||
|
input_ids: 1 x seqlen |
||||
|
''' |
||||
|
|
||||
|
if first_step: |
||||
|
output = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True, output_hidden_states=True) |
||||
|
past_key_values = output.past_key_values |
||||
|
last_hidden_states = output.hidden_states[-1] # [B, S, E] |
||||
|
logit_for_next_step = output.logits[:, -1, :] # [B, V] |
||||
|
bsz, seqlen, embed_dim = last_hidden_states.size() |
||||
|
next_probs = F.softmax(logit_for_next_step, dim = -1) |
||||
|
_, top_k_ids = torch.topk(logit_for_next_step, dim = -1, k = beam_width) |
||||
|
top_k_probs = torch.gather(next_probs, dim = 1, index=top_k_ids) |
||||
|
|
||||
|
# compute the new hidden |
||||
|
past_key_values = enlarge_past_key_values(past_key_values, beam_width) |
||||
|
output = model( |
||||
|
input_ids=top_k_ids.view(-1, 1) , |
||||
|
attention_mask=torch.ones_like(top_k_ids.view(-1, 1)), |
||||
|
past_key_values=past_key_values, |
||||
|
output_hidden_states=True, |
||||
|
use_cache=True, |
||||
|
) |
||||
|
past_key_values = output.past_key_values |
||||
|
logits = output.logits[:, -1, :] |
||||
|
next_hidden = output.hidden_states[-1] |
||||
|
context_hidden = last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz*beam_width, seqlen, embed_dim) |
||||
|
|
||||
|
# prepare for the classification model |
||||
|
input_ids_for_class_ = torch.cat([ |
||||
|
input_ids_for_class.unsqueeze(1).expand(-1, beam_width, -1).reshape(bsz*beam_width, seqlen), |
||||
|
top_k_ids.view(-1, 1) |
||||
|
], dim=-1 |
||||
|
) |
||||
|
|
||||
|
batch_text_list = [] |
||||
|
for one_input_id in input_ids_for_class_: |
||||
|
one_text = simctg_tokenizer.decode(one_input_id[prefix_len:][-clip_text_max_len:]) |
||||
|
# we only consider the class score of the generated text continuation |
||||
|
batch_text_list.append(one_text) |
||||
|
batch_score = clip.compute_image_text_similarity_via_raw_text(image_embeds, batch_text_list) |
||||
|
|
||||
|
selected_idx = plug_and_play_fast_ranking( |
||||
|
context_hidden, |
||||
|
next_hidden, |
||||
|
top_k_ids, |
||||
|
top_k_probs, |
||||
|
alpha, |
||||
|
beta, |
||||
|
batch_score, |
||||
|
beam_width, |
||||
|
) |
||||
|
|
||||
|
# prepare for the next step |
||||
|
next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1) |
||||
|
next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width)) |
||||
|
next_hidden = next_hidden[range(bsz), selected_idx, :] |
||||
|
last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) |
||||
|
past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx) |
||||
|
logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :] |
||||
|
input_ids_for_class = torch.cat([input_ids_for_class, next_id], dim=-1) |
||||
|
return next_id, past_key_values, last_hidden_states, logits, input_ids_for_class |
||||
|
|
||||
|
|
@ -1,89 +0,0 @@ |
|||||
### Our Implementation of the ZeroCap Baseline Model |
|
||||
|
|
||||
**** |
|
||||
### Catalogue: |
|
||||
* <a href='#environment'>1. Environment Preparation</a> |
|
||||
* <a href='#mscoco'>2. Image Captioning on MSCOCO</a> |
|
||||
* <a href='#flickr30k'>3. Image Captioning on Flickr30k</a> |
|
||||
* <a href='#flickr30k_to_mscoco'>4. Cross Domain Image Captioning on MSCOCO</a> |
|
||||
* <a href='#mscoco_to_flickr30k'>5. Cross Domain Image Captioning on Flickr30k</a> |
|
||||
* <a href='#citation'>6. Citation</a> |
|
||||
* <a href='#acknowledgements'>7. Acknowledgements</a> |
|
||||
|
|
||||
**** |
|
||||
|
|
||||
<span id='environment'/> |
|
||||
|
|
||||
#### 1. Environment Preparation: |
|
||||
To install the correct environment, please run the following command: |
|
||||
```yaml |
|
||||
pip install -r requirements.txt |
|
||||
``` |
|
||||
|
|
||||
**** |
|
||||
|
|
||||
<span id='mscoco'/> |
|
||||
|
|
||||
#### 2. Image Captioning on MSCOCO: |
|
||||
To perform image captioning on MSCOCO, please run the following command: |
|
||||
```yaml |
|
||||
chmod +x ./mscoco_zerocap.sh |
|
||||
./mscoco_zerocap.sh |
|
||||
``` |
|
||||
|
|
||||
**** |
|
||||
|
|
||||
<span id='flickr30k'/> |
|
||||
|
|
||||
#### 3. Image Captioning on Flickr30k: |
|
||||
To perform image captioning on Flickr30k, please run the following command: |
|
||||
```yaml |
|
||||
chmod +x ./flickr30k_zerocap.sh |
|
||||
./flickr30k_zerocap.sh |
|
||||
``` |
|
||||
|
|
||||
**** |
|
||||
|
|
||||
<span id='flickr30k_to_mscoco'/> |
|
||||
|
|
||||
#### 4. Cross Domain Image Captioning on MSCOCO: |
|
||||
To perform image captioning on MSCOCO with the language model from Flickr30k domain, please run the following command: |
|
||||
```yaml |
|
||||
chmod +x ./flickr30k_to_mscoco_zerocap.sh |
|
||||
./flickr30k_to_mscoco_zerocap.sh |
|
||||
``` |
|
||||
|
|
||||
**** |
|
||||
|
|
||||
<span id='mscoco_to_flickr30k'/> |
|
||||
|
|
||||
#### 5. Cross Domain Image Captioning on Flickr30k: |
|
||||
To perform image captioning on Flickr30k with the language model from MSCOCO domain, please run the following command: |
|
||||
```yaml |
|
||||
chmod +x ./mscoco_to_flickr30k_zerocap.sh |
|
||||
./mscoco_to_flickr30k_zerocap.sh |
|
||||
``` |
|
||||
|
|
||||
**** |
|
||||
|
|
||||
<span id='citation'/> |
|
||||
|
|
||||
#### 6. Citation: |
|
||||
If you find our code helpful, please cite the original paper as |
|
||||
|
|
||||
```bibtex |
|
||||
@article{tewel2021zero, |
|
||||
title={Zero-Shot Image-to-Text Generation for Visual-Semantic Arithmetic}, |
|
||||
author={Tewel, Yoad and Shalev, Yoav and Schwartz, Idan and Wolf, Lior}, |
|
||||
journal={arXiv preprint arXiv:2111.14447}, |
|
||||
year={2021} |
|
||||
} |
|
||||
``` |
|
||||
|
|
||||
**** |
|
||||
|
|
||||
<span id='acknowledgements'/> |
|
||||
|
|
||||
#### 7. Acknowledgements: |
|
||||
We thank the authors for releasing their code. Our reimplementation of the baseline is based on their original codebase [[here]](https://github.com/yoadtew/zero-shot-image-to-text). |
|
||||
|
|
@ -1,12 +0,0 @@ |
|||||
build: |
|
||||
gpu: true |
|
||||
python_version: "3.8" |
|
||||
system_packages: |
|
||||
- "libgl1-mesa-glx" |
|
||||
- "libglib2.0-0" |
|
||||
python_packages: |
|
||||
- "git+https://github.com/openai/CLIP.git" |
|
||||
- "git+https://github.com/YoadTew/zero-shot-image-to-text.git" |
|
||||
|
|
||||
predict: "predict.py:Predictor" |
|
||||
#predict: "predict_arithmetic.py:Predictor" |
|
@ -1,14 +0,0 @@ |
|||||
#!/bin/bash |
|
||||
|
|
||||
# lm_model: |
|
||||
# 1. cambridgeltl/magic_mscoco |
|
||||
# 2. cambridgeltl/magic_flickr30k |
|
||||
CUDA_VISIBLE_DEVICES=1 python run.py \ |
|
||||
--beam_size 1 \ |
|
||||
--target_seq_length 16 \ |
|
||||
--reset_context_delta \ |
|
||||
--lm_model cambridgeltl/magic_flickr30k \ |
|
||||
--test_image_prefix_path ../data/flickr30k/test_images \ |
|
||||
--test_path ../data/flickr30k/flickr30k_test.json \ |
|
||||
--save_path_prefix ../inference_result/flickr30k/baselines/ \ |
|
||||
--save_name zerocap_result.json |
|
Binary file not shown.
@ -1,389 +0,0 @@ |
|||||
import numpy as np |
|
||||
from torch import nn |
|
||||
from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer |
|
||||
from transformers.models.gpt_neo import GPTNeoForCausalLM |
|
||||
import torch |
|
||||
import clip |
|
||||
from PIL import Image |
|
||||
from datetime import datetime |
|
||||
import sys |
|
||||
|
|
||||
|
|
||||
def log_info(text, verbose=True): |
|
||||
if verbose: |
|
||||
dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S") |
|
||||
print(f'{dt_string} | {text}') |
|
||||
sys.stdout.flush() |
|
||||
|
|
||||
|
|
||||
def add_context(x, y): |
|
||||
return (x[0] + y[0], x[1] + y[1]) |
|
||||
|
|
||||
|
|
||||
def convert_models_to_fp32(model): |
|
||||
for p in model.parameters(): |
|
||||
p.data = p.data.float() |
|
||||
|
|
||||
|
|
||||
class CLIPTextGenerator: |
|
||||
def __init__(self, |
|
||||
seed=0, |
|
||||
lm_model='gpt-2', |
|
||||
forbidden_tokens_file_path='./forbidden_tokens.npy', |
|
||||
clip_checkpoints='./clip_checkpoints', |
|
||||
target_seq_length=15, |
|
||||
reset_context_delta=True, |
|
||||
num_iterations=5, |
|
||||
clip_loss_temperature=0.01, |
|
||||
clip_scale=1., |
|
||||
ce_scale=0.2, |
|
||||
stepsize=0.3, |
|
||||
grad_norm_factor=0.9, |
|
||||
fusion_factor=0.99, |
|
||||
repetition_penalty=1., |
|
||||
end_token='.', |
|
||||
end_factor=1.01, |
|
||||
forbidden_factor=20, |
|
||||
**kwargs): |
|
||||
|
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
||||
|
|
||||
# set Random seed |
|
||||
torch.manual_seed(seed) |
|
||||
np.random.seed(seed) |
|
||||
|
|
||||
# Initialize Language model |
|
||||
self.context_prefix = '' |
|
||||
|
|
||||
self.lm_tokenizer = GPT2Tokenizer.from_pretrained(lm_model) |
|
||||
self.lm_model = GPT2LMHeadModel.from_pretrained(lm_model, output_hidden_states=True) |
|
||||
self.context_prefix = self.lm_tokenizer.bos_token |
|
||||
|
|
||||
self.lm_model.to(self.device) |
|
||||
self.lm_model.eval() |
|
||||
|
|
||||
self.forbidden_tokens = np.load(forbidden_tokens_file_path) |
|
||||
self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if |
|
||||
(x[0] == 'Ä ' and len(x) > 1 and x[1].isupper())] |
|
||||
|
|
||||
# Freeze LM weights |
|
||||
for param in self.lm_model.parameters(): |
|
||||
param.requires_grad = False |
|
||||
|
|
||||
# Initialize CLIP |
|
||||
self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device, |
|
||||
download_root=clip_checkpoints, jit=False) |
|
||||
# convert_models_to_fp32(self.clip) |
|
||||
|
|
||||
# Init arguments |
|
||||
self.target_seq_length = target_seq_length |
|
||||
self.reset_context_delta = reset_context_delta |
|
||||
self.num_iterations = num_iterations |
|
||||
self.clip_loss_temperature = clip_loss_temperature |
|
||||
self.clip_scale = clip_scale |
|
||||
self.ce_scale = ce_scale |
|
||||
self.stepsize = stepsize |
|
||||
self.grad_norm_factor = grad_norm_factor |
|
||||
self.fusion_factor = fusion_factor |
|
||||
self.repetition_penalty = repetition_penalty |
|
||||
self.end_token = self.lm_tokenizer.encode(end_token)[0] |
|
||||
self.end_factor = end_factor |
|
||||
self.ef_idx = 1 |
|
||||
self.forbidden_factor = forbidden_factor |
|
||||
|
|
||||
def get_img_feature(self, img_path, weights): |
|
||||
imgs = [Image.open(x) for x in img_path] |
|
||||
clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs] |
|
||||
|
|
||||
with torch.no_grad(): |
|
||||
image_fts = [self.clip.encode_image(x) for x in clip_imgs] |
|
||||
|
|
||||
if weights is not None: |
|
||||
image_features = sum([x * weights[i] for i, x in enumerate(image_fts)]) |
|
||||
else: |
|
||||
image_features = sum(image_fts) |
|
||||
|
|
||||
image_features = image_features / image_features.norm(dim=-1, keepdim=True) |
|
||||
return image_features.detach() |
|
||||
|
|
||||
def get_txt_features(self, text): |
|
||||
clip_texts = clip.tokenize(text).to(self.device) |
|
||||
|
|
||||
with torch.no_grad(): |
|
||||
text_features = self.clip.encode_text(clip_texts) |
|
||||
|
|
||||
text_features = text_features / text_features.norm(dim=-1, keepdim=True) |
|
||||
return text_features.detach() |
|
||||
|
|
||||
def get_combined_feature(self, img_path, texts, weights_i, weights_t): |
|
||||
imgs = [Image.open(x) for x in img_path] |
|
||||
clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs] |
|
||||
clip_texts = [clip.tokenize(x).to(self.device) for x in texts] |
|
||||
|
|
||||
with torch.no_grad(): |
|
||||
image_fts = [self.clip.encode_image(x) for x in clip_imgs] |
|
||||
text_fts = [self.clip.encode_text(x) for x in clip_texts] |
|
||||
|
|
||||
features = sum([x * weights_i[i] for i, x in enumerate(image_fts)]) |
|
||||
if weights_t is not None: |
|
||||
features += sum([x * weights_t[i] for i, x in enumerate(text_fts)]) |
|
||||
|
|
||||
features = features / features.norm(dim=-1, keepdim=True) |
|
||||
return features.detach() |
|
||||
|
|
||||
def run(self, image_features, cond_text, beam_size): |
|
||||
self.image_features = image_features |
|
||||
|
|
||||
context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text) |
|
||||
|
|
||||
output_tokens, output_text = self.generate_text(context_tokens, beam_size) |
|
||||
|
|
||||
return output_text |
|
||||
|
|
||||
def generate_text(self, context_tokens, beam_size): |
|
||||
context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0) |
|
||||
|
|
||||
gen_tokens = None |
|
||||
scores = None |
|
||||
seq_lengths = torch.ones(beam_size, device=self.device) |
|
||||
is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool) |
|
||||
|
|
||||
for i in range(self.target_seq_length): |
|
||||
probs = self.get_next_probs(i, context_tokens) |
|
||||
logits = probs.log() |
|
||||
|
|
||||
if scores is None: |
|
||||
scores, next_tokens = logits.topk(beam_size, -1) |
|
||||
context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:]) |
|
||||
next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0) |
|
||||
|
|
||||
if gen_tokens is None: |
|
||||
gen_tokens = next_tokens |
|
||||
else: |
|
||||
gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:]) |
|
||||
gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1) |
|
||||
else: |
|
||||
logits[is_stopped] = -float(np.inf) |
|
||||
logits[is_stopped, 0] = 0 |
|
||||
scores_sum = scores[:, None] + logits |
|
||||
seq_lengths[~is_stopped] += 1 |
|
||||
scores_sum_average = scores_sum / seq_lengths[:, None] |
|
||||
scores_sum_average, next_tokens = scores_sum_average.view(-1).topk( |
|
||||
beam_size, -1) |
|
||||
next_tokens_source = next_tokens // scores_sum.shape[1] |
|
||||
seq_lengths = seq_lengths[next_tokens_source] |
|
||||
next_tokens = next_tokens % scores_sum.shape[1] |
|
||||
next_tokens = next_tokens.unsqueeze(1) |
|
||||
gen_tokens = gen_tokens[next_tokens_source] |
|
||||
gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1) |
|
||||
context_tokens = context_tokens[next_tokens_source] |
|
||||
scores = scores_sum_average * seq_lengths |
|
||||
is_stopped = is_stopped[next_tokens_source] |
|
||||
|
|
||||
context_tokens = torch.cat((context_tokens, next_tokens), dim=1) |
|
||||
is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze() |
|
||||
|
|
||||
#### |
|
||||
tmp_scores = scores / seq_lengths |
|
||||
tmp_output_list = gen_tokens.cpu().numpy() |
|
||||
tmp_output_texts = [ |
|
||||
self.lm_tokenizer.decode(tmp_output) |
|
||||
for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths) |
|
||||
] |
|
||||
tmp_order = tmp_scores.argsort(descending=True) |
|
||||
tmp_output_texts = [tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order] |
|
||||
log_info(tmp_output_texts, verbose=True) |
|
||||
#### |
|
||||
|
|
||||
if is_stopped.all(): |
|
||||
break |
|
||||
|
|
||||
scores = scores / seq_lengths |
|
||||
output_list = gen_tokens.cpu().numpy() |
|
||||
output_texts = [ |
|
||||
self.lm_tokenizer.decode(output[: int(length)]) |
|
||||
for output, length in zip(output_list, seq_lengths) |
|
||||
] |
|
||||
order = scores.argsort(descending=True) |
|
||||
output_texts = [output_texts[i] for i in order] |
|
||||
|
|
||||
return context_tokens, output_texts |
|
||||
|
|
||||
def get_next_probs(self, i, context_tokens): |
|
||||
last_token = context_tokens[:, -1:] |
|
||||
|
|
||||
if self.reset_context_delta and context_tokens.size(1) > 1: |
|
||||
context = self.lm_model(context_tokens[:, :-1])["past_key_values"] |
|
||||
|
|
||||
# Logits of LM with unshifted context |
|
||||
logits_before_shift = self.lm_model(context_tokens)["logits"] |
|
||||
logits_before_shift = logits_before_shift[:, -1, :] |
|
||||
probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1) |
|
||||
|
|
||||
if context: |
|
||||
context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift) |
|
||||
|
|
||||
lm_output = self.lm_model(last_token, past_key_values=context) |
|
||||
logits, past = ( |
|
||||
lm_output["logits"], |
|
||||
lm_output["past_key_values"], |
|
||||
) |
|
||||
logits = logits[:, -1, :] |
|
||||
|
|
||||
logits = self.update_special_tokens_logits(context_tokens, i, logits) |
|
||||
|
|
||||
probs = nn.functional.softmax(logits, dim=-1) |
|
||||
probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor)) |
|
||||
probs = probs / probs.sum() |
|
||||
|
|
||||
return probs |
|
||||
|
|
||||
def shift_context(self, i, context, last_token, context_tokens, probs_before_shift): |
|
||||
context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context] |
|
||||
|
|
||||
window_mask = torch.ones_like(context[0][0]).to(self.device) |
|
||||
|
|
||||
for i in range(self.num_iterations): |
|
||||
curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in |
|
||||
context_delta] |
|
||||
|
|
||||
for p0, p1 in curr_shift: |
|
||||
p0.retain_grad() |
|
||||
p1.retain_grad() |
|
||||
|
|
||||
shifted_context = list(map(add_context, context, curr_shift)) |
|
||||
|
|
||||
shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context) |
|
||||
logits = shifted_outputs["logits"][:, -1, :] |
|
||||
probs = nn.functional.softmax(logits, dim=-1) |
|
||||
|
|
||||
loss = 0.0 |
|
||||
|
|
||||
# CLIP LOSS |
|
||||
clip_loss, clip_losses = self.clip_loss(probs, context_tokens) |
|
||||
loss += self.clip_scale * clip_loss |
|
||||
|
|
||||
# CE/Fluency loss |
|
||||
ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * probs_before_shift.log())).sum(-1) |
|
||||
loss += ce_loss.sum() |
|
||||
|
|
||||
loss.backward() |
|
||||
|
|
||||
# ---------- Weights ---------- |
|
||||
combined_scores_k = -(ce_loss) |
|
||||
combined_scores_c = -(self.clip_scale * torch.stack(clip_losses)) |
|
||||
|
|
||||
# minmax |
|
||||
if combined_scores_k.shape[0] == 1: |
|
||||
tmp_weights_c = tmp_weights_k = torch.ones(*combined_scores_k.shape).to(self.device) |
|
||||
else: |
|
||||
tmp_weights_k = ((combined_scores_k - combined_scores_k.min())) / ( |
|
||||
combined_scores_k.max() - combined_scores_k.min()) |
|
||||
tmp_weights_c = ((combined_scores_c - combined_scores_c.min())) / ( |
|
||||
combined_scores_c.max() - combined_scores_c.min()) |
|
||||
|
|
||||
tmp_weights = 0.5 * tmp_weights_k + 0.5 * tmp_weights_c |
|
||||
tmp_weights = tmp_weights.view(tmp_weights.shape[0], 1, 1, 1) |
|
||||
|
|
||||
factor = 1 |
|
||||
|
|
||||
# --------- Specific Gen --------- |
|
||||
sep_grads = None |
|
||||
|
|
||||
for b in range(context_tokens.shape[0]): |
|
||||
tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_] |
|
||||
for p_ in curr_shift] |
|
||||
|
|
||||
# normalize gradients |
|
||||
tmp_grad = [tuple([-self.stepsize * factor * ( |
|
||||
x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][ |
|
||||
j] ** self.grad_norm_factor).data.cpu().numpy() |
|
||||
for j, x in enumerate(p_)]) |
|
||||
for i, p_ in enumerate(curr_shift)] |
|
||||
if sep_grads is None: |
|
||||
sep_grads = tmp_grad |
|
||||
else: |
|
||||
for l_index in range(len(sep_grads)): |
|
||||
sep_grads[l_index] = list(sep_grads[l_index]) |
|
||||
for k_index in range(len(sep_grads[0])): |
|
||||
sep_grads[l_index][k_index] = np.concatenate( |
|
||||
(sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0) |
|
||||
sep_grads[l_index] = tuple(sep_grads[l_index]) |
|
||||
final_grads = sep_grads |
|
||||
|
|
||||
# --------- update context --------- |
|
||||
context_delta = list(map(add_context, final_grads, context_delta)) |
|
||||
|
|
||||
for p0, p1 in curr_shift: |
|
||||
p0.grad.data.zero_() |
|
||||
p1.grad.data.zero_() |
|
||||
|
|
||||
new_context = [] |
|
||||
for p0, p1 in context: |
|
||||
new_context.append((p0.detach(), p1.detach())) |
|
||||
context = new_context |
|
||||
|
|
||||
context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) |
|
||||
for p_ in context_delta] |
|
||||
context = list(map(add_context, context, context_delta)) |
|
||||
|
|
||||
new_context = [] |
|
||||
for p0, p1 in context: |
|
||||
new_context.append((p0.detach(), p1.detach())) |
|
||||
context = new_context |
|
||||
|
|
||||
return context |
|
||||
|
|
||||
def update_special_tokens_logits(self, context_tokens, i, logits): |
|
||||
for beam_id in range(context_tokens.shape[0]): |
|
||||
for token_idx in set(context_tokens[beam_id][-4:].tolist()): |
|
||||
factor = self.repetition_penalty if logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty) |
|
||||
logits[beam_id, token_idx] /= factor |
|
||||
|
|
||||
if i >= self.ef_idx: |
|
||||
factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor) |
|
||||
logits[beam_id, self.end_token] *= factor |
|
||||
if i == 0: |
|
||||
start_factor = 1.6 |
|
||||
factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor) |
|
||||
logits[beam_id, self.end_token] /= factor |
|
||||
|
|
||||
for token_idx in list(self.forbidden_tokens): |
|
||||
factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor) |
|
||||
logits[beam_id, token_idx] /= factor |
|
||||
|
|
||||
return logits |
|
||||
|
|
||||
def clip_loss(self, probs, context_tokens): |
|
||||
for p_ in self.clip.transformer.parameters(): |
|
||||
if p_.grad is not None: |
|
||||
p_.grad.data.zero_() |
|
||||
|
|
||||
top_size = 512 |
|
||||
_, top_indices = probs.topk(top_size, -1) |
|
||||
|
|
||||
prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens] |
|
||||
|
|
||||
clip_loss = 0 |
|
||||
losses = [] |
|
||||
for idx_p in range(probs.shape[0]): |
|
||||
top_texts = [] |
|
||||
prefix_text = prefix_texts[idx_p] |
|
||||
for x in top_indices[idx_p]: |
|
||||
top_texts.append(prefix_text + self.lm_tokenizer.decode(x)) |
|
||||
text_features = self.get_txt_features(top_texts) |
|
||||
|
|
||||
with torch.no_grad(): |
|
||||
similiraties = (self.image_features @ text_features.T) |
|
||||
target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() |
|
||||
target_probs = target_probs.type(torch.float32) |
|
||||
|
|
||||
target = torch.zeros_like(probs[idx_p]) |
|
||||
target[top_indices[idx_p]] = target_probs[0] |
|
||||
target = target.unsqueeze(0) |
|
||||
cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)]))) |
|
||||
|
|
||||
clip_loss += cur_clip_loss |
|
||||
losses.append(cur_clip_loss) |
|
||||
|
|
||||
return clip_loss, losses |
|
@ -1,449 +0,0 @@ |
|||||
import numpy as np |
|
||||
from torch import nn |
|
||||
from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer |
|
||||
from transformers.models.gpt_neo import GPTNeoForCausalLM |
|
||||
import torch |
|
||||
import clip |
|
||||
from PIL import Image |
|
||||
from datetime import datetime |
|
||||
import sys |
|
||||
|
|
||||
class TextCLIP(nn.Module): |
|
||||
def __init__(self, model): |
|
||||
super(TextCLIP, self).__init__() |
|
||||
self.model = model |
|
||||
|
|
||||
def forward(self, text): |
|
||||
return self.model.encode_text(text) |
|
||||
|
|
||||
|
|
||||
class ImageCLIP(nn.Module): |
|
||||
def __init__(self, model): |
|
||||
super(ImageCLIP, self).__init__() |
|
||||
self.model = model |
|
||||
|
|
||||
def forward(self, image): |
|
||||
return self.model.encode_image(image) |
|
||||
|
|
||||
def log_info(text, verbose=True): |
|
||||
if verbose: |
|
||||
dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S") |
|
||||
print(f'{dt_string} | {text}') |
|
||||
sys.stdout.flush() |
|
||||
|
|
||||
|
|
||||
def add_context(x, y): |
|
||||
return (x[0] + y[0], x[1] + y[1]) |
|
||||
|
|
||||
|
|
||||
def convert_models_to_fp32(model): |
|
||||
for p in model.parameters(): |
|
||||
p.data = p.data.float() |
|
||||
|
|
||||
|
|
||||
class CLIPTextGenerator: |
|
||||
def __init__(self, |
|
||||
seed=0, |
|
||||
lm_model='gpt-2', |
|
||||
forbidden_tokens_file_path='./forbidden_tokens.npy', |
|
||||
clip_checkpoints='./clip_checkpoints', |
|
||||
target_seq_length=15, |
|
||||
reset_context_delta=True, |
|
||||
num_iterations=5, |
|
||||
clip_loss_temperature=0.01, |
|
||||
clip_scale=1., |
|
||||
ce_scale=0.2, |
|
||||
stepsize=0.3, |
|
||||
grad_norm_factor=0.9, |
|
||||
fusion_factor=0.99, |
|
||||
repetition_penalty=1., |
|
||||
end_token='.', |
|
||||
end_factor=1.01, |
|
||||
forbidden_factor=20, |
|
||||
**kwargs): |
|
||||
|
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
||||
|
|
||||
# set Random seed |
|
||||
torch.manual_seed(seed) |
|
||||
np.random.seed(seed) |
|
||||
|
|
||||
# Initialize Language model |
|
||||
self.context_prefix = '' |
|
||||
|
|
||||
if lm_model == 'gpt-neo': |
|
||||
self.lm_tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-125M') |
|
||||
self.lm_model = GPTNeoForCausalLM.from_pretrained('EleutherAI/gpt-neo-125M', output_hidden_states=True) |
|
||||
elif lm_model == 'gpt-2': |
|
||||
self.lm_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium') |
|
||||
self.lm_model = GPT2LMHeadModel.from_pretrained('gpt2-medium', output_hidden_states=True) |
|
||||
self.context_prefix = self.lm_tokenizer.bos_token |
|
||||
|
|
||||
self.lm_model.to(self.device) |
|
||||
self.lm_model.eval() |
|
||||
|
|
||||
self.forbidden_tokens = np.load(forbidden_tokens_file_path) |
|
||||
self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if |
|
||||
(x[0] == 'Ä ' and len(x) > 1 and x[1].isupper())] |
|
||||
|
|
||||
# Freeze LM weights |
|
||||
for param in self.lm_model.parameters(): |
|
||||
param.requires_grad = False |
|
||||
|
|
||||
# Initialize CLIP |
|
||||
self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device, |
|
||||
download_root=clip_checkpoints, jit=False) |
|
||||
self.clip_image = ImageCLIP(self.clip) |
|
||||
self.clip_image = torch.nn.DataParallel(self.clip_image) |
|
||||
self.clip_text = TextCLIP(self.clip) |
|
||||
self.clip_text = torch.nn.DataParallel(self.clip_text) |
|
||||
|
|
||||
# Init arguments |
|
||||
self.target_seq_length = target_seq_length |
|
||||
self.reset_context_delta = reset_context_delta |
|
||||
self.num_iterations = num_iterations |
|
||||
self.clip_loss_temperature = clip_loss_temperature |
|
||||
self.clip_scale = clip_scale |
|
||||
self.ce_scale = ce_scale |
|
||||
self.stepsize = stepsize |
|
||||
self.grad_norm_factor = grad_norm_factor |
|
||||
self.fusion_factor = fusion_factor |
|
||||
self.repetition_penalty = repetition_penalty |
|
||||
self.end_token = self.lm_tokenizer.encode(end_token)[0] |
|
||||
self.end_factor = end_factor |
|
||||
self.ef_idx = 1 |
|
||||
self.forbidden_factor = forbidden_factor |
|
||||
|
|
||||
def get_img_feature(self, img_path, weights): |
|
||||
imgs = [Image.open(x) for x in img_path] |
|
||||
clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs] |
|
||||
|
|
||||
with torch.no_grad(): |
|
||||
image_fts = [self.clip_image(x) for x in clip_imgs] |
|
||||
|
|
||||
if weights is not None: |
|
||||
image_features = sum([x * weights[i] for i, x in enumerate(image_fts)]) |
|
||||
else: |
|
||||
image_features = sum(image_fts) |
|
||||
|
|
||||
image_features = torch.nn.functional.normalize(image_features, dim=-1) |
|
||||
return image_features.detach() |
|
||||
|
|
||||
def get_txt_features(self, text): |
|
||||
clip_texts = clip.tokenize(text).to(self.device) |
|
||||
|
|
||||
with torch.no_grad(): |
|
||||
text_features = self.clip_text(clip_texts) |
|
||||
|
|
||||
text_features = torch.nn.functional.normalize(text_features, dim=-1) |
|
||||
return text_features.detach() |
|
||||
|
|
||||
def get_combined_feature(self, img_path, texts, weights_i, weights_t): |
|
||||
imgs = [Image.open(x) for x in img_path] |
|
||||
clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs] |
|
||||
clip_texts = [clip.tokenize(x).to(self.device) for x in texts] |
|
||||
|
|
||||
with torch.no_grad(): |
|
||||
image_fts = [self.clip.encode_image(x) for x in clip_imgs] |
|
||||
text_fts = [self.clip.encode_text(x) for x in clip_texts] |
|
||||
|
|
||||
features = sum([x * weights_i[i] for i, x in enumerate(image_fts)]) |
|
||||
if weights_t is not None: |
|
||||
features += sum([x * weights_t[i] for i, x in enumerate(text_fts)]) |
|
||||
|
|
||||
features = features / features.norm(dim=-1, keepdim=True) |
|
||||
return features.detach() |
|
||||
|
|
||||
def run(self, image_features, cond_text, beam_size): |
|
||||
self.image_features = image_features |
|
||||
|
|
||||
context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text) |
|
||||
|
|
||||
output_tokens, output_text = self.generate_text(context_tokens, beam_size) |
|
||||
|
|
||||
return output_text |
|
||||
|
|
||||
def generate_text(self, context_tokens, beam_size): |
|
||||
context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0) |
|
||||
|
|
||||
gen_tokens = None |
|
||||
scores = None |
|
||||
seq_lengths = torch.ones(beam_size, device=self.device) |
|
||||
is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool) |
|
||||
|
|
||||
for i in range(self.target_seq_length): |
|
||||
probs = self.get_next_probs(i, context_tokens) |
|
||||
logits = probs.log() |
|
||||
|
|
||||
if scores is None: |
|
||||
scores, next_tokens = logits.topk(beam_size, -1) |
|
||||
context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:]) |
|
||||
next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0) |
|
||||
|
|
||||
if gen_tokens is None: |
|
||||
gen_tokens = next_tokens |
|
||||
else: |
|
||||
gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:]) |
|
||||
gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1) |
|
||||
else: |
|
||||
logits[is_stopped] = -float(np.inf) |
|
||||
logits[is_stopped, 0] = 0 |
|
||||
scores_sum = scores[:, None] + logits |
|
||||
seq_lengths[~is_stopped] += 1 |
|
||||
scores_sum_average = scores_sum / seq_lengths[:, None] |
|
||||
scores_sum_average, next_tokens = scores_sum_average.view(-1).topk( |
|
||||
beam_size, -1) |
|
||||
next_tokens_source = next_tokens // scores_sum.shape[1] |
|
||||
seq_lengths = seq_lengths[next_tokens_source] |
|
||||
next_tokens = next_tokens % scores_sum.shape[1] |
|
||||
next_tokens = next_tokens.unsqueeze(1) |
|
||||
gen_tokens = gen_tokens[next_tokens_source] |
|
||||
gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1) |
|
||||
context_tokens = context_tokens[next_tokens_source] |
|
||||
scores = scores_sum_average * seq_lengths |
|
||||
is_stopped = is_stopped[next_tokens_source] |
|
||||
|
|
||||
context_tokens = torch.cat((context_tokens, next_tokens), dim=1) |
|
||||
is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze() |
|
||||
|
|
||||
#### |
|
||||
tmp_scores = scores / seq_lengths |
|
||||
tmp_output_list = gen_tokens.cpu().numpy() |
|
||||
tmp_output_texts = [ |
|
||||
self.lm_tokenizer.decode(tmp_output) |
|
||||
for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths) |
|
||||
] |
|
||||
tmp_order = tmp_scores.argsort(descending=True) |
|
||||
tmp_output_texts = [tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order] |
|
||||
log_info(tmp_output_texts, verbose=True) |
|
||||
#### |
|
||||
|
|
||||
if is_stopped.all(): |
|
||||
break |
|
||||
|
|
||||
scores = scores / seq_lengths |
|
||||
output_list = gen_tokens.cpu().numpy() |
|
||||
output_texts = [ |
|
||||
self.lm_tokenizer.decode(output[: int(length)]) |
|
||||
for output, length in zip(output_list, seq_lengths) |
|
||||
] |
|
||||
order = scores.argsort(descending=True) |
|
||||
output_texts = [output_texts[i] for i in order] |
|
||||
|
|
||||
return context_tokens, output_texts |
|
||||
|
|
||||
def get_next_probs(self, i, context_tokens): |
|
||||
last_token = context_tokens[:, -1:] |
|
||||
|
|
||||
if self.reset_context_delta and context_tokens.size(1) > 1: |
|
||||
context = self.lm_model(context_tokens[:, :-1])["past_key_values"] |
|
||||
|
|
||||
# Logits of LM with unshifted context |
|
||||
logits_before_shift = self.lm_model(context_tokens)["logits"] |
|
||||
logits_before_shift = logits_before_shift[:, -1, :] |
|
||||
probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1) |
|
||||
|
|
||||
if context: |
|
||||
context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift) |
|
||||
|
|
||||
lm_output = self.lm_model(last_token, past_key_values=context) |
|
||||
logits, past = ( |
|
||||
lm_output["logits"], |
|
||||
lm_output["past_key_values"], |
|
||||
) |
|
||||
logits = logits[:, -1, :] |
|
||||
|
|
||||
logits = self.update_special_tokens_logits(context_tokens, i, logits) |
|
||||
|
|
||||
probs = nn.functional.softmax(logits, dim=-1) |
|
||||
probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor)) |
|
||||
probs = probs / probs.sum() |
|
||||
|
|
||||
return probs |
|
||||
|
|
||||
def shift_context(self, i, context, last_token, context_tokens, probs_before_shift): |
|
||||
context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context] |
|
||||
|
|
||||
for i in range(self.num_iterations): |
|
||||
curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in |
|
||||
context_delta] |
|
||||
|
|
||||
for p0, p1 in curr_shift: |
|
||||
p0.retain_grad() |
|
||||
p1.retain_grad() |
|
||||
|
|
||||
shifted_context = list(map(add_context, context, curr_shift)) |
|
||||
|
|
||||
shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context) |
|
||||
logits = shifted_outputs["logits"][:, -1, :] |
|
||||
probs = nn.functional.softmax(logits, dim=-1) |
|
||||
|
|
||||
loss = 0.0 |
|
||||
|
|
||||
# CLIP LOSS |
|
||||
clip_loss, clip_losses = self.clip_loss(probs, context_tokens) |
|
||||
loss += self.clip_scale * clip_loss |
|
||||
|
|
||||
# CE/Fluency loss |
|
||||
ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * probs_before_shift.log())).sum(-1) |
|
||||
loss += ce_loss.sum() |
|
||||
|
|
||||
loss.backward() |
|
||||
|
|
||||
# --------- Specific Gen --------- |
|
||||
final_grads = self.norm_grad(context, context_tokens, curr_shift) |
|
||||
|
|
||||
# --------- update context --------- |
|
||||
context_delta = list(map(add_context, final_grads, context_delta)) |
|
||||
|
|
||||
for p0, p1 in curr_shift: |
|
||||
p0.grad.data.zero_() |
|
||||
p1.grad.data.zero_() |
|
||||
|
|
||||
new_context = [] |
|
||||
for p0, p1 in context: |
|
||||
new_context.append((p0.detach(), p1.detach())) |
|
||||
context = new_context |
|
||||
|
|
||||
context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) |
|
||||
for p_ in context_delta] |
|
||||
context = list(map(add_context, context, context_delta)) |
|
||||
|
|
||||
new_context = [] |
|
||||
for p0, p1 in context: |
|
||||
new_context.append((p0.detach(), p1.detach())) |
|
||||
context = new_context |
|
||||
|
|
||||
return context |
|
||||
|
|
||||
def norm_grad(self, context, context_tokens, curr_shift, ): |
|
||||
factor = 1 |
|
||||
sep_grads = None |
|
||||
window_mask = torch.ones_like(context[0][0]).to(self.device) |
|
||||
|
|
||||
for b in range(context_tokens.shape[0]): |
|
||||
tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_] |
|
||||
for p_ in curr_shift] |
|
||||
|
|
||||
# normalize gradients |
|
||||
tmp_grad = [tuple([-self.stepsize * factor * ( |
|
||||
x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][ |
|
||||
j] ** self.grad_norm_factor).data.cpu().numpy() |
|
||||
for j, x in enumerate(p_)]) |
|
||||
for i, p_ in enumerate(curr_shift)] |
|
||||
if sep_grads is None: |
|
||||
sep_grads = tmp_grad |
|
||||
else: |
|
||||
for l_index in range(len(sep_grads)): |
|
||||
sep_grads[l_index] = list(sep_grads[l_index]) |
|
||||
for k_index in range(len(sep_grads[0])): |
|
||||
sep_grads[l_index][k_index] = np.concatenate( |
|
||||
(sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0) |
|
||||
sep_grads[l_index] = tuple(sep_grads[l_index]) |
|
||||
final_grads = sep_grads |
|
||||
|
|
||||
return final_grads |
|
||||
|
|
||||
def update_special_tokens_logits(self, context_tokens, i, logits): |
|
||||
for beam_id in range(context_tokens.shape[0]): |
|
||||
for token_idx in set(context_tokens[beam_id][-4:].tolist()): |
|
||||
factor = self.repetition_penalty if logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty) |
|
||||
logits[beam_id, token_idx] /= factor |
|
||||
|
|
||||
if i >= self.ef_idx: |
|
||||
factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor) |
|
||||
logits[beam_id, self.end_token] *= factor |
|
||||
if i == 0: |
|
||||
start_factor = 1.6 |
|
||||
factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor) |
|
||||
logits[beam_id, self.end_token] /= factor |
|
||||
|
|
||||
for token_idx in list(self.forbidden_tokens): |
|
||||
factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor) |
|
||||
logits[beam_id, token_idx] /= factor |
|
||||
|
|
||||
return logits |
|
||||
|
|
||||
def clip_loss(self, probs, context_tokens): |
|
||||
for p_ in self.clip.transformer.parameters(): |
|
||||
if p_.grad is not None: |
|
||||
p_.grad.data.zero_() |
|
||||
|
|
||||
top_size = 512 |
|
||||
top_probs, top_indices = probs.topk(top_size, -1) |
|
||||
|
|
||||
prefix_texts = [self.lm_tokenizer.decode(x, skip_special_tokens=True) for x in context_tokens] |
|
||||
|
|
||||
clip_loss = 0 |
|
||||
losses = [] |
|
||||
|
|
||||
top_texts = [] |
|
||||
for idx_p in range(probs.shape[0]): |
|
||||
prefix_text = prefix_texts[idx_p] |
|
||||
for x in top_indices[idx_p]: |
|
||||
top_texts.append(prefix_text + self.lm_tokenizer.decode(x)) |
|
||||
|
|
||||
text_features = self.get_txt_features(top_texts)#.reshape(probs.size(0), top_size, -1) |
|
||||
|
|
||||
with torch.no_grad(): |
|
||||
similiraties = (self.image_features @ text_features.T).reshape(probs.size(0), -1) |
|
||||
similiraties = similiraties.reshape(probs.size(0), -1) |
|
||||
target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() |
|
||||
target_probs = target_probs.type(torch.float32) |
|
||||
|
|
||||
clip_loss += torch.sum(-(target_probs * torch.log(top_probs))) |
|
||||
# for idx_p in range(probs.shape[0]): |
|
||||
# top_texts = [] |
|
||||
# prefix_text = prefix_texts[idx_p] |
|
||||
# for x in top_indices[idx_p]: |
|
||||
# top_texts.append(prefix_text + self.lm_tokenizer.decode(x)) |
|
||||
# text_features = self.get_txt_features(top_texts) |
|
||||
# |
|
||||
# with torch.no_grad(): |
|
||||
# similiraties = (self.image_features @ text_features.T) |
|
||||
# target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() |
|
||||
# target_probs = target_probs.type(torch.float32) |
|
||||
# |
|
||||
# target = torch.zeros_like(probs[idx_p]) |
|
||||
# target[top_indices[idx_p]] = target_probs[0] |
|
||||
# target = target.unsqueeze(0) |
|
||||
# cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)]))) |
|
||||
# |
|
||||
# clip_loss += cur_clip_loss |
|
||||
# losses.append(cur_clip_loss) |
|
||||
|
|
||||
return clip_loss, losses |
|
||||
|
|
||||
def clip_loss_old(self, probs, context_tokens): |
|
||||
for p_ in self.clip.transformer.parameters(): |
|
||||
if p_.grad is not None: |
|
||||
p_.grad.data.zero_() |
|
||||
|
|
||||
top_size = 512 |
|
||||
_, top_indices = probs.topk(top_size, -1) |
|
||||
|
|
||||
prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens] |
|
||||
|
|
||||
clip_loss = 0 |
|
||||
losses = [] |
|
||||
for idx_p in range(probs.shape[0]): |
|
||||
top_texts = [] |
|
||||
prefix_text = prefix_texts[idx_p] |
|
||||
for x in top_indices[idx_p]: |
|
||||
top_texts.append(prefix_text + self.lm_tokenizer.decode(x)) |
|
||||
text_features = self.get_txt_features(top_texts) |
|
||||
|
|
||||
with torch.no_grad(): |
|
||||
similiraties = (self.image_features @ text_features.T) |
|
||||
target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() |
|
||||
target_probs = target_probs.type(torch.float32) |
|
||||
|
|
||||
target = torch.zeros_like(probs[idx_p]) |
|
||||
target[top_indices[idx_p]] = target_probs[0] |
|
||||
target = target.unsqueeze(0) |
|
||||
cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)]))) |
|
||||
|
|
||||
clip_loss += cur_clip_loss |
|
||||
losses.append(cur_clip_loss) |
|
||||
|
|
||||
return clip_loss, losses |
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,14 +0,0 @@ |
|||||
#!/bin/bash |
|
||||
|
|
||||
# lm_model: |
|
||||
# 1. cambridgeltl/magic_mscoco |
|
||||
# 2. cambridgeltl/magic_flickr30k |
|
||||
CUDA_VISIBLE_DEVICES=1 python run.py \ |
|
||||
--beam_size 1 \ |
|
||||
--target_seq_length 16 \ |
|
||||
--reset_context_delta \ |
|
||||
--lm_model cambridgeltl/magic_mscoco \ |
|
||||
--test_image_prefix_path ../data/mscoco/test_images \ |
|
||||
--test_path ../data/mscoco/mscoco_test.json \ |
|
||||
--save_path_prefix ../inference_result/mscoco/baselines/ \ |
|
||||
--save_name zerocap_result.json |
|
@ -1,117 +0,0 @@ |
|||||
import os |
|
||||
import tempfile |
|
||||
import sys |
|
||||
sys.path.append('CLIP') |
|
||||
from pathlib import Path |
|
||||
import cog |
|
||||
import argparse |
|
||||
import torch |
|
||||
import clip |
|
||||
from model.ZeroCLIP import CLIPTextGenerator |
|
||||
|
|
||||
def perplexity_score(text, lm_model, lm_tokenizer, device): |
|
||||
encodings = lm_tokenizer(f'{lm_tokenizer.bos_token + text}', return_tensors='pt') |
|
||||
input_ids = encodings.input_ids.to(device) |
|
||||
target_ids = input_ids.clone() |
|
||||
|
|
||||
outputs = lm_model(input_ids, labels=target_ids) |
|
||||
log_likelihood = outputs[0] |
|
||||
ll = log_likelihood.item() |
|
||||
|
|
||||
return ll |
|
||||
|
|
||||
class Predictor(cog.Predictor): |
|
||||
def setup(self): |
|
||||
self.args = get_args() |
|
||||
self.args.reset_context_delta = True |
|
||||
self.text_generator = CLIPTextGenerator(**vars(self.args)) |
|
||||
|
|
||||
@cog.input( |
|
||||
"image", |
|
||||
type=Path, |
|
||||
help="input image" |
|
||||
) |
|
||||
@cog.input( |
|
||||
"cond_text", |
|
||||
type=str, |
|
||||
default='Image of a', |
|
||||
help="conditional text", |
|
||||
) |
|
||||
@cog.input( |
|
||||
"beam_size", |
|
||||
type=int, |
|
||||
default=5, min=1, max=10, |
|
||||
help="Number of beams to use", |
|
||||
) |
|
||||
@cog.input( |
|
||||
"end_factor", |
|
||||
type=float, |
|
||||
default=1.01, min=1.0, max=1.10, |
|
||||
help="Higher value for shorter captions", |
|
||||
) |
|
||||
@cog.input( |
|
||||
"max_seq_length", |
|
||||
type=int, |
|
||||
default=15, min=1, max=20, |
|
||||
help="Maximum number of tokens to generate", |
|
||||
) |
|
||||
@cog.input( |
|
||||
"ce_loss_scale", |
|
||||
type=float, |
|
||||
default=0.2, min=0.0, max=0.6, |
|
||||
help="Scale of cross-entropy loss with un-shifted language model", |
|
||||
) |
|
||||
def predict(self, image, cond_text, beam_size, end_factor, max_seq_length, ce_loss_scale): |
|
||||
self.args.cond_text = cond_text |
|
||||
self.text_generator.end_factor = end_factor |
|
||||
self.text_generator.target_seq_length = max_seq_length |
|
||||
self.text_generator.ce_scale = ce_loss_scale |
|
||||
|
|
||||
image_features = self.text_generator.get_img_feature([str(image)], None) |
|
||||
captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size) |
|
||||
|
|
||||
# CLIP SCORE |
|
||||
encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device)) |
|
||||
for c in captions] |
|
||||
encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions] |
|
||||
best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item() |
|
||||
|
|
||||
# Perplexity SCORE |
|
||||
ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions] |
|
||||
best_ppl_index = torch.tensor(ppl_scores).argmin().item() |
|
||||
|
|
||||
best_clip_caption = self.args.cond_text + captions[best_clip_idx] |
|
||||
best_mixed = self.args.cond_text + captions[0] |
|
||||
best_PPL = self.args.cond_text + captions[best_ppl_index] |
|
||||
|
|
||||
final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}' |
|
||||
|
|
||||
return final |
|
||||
# return self.args.cond_text + captions[best_clip_idx] |
|
||||
|
|
||||
|
|
||||
def get_args(): |
|
||||
parser = argparse.ArgumentParser() |
|
||||
|
|
||||
parser.add_argument("--seed", type=int, default=0) |
|
||||
parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo") |
|
||||
parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP") |
|
||||
parser.add_argument("--target_seq_length", type=int, default=15) |
|
||||
parser.add_argument("--cond_text", type=str, default="Image of a") |
|
||||
parser.add_argument("--reset_context_delta", action="store_true", |
|
||||
help="Should we reset the context at each token gen") |
|
||||
parser.add_argument("--num_iterations", type=int, default=5) |
|
||||
parser.add_argument("--clip_loss_temperature", type=float, default=0.01) |
|
||||
parser.add_argument("--clip_scale", type=float, default=1) |
|
||||
parser.add_argument("--ce_scale", type=float, default=0.2) |
|
||||
parser.add_argument("--stepsize", type=float, default=0.3) |
|
||||
parser.add_argument("--grad_norm_factor", type=float, default=0.9) |
|
||||
parser.add_argument("--fusion_factor", type=float, default=0.99) |
|
||||
parser.add_argument("--repetition_penalty", type=float, default=1) |
|
||||
parser.add_argument("--end_token", type=str, default=".", help="Token to end text") |
|
||||
parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token") |
|
||||
parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens") |
|
||||
parser.add_argument("--beam_size", type=int, default=5) |
|
||||
|
|
||||
args = parser.parse_args('') |
|
||||
return args |
|
@ -1,129 +0,0 @@ |
|||||
import os |
|
||||
import tempfile |
|
||||
import sys |
|
||||
sys.path.append('CLIP') |
|
||||
from pathlib import Path |
|
||||
import cog |
|
||||
import argparse |
|
||||
import torch |
|
||||
import clip |
|
||||
from model.ZeroCLIP import CLIPTextGenerator |
|
||||
|
|
||||
def perplexity_score(text, lm_model, lm_tokenizer, device): |
|
||||
encodings = lm_tokenizer(f'{lm_tokenizer.bos_token + text}', return_tensors='pt') |
|
||||
input_ids = encodings.input_ids.to(device) |
|
||||
target_ids = input_ids.clone() |
|
||||
|
|
||||
outputs = lm_model(input_ids, labels=target_ids) |
|
||||
log_likelihood = outputs[0] |
|
||||
ll = log_likelihood.item() |
|
||||
|
|
||||
return ll |
|
||||
|
|
||||
class Predictor(cog.Predictor): |
|
||||
def setup(self): |
|
||||
self.args = get_args() |
|
||||
self.args.reset_context_delta = True |
|
||||
self.text_generator = CLIPTextGenerator(**vars(self.args)) |
|
||||
|
|
||||
@cog.input( |
|
||||
"image1", |
|
||||
type=Path, |
|
||||
help="Final result will be: image1 + (image2 - image3)" |
|
||||
) |
|
||||
@cog.input( |
|
||||
"image2", |
|
||||
type=Path, |
|
||||
help="Final result will be: image1 + (image2 - image3)" |
|
||||
) |
|
||||
@cog.input( |
|
||||
"image3", |
|
||||
type=Path, |
|
||||
help="Final result will be: image1 + (image2 - image3)" |
|
||||
) |
|
||||
@cog.input( |
|
||||
"cond_text", |
|
||||
type=str, |
|
||||
default='Image of a', |
|
||||
help="conditional text", |
|
||||
) |
|
||||
@cog.input( |
|
||||
"beam_size", |
|
||||
type=int, |
|
||||
default=3, min=1, max=10, |
|
||||
help="Number of beams to use", |
|
||||
) |
|
||||
@cog.input( |
|
||||
"end_factors", |
|
||||
type=float, |
|
||||
default=1.06, min=1.0, max=1.10, |
|
||||
help="Higher value for shorter captions", |
|
||||
) |
|
||||
@cog.input( |
|
||||
"max_seq_lengths", |
|
||||
type=int, |
|
||||
default=3, min=1, max=20, |
|
||||
help="Maximum number of tokens to generate", |
|
||||
) |
|
||||
@cog.input( |
|
||||
"ce_loss_scale", |
|
||||
type=float, |
|
||||
default=0.2, min=0.0, max=0.6, |
|
||||
help="Scale of cross-entropy loss with un-shifted language model", |
|
||||
) |
|
||||
def predict(self, image1, image2, image3, cond_text, beam_size, end_factors, max_seq_lengths, ce_loss_scale): |
|
||||
self.args.cond_text = cond_text |
|
||||
self.text_generator.end_factor = end_factors |
|
||||
self.text_generator.target_seq_length = max_seq_lengths |
|
||||
self.text_generator.ce_scale = ce_loss_scale |
|
||||
self.text_generator.fusion_factor = 0.95 |
|
||||
self.text_generator.grad_norm_factor = 0.95 |
|
||||
|
|
||||
image_features = self.text_generator.get_combined_feature([str(image1), str(image2), str(image3)], [], [1, 1, -1], None) |
|
||||
captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size) |
|
||||
|
|
||||
# CLIP SCORE |
|
||||
encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device)) |
|
||||
for c in captions] |
|
||||
encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions] |
|
||||
best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item() |
|
||||
|
|
||||
# Perplexity SCORE |
|
||||
ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions] |
|
||||
best_ppl_index = torch.tensor(ppl_scores).argmin().item() |
|
||||
|
|
||||
best_clip_caption = self.args.cond_text + captions[best_clip_idx] |
|
||||
best_mixed = self.args.cond_text + captions[0] |
|
||||
best_PPL = self.args.cond_text + captions[best_ppl_index] |
|
||||
|
|
||||
final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}' |
|
||||
|
|
||||
return final |
|
||||
# return self.args.cond_text + captions[best_clip_idx] |
|
||||
|
|
||||
|
|
||||
def get_args(): |
|
||||
parser = argparse.ArgumentParser() |
|
||||
|
|
||||
parser.add_argument("--seed", type=int, default=0) |
|
||||
parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo") |
|
||||
parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP") |
|
||||
parser.add_argument("--target_seq_length", type=int, default=15) |
|
||||
parser.add_argument("--cond_text", type=str, default="Image of a") |
|
||||
parser.add_argument("--reset_context_delta", action="store_true", |
|
||||
help="Should we reset the context at each token gen") |
|
||||
parser.add_argument("--num_iterations", type=int, default=5) |
|
||||
parser.add_argument("--clip_loss_temperature", type=float, default=0.01) |
|
||||
parser.add_argument("--clip_scale", type=float, default=1) |
|
||||
parser.add_argument("--ce_scale", type=float, default=0.2) |
|
||||
parser.add_argument("--stepsize", type=float, default=0.3) |
|
||||
parser.add_argument("--grad_norm_factor", type=float, default=0.95) |
|
||||
parser.add_argument("--fusion_factor", type=float, default=0.95) |
|
||||
parser.add_argument("--repetition_penalty", type=float, default=1) |
|
||||
parser.add_argument("--end_token", type=str, default=".", help="Token to end text") |
|
||||
parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token") |
|
||||
parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens") |
|
||||
parser.add_argument("--beam_size", type=int, default=5) |
|
||||
|
|
||||
args = parser.parse_args('') |
|
||||
return args |
|
@ -1,3 +0,0 @@ |
|||||
ftfy |
|
||||
regex |
|
||||
tqdm |
|
@ -1,131 +0,0 @@ |
|||||
import argparse |
|
||||
import ipdb |
|
||||
from tqdm import tqdm |
|
||||
import progressbar |
|
||||
import torch |
|
||||
import ipdb |
|
||||
import clip |
|
||||
from model.ZeroCLIP import CLIPTextGenerator |
|
||||
from model.ZeroCLIP_batched import CLIPTextGenerator as CLIPTextGenerator_multigpu |
|
||||
|
|
||||
def get_args(): |
|
||||
parser = argparse.ArgumentParser() |
|
||||
|
|
||||
parser.add_argument("--test_image_prefix_path", type=str, help="the folder that stores all test images") |
|
||||
parser.add_argument("--test_path", type=str) |
|
||||
parser.add_argument("--save_path_prefix", type=str, help="save the result in which directory") |
|
||||
parser.add_argument("--save_name", type=str, help="the name of the saved file") |
|
||||
|
|
||||
parser.add_argument("--seed", type=int, default=0) |
|
||||
parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo") |
|
||||
parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP") |
|
||||
parser.add_argument("--target_seq_length", type=int, default=15) |
|
||||
parser.add_argument("--cond_text", type=str, default="Image of a") |
|
||||
parser.add_argument("--reset_context_delta", action="store_true", |
|
||||
help="Should we reset the context at each token gen") |
|
||||
parser.add_argument("--num_iterations", type=int, default=5) |
|
||||
parser.add_argument("--clip_loss_temperature", type=float, default=0.01) |
|
||||
parser.add_argument("--clip_scale", type=float, default=1) |
|
||||
parser.add_argument("--ce_scale", type=float, default=0.2) |
|
||||
parser.add_argument("--stepsize", type=float, default=0.3) |
|
||||
parser.add_argument("--grad_norm_factor", type=float, default=0.9) |
|
||||
parser.add_argument("--fusion_factor", type=float, default=0.99) |
|
||||
parser.add_argument("--repetition_penalty", type=float, default=1) |
|
||||
parser.add_argument("--end_token", type=str, default=".", help="Token to end text") |
|
||||
parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token") |
|
||||
parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens") |
|
||||
parser.add_argument("--beam_size", type=int, default=1) |
|
||||
|
|
||||
parser.add_argument("--multi_gpu", action="store_true") |
|
||||
|
|
||||
parser.add_argument('--run_type', |
|
||||
default='caption', |
|
||||
nargs='?', |
|
||||
choices=['caption', 'arithmetics']) |
|
||||
|
|
||||
parser.add_argument("--caption_img_path", type=str, default='example_images/captions/COCO_val2014_000000008775.jpg', |
|
||||
help="Path to image for captioning") |
|
||||
|
|
||||
parser.add_argument("--arithmetics_imgs", nargs="+", |
|
||||
default=['example_images/arithmetics/woman2.jpg', |
|
||||
'example_images/arithmetics/king2.jpg', |
|
||||
'example_images/arithmetics/man2.jpg']) |
|
||||
parser.add_argument("--arithmetics_weights", nargs="+", default=[1, 1, -1]) |
|
||||
|
|
||||
args = parser.parse_args() |
|
||||
|
|
||||
return args |
|
||||
|
|
||||
def run(args, text_generator, img_path): |
|
||||
image_features = text_generator.get_img_feature([img_path], None) |
|
||||
captions = text_generator.run(image_features, args.cond_text, beam_size=args.beam_size) |
|
||||
|
|
||||
encoded_captions = [text_generator.clip.encode_text(clip.tokenize(c).to(text_generator.device)) for c in captions] |
|
||||
encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions] |
|
||||
best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item() |
|
||||
return captions |
|
||||
|
|
||||
|
|
||||
if __name__ == '__main__': |
|
||||
if torch.cuda.is_available(): |
|
||||
print ('Cuda is available.') |
|
||||
cuda_available = torch.cuda.is_available() |
|
||||
args = get_args() |
|
||||
device = torch.device('cuda') |
|
||||
|
|
||||
save_path_prefix = args.save_path_prefix |
|
||||
import os |
|
||||
if os.path.exists(save_path_prefix): |
|
||||
pass |
|
||||
else: # recursively construct directory |
|
||||
os.makedirs(save_path_prefix, exist_ok=True) |
|
||||
# parse save name |
|
||||
save_name = args.save_name |
|
||||
full_save_path = save_path_prefix + '/' + save_name |
|
||||
print ('full save path is {}'.format(full_save_path)) |
|
||||
|
|
||||
print ('Loading data...') |
|
||||
import json |
|
||||
with open(args.test_path) as f: |
|
||||
item_list = json.load(f) |
|
||||
print ('Data loaded.') |
|
||||
print ('Number of test instances is {}'.format(len(item_list))) |
|
||||
|
|
||||
# ZeroCap generator |
|
||||
text_generator = CLIPTextGenerator(**vars(args)) |
|
||||
|
|
||||
result_list = [] |
|
||||
invalid_num = 0 |
|
||||
print ('----------------------------------------------------------------') |
|
||||
test_num = len(item_list) |
|
||||
#test_num = 10 |
|
||||
print ('Number of inference instances is {}'.format(test_num)) |
|
||||
p = progressbar.ProgressBar(test_num) |
|
||||
p.start() |
|
||||
for p_idx in tqdm(range(test_num)): |
|
||||
p.update(p_idx) |
|
||||
one_test_dict = item_list[p_idx] |
|
||||
|
|
||||
one_res_dict = { |
|
||||
'split':one_test_dict['split'], |
|
||||
'image_name':one_test_dict['image_name'], |
|
||||
#'file_path':one_test_dict['file_path'], |
|
||||
'captions':one_test_dict['captions'] |
|
||||
} |
|
||||
|
|
||||
image_full_path = args.test_image_prefix_path + '/' + one_test_dict['image_name'] |
|
||||
try: |
|
||||
output_text = run(args, text_generator, img_path=image_full_path) |
|
||||
one_res_dict['prediction'] = output_text[0] |
|
||||
result_list.append(one_res_dict) |
|
||||
except Exception as error: |
|
||||
print(f'[!] ERROR:', error) |
|
||||
invalid_num += 1 |
|
||||
print ('invalid number is {}'.format(invalid_num)) |
|
||||
continue |
|
||||
p.finish() |
|
||||
print ('Inference completed!') |
|
||||
|
|
||||
import json |
|
||||
with open(full_save_path, 'w') as outfile: |
|
||||
json.dump(result_list, outfile, indent=4) |
|
@ -1,19 +0,0 @@ |
|||||
import os |
|
||||
|
|
||||
import pkg_resources |
|
||||
from setuptools import setup, find_packages |
|
||||
|
|
||||
setup( |
|
||||
name="zero-shot-image-to-text", |
|
||||
py_modules=["zero-shot-image-to-text"], |
|
||||
version="1.0", |
|
||||
description="", |
|
||||
packages=find_packages(), |
|
||||
install_requires=[ |
|
||||
str(r) |
|
||||
for r in pkg_resources.parse_requirements( |
|
||||
open(os.path.join(os.path.dirname(__file__), "requirements.txt")) |
|
||||
) |
|
||||
], |
|
||||
include_package_data=True |
|
||||
) |
|
Loading…
Reference in new issue