magic
copied
wxywb
2 years ago
32 changed files with 1334 additions and 1376 deletions
Binary file not shown.
Binary file not shown.
@ -1,2 +1,81 @@ |
|||
# magic |
|||
# Image Captioning with MAGIC |
|||
|
|||
*author: David Wang* |
|||
|
|||
|
|||
<br /> |
|||
|
|||
|
|||
## Description |
|||
|
|||
This operator generates the caption with [MAGIC](https://arxiv.org/abs/2205.02655) which describes the content of the given image. MAGIC is a simple yet efficient plug-and-play framework, which directly combines an off-the-shelf LM (i.e., GPT-2) and an image-text matching model (i.e., CLIP) for image-grounded text generation. During decoding, MAGIC influences the generation of the LM by introducing a CLIP-induced score, called magic score, which regularizes the generated result to be semantically related to a given image while being coherent to the previously generated context. This is an adaptation from [yxuansu / MAGIC](https://github.com/yxuansu/MAGIC). |
|||
|
|||
|
|||
<br /> |
|||
|
|||
|
|||
## Code Example |
|||
|
|||
Load an image from path './image.jpg' to generate the caption. |
|||
|
|||
*Write the pipeline in simplified style*: |
|||
|
|||
```python |
|||
import towhee |
|||
|
|||
towhee.glob('./image.jpg') \ |
|||
.image_decode() \ |
|||
.image_captioning.magic(model_name='expansionnet_rf') \ |
|||
.show() |
|||
``` |
|||
<img src="./cap.png" alt="result1" style="height:20px;"/> |
|||
|
|||
*Write a same pipeline with explicit inputs/outputs name specifications:* |
|||
|
|||
```python |
|||
import towhee |
|||
|
|||
towhee.glob['path']('./image.jpg') \ |
|||
.image_decode['path', 'img']() \ |
|||
.image_captioning.magic['img', 'text'](model_name='expansionnet_rf') \ |
|||
.select['img', 'text']() \ |
|||
.show() |
|||
``` |
|||
<img src="./tabular.png" alt="result2" style="height:60px;"/> |
|||
|
|||
|
|||
<br /> |
|||
|
|||
|
|||
## Factory Constructor |
|||
|
|||
Create the operator via the following factory method |
|||
|
|||
***expansionnet_v2(model_name)*** |
|||
|
|||
**Parameters:** |
|||
|
|||
***model_name:*** *str* |
|||
|
|||
The model name of MAGIC. Supported model names: |
|||
- magic_mscoco |
|||
|
|||
<br /> |
|||
|
|||
## Interface |
|||
|
|||
An image-text embedding operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generate the correspoing caption. |
|||
|
|||
|
|||
**Parameters:** |
|||
|
|||
***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)* |
|||
|
|||
The image to generate embedding. |
|||
|
|||
|
|||
|
|||
**Returns:** *str* |
|||
|
|||
The caption generated by model. |
|||
|
|||
|
@ -0,0 +1,167 @@ |
|||
## Unsupervised Domain Adaptation of Language Model |
|||
**** |
|||
### Catalogue: |
|||
* <a href='#mscoco'>1. MSCOCO Benchmark</a> |
|||
* <a href='#mscoco_data_preparation'>1.1. MSCOCO Data Preparation</a> |
|||
* <a href='#mscoco_training'>1.2. Unsupervised Domain Adaptation on MSCOCO</a> |
|||
* <a href='#flickr30k'>2. Flickr30k Benchmark</a> |
|||
* <a href='#flickr30k_data_preparation'>2.1. Flickr30k Data Preparation</a> |
|||
* <a href='#flickr30k_training'>2.2. Unsupervised Domain Adaptation on Flickr30k</a> |
|||
* <a href='#unsupervised_baselines'>3. Unsupervised Baselines</a> |
|||
* <a href='#contrastive_search'>3.1. Contrastive Search</a> |
|||
* <a href='#top_k_sampling'>3.2. Top-k Sampling</a> |
|||
* <a href='#nucleus_sampling'>3.3. Nucleus Sampling</a> |
|||
|
|||
**** |
|||
<span id='mscoco'/> |
|||
|
|||
#### 1. MSCOCO Benchmark: |
|||
|
|||
We first describe how to perform unsupervised domain adaptation of language model on the text corpus of MSCOCO benchmark. |
|||
|
|||
<span id='mscoco_data_preparation'/> |
|||
|
|||
##### 1.1. MSCOCO Data Preparation: |
|||
|
|||
To prepare the MSCOCO benchmark, please follow the instructions [[here]](https://github.com/yxuansu/MAGIC/tree/main/image_captioning/data#1-mscoco-benchmark). |
|||
|
|||
<span id='mscoco_training'/> |
|||
|
|||
##### 1.2.Unsupervised Domain Adaptation on MSCOCO: |
|||
After preparing the MSCOCO data, run the following command to train the language model. |
|||
```yaml |
|||
chmod +x ./train_mscoco.sh |
|||
./train_mscoco.sh |
|||
``` |
|||
The arguments are as follows: |
|||
* `--model_name`: The name of huggingface pre-trained gpt model (e.g. gpt2, gpt-large). |
|||
* `--train_path`: The file path of training set. |
|||
* `--dev_path`: The file path of validation set. |
|||
* `--test_path`: The file path of test set. |
|||
* `--add_eos_token_to_data`: Whether adding an eos token at the end of text sequence. |
|||
* `--margin`: The contrastive margin $\rho$. |
|||
* `--max_len`: The maximum length of training samples. |
|||
* `--number_of_gpu`: The number of available GPUs. |
|||
* `--batch_size_per_gpu`: The batch size for each GPU. |
|||
* `--gradient_accumulation_steps`: How many forward computations between two gradient updates. |
|||
* `--effective_batch_size`: The overall batch size. It equals to batch_size_per_gpu x gradient_accumulation_steps x number_of_gpu. |
|||
* `--total_steps`: The number of total gradient update steps. |
|||
* `--print_every`: Have many steps to show the intermediate results. |
|||
* `--save_every`: How many steps to save one checkpoint. |
|||
* `--learning_rate`: The learning rate. |
|||
* `--save_path_prefix`: Where to save the checkpoints. |
|||
|
|||
**** |
|||
<span id='flickr30k'/> |
|||
|
|||
#### 2. Flickr30k Benchmark: |
|||
|
|||
We then describe how to perform unsupervised domain adaptation of language model on the text corpus of Flickr30k benchmark. |
|||
|
|||
<span id='flickr30k_data_preparation'/> |
|||
|
|||
##### 2.1. Flickr30k Data Preparation: |
|||
|
|||
To prepare the Flickr30k benchmark, please follow the instructions [[here]](https://github.com/yxuansu/MAGIC/tree/main/image_captioning/data#2-flickr30k-benchmark). |
|||
|
|||
<span id='flickr30k_training'/> |
|||
|
|||
##### 2.2. Unsupervised Domain Adaptation on Flickr30k: |
|||
After preparing the Flickr30k data, run the following command to train the language model. |
|||
```yaml |
|||
chmod +x ./train_flickr30k.sh |
|||
./train_flickr30k.sh |
|||
``` |
|||
|
|||
**** |
|||
<span id='unsupervised_baselines'/> |
|||
|
|||
#### 3. Unsupervised Baselines: |
|||
|
|||
Here, we illustrate how to use the language model to perform unsupervised baselines as described in our paper. Note that, all these methods are **unsupervised** as the language model is a text-only model and does not take image as input. |
|||
|
|||
```python |
|||
# first, load the language model |
|||
import torch |
|||
from simctg import SimCTG |
|||
sos_token, pad_token = r'<-start_of_text->', r'<-pad->' |
|||
# we use the language model adapted on MSCOCO as an example. |
|||
language_model_name = r'cambridgeltl/magic_mscoco' |
|||
generation_model = SimCTG(language_model_name, sos_token, pad_token) |
|||
generation_model.eval() |
|||
|
|||
# then, prepare the input ids. Note that, the text is always generated from the same start of sentence token. |
|||
tokens = generation_model.tokenizer.tokenize(sos_token) |
|||
input_ids = generation_model.tokenizer.convert_tokens_to_ids(tokens) |
|||
input_ids = torch.LongTensor(input_ids).view(1,-1) |
|||
``` |
|||
|
|||
<span id='contrastive_search'/> |
|||
|
|||
##### 3.1. Contrastive Search : |
|||
```python |
|||
''' |
|||
use contrastive search to generate the result. |
|||
note that, contrastive search is a deterministic decoding method, thus the generated text is always the same. |
|||
''' |
|||
beam_width, alpha, decoding_len = 45, 0.1, 16 |
|||
output_text = generation_model.fast_contrastive_search(input_ids, beam_width, alpha, decoding_len) |
|||
print (output_text) |
|||
''' |
|||
A man is riding a skateboard down a street. |
|||
''' |
|||
``` |
|||
The arguments are as follows: |
|||
* `--input_ids`: The id of the start of sentence token. |
|||
* `--beam_width`: k in the contrastive search. |
|||
* `--alpha`: alpha in the contrastive search. |
|||
* `--decoding_len`: Number of tokens to generate. |
|||
|
|||
<span id='top_k_sampling'/> |
|||
|
|||
##### 3.2. Top-k Sampling : |
|||
```python |
|||
''' |
|||
use top-k sampling to generate the result. |
|||
note that, the this method is a stochastic method, thus the generated text is always different. |
|||
''' |
|||
top_k, decoding_len = 40, 16 |
|||
output_text = generation_model.top_k_sampling(input_ids, top_k, decoding_len) |
|||
print (output_text) |
|||
''' |
|||
some very different types of vases with flowers together |
|||
''' |
|||
``` |
|||
The arguments are as follows: |
|||
* `--input_ids`: The id of the start of sentence token. |
|||
* `--k`: The k in top-k sampling. |
|||
* `--decoding_len`: Number of tokens to generate. |
|||
|
|||
<span id='nucleus_sampling'/> |
|||
|
|||
##### 3.3. Nucleus Sampling : |
|||
```python |
|||
''' |
|||
use nucleus sampling to generate the result. |
|||
note that, the this method is a stochastic method, thus the generated text is always different. |
|||
''' |
|||
nucleus_p, decoding_len = 0.95, 16 |
|||
output_text = generation_model.nucleus_sampling(input_ids, nucleus_p, decoding_len) |
|||
print (output_text) |
|||
''' |
|||
Two young girls enjoying a hot dog hot dog bun. |
|||
''' |
|||
``` |
|||
The arguments are as follows: |
|||
* `--input_ids`: The id of the start of sentence token. |
|||
* `--nucleus_p`: The probability in nucleus sampling. |
|||
* `--decoding_len`: Number of tokens to generate. |
|||
|
|||
|
|||
|
|||
|
|||
|
|||
|
|||
|
|||
|
|||
|
@ -0,0 +1,157 @@ |
|||
import json |
|||
import random |
|||
import torch |
|||
import numpy as np |
|||
import progressbar |
|||
from torch.nn.utils import rnn |
|||
|
|||
class Data: |
|||
def __init__(self, model_name, train_path, dev_path, test_path, max_len, |
|||
sos_token, pad_token, add_eos_token_to_data): |
|||
''' |
|||
model_name: gpt2 |
|||
train_path: training data path |
|||
dev_path: validation data path |
|||
test_path: test data path |
|||
max_len: maximum length for training sequences |
|||
sos_token: initialized sos token <-start_of_text-> |
|||
pad_token: used to pad the sequences <-pad-> |
|||
add_eos_token_to_data: whether we want to the model learn to generate eos token; |
|||
if so, the model could automatically stop generation by generating eos token |
|||
''' |
|||
from transformers import GPT2TokenizerFast |
|||
self.tokenizer = GPT2TokenizerFast.from_pretrained(model_name) |
|||
self.sos_token, self.sos_token_id = self.add_special_token(sos_token) |
|||
print ('sos token is {}, sos token id is {}'.format(self.sos_token, self.sos_token_id)) |
|||
self.pad_token, self.pad_token_id = self.add_special_token(pad_token) |
|||
print ('pad token is {}, pad token id is {}'.format(self.pad_token, self.pad_token_id)) |
|||
self.eos_token, self.eos_token_id = self.tokenizer.bos_token, self.tokenizer.bos_token_id |
|||
print ('eos token is {}, eos token id is {}'.format(self.eos_token, self.eos_token_id)) |
|||
self.add_eos_token_to_data = add_eos_token_to_data |
|||
|
|||
self.max_len = max_len |
|||
self.train_token_list, self.train_token_id_list = self.process_one_file(train_path) |
|||
self.dev_token_list, self.dev_token_id_list = self.process_one_file(dev_path) |
|||
self.test_token_list, self.test_token_id_list = self.process_one_file(test_path) |
|||
self.train_num, self.dev_num, self.test_num = len(self.train_token_list), len(self.dev_token_list), \ |
|||
len(self.test_token_list) |
|||
print ('train number:{}, dev number:{}, test number:{}'.format(self.train_num, self.dev_num, self.test_num)) |
|||
|
|||
self.train_idx_list = [i for i in range(self.train_num)] |
|||
random.shuffle(self.train_idx_list) |
|||
self.dev_idx_list = [j for j in range(self.dev_num)] |
|||
self.test_idx_list = [j for j in range(self.test_num)] |
|||
self.dev_current_idx, self.test_current_idx = 0, 0 |
|||
|
|||
def add_special_token(self, special_token): |
|||
if special_token in self.tokenizer.vocab: |
|||
print (special_token + ' token exists.') |
|||
else: |
|||
print ('Add token to the tokenizer.') |
|||
print ('Original vocabulary size is {}'.format(len(self.tokenizer))) |
|||
self.tokenizer.add_tokens([special_token]) |
|||
print ('Vocabulary size after extension is {}'.format(len(self.tokenizer))) |
|||
assert len(self.tokenizer.convert_tokens_to_ids([special_token])) == 1 |
|||
special_token_id = self.tokenizer.convert_tokens_to_ids([special_token])[0] |
|||
return special_token, special_token_id |
|||
|
|||
def process_one_file(self, path): |
|||
print ('Processing {}'.format(path)) |
|||
with open(path) as f: |
|||
item_list = json.load(f) |
|||
lines = [] |
|||
for item in item_list: |
|||
captions_list = item['captions'] |
|||
for one_caption in captions_list: |
|||
lines.append(one_caption.strip()) |
|||
|
|||
res_token_list, res_token_id_list = [], [] |
|||
n = len(lines) |
|||
p = progressbar.ProgressBar(n) |
|||
p.start() |
|||
for i in range(n): |
|||
p.update(i) |
|||
text = lines[i].strip('\n') |
|||
self.process_one_text(text, res_token_list, res_token_id_list) |
|||
p.finish() |
|||
print ('{} processed!'.format(path)) |
|||
return res_token_list, res_token_id_list |
|||
|
|||
def process_one_text(self, text, res_token_list, res_token_id_list): |
|||
tokens = self.tokenizer.tokenize(text, max_length=self.max_len, truncation=True) |
|||
if len(tokens) <= 1: # filter out too short sequence |
|||
return |
|||
tokens = [self.sos_token] + tokens[:self.max_len] |
|||
if self.add_eos_token_to_data: |
|||
tokens = tokens + [self.eos_token] |
|||
token_ids = self.tokenizer.convert_tokens_to_ids(tokens) |
|||
res_token_list.append(tokens) |
|||
res_token_id_list.append(token_ids) |
|||
return |
|||
|
|||
def pad_batch(self, batch_id_list): |
|||
batch_id_list = [torch.LongTensor(item) for item in batch_id_list] |
|||
batch_tensor = rnn.pad_sequence(batch_id_list, batch_first=True, padding_value=self.pad_token_id) |
|||
batch_mask = torch.ones_like(batch_tensor) |
|||
batch_mask = batch_mask.masked_fill(batch_tensor.eq(self.pad_token_id), 0.0).type(torch.FloatTensor) |
|||
return batch_tensor, batch_mask |
|||
|
|||
def process_output(self, batch_tgt_id_list): |
|||
batch_tgt_id_list = [torch.LongTensor(item) for item in batch_tgt_id_list] |
|||
batch_tgt_tensor, _ = self.pad_batch(batch_tgt_id_list) # padded target sequence |
|||
batch_tgt_input_tensor = batch_tgt_tensor[:, :-1].clone() |
|||
batch_tgt_output_tensor = batch_tgt_tensor[:, 1:].clone() |
|||
return batch_tgt_input_tensor, batch_tgt_output_tensor |
|||
|
|||
def parse_batch(self, batch_id_list): |
|||
batch_input, batch_labels = self.process_output(batch_id_list) |
|||
batch_labels[batch_labels[:, :] == self.pad_token_id] = -100 |
|||
return batch_input, batch_labels |
|||
|
|||
def get_next_train_batch(self, batch_size): |
|||
batch_idx_list = random.sample(self.train_idx_list, batch_size) |
|||
batch_id_list, batch_token_list = [], [] |
|||
|
|||
for idx in batch_idx_list: |
|||
batch_id_list.append(self.train_token_id_list[idx]) |
|||
batch_token_list.append(self.train_token_list[idx]) |
|||
batch_input_tensor, batch_labels = self.parse_batch(batch_id_list) |
|||
return batch_input_tensor, batch_labels, batch_token_list |
|||
|
|||
def get_next_validation_batch(self, batch_size, mode): |
|||
batch_id_list, batch_token_list = [], [] |
|||
if mode == 'dev': |
|||
curr_select_idx, instance_num = self.dev_current_idx, self.dev_num |
|||
tgt_token_id_list, tgt_token_list = self.dev_token_id_list, self.dev_token_list |
|||
elif mode == 'test': |
|||
curr_select_idx, instance_num = self.test_current_idx, self.test_num |
|||
tgt_token_id_list, tgt_token_list = self.test_token_id_list, self.test_token_list |
|||
else: |
|||
raise Exception('Wrong Validation Mode!!!') |
|||
|
|||
if curr_select_idx + batch_size < instance_num: |
|||
for i in range(batch_size): |
|||
curr_idx = curr_select_idx + i |
|||
batch_id_list.append(tgt_token_id_list[curr_idx]) |
|||
batch_token_list.append(tgt_token_list[curr_idx]) |
|||
if mode == 'dev': |
|||
self.dev_current_idx += batch_size |
|||
else: |
|||
self.test_current_idx += batch_size |
|||
else: |
|||
for i in range(batch_size): |
|||
curr_idx = curr_select_idx + i |
|||
if curr_idx > instance_num - 1: |
|||
curr_idx = 0 |
|||
if mode == 'dev': |
|||
self.dev_current_idx = 0 |
|||
else: |
|||
self.test_current_idx = 0 |
|||
batch_id_list.append(tgt_token_id_list[curr_idx]) |
|||
batch_token_list.append(tgt_token_list[curr_idx]) |
|||
if mode == 'dev': |
|||
self.dev_current_idx = 0 |
|||
else: |
|||
self.test_current_idx = 0 |
|||
batch_input_tensor, batch_labels = self.parse_batch(batch_id_list) |
|||
return batch_input_tensor, batch_labels, batch_token_list |
@ -0,0 +1,80 @@ |
|||
import torch |
|||
|
|||
def compute_valid_token_num(valid_len_list): |
|||
res = 0 |
|||
for one_len in valid_len_list: |
|||
res += one_len * (one_len - 1) |
|||
return res |
|||
|
|||
def build_mask_matrix(seqlen, valid_len_list, prefix_len = 0): |
|||
''' |
|||
prefix_len: the length of prefix that we do not want to compute CL loss for. |
|||
|
|||
(1) if a sequence of length 4 contains zero padding token (i.e., the valid length is 4), |
|||
then the loss padding matrix looks like |
|||
[0., 1., 1., 1.], |
|||
[1., 0., 1., 1.], |
|||
[1., 1., 0., 1.], |
|||
[1., 1., 1., 0.] |
|||
|
|||
(2) if a sequence of length 4 contains 1 padding token (i.e., the valid length is 3), |
|||
then the loss padding matrix looks like |
|||
[0., 1., 1., 0.], |
|||
[1., 0., 1., 0.], |
|||
[1., 1., 0., 0.], |
|||
[0., 0., 0., 0.] |
|||
''' |
|||
res_list = [] |
|||
base_mask = torch.ones(seqlen, seqlen) - torch.eye(seqlen, seqlen) |
|||
base_mask = base_mask.type(torch.FloatTensor) |
|||
bsz = len(valid_len_list) |
|||
for i in range(bsz): |
|||
one_base_mask = base_mask.clone() |
|||
one_valid_len = valid_len_list[i] |
|||
one_base_mask[:,one_valid_len:] = 0. |
|||
one_base_mask[one_valid_len:, :] = 0. |
|||
if prefix_len > 0: |
|||
one_base_mask[:prefix_len, :prefix_len] = 0. |
|||
res_list.append(one_base_mask) |
|||
res_mask = torch.stack(res_list, dim = 0)#torch.FloatTensor(res_list) |
|||
#print (res_mask) |
|||
assert res_mask.size() == torch.Size([bsz, seqlen, seqlen]) |
|||
return res_mask |
|||
|
|||
def contrastive_loss(margin, score_matrix, input_ids, pad_token_id, prefix_len=0): |
|||
''' |
|||
margin: predefined margin to push similarity score away |
|||
score_matrix: bsz x seqlen x seqlen |
|||
input_ids: bsz x seqlen |
|||
pad_token_id: indicating which tokens are padding token |
|||
''' |
|||
bsz, seqlen, _ = score_matrix.size() |
|||
gold_score = torch.diagonal(score_matrix, offset=0, dim1=1, dim2=2) # bsz x seqlen |
|||
gold_score = torch.unsqueeze(gold_score, -1) |
|||
assert gold_score.size() == torch.Size([bsz, seqlen, 1]) |
|||
difference_matrix = gold_score - score_matrix |
|||
assert difference_matrix.size() == torch.Size([bsz, seqlen, seqlen]) |
|||
loss_matrix = margin - difference_matrix # bsz x seqlen x seqlen |
|||
loss_matrix = torch.nn.functional.relu(loss_matrix) |
|||
|
|||
### input mask |
|||
input_mask = torch.ones_like(input_ids).type(torch.FloatTensor) |
|||
if loss_matrix.is_cuda: |
|||
input_mask = input_mask.cuda(loss_matrix.get_device()) |
|||
input_mask = input_mask.masked_fill(input_ids.eq(pad_token_id), 0.0) |
|||
|
|||
if loss_matrix.is_cuda: |
|||
input_mask = input_mask.cuda(loss_matrix.get_device()) |
|||
|
|||
valid_len_list = torch.sum(input_mask, dim = -1).tolist() |
|||
loss_mask = build_mask_matrix(seqlen, [int(item) for item in valid_len_list], prefix_len) |
|||
if score_matrix.is_cuda: |
|||
loss_mask = loss_mask.cuda(score_matrix.get_device()) |
|||
masked_loss_matrix = loss_matrix * loss_mask |
|||
|
|||
loss_matrix = torch.sum(masked_loss_matrix, dim = -1) |
|||
assert loss_matrix.size() == input_ids.size() |
|||
loss_matrix = loss_matrix * input_mask |
|||
cl_loss = torch.sum(loss_matrix) / torch.sum(loss_mask) |
|||
return cl_loss |
|||
|
@ -0,0 +1,233 @@ |
|||
import os |
|||
import sys |
|||
import operator |
|||
from tqdm import tqdm |
|||
from operator import itemgetter |
|||
import torch |
|||
from torch import nn |
|||
import random |
|||
import argparse |
|||
import numpy as np |
|||
import torch.nn.functional as F |
|||
from torch.nn import CrossEntropyLoss |
|||
from loss_func import contrastive_loss |
|||
from utlis import PlugAndPlayContrastiveDecodingOneStepFast |
|||
|
|||
import seaborn as sns |
|||
import matplotlib.pyplot as plt |
|||
import pandas as pd |
|||
import datetime |
|||
|
|||
train_fct = CrossEntropyLoss() |
|||
val_fct = CrossEntropyLoss(reduction='none') |
|||
class SimCTG(nn.Module): |
|||
def __init__(self, model_name, sos_token, pad_token): |
|||
super(SimCTG, self).__init__() |
|||
from transformers import AutoTokenizer, GPT2LMHeadModel |
|||
self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
|||
self.sos_token, self.sos_token_id = self.add_special_token(sos_token) |
|||
print ('sos token is {}, sos token id is {}'.format(self.sos_token, self.sos_token_id)) |
|||
self.pad_token, self.pad_token_id = self.add_special_token(pad_token) |
|||
print ('pad token is {}, pad token id is {}'.format(self.pad_token, self.pad_token_id)) |
|||
self.eos_token, self.eos_token_id = self.tokenizer.bos_token, self.tokenizer.bos_token_id |
|||
print ('eos token is {}, eos token id is {}'.format(self.eos_token, self.eos_token_id)) |
|||
self.model = GPT2LMHeadModel.from_pretrained(model_name) |
|||
self.vocab_size = len(self.tokenizer) |
|||
print ('Resizing model embedding...') |
|||
self.model.resize_token_embeddings(len(self.tokenizer)) |
|||
print ('Model embedding resized!') |
|||
self.embed_dim = self.model.config.hidden_size |
|||
|
|||
def add_special_token(self, special_token): |
|||
if special_token in self.tokenizer.vocab: |
|||
print (special_token + ' token exists.') |
|||
else: |
|||
print ('Add token to the tokenizer.') |
|||
print ('Original vocabulary size is {}'.format(len(self.tokenizer))) |
|||
self.tokenizer.add_tokens([special_token]) |
|||
print ('Vocabulary size after extension is {}'.format(len(self.tokenizer))) |
|||
assert len(self.tokenizer.convert_tokens_to_ids([special_token])) == 1 |
|||
special_token_id = self.tokenizer.convert_tokens_to_ids([special_token])[0] |
|||
return special_token, special_token_id |
|||
|
|||
def compute_logits_and_hidden_states(self, input_ids): |
|||
# used for advanced decoding |
|||
# input_ids: 1 x seqlen |
|||
outputs = self.model(input_ids=input_ids, output_hidden_states=True) |
|||
last_hidden_states = outputs.hidden_states[-1] |
|||
logits = outputs.logits |
|||
return last_hidden_states, logits |
|||
|
|||
def forward(self, input_ids, labels, margin): |
|||
bsz, seqlen = input_ids.size() |
|||
outputs = self.model(input_ids=input_ids, output_hidden_states=True) |
|||
logits = outputs.logits |
|||
assert logits.size() == torch.Size([bsz, seqlen, self.vocab_size]) |
|||
last_hidden_states = outputs.hidden_states[-1] |
|||
assert last_hidden_states.size() == torch.Size([bsz, seqlen, self.embed_dim]) |
|||
mle_loss = train_fct(logits.view(-1, self.vocab_size), labels.view(-1)) |
|||
|
|||
norm_rep = last_hidden_states / last_hidden_states.norm(dim=2, keepdim=True) |
|||
cosine_scores = torch.matmul(norm_rep, norm_rep.transpose(1,2)) |
|||
assert cosine_scores.size() == torch.Size([bsz, seqlen, seqlen]) |
|||
cl_loss = contrastive_loss(margin, cosine_scores, input_ids, self.pad_token_id, prefix_len=0) |
|||
return mle_loss, cl_loss |
|||
|
|||
def eval_loss(self, input_ids, labels): |
|||
bsz, seqlen = input_ids.size() |
|||
outputs = self.model(input_ids=input_ids, output_hidden_states=True) |
|||
logits = outputs.logits |
|||
assert logits.size() == torch.Size([bsz, seqlen, self.vocab_size]) |
|||
last_hidden_states = outputs.hidden_states[-1] |
|||
assert last_hidden_states.size() == torch.Size([bsz, seqlen, self.embed_dim]) |
|||
mle_loss = val_fct(logits.view(-1, self.vocab_size), labels.view(-1)) |
|||
assert mle_loss.size() == torch.Size([bsz * seqlen]) |
|||
mask_tmp = labels.masked_fill(~labels.eq(-100), 1.0) |
|||
mask = mask_tmp.masked_fill(mask_tmp.eq(-100), 0.0) |
|||
# sum |
|||
mle_loss_sum = torch.sum(mle_loss) |
|||
token_num_sum = torch.sum(mask) |
|||
return mle_loss_sum, token_num_sum |
|||
|
|||
def save_model(self, ckpt_save_path): |
|||
import os |
|||
if os.path.exists(ckpt_save_path): |
|||
pass |
|||
else: # recursively construct directory |
|||
os.makedirs(ckpt_save_path, exist_ok=True) |
|||
# save model |
|||
self.model.save_pretrained(ckpt_save_path) |
|||
# save tokenizer |
|||
self.tokenizer.save_pretrained(ckpt_save_path) |
|||
|
|||
def parse_sentences(self, text, num_of_sentences_to_keep): |
|||
item_list = text.split('.') |
|||
res_list = item_list[:num_of_sentences_to_keep] |
|||
if len(item_list) > num_of_sentences_to_keep: |
|||
res_text = '.'.join(res_list).strip('.') + '.' |
|||
else: |
|||
res_text = '.'.join(res_list).strip('.').strip() |
|||
return res_text |
|||
|
|||
def parse_generated_result(self, output, num_of_sentences_to_keep): |
|||
output_text = self.tokenizer.decode(output) |
|||
item_list = output_text.split(self.eos_token) |
|||
full_text = self.eos_token.join(item_list[:2]).strip() |
|||
full_text = self.parse_sentences(full_text, num_of_sentences_to_keep) |
|||
generated_text = item_list[1].strip() |
|||
generated_text = self.parse_sentences(generated_text, num_of_sentences_to_keep) |
|||
return full_text, generated_text |
|||
|
|||
# decoding functions |
|||
# ------------------------------------------------------- # |
|||
|
|||
def parse_output_token_list(self, output): |
|||
output = output.tolist() |
|||
res_list = [] |
|||
for token_id in output: |
|||
if token_id == self.sos_token_id: |
|||
continue |
|||
elif token_id == self.eos_token_id: |
|||
break |
|||
else: |
|||
res_list.append(token_id) |
|||
text = self.tokenizer.decode(res_list).strip() |
|||
return ' '.join(text.split()).strip() |
|||
|
|||
@torch.no_grad() |
|||
def magic_search(self, input_ids, beam_width, alpha, decoding_len, beta, image_instance, clip, |
|||
clip_text_max_len):#, add_token_level_score=False): |
|||
prefix_len = input_ids.size()[1] |
|||
#from utlis import PlugAndPlayContrastiveDecodingOneStepFast |
|||
past_key_values, last_hidden_states, logits = None, None, None |
|||
generated = [item for item in input_ids.tolist()] |
|||
input_ids_for_class = input_ids.clone() |
|||
|
|||
image_embeds = clip.compute_image_representation_from_image_instance(image_instance) |
|||
|
|||
start_time = datetime.datetime.now() |
|||
|
|||
# the maximum supported length of generation for SimCTG is 256 |
|||
# to support longer generated length, you can re-train the SimCTG model with longer sequences |
|||
decoding_len = decoding_len - prefix_len |
|||
for step in range(decoding_len): |
|||
input_ids, past_key_values, last_hidden_states, logits, input_ids_for_class = \ |
|||
PlugAndPlayContrastiveDecodingOneStepFast( |
|||
self.model, |
|||
input_ids, |
|||
prefix_len, |
|||
beam_width, |
|||
alpha, |
|||
beta, |
|||
self.tokenizer, |
|||
image_embeds, |
|||
clip, |
|||
clip_text_max_len, |
|||
past_key_values, |
|||
last_hidden_states, |
|||
logits, |
|||
first_step=step==0, |
|||
input_ids_for_class=input_ids_for_class, |
|||
) |
|||
end_time = datetime.datetime.now() |
|||
time_diff = (end_time - start_time) |
|||
execution_time = time_diff.total_seconds() * 1000 |
|||
return self.parse_output_token_list(input_ids_for_class[0]) |
|||
|
|||
def fast_contrastive_search(self, input_ids, beam_width, alpha, decoding_len): |
|||
''' |
|||
input_ids: prefix input; 1 x prefix_len |
|||
decoding_len: how many tokens to generate |
|||
beam_width: size of candidate pool during decoding |
|||
alpha: regulates importance of model confidence and degeneration penalty |
|||
''' |
|||
self.model.eval() |
|||
#from utlis import ContrastiveDecodingOneStepFast |
|||
# sanity check |
|||
assert alpha >= 0. and alpha <= 1.0 |
|||
|
|||
# fast mode |
|||
prefix_len = input_ids.size()[1] |
|||
batch_size, seqlen = input_ids.size() |
|||
#generated = [[] for _ in range(batch_size)] |
|||
generated = [item for item in input_ids.tolist()] |
|||
past_key_values = None |
|||
last_hidden_states = None |
|||
logits = None |
|||
decoding_len = decoding_len - prefix_len |
|||
for step in range(decoding_len): |
|||
input_ids, past_key_values, last_hidden_states, logits = ContrastiveDecodingOneStepFast( |
|||
self.model, |
|||
input_ids, |
|||
beam_width, |
|||
alpha, |
|||
past_key_values, |
|||
last_hidden_states, |
|||
self.tokenizer, |
|||
logits, |
|||
first_step=step == 0, |
|||
) |
|||
tokens = input_ids.squeeze(dim=-1).tolist() |
|||
for idx, t in enumerate(tokens): |
|||
generated[idx].append(t) |
|||
return self.parse_output_token_list(torch.LongTensor(generated[0])) |
|||
|
|||
def top_k_sampling(self, input_ids, k, decoding_len): |
|||
_, prefix_len = input_ids.size() |
|||
output = self.model.generate( |
|||
input_ids, |
|||
do_sample=True, |
|||
max_length=decoding_len, |
|||
top_p=1.0, |
|||
top_k=k) |
|||
return self.parse_output_token_list(output[0]) |
|||
|
|||
def nucleus_sampling(self, input_ids, nucleus_p, decoding_len): |
|||
_, prefix_len = input_ids.size() |
|||
output = self.model.generate( |
|||
input_ids, |
|||
do_sample=True, |
|||
max_length=decoding_len, |
|||
top_p=nucleus_p, |
|||
top_k=0) |
|||
return self.parse_output_token_list(output[0]) |
@ -0,0 +1,107 @@ |
|||
# coding=utf-8 |
|||
import torch |
|||
import torch.nn as nn |
|||
import torch.nn.functional as F |
|||
import torch.multiprocessing as mp |
|||
import argparse, os |
|||
import random |
|||
import numpy as np |
|||
import time |
|||
import logging |
|||
import progressbar |
|||
|
|||
import logging |
|||
logging.getLogger('transformers.generation_utils').disabled = True |
|||
|
|||
def parse_config(): |
|||
parser = argparse.ArgumentParser() |
|||
# data configuration |
|||
parser.add_argument("--model_name", type=str, default='gpt2') |
|||
parser.add_argument("--train_path", type=str) |
|||
parser.add_argument("--dev_path", type=str) |
|||
parser.add_argument("--test_path", type=str) |
|||
parser.add_argument("--max_len", type=int) |
|||
parser.add_argument("--add_eos_token_to_data", type=str) |
|||
# mini-batch training configuration |
|||
parser.add_argument("--number_of_gpu", type=int, help="Number of available GPUs.") |
|||
parser.add_argument("--batch_size_per_gpu", type=int, help='batch size for each gpu.') |
|||
parser.add_argument("--gradient_accumulation_steps", type=int, help="gradient accumulation step.") |
|||
parser.add_argument("--effective_batch_size", type=int, |
|||
help="effective_bsz = batch_size_per_gpu x number_of_gpu x gradient_accumulation_steps") |
|||
# pre-training configuration |
|||
parser.add_argument("--total_steps", type=int, |
|||
help="total effective training steps") |
|||
parser.add_argument("--print_every", type=int, |
|||
help="how many update steps to print one intermediate result") |
|||
parser.add_argument("--save_every", type=int, |
|||
help="how many update steps to save one model") |
|||
# learning configuration |
|||
parser.add_argument("--learning_rate", type=float, default=2e-5) |
|||
parser.add_argument("--margin", type=float) |
|||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") |
|||
parser.add_argument("--save_path_prefix", type=str, help="directory to save the model parameters.") |
|||
return parser.parse_args() |
|||
|
|||
def load_previous_best_model(path): |
|||
import os |
|||
filenames = os.listdir(path) |
|||
for file in filenames: |
|||
if file.startswith('training_step'): |
|||
return path + '/' + file |
|||
raise Exception('No best model found!') |
|||
|
|||
import argparse |
|||
if __name__ == '__main__': |
|||
if torch.cuda.is_available(): |
|||
print ('Cuda is available.') |
|||
cuda_available = torch.cuda.is_available() |
|||
multi_gpu_training = False |
|||
if cuda_available: |
|||
if torch.cuda.device_count() > 1: |
|||
multi_gpu_training = True |
|||
print ('Using Multi-GPU training, number of GPU is {}'.format(torch.cuda.device_count())) |
|||
else: |
|||
print ('Using single GPU training.') |
|||
else: |
|||
pass |
|||
args = parse_config() |
|||
device = torch.device('cuda') |
|||
model_name = args.model_name |
|||
|
|||
sos_token, pad_token = r'<-start_of_text->', r'<-pad->' |
|||
add_eos_token_to_data = args.add_eos_token_to_data |
|||
if add_eos_token_to_data == 'True': |
|||
add_eos_token_to_data = True |
|||
print ('Add eos token to data!') |
|||
elif add_eos_token_to_data == 'False': |
|||
add_eos_token_to_data = False |
|||
print ('Do not add eos token to data!') |
|||
else: |
|||
raise Exception('Wrong eos configuration for data!!!') |
|||
print ('Loading data...') |
|||
from dataclass import Data |
|||
data = Data(model_name, args.train_path, args.dev_path, args.test_path, args.max_len, |
|||
sos_token, pad_token, add_eos_token_to_data) |
|||
print ('Data loaded.') |
|||
|
|||
from trainer import model_training |
|||
print ('############################################################') |
|||
print ('Start Training...') |
|||
from simctg import SimCTG |
|||
print ('Initializaing SimCTG model...') |
|||
model = SimCTG(model_name, sos_token, pad_token) |
|||
if cuda_available: |
|||
if multi_gpu_training: |
|||
model = nn.DataParallel(model) # multi-gpu training |
|||
else: |
|||
pass |
|||
model = model.to(device) |
|||
else: |
|||
pass |
|||
print ('Model loaded') |
|||
total_steps, print_every, save_every = args.total_steps, args.print_every, args.save_every |
|||
ckpt_save_path = args.save_path_prefix |
|||
model = model_training(args, data, model, total_steps, print_every, save_every, |
|||
ckpt_save_path, cuda_available, device) |
|||
print ('Training stage completed!') |
|||
print ('############################################################') |
@ -0,0 +1,17 @@ |
|||
CUDA_VISIBLE_DEVICES=0 python train.py\ |
|||
--model_name gpt2\ |
|||
--train_path ../data/flickr30k/flickr30k_train.json\ |
|||
--dev_path ../data/flickr30k/flickr30k_val.json\ |
|||
--test_path ../data/flickr30k/flickr30k_test.json\ |
|||
--add_eos_token_to_data True\ |
|||
--margin 0.5\ |
|||
--max_len 64\ |
|||
--number_of_gpu 1\ |
|||
--batch_size_per_gpu 32\ |
|||
--gradient_accumulation_steps 4\ |
|||
--effective_batch_size 128\ |
|||
--total_steps 10000\ |
|||
--print_every 50\ |
|||
--save_every 250\ |
|||
--learning_rate 2e-5\ |
|||
--save_path_prefix ./magic_flickr30k/ |
@ -0,0 +1,17 @@ |
|||
CUDA_VISIBLE_DEVICES=0 python train.py\ |
|||
--model_name gpt2\ |
|||
--train_path ../data/mscoco/mscoco_train.json\ |
|||
--dev_path ../data/mscoco/mscoco_val.json\ |
|||
--test_path ../data/mscoco/mscoco_test.json\ |
|||
--add_eos_token_to_data True\ |
|||
--margin 0.5\ |
|||
--max_len 64\ |
|||
--number_of_gpu 1\ |
|||
--batch_size_per_gpu 32\ |
|||
--gradient_accumulation_steps 4\ |
|||
--effective_batch_size 128\ |
|||
--total_steps 20000\ |
|||
--print_every 100\ |
|||
--save_every 500\ |
|||
--learning_rate 2e-5\ |
|||
--save_path_prefix ./magic_mscoco/ |
@ -0,0 +1,165 @@ |
|||
# coding=utf-8 |
|||
import torch |
|||
import torch.nn as nn |
|||
import torch.nn.functional as F |
|||
import torch.multiprocessing as mp |
|||
import argparse, os |
|||
import random |
|||
import numpy as np |
|||
import time |
|||
import logging |
|||
import progressbar |
|||
|
|||
import logging |
|||
logging.getLogger('transformers.generation_utils').disabled = True |
|||
|
|||
def eval_model(args, model, data, cuda_available, device): |
|||
dataset_batch_size = args.batch_size_per_gpu * args.number_of_gpu |
|||
eval_step = int(data.test_num / dataset_batch_size) + 1 |
|||
val_loss, token_sum = 0., 0. |
|||
model.eval() |
|||
with torch.no_grad(): |
|||
p = progressbar.ProgressBar(eval_step) |
|||
p.start() |
|||
for idx in range(eval_step): |
|||
p.update(idx) |
|||
batch_input_tensor, batch_labels, _ = \ |
|||
data.get_next_validation_batch(batch_size=dataset_batch_size, mode='test') |
|||
if cuda_available: |
|||
batch_input_tensor = batch_input_tensor.cuda(device) |
|||
batch_labels = batch_labels.cuda(device) |
|||
one_val_loss, one_val_token_sum = model.eval_loss(batch_input_tensor, batch_labels) |
|||
one_val_loss = torch.sum(one_val_loss) |
|||
one_val_token_sum = torch.sum(one_val_token_sum) |
|||
val_loss += one_val_loss.item() |
|||
token_sum += one_val_token_sum.item() |
|||
p.finish() |
|||
model.train() |
|||
val_loss = val_loss / token_sum |
|||
return val_loss |
|||
|
|||
def model_training(args, data, model, total_steps, print_every, save_every, ckpt_save_path, cuda_available, device): |
|||
import os |
|||
if os.path.exists(ckpt_save_path): |
|||
pass |
|||
else: # recursively construct directory |
|||
os.makedirs(ckpt_save_path, exist_ok=True) |
|||
|
|||
max_save_num = 1 |
|||
|
|||
batch_size_per_gpu, gradient_accumulation_steps, number_of_gpu, effective_batch_size = \ |
|||
args.batch_size_per_gpu, args.gradient_accumulation_steps, args.number_of_gpu, args.effective_batch_size |
|||
assert effective_batch_size == batch_size_per_gpu * gradient_accumulation_steps * number_of_gpu |
|||
|
|||
warmup_steps = int(0.1 * total_steps) # 10% of training steps are used for warmup |
|||
print ('total training steps is {}, warmup steps is {}'.format(total_steps, warmup_steps)) |
|||
from transformers.optimization import AdamW, get_linear_schedule_with_warmup |
|||
optimizer = AdamW(model.parameters(), lr=args.learning_rate) |
|||
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps) |
|||
optimizer.zero_grad() |
|||
|
|||
effective_batch_acm = 0 |
|||
all_batch_step = 1 |
|||
print_valid, save_valid = False, False |
|||
train_loss, train_cl_loss, min_val_loss = 0., 0., 1e10 |
|||
train_ave_bleu = 0. |
|||
|
|||
print ('--------------------------------------------------------------------------') |
|||
print ('Start Training:') |
|||
model.train() |
|||
number_of_saves = 0 |
|||
|
|||
while effective_batch_acm < total_steps: |
|||
all_batch_step += 1 |
|||
train_batch_input_tensor, train_batch_labels, _ = data.get_next_train_batch(batch_size_per_gpu * number_of_gpu) |
|||
if cuda_available: |
|||
train_batch_input_tensor = train_batch_input_tensor.cuda(device) |
|||
train_batch_labels = train_batch_labels.cuda(device) |
|||
mle_loss, cl_loss = model(train_batch_input_tensor, train_batch_labels, args.margin) |
|||
|
|||
loss = mle_loss + cl_loss |
|||
loss = loss.mean() |
|||
loss.backward() |
|||
train_loss += mle_loss.item() |
|||
train_cl_loss += cl_loss.item() |
|||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) |
|||
|
|||
# parameter update |
|||
if all_batch_step % gradient_accumulation_steps == 0: |
|||
optimizer.step() |
|||
scheduler.step() |
|||
optimizer.zero_grad() |
|||
effective_batch_acm += 1 |
|||
print_valid, save_valid = True, True |
|||
|
|||
# print intermediate result |
|||
if effective_batch_acm % print_every == 0 and print_valid: |
|||
denominator = (effective_batch_acm - (number_of_saves * save_every)) * gradient_accumulation_steps |
|||
one_train_loss = train_loss / denominator |
|||
one_train_cl_loss = train_cl_loss / denominator |
|||
print ('At training steps {}, training MLE loss is {}, train CL loss is {}'.format(effective_batch_acm, |
|||
one_train_loss, one_train_cl_loss)) |
|||
print_valid = False |
|||
|
|||
# saving result |
|||
if effective_batch_acm % save_every == 0 and save_valid: |
|||
number_of_saves += 1 |
|||
|
|||
save_valid = False |
|||
one_train_loss = train_loss / (save_every * gradient_accumulation_steps) |
|||
one_train_cl_loss = train_cl_loss / (save_every * gradient_accumulation_steps) |
|||
|
|||
model.eval() |
|||
one_val_loss = eval_model(args, model, data, cuda_available, device) |
|||
model.train() |
|||
|
|||
print ('At training steps {}, training MLE loss is {}, train CL loss is {}, validation loss is {}'.format(effective_batch_acm, |
|||
one_train_loss, one_train_cl_loss, one_val_loss)) |
|||
|
|||
train_loss, train_cl_loss = 0., 0. |
|||
|
|||
if one_val_loss < min_val_loss: |
|||
# in finetuning stage, we always save the model |
|||
min_val_loss = min(one_val_loss, min_val_loss) |
|||
print ('Saving model...') |
|||
one_val_ppl = np.exp(one_val_loss) |
|||
one_val_ppl = round(one_val_ppl, 3) |
|||
save_name = 'training_step_{}_train_mle_loss_{}_train_cl_loss_{}_dev_loss_{}_dev_ppl_{}'.format(effective_batch_acm, |
|||
round(one_train_loss,5), round(one_train_cl_loss,5), round(one_val_loss,5), one_val_ppl) |
|||
|
|||
model_save_path = ckpt_save_path + '/' + save_name |
|||
import os |
|||
if os.path.exists(model_save_path): |
|||
pass |
|||
else: # recursively construct directory |
|||
os.makedirs(model_save_path, exist_ok=True) |
|||
if cuda_available and torch.cuda.device_count() > 1: |
|||
model.module.save_model(model_save_path) |
|||
else: |
|||
model.save_model(model_save_path) |
|||
print ('Model Saved!') |
|||
|
|||
# --------------------------------------------------------------------------------------------- # |
|||
# removing extra checkpoints... |
|||
import os |
|||
from operator import itemgetter |
|||
fileData = {} |
|||
test_output_dir = ckpt_save_path |
|||
for fname in os.listdir(test_output_dir): |
|||
if fname.startswith('training_step'): |
|||
fileData[fname] = os.stat(test_output_dir + '/' + fname).st_mtime |
|||
else: |
|||
pass |
|||
sortedFiles = sorted(fileData.items(), key=itemgetter(1)) |
|||
|
|||
if len(sortedFiles) < max_save_num: |
|||
pass |
|||
else: |
|||
delete = len(sortedFiles) - max_save_num |
|||
for x in range(0, delete): |
|||
one_folder_name = test_output_dir + '/' + sortedFiles[x][0] |
|||
os.system('rm -r ' + one_folder_name) |
|||
print ('-----------------------------------') |
|||
# --------------------------------------------------------------------------------------------- # |
|||
return model |
|||
|
@ -0,0 +1,291 @@ |
|||
import sys |
|||
import os |
|||
import operator |
|||
from operator import itemgetter |
|||
import torch |
|||
from torch import nn |
|||
import torch.nn.functional as F |
|||
import random |
|||
import numpy as np |
|||
import argparse |
|||
import random |
|||
|
|||
def parse_prompt(text): |
|||
''' |
|||
process the prompt text; |
|||
''' |
|||
eos_token = '<|endoftext|>' |
|||
text = text.strip(eos_token).strip() |
|||
left_bracket_idx, right_bracket_idx = -1, -1 |
|||
for idx in range(len(text)): |
|||
char = text[idx] |
|||
if char == '[' and left_bracket_idx == -1: # first [ is met |
|||
left_bracket_idx = idx |
|||
elif char == ']' and right_bracket_idx == -1: # first ] is met |
|||
right_bracket_idx = idx |
|||
else: |
|||
pass |
|||
res_text = '' |
|||
remove = False |
|||
if left_bracket_idx > -1 and right_bracket_idx > left_bracket_idx: |
|||
if right_bracket_idx - left_bracket_idx <= 6: |
|||
remove = True |
|||
else: |
|||
pass |
|||
|
|||
for idx in range(len(text)): |
|||
if remove: |
|||
if idx >= left_bracket_idx and idx <= right_bracket_idx: |
|||
continue |
|||
else: |
|||
res_text += text[idx] |
|||
else: |
|||
res_text += text[idx] |
|||
res_text = res_text.strip() |
|||
res_text = ' '.join(res_text.split()).strip() |
|||
return res_text |
|||
|
|||
def typical_filtering(scores, mass, min_tokens_to_keep, filter_value): |
|||
# calculate entropy |
|||
normalized = torch.nn.functional.log_softmax(scores, dim=-1) |
|||
p = torch.exp(normalized) |
|||
ent = -(normalized * p).nansum(-1, keepdim=True) |
|||
|
|||
# shift and sort |
|||
shifted_scores = torch.abs((-normalized) - ent) |
|||
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False) |
|||
sorted_logits = scores.gather(-1, sorted_indices) |
|||
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) |
|||
|
|||
# Remove tokens with cumulative mass above the threshold |
|||
last_ind = (cumulative_probs < mass).sum(dim=1) |
|||
last_ind[last_ind < 0] = 0 |
|||
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) |
|||
if min_tokens_to_keep > 1: |
|||
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) |
|||
sorted_indices_to_remove[..., : min_tokens_to_keep] = 0 |
|||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) |
|||
|
|||
scores = scores.masked_fill(indices_to_remove, filter_value) |
|||
return scores |
|||
|
|||
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, threshold=-float('Inf'), filter_value=-np.inf): |
|||
assert logits.dim() == 1 |
|||
top_k = min(top_k, logits.size(-1)) |
|||
if top_k > 0: |
|||
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] |
|||
logits[indices_to_remove] = filter_value |
|||
if top_p > 0.0: |
|||
sorted_logits, sorted_indices = torch.sort(logits, descending=True) |
|||
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) |
|||
sorted_indices_to_remove = cumulative_probs > top_p |
|||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() |
|||
sorted_indices_to_remove[..., 0] = 0 |
|||
|
|||
indices_to_remove = sorted_indices[sorted_indices_to_remove] |
|||
logits[indices_to_remove] = filter_value |
|||
|
|||
indices_to_remove = logits < threshold |
|||
logits[indices_to_remove] = filter_value |
|||
return logits |
|||
|
|||
# ========== batch version ========= # |
|||
def ranking_fast(context_hidden, next_hidden, next_top_k_probs, alpha, beam_width): |
|||
''' |
|||
context_hidden: bsz*beam x seqlen x embed_dim |
|||
next_hidden: bsz*beam x 1 x embed_dim |
|||
next_top_k_probs: bsz x beam |
|||
''' |
|||
_, context_len, embed_dim = context_hidden.size() |
|||
norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True) |
|||
norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True) |
|||
cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1) # [B*K, S] |
|||
scores, _ = torch.max(cosine_matrix, dim=-1) # [B*K] |
|||
next_top_k_probs = next_top_k_probs.view(-1) # [B*K] |
|||
scores = (1.0 - alpha) * next_top_k_probs - alpha * scores |
|||
scores = torch.stack(torch.split(scores, beam_width)) # [B, K] |
|||
selected_idx = scores.max(dim=-1)[1] # [B] |
|||
return selected_idx |
|||
|
|||
def ContrastiveDecodingOneStepFast( |
|||
model, |
|||
ids, |
|||
beam_width, |
|||
alpha, |
|||
past_key_values, |
|||
last_hidden_states, |
|||
vocab, |
|||
logit_for_next_step, |
|||
first_step=False, |
|||
): |
|||
# input_ids: [B, S] |
|||
if first_step: |
|||
output = model( |
|||
input_ids=ids, |
|||
past_key_values=past_key_values, |
|||
use_cache=True, |
|||
output_hidden_states=True |
|||
) |
|||
past_key_values = output.past_key_values |
|||
last_hidden_states = output.hidden_states[-1] # [B, S, E] |
|||
logit_for_next_step = output.logits[:, -1, :] # [B, V] |
|||
bsz, seqlen, embed_dim = last_hidden_states.size() |
|||
p = random.uniform(0, 1) |
|||
|
|||
next_probs = F.softmax(logit_for_next_step, dim=-1) |
|||
_, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width) # [B, K] |
|||
top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids) # [B, K] |
|||
# compute new hidden |
|||
past_key_values = enlarge_past_key_values(past_key_values, beam_width) |
|||
output = model( |
|||
input_ids=top_k_ids.view(-1, 1), |
|||
attention_mask=torch.ones_like(top_k_ids.view(-1, 1)), |
|||
past_key_values=past_key_values, |
|||
output_hidden_states=True, |
|||
use_cache=True, |
|||
) |
|||
past_key_values = output.past_key_values |
|||
logits = output.logits[:, -1, :] # [B*K, V] |
|||
next_hidden = output.hidden_states[-1] # [B*K, 1, E] |
|||
context_hidden = last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz*beam_width, seqlen, embed_dim) # [B*K, S, E] |
|||
|
|||
selected_idx = ranking_fast( |
|||
context_hidden, |
|||
next_hidden, |
|||
top_k_probs, # [B, K] |
|||
alpha, |
|||
beam_width, |
|||
) # [B] |
|||
# prepare for the next step |
|||
next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1) # [B, 1] |
|||
next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width)) # [B, K, E] |
|||
next_hidden = next_hidden[range(bsz), selected_idx, :] # [B, E] |
|||
last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) # [B, S, E] |
|||
past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx) |
|||
logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :] # [B, V] |
|||
# next_id: [B, 1] |
|||
return next_id, past_key_values, last_hidden_states, logits |
|||
|
|||
def enlarge_past_key_values(past_key_values, beam_width): |
|||
# from [B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz] |
|||
new_key_values = [] |
|||
for layer in past_key_values: |
|||
items = [] |
|||
for item in layer: |
|||
# item is the key and value matrix |
|||
bsz, num_head, seq_len, esz = item.size() |
|||
item = item.unsqueeze(1).expand(-1, beam_width, -1, -1, -1).reshape(bsz*beam_width, num_head, seq_len, esz) # [bsz*beam, num_head, seq_len, esz] |
|||
items.append(item) |
|||
new_key_values.append(items) |
|||
return new_key_values |
|||
|
|||
def select_past_key_values(past_key_values, beam_width, selected_idx): |
|||
'''select_idx: [B]''' |
|||
new_key_values = [] |
|||
for layer in past_key_values: |
|||
items = [] |
|||
for item in layer: |
|||
bsz_and_beam, num_head, seq_len, esz = item.size() |
|||
bsz = int(bsz_and_beam//beam_width) |
|||
item = torch.stack(torch.split(item, beam_width, dim=0)) # [B, K, num_head, seq_len, esz] |
|||
item = item[range(bsz), selected_idx, :, :, :] # [B, num_head, seq_len, esz] |
|||
items.append(item) |
|||
new_key_values.append(items) |
|||
return new_key_values |
|||
|
|||
# ========== fast plug and play version ========= # |
|||
def plug_and_play_fast_ranking( |
|||
context_hidden, |
|||
next_hidden, |
|||
next_top_k_ids, |
|||
next_top_k_probs, |
|||
alpha, |
|||
beta, |
|||
batch_class_score, |
|||
beam_width, |
|||
): |
|||
''' |
|||
context_hidden: beam_width x context_len x embed_dim |
|||
next_hidden: beam_width x 1 x embed_dim |
|||
next_top_k_ids: beam_width x 1 |
|||
batch_class_score: beam_width x 1 |
|||
''' |
|||
_, context_len, embed_dim = context_hidden.size() |
|||
norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True) |
|||
norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True) |
|||
cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1) |
|||
scores, _ = torch.max(cosine_matrix, dim = -1) |
|||
next_top_k_probs = next_top_k_probs.view(-1) |
|||
scores = (1.0 - alpha) * next_top_k_probs - alpha * scores + beta * batch_class_score.view([beam_width]) |
|||
scores = torch.stack(torch.split(scores, beam_width)) |
|||
selected_idx = scores.max(dim=-1)[1] |
|||
return selected_idx |
|||
|
|||
def PlugAndPlayContrastiveDecodingOneStepFast(model, input_ids, prefix_len, beam_width, alpha, beta, |
|||
simctg_tokenizer, image_embeds, clip, clip_text_max_len, past_key_values, last_hidden_states, |
|||
logit_for_next_step, first_step=False, input_ids_for_class=None):#, add_token_level_score=False): |
|||
''' |
|||
model: the generation model, e.g., gpt2 |
|||
input_ids: 1 x seqlen |
|||
''' |
|||
|
|||
if first_step: |
|||
output = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True, output_hidden_states=True) |
|||
past_key_values = output.past_key_values |
|||
last_hidden_states = output.hidden_states[-1] # [B, S, E] |
|||
logit_for_next_step = output.logits[:, -1, :] # [B, V] |
|||
bsz, seqlen, embed_dim = last_hidden_states.size() |
|||
next_probs = F.softmax(logit_for_next_step, dim = -1) |
|||
_, top_k_ids = torch.topk(logit_for_next_step, dim = -1, k = beam_width) |
|||
top_k_probs = torch.gather(next_probs, dim = 1, index=top_k_ids) |
|||
|
|||
# compute the new hidden |
|||
past_key_values = enlarge_past_key_values(past_key_values, beam_width) |
|||
output = model( |
|||
input_ids=top_k_ids.view(-1, 1) , |
|||
attention_mask=torch.ones_like(top_k_ids.view(-1, 1)), |
|||
past_key_values=past_key_values, |
|||
output_hidden_states=True, |
|||
use_cache=True, |
|||
) |
|||
past_key_values = output.past_key_values |
|||
logits = output.logits[:, -1, :] |
|||
next_hidden = output.hidden_states[-1] |
|||
context_hidden = last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz*beam_width, seqlen, embed_dim) |
|||
|
|||
# prepare for the classification model |
|||
input_ids_for_class_ = torch.cat([ |
|||
input_ids_for_class.unsqueeze(1).expand(-1, beam_width, -1).reshape(bsz*beam_width, seqlen), |
|||
top_k_ids.view(-1, 1) |
|||
], dim=-1 |
|||
) |
|||
|
|||
batch_text_list = [] |
|||
for one_input_id in input_ids_for_class_: |
|||
one_text = simctg_tokenizer.decode(one_input_id[prefix_len:][-clip_text_max_len:]) |
|||
# we only consider the class score of the generated text continuation |
|||
batch_text_list.append(one_text) |
|||
batch_score = clip.compute_image_text_similarity_via_raw_text(image_embeds, batch_text_list) |
|||
|
|||
selected_idx = plug_and_play_fast_ranking( |
|||
context_hidden, |
|||
next_hidden, |
|||
top_k_ids, |
|||
top_k_probs, |
|||
alpha, |
|||
beta, |
|||
batch_score, |
|||
beam_width, |
|||
) |
|||
|
|||
# prepare for the next step |
|||
next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1) |
|||
next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width)) |
|||
next_hidden = next_hidden[range(bsz), selected_idx, :] |
|||
last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) |
|||
past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx) |
|||
logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :] |
|||
input_ids_for_class = torch.cat([input_ids_for_class, next_id], dim=-1) |
|||
return next_id, past_key_values, last_hidden_states, logits, input_ids_for_class |
|||
|
|||
|
@ -1,89 +0,0 @@ |
|||
### Our Implementation of the ZeroCap Baseline Model |
|||
|
|||
**** |
|||
### Catalogue: |
|||
* <a href='#environment'>1. Environment Preparation</a> |
|||
* <a href='#mscoco'>2. Image Captioning on MSCOCO</a> |
|||
* <a href='#flickr30k'>3. Image Captioning on Flickr30k</a> |
|||
* <a href='#flickr30k_to_mscoco'>4. Cross Domain Image Captioning on MSCOCO</a> |
|||
* <a href='#mscoco_to_flickr30k'>5. Cross Domain Image Captioning on Flickr30k</a> |
|||
* <a href='#citation'>6. Citation</a> |
|||
* <a href='#acknowledgements'>7. Acknowledgements</a> |
|||
|
|||
**** |
|||
|
|||
<span id='environment'/> |
|||
|
|||
#### 1. Environment Preparation: |
|||
To install the correct environment, please run the following command: |
|||
```yaml |
|||
pip install -r requirements.txt |
|||
``` |
|||
|
|||
**** |
|||
|
|||
<span id='mscoco'/> |
|||
|
|||
#### 2. Image Captioning on MSCOCO: |
|||
To perform image captioning on MSCOCO, please run the following command: |
|||
```yaml |
|||
chmod +x ./mscoco_zerocap.sh |
|||
./mscoco_zerocap.sh |
|||
``` |
|||
|
|||
**** |
|||
|
|||
<span id='flickr30k'/> |
|||
|
|||
#### 3. Image Captioning on Flickr30k: |
|||
To perform image captioning on Flickr30k, please run the following command: |
|||
```yaml |
|||
chmod +x ./flickr30k_zerocap.sh |
|||
./flickr30k_zerocap.sh |
|||
``` |
|||
|
|||
**** |
|||
|
|||
<span id='flickr30k_to_mscoco'/> |
|||
|
|||
#### 4. Cross Domain Image Captioning on MSCOCO: |
|||
To perform image captioning on MSCOCO with the language model from Flickr30k domain, please run the following command: |
|||
```yaml |
|||
chmod +x ./flickr30k_to_mscoco_zerocap.sh |
|||
./flickr30k_to_mscoco_zerocap.sh |
|||
``` |
|||
|
|||
**** |
|||
|
|||
<span id='mscoco_to_flickr30k'/> |
|||
|
|||
#### 5. Cross Domain Image Captioning on Flickr30k: |
|||
To perform image captioning on Flickr30k with the language model from MSCOCO domain, please run the following command: |
|||
```yaml |
|||
chmod +x ./mscoco_to_flickr30k_zerocap.sh |
|||
./mscoco_to_flickr30k_zerocap.sh |
|||
``` |
|||
|
|||
**** |
|||
|
|||
<span id='citation'/> |
|||
|
|||
#### 6. Citation: |
|||
If you find our code helpful, please cite the original paper as |
|||
|
|||
```bibtex |
|||
@article{tewel2021zero, |
|||
title={Zero-Shot Image-to-Text Generation for Visual-Semantic Arithmetic}, |
|||
author={Tewel, Yoad and Shalev, Yoav and Schwartz, Idan and Wolf, Lior}, |
|||
journal={arXiv preprint arXiv:2111.14447}, |
|||
year={2021} |
|||
} |
|||
``` |
|||
|
|||
**** |
|||
|
|||
<span id='acknowledgements'/> |
|||
|
|||
#### 7. Acknowledgements: |
|||
We thank the authors for releasing their code. Our reimplementation of the baseline is based on their original codebase [[here]](https://github.com/yoadtew/zero-shot-image-to-text). |
|||
|
@ -1,12 +0,0 @@ |
|||
build: |
|||
gpu: true |
|||
python_version: "3.8" |
|||
system_packages: |
|||
- "libgl1-mesa-glx" |
|||
- "libglib2.0-0" |
|||
python_packages: |
|||
- "git+https://github.com/openai/CLIP.git" |
|||
- "git+https://github.com/YoadTew/zero-shot-image-to-text.git" |
|||
|
|||
predict: "predict.py:Predictor" |
|||
#predict: "predict_arithmetic.py:Predictor" |
@ -1,14 +0,0 @@ |
|||
#!/bin/bash |
|||
|
|||
# lm_model: |
|||
# 1. cambridgeltl/magic_mscoco |
|||
# 2. cambridgeltl/magic_flickr30k |
|||
CUDA_VISIBLE_DEVICES=1 python run.py \ |
|||
--beam_size 1 \ |
|||
--target_seq_length 16 \ |
|||
--reset_context_delta \ |
|||
--lm_model cambridgeltl/magic_flickr30k \ |
|||
--test_image_prefix_path ../data/flickr30k/test_images \ |
|||
--test_path ../data/flickr30k/flickr30k_test.json \ |
|||
--save_path_prefix ../inference_result/flickr30k/baselines/ \ |
|||
--save_name zerocap_result.json |
Binary file not shown.
@ -1,389 +0,0 @@ |
|||
import numpy as np |
|||
from torch import nn |
|||
from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer |
|||
from transformers.models.gpt_neo import GPTNeoForCausalLM |
|||
import torch |
|||
import clip |
|||
from PIL import Image |
|||
from datetime import datetime |
|||
import sys |
|||
|
|||
|
|||
def log_info(text, verbose=True): |
|||
if verbose: |
|||
dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S") |
|||
print(f'{dt_string} | {text}') |
|||
sys.stdout.flush() |
|||
|
|||
|
|||
def add_context(x, y): |
|||
return (x[0] + y[0], x[1] + y[1]) |
|||
|
|||
|
|||
def convert_models_to_fp32(model): |
|||
for p in model.parameters(): |
|||
p.data = p.data.float() |
|||
|
|||
|
|||
class CLIPTextGenerator: |
|||
def __init__(self, |
|||
seed=0, |
|||
lm_model='gpt-2', |
|||
forbidden_tokens_file_path='./forbidden_tokens.npy', |
|||
clip_checkpoints='./clip_checkpoints', |
|||
target_seq_length=15, |
|||
reset_context_delta=True, |
|||
num_iterations=5, |
|||
clip_loss_temperature=0.01, |
|||
clip_scale=1., |
|||
ce_scale=0.2, |
|||
stepsize=0.3, |
|||
grad_norm_factor=0.9, |
|||
fusion_factor=0.99, |
|||
repetition_penalty=1., |
|||
end_token='.', |
|||
end_factor=1.01, |
|||
forbidden_factor=20, |
|||
**kwargs): |
|||
|
|||
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|||
|
|||
# set Random seed |
|||
torch.manual_seed(seed) |
|||
np.random.seed(seed) |
|||
|
|||
# Initialize Language model |
|||
self.context_prefix = '' |
|||
|
|||
self.lm_tokenizer = GPT2Tokenizer.from_pretrained(lm_model) |
|||
self.lm_model = GPT2LMHeadModel.from_pretrained(lm_model, output_hidden_states=True) |
|||
self.context_prefix = self.lm_tokenizer.bos_token |
|||
|
|||
self.lm_model.to(self.device) |
|||
self.lm_model.eval() |
|||
|
|||
self.forbidden_tokens = np.load(forbidden_tokens_file_path) |
|||
self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if |
|||
(x[0] == 'Ġ' and len(x) > 1 and x[1].isupper())] |
|||
|
|||
# Freeze LM weights |
|||
for param in self.lm_model.parameters(): |
|||
param.requires_grad = False |
|||
|
|||
# Initialize CLIP |
|||
self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device, |
|||
download_root=clip_checkpoints, jit=False) |
|||
# convert_models_to_fp32(self.clip) |
|||
|
|||
# Init arguments |
|||
self.target_seq_length = target_seq_length |
|||
self.reset_context_delta = reset_context_delta |
|||
self.num_iterations = num_iterations |
|||
self.clip_loss_temperature = clip_loss_temperature |
|||
self.clip_scale = clip_scale |
|||
self.ce_scale = ce_scale |
|||
self.stepsize = stepsize |
|||
self.grad_norm_factor = grad_norm_factor |
|||
self.fusion_factor = fusion_factor |
|||
self.repetition_penalty = repetition_penalty |
|||
self.end_token = self.lm_tokenizer.encode(end_token)[0] |
|||
self.end_factor = end_factor |
|||
self.ef_idx = 1 |
|||
self.forbidden_factor = forbidden_factor |
|||
|
|||
def get_img_feature(self, img_path, weights): |
|||
imgs = [Image.open(x) for x in img_path] |
|||
clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs] |
|||
|
|||
with torch.no_grad(): |
|||
image_fts = [self.clip.encode_image(x) for x in clip_imgs] |
|||
|
|||
if weights is not None: |
|||
image_features = sum([x * weights[i] for i, x in enumerate(image_fts)]) |
|||
else: |
|||
image_features = sum(image_fts) |
|||
|
|||
image_features = image_features / image_features.norm(dim=-1, keepdim=True) |
|||
return image_features.detach() |
|||
|
|||
def get_txt_features(self, text): |
|||
clip_texts = clip.tokenize(text).to(self.device) |
|||
|
|||
with torch.no_grad(): |
|||
text_features = self.clip.encode_text(clip_texts) |
|||
|
|||
text_features = text_features / text_features.norm(dim=-1, keepdim=True) |
|||
return text_features.detach() |
|||
|
|||
def get_combined_feature(self, img_path, texts, weights_i, weights_t): |
|||
imgs = [Image.open(x) for x in img_path] |
|||
clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs] |
|||
clip_texts = [clip.tokenize(x).to(self.device) for x in texts] |
|||
|
|||
with torch.no_grad(): |
|||
image_fts = [self.clip.encode_image(x) for x in clip_imgs] |
|||
text_fts = [self.clip.encode_text(x) for x in clip_texts] |
|||
|
|||
features = sum([x * weights_i[i] for i, x in enumerate(image_fts)]) |
|||
if weights_t is not None: |
|||
features += sum([x * weights_t[i] for i, x in enumerate(text_fts)]) |
|||
|
|||
features = features / features.norm(dim=-1, keepdim=True) |
|||
return features.detach() |
|||
|
|||
def run(self, image_features, cond_text, beam_size): |
|||
self.image_features = image_features |
|||
|
|||
context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text) |
|||
|
|||
output_tokens, output_text = self.generate_text(context_tokens, beam_size) |
|||
|
|||
return output_text |
|||
|
|||
def generate_text(self, context_tokens, beam_size): |
|||
context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0) |
|||
|
|||
gen_tokens = None |
|||
scores = None |
|||
seq_lengths = torch.ones(beam_size, device=self.device) |
|||
is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool) |
|||
|
|||
for i in range(self.target_seq_length): |
|||
probs = self.get_next_probs(i, context_tokens) |
|||
logits = probs.log() |
|||
|
|||
if scores is None: |
|||
scores, next_tokens = logits.topk(beam_size, -1) |
|||
context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:]) |
|||
next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0) |
|||
|
|||
if gen_tokens is None: |
|||
gen_tokens = next_tokens |
|||
else: |
|||
gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:]) |
|||
gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1) |
|||
else: |
|||
logits[is_stopped] = -float(np.inf) |
|||
logits[is_stopped, 0] = 0 |
|||
scores_sum = scores[:, None] + logits |
|||
seq_lengths[~is_stopped] += 1 |
|||
scores_sum_average = scores_sum / seq_lengths[:, None] |
|||
scores_sum_average, next_tokens = scores_sum_average.view(-1).topk( |
|||
beam_size, -1) |
|||
next_tokens_source = next_tokens // scores_sum.shape[1] |
|||
seq_lengths = seq_lengths[next_tokens_source] |
|||
next_tokens = next_tokens % scores_sum.shape[1] |
|||
next_tokens = next_tokens.unsqueeze(1) |
|||
gen_tokens = gen_tokens[next_tokens_source] |
|||
gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1) |
|||
context_tokens = context_tokens[next_tokens_source] |
|||
scores = scores_sum_average * seq_lengths |
|||
is_stopped = is_stopped[next_tokens_source] |
|||
|
|||
context_tokens = torch.cat((context_tokens, next_tokens), dim=1) |
|||
is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze() |
|||
|
|||
#### |
|||
tmp_scores = scores / seq_lengths |
|||
tmp_output_list = gen_tokens.cpu().numpy() |
|||
tmp_output_texts = [ |
|||
self.lm_tokenizer.decode(tmp_output) |
|||
for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths) |
|||
] |
|||
tmp_order = tmp_scores.argsort(descending=True) |
|||
tmp_output_texts = [tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order] |
|||
log_info(tmp_output_texts, verbose=True) |
|||
#### |
|||
|
|||
if is_stopped.all(): |
|||
break |
|||
|
|||
scores = scores / seq_lengths |
|||
output_list = gen_tokens.cpu().numpy() |
|||
output_texts = [ |
|||
self.lm_tokenizer.decode(output[: int(length)]) |
|||
for output, length in zip(output_list, seq_lengths) |
|||
] |
|||
order = scores.argsort(descending=True) |
|||
output_texts = [output_texts[i] for i in order] |
|||
|
|||
return context_tokens, output_texts |
|||
|
|||
def get_next_probs(self, i, context_tokens): |
|||
last_token = context_tokens[:, -1:] |
|||
|
|||
if self.reset_context_delta and context_tokens.size(1) > 1: |
|||
context = self.lm_model(context_tokens[:, :-1])["past_key_values"] |
|||
|
|||
# Logits of LM with unshifted context |
|||
logits_before_shift = self.lm_model(context_tokens)["logits"] |
|||
logits_before_shift = logits_before_shift[:, -1, :] |
|||
probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1) |
|||
|
|||
if context: |
|||
context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift) |
|||
|
|||
lm_output = self.lm_model(last_token, past_key_values=context) |
|||
logits, past = ( |
|||
lm_output["logits"], |
|||
lm_output["past_key_values"], |
|||
) |
|||
logits = logits[:, -1, :] |
|||
|
|||
logits = self.update_special_tokens_logits(context_tokens, i, logits) |
|||
|
|||
probs = nn.functional.softmax(logits, dim=-1) |
|||
probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor)) |
|||
probs = probs / probs.sum() |
|||
|
|||
return probs |
|||
|
|||
def shift_context(self, i, context, last_token, context_tokens, probs_before_shift): |
|||
context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context] |
|||
|
|||
window_mask = torch.ones_like(context[0][0]).to(self.device) |
|||
|
|||
for i in range(self.num_iterations): |
|||
curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in |
|||
context_delta] |
|||
|
|||
for p0, p1 in curr_shift: |
|||
p0.retain_grad() |
|||
p1.retain_grad() |
|||
|
|||
shifted_context = list(map(add_context, context, curr_shift)) |
|||
|
|||
shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context) |
|||
logits = shifted_outputs["logits"][:, -1, :] |
|||
probs = nn.functional.softmax(logits, dim=-1) |
|||
|
|||
loss = 0.0 |
|||
|
|||
# CLIP LOSS |
|||
clip_loss, clip_losses = self.clip_loss(probs, context_tokens) |
|||
loss += self.clip_scale * clip_loss |
|||
|
|||
# CE/Fluency loss |
|||
ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * probs_before_shift.log())).sum(-1) |
|||
loss += ce_loss.sum() |
|||
|
|||
loss.backward() |
|||
|
|||
# ---------- Weights ---------- |
|||
combined_scores_k = -(ce_loss) |
|||
combined_scores_c = -(self.clip_scale * torch.stack(clip_losses)) |
|||
|
|||
# minmax |
|||
if combined_scores_k.shape[0] == 1: |
|||
tmp_weights_c = tmp_weights_k = torch.ones(*combined_scores_k.shape).to(self.device) |
|||
else: |
|||
tmp_weights_k = ((combined_scores_k - combined_scores_k.min())) / ( |
|||
combined_scores_k.max() - combined_scores_k.min()) |
|||
tmp_weights_c = ((combined_scores_c - combined_scores_c.min())) / ( |
|||
combined_scores_c.max() - combined_scores_c.min()) |
|||
|
|||
tmp_weights = 0.5 * tmp_weights_k + 0.5 * tmp_weights_c |
|||
tmp_weights = tmp_weights.view(tmp_weights.shape[0], 1, 1, 1) |
|||
|
|||
factor = 1 |
|||
|
|||
# --------- Specific Gen --------- |
|||
sep_grads = None |
|||
|
|||
for b in range(context_tokens.shape[0]): |
|||
tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_] |
|||
for p_ in curr_shift] |
|||
|
|||
# normalize gradients |
|||
tmp_grad = [tuple([-self.stepsize * factor * ( |
|||
x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][ |
|||
j] ** self.grad_norm_factor).data.cpu().numpy() |
|||
for j, x in enumerate(p_)]) |
|||
for i, p_ in enumerate(curr_shift)] |
|||
if sep_grads is None: |
|||
sep_grads = tmp_grad |
|||
else: |
|||
for l_index in range(len(sep_grads)): |
|||
sep_grads[l_index] = list(sep_grads[l_index]) |
|||
for k_index in range(len(sep_grads[0])): |
|||
sep_grads[l_index][k_index] = np.concatenate( |
|||
(sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0) |
|||
sep_grads[l_index] = tuple(sep_grads[l_index]) |
|||
final_grads = sep_grads |
|||
|
|||
# --------- update context --------- |
|||
context_delta = list(map(add_context, final_grads, context_delta)) |
|||
|
|||
for p0, p1 in curr_shift: |
|||
p0.grad.data.zero_() |
|||
p1.grad.data.zero_() |
|||
|
|||
new_context = [] |
|||
for p0, p1 in context: |
|||
new_context.append((p0.detach(), p1.detach())) |
|||
context = new_context |
|||
|
|||
context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) |
|||
for p_ in context_delta] |
|||
context = list(map(add_context, context, context_delta)) |
|||
|
|||
new_context = [] |
|||
for p0, p1 in context: |
|||
new_context.append((p0.detach(), p1.detach())) |
|||
context = new_context |
|||
|
|||
return context |
|||
|
|||
def update_special_tokens_logits(self, context_tokens, i, logits): |
|||
for beam_id in range(context_tokens.shape[0]): |
|||
for token_idx in set(context_tokens[beam_id][-4:].tolist()): |
|||
factor = self.repetition_penalty if logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty) |
|||
logits[beam_id, token_idx] /= factor |
|||
|
|||
if i >= self.ef_idx: |
|||
factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor) |
|||
logits[beam_id, self.end_token] *= factor |
|||
if i == 0: |
|||
start_factor = 1.6 |
|||
factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor) |
|||
logits[beam_id, self.end_token] /= factor |
|||
|
|||
for token_idx in list(self.forbidden_tokens): |
|||
factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor) |
|||
logits[beam_id, token_idx] /= factor |
|||
|
|||
return logits |
|||
|
|||
def clip_loss(self, probs, context_tokens): |
|||
for p_ in self.clip.transformer.parameters(): |
|||
if p_.grad is not None: |
|||
p_.grad.data.zero_() |
|||
|
|||
top_size = 512 |
|||
_, top_indices = probs.topk(top_size, -1) |
|||
|
|||
prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens] |
|||
|
|||
clip_loss = 0 |
|||
losses = [] |
|||
for idx_p in range(probs.shape[0]): |
|||
top_texts = [] |
|||
prefix_text = prefix_texts[idx_p] |
|||
for x in top_indices[idx_p]: |
|||
top_texts.append(prefix_text + self.lm_tokenizer.decode(x)) |
|||
text_features = self.get_txt_features(top_texts) |
|||
|
|||
with torch.no_grad(): |
|||
similiraties = (self.image_features @ text_features.T) |
|||
target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() |
|||
target_probs = target_probs.type(torch.float32) |
|||
|
|||
target = torch.zeros_like(probs[idx_p]) |
|||
target[top_indices[idx_p]] = target_probs[0] |
|||
target = target.unsqueeze(0) |
|||
cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)]))) |
|||
|
|||
clip_loss += cur_clip_loss |
|||
losses.append(cur_clip_loss) |
|||
|
|||
return clip_loss, losses |
@ -1,449 +0,0 @@ |
|||
import numpy as np |
|||
from torch import nn |
|||
from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer |
|||
from transformers.models.gpt_neo import GPTNeoForCausalLM |
|||
import torch |
|||
import clip |
|||
from PIL import Image |
|||
from datetime import datetime |
|||
import sys |
|||
|
|||
class TextCLIP(nn.Module): |
|||
def __init__(self, model): |
|||
super(TextCLIP, self).__init__() |
|||
self.model = model |
|||
|
|||
def forward(self, text): |
|||
return self.model.encode_text(text) |
|||
|
|||
|
|||
class ImageCLIP(nn.Module): |
|||
def __init__(self, model): |
|||
super(ImageCLIP, self).__init__() |
|||
self.model = model |
|||
|
|||
def forward(self, image): |
|||
return self.model.encode_image(image) |
|||
|
|||
def log_info(text, verbose=True): |
|||
if verbose: |
|||
dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S") |
|||
print(f'{dt_string} | {text}') |
|||
sys.stdout.flush() |
|||
|
|||
|
|||
def add_context(x, y): |
|||
return (x[0] + y[0], x[1] + y[1]) |
|||
|
|||
|
|||
def convert_models_to_fp32(model): |
|||
for p in model.parameters(): |
|||
p.data = p.data.float() |
|||
|
|||
|
|||
class CLIPTextGenerator: |
|||
def __init__(self, |
|||
seed=0, |
|||
lm_model='gpt-2', |
|||
forbidden_tokens_file_path='./forbidden_tokens.npy', |
|||
clip_checkpoints='./clip_checkpoints', |
|||
target_seq_length=15, |
|||
reset_context_delta=True, |
|||
num_iterations=5, |
|||
clip_loss_temperature=0.01, |
|||
clip_scale=1., |
|||
ce_scale=0.2, |
|||
stepsize=0.3, |
|||
grad_norm_factor=0.9, |
|||
fusion_factor=0.99, |
|||
repetition_penalty=1., |
|||
end_token='.', |
|||
end_factor=1.01, |
|||
forbidden_factor=20, |
|||
**kwargs): |
|||
|
|||
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|||
|
|||
# set Random seed |
|||
torch.manual_seed(seed) |
|||
np.random.seed(seed) |
|||
|
|||
# Initialize Language model |
|||
self.context_prefix = '' |
|||
|
|||
if lm_model == 'gpt-neo': |
|||
self.lm_tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-125M') |
|||
self.lm_model = GPTNeoForCausalLM.from_pretrained('EleutherAI/gpt-neo-125M', output_hidden_states=True) |
|||
elif lm_model == 'gpt-2': |
|||
self.lm_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium') |
|||
self.lm_model = GPT2LMHeadModel.from_pretrained('gpt2-medium', output_hidden_states=True) |
|||
self.context_prefix = self.lm_tokenizer.bos_token |
|||
|
|||
self.lm_model.to(self.device) |
|||
self.lm_model.eval() |
|||
|
|||
self.forbidden_tokens = np.load(forbidden_tokens_file_path) |
|||
self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if |
|||
(x[0] == 'Ġ' and len(x) > 1 and x[1].isupper())] |
|||
|
|||
# Freeze LM weights |
|||
for param in self.lm_model.parameters(): |
|||
param.requires_grad = False |
|||
|
|||
# Initialize CLIP |
|||
self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device, |
|||
download_root=clip_checkpoints, jit=False) |
|||
self.clip_image = ImageCLIP(self.clip) |
|||
self.clip_image = torch.nn.DataParallel(self.clip_image) |
|||
self.clip_text = TextCLIP(self.clip) |
|||
self.clip_text = torch.nn.DataParallel(self.clip_text) |
|||
|
|||
# Init arguments |
|||
self.target_seq_length = target_seq_length |
|||
self.reset_context_delta = reset_context_delta |
|||
self.num_iterations = num_iterations |
|||
self.clip_loss_temperature = clip_loss_temperature |
|||
self.clip_scale = clip_scale |
|||
self.ce_scale = ce_scale |
|||
self.stepsize = stepsize |
|||
self.grad_norm_factor = grad_norm_factor |
|||
self.fusion_factor = fusion_factor |
|||
self.repetition_penalty = repetition_penalty |
|||
self.end_token = self.lm_tokenizer.encode(end_token)[0] |
|||
self.end_factor = end_factor |
|||
self.ef_idx = 1 |
|||
self.forbidden_factor = forbidden_factor |
|||
|
|||
def get_img_feature(self, img_path, weights): |
|||
imgs = [Image.open(x) for x in img_path] |
|||
clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs] |
|||
|
|||
with torch.no_grad(): |
|||
image_fts = [self.clip_image(x) for x in clip_imgs] |
|||
|
|||
if weights is not None: |
|||
image_features = sum([x * weights[i] for i, x in enumerate(image_fts)]) |
|||
else: |
|||
image_features = sum(image_fts) |
|||
|
|||
image_features = torch.nn.functional.normalize(image_features, dim=-1) |
|||
return image_features.detach() |
|||
|
|||
def get_txt_features(self, text): |
|||
clip_texts = clip.tokenize(text).to(self.device) |
|||
|
|||
with torch.no_grad(): |
|||
text_features = self.clip_text(clip_texts) |
|||
|
|||
text_features = torch.nn.functional.normalize(text_features, dim=-1) |
|||
return text_features.detach() |
|||
|
|||
def get_combined_feature(self, img_path, texts, weights_i, weights_t): |
|||
imgs = [Image.open(x) for x in img_path] |
|||
clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs] |
|||
clip_texts = [clip.tokenize(x).to(self.device) for x in texts] |
|||
|
|||
with torch.no_grad(): |
|||
image_fts = [self.clip.encode_image(x) for x in clip_imgs] |
|||
text_fts = [self.clip.encode_text(x) for x in clip_texts] |
|||
|
|||
features = sum([x * weights_i[i] for i, x in enumerate(image_fts)]) |
|||
if weights_t is not None: |
|||
features += sum([x * weights_t[i] for i, x in enumerate(text_fts)]) |
|||
|
|||
features = features / features.norm(dim=-1, keepdim=True) |
|||
return features.detach() |
|||
|
|||
def run(self, image_features, cond_text, beam_size): |
|||
self.image_features = image_features |
|||
|
|||
context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text) |
|||
|
|||
output_tokens, output_text = self.generate_text(context_tokens, beam_size) |
|||
|
|||
return output_text |
|||
|
|||
def generate_text(self, context_tokens, beam_size): |
|||
context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0) |
|||
|
|||
gen_tokens = None |
|||
scores = None |
|||
seq_lengths = torch.ones(beam_size, device=self.device) |
|||
is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool) |
|||
|
|||
for i in range(self.target_seq_length): |
|||
probs = self.get_next_probs(i, context_tokens) |
|||
logits = probs.log() |
|||
|
|||
if scores is None: |
|||
scores, next_tokens = logits.topk(beam_size, -1) |
|||
context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:]) |
|||
next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0) |
|||
|
|||
if gen_tokens is None: |
|||
gen_tokens = next_tokens |
|||
else: |
|||
gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:]) |
|||
gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1) |
|||
else: |
|||
logits[is_stopped] = -float(np.inf) |
|||
logits[is_stopped, 0] = 0 |
|||
scores_sum = scores[:, None] + logits |
|||
seq_lengths[~is_stopped] += 1 |
|||
scores_sum_average = scores_sum / seq_lengths[:, None] |
|||
scores_sum_average, next_tokens = scores_sum_average.view(-1).topk( |
|||
beam_size, -1) |
|||
next_tokens_source = next_tokens // scores_sum.shape[1] |
|||
seq_lengths = seq_lengths[next_tokens_source] |
|||
next_tokens = next_tokens % scores_sum.shape[1] |
|||
next_tokens = next_tokens.unsqueeze(1) |
|||
gen_tokens = gen_tokens[next_tokens_source] |
|||
gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1) |
|||
context_tokens = context_tokens[next_tokens_source] |
|||
scores = scores_sum_average * seq_lengths |
|||
is_stopped = is_stopped[next_tokens_source] |
|||
|
|||
context_tokens = torch.cat((context_tokens, next_tokens), dim=1) |
|||
is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze() |
|||
|
|||
#### |
|||
tmp_scores = scores / seq_lengths |
|||
tmp_output_list = gen_tokens.cpu().numpy() |
|||
tmp_output_texts = [ |
|||
self.lm_tokenizer.decode(tmp_output) |
|||
for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths) |
|||
] |
|||
tmp_order = tmp_scores.argsort(descending=True) |
|||
tmp_output_texts = [tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order] |
|||
log_info(tmp_output_texts, verbose=True) |
|||
#### |
|||
|
|||
if is_stopped.all(): |
|||
break |
|||
|
|||
scores = scores / seq_lengths |
|||
output_list = gen_tokens.cpu().numpy() |
|||
output_texts = [ |
|||
self.lm_tokenizer.decode(output[: int(length)]) |
|||
for output, length in zip(output_list, seq_lengths) |
|||
] |
|||
order = scores.argsort(descending=True) |
|||
output_texts = [output_texts[i] for i in order] |
|||
|
|||
return context_tokens, output_texts |
|||
|
|||
def get_next_probs(self, i, context_tokens): |
|||
last_token = context_tokens[:, -1:] |
|||
|
|||
if self.reset_context_delta and context_tokens.size(1) > 1: |
|||
context = self.lm_model(context_tokens[:, :-1])["past_key_values"] |
|||
|
|||
# Logits of LM with unshifted context |
|||
logits_before_shift = self.lm_model(context_tokens)["logits"] |
|||
logits_before_shift = logits_before_shift[:, -1, :] |
|||
probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1) |
|||
|
|||
if context: |
|||
context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift) |
|||
|
|||
lm_output = self.lm_model(last_token, past_key_values=context) |
|||
logits, past = ( |
|||
lm_output["logits"], |
|||
lm_output["past_key_values"], |
|||
) |
|||
logits = logits[:, -1, :] |
|||
|
|||
logits = self.update_special_tokens_logits(context_tokens, i, logits) |
|||
|
|||
probs = nn.functional.softmax(logits, dim=-1) |
|||
probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor)) |
|||
probs = probs / probs.sum() |
|||
|
|||
return probs |
|||
|
|||
def shift_context(self, i, context, last_token, context_tokens, probs_before_shift): |
|||
context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context] |
|||
|
|||
for i in range(self.num_iterations): |
|||
curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in |
|||
context_delta] |
|||
|
|||
for p0, p1 in curr_shift: |
|||
p0.retain_grad() |
|||
p1.retain_grad() |
|||
|
|||
shifted_context = list(map(add_context, context, curr_shift)) |
|||
|
|||
shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context) |
|||
logits = shifted_outputs["logits"][:, -1, :] |
|||
probs = nn.functional.softmax(logits, dim=-1) |
|||
|
|||
loss = 0.0 |
|||
|
|||
# CLIP LOSS |
|||
clip_loss, clip_losses = self.clip_loss(probs, context_tokens) |
|||
loss += self.clip_scale * clip_loss |
|||
|
|||
# CE/Fluency loss |
|||
ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * probs_before_shift.log())).sum(-1) |
|||
loss += ce_loss.sum() |
|||
|
|||
loss.backward() |
|||
|
|||
# --------- Specific Gen --------- |
|||
final_grads = self.norm_grad(context, context_tokens, curr_shift) |
|||
|
|||
# --------- update context --------- |
|||
context_delta = list(map(add_context, final_grads, context_delta)) |
|||
|
|||
for p0, p1 in curr_shift: |
|||
p0.grad.data.zero_() |
|||
p1.grad.data.zero_() |
|||
|
|||
new_context = [] |
|||
for p0, p1 in context: |
|||
new_context.append((p0.detach(), p1.detach())) |
|||
context = new_context |
|||
|
|||
context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) |
|||
for p_ in context_delta] |
|||
context = list(map(add_context, context, context_delta)) |
|||
|
|||
new_context = [] |
|||
for p0, p1 in context: |
|||
new_context.append((p0.detach(), p1.detach())) |
|||
context = new_context |
|||
|
|||
return context |
|||
|
|||
def norm_grad(self, context, context_tokens, curr_shift, ): |
|||
factor = 1 |
|||
sep_grads = None |
|||
window_mask = torch.ones_like(context[0][0]).to(self.device) |
|||
|
|||
for b in range(context_tokens.shape[0]): |
|||
tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_] |
|||
for p_ in curr_shift] |
|||
|
|||
# normalize gradients |
|||
tmp_grad = [tuple([-self.stepsize * factor * ( |
|||
x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][ |
|||
j] ** self.grad_norm_factor).data.cpu().numpy() |
|||
for j, x in enumerate(p_)]) |
|||
for i, p_ in enumerate(curr_shift)] |
|||
if sep_grads is None: |
|||
sep_grads = tmp_grad |
|||
else: |
|||
for l_index in range(len(sep_grads)): |
|||
sep_grads[l_index] = list(sep_grads[l_index]) |
|||
for k_index in range(len(sep_grads[0])): |
|||
sep_grads[l_index][k_index] = np.concatenate( |
|||
(sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0) |
|||
sep_grads[l_index] = tuple(sep_grads[l_index]) |
|||
final_grads = sep_grads |
|||
|
|||
return final_grads |
|||
|
|||
def update_special_tokens_logits(self, context_tokens, i, logits): |
|||
for beam_id in range(context_tokens.shape[0]): |
|||
for token_idx in set(context_tokens[beam_id][-4:].tolist()): |
|||
factor = self.repetition_penalty if logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty) |
|||
logits[beam_id, token_idx] /= factor |
|||
|
|||
if i >= self.ef_idx: |
|||
factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor) |
|||
logits[beam_id, self.end_token] *= factor |
|||
if i == 0: |
|||
start_factor = 1.6 |
|||
factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor) |
|||
logits[beam_id, self.end_token] /= factor |
|||
|
|||
for token_idx in list(self.forbidden_tokens): |
|||
factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor) |
|||
logits[beam_id, token_idx] /= factor |
|||
|
|||
return logits |
|||
|
|||
def clip_loss(self, probs, context_tokens): |
|||
for p_ in self.clip.transformer.parameters(): |
|||
if p_.grad is not None: |
|||
p_.grad.data.zero_() |
|||
|
|||
top_size = 512 |
|||
top_probs, top_indices = probs.topk(top_size, -1) |
|||
|
|||
prefix_texts = [self.lm_tokenizer.decode(x, skip_special_tokens=True) for x in context_tokens] |
|||
|
|||
clip_loss = 0 |
|||
losses = [] |
|||
|
|||
top_texts = [] |
|||
for idx_p in range(probs.shape[0]): |
|||
prefix_text = prefix_texts[idx_p] |
|||
for x in top_indices[idx_p]: |
|||
top_texts.append(prefix_text + self.lm_tokenizer.decode(x)) |
|||
|
|||
text_features = self.get_txt_features(top_texts)#.reshape(probs.size(0), top_size, -1) |
|||
|
|||
with torch.no_grad(): |
|||
similiraties = (self.image_features @ text_features.T).reshape(probs.size(0), -1) |
|||
similiraties = similiraties.reshape(probs.size(0), -1) |
|||
target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() |
|||
target_probs = target_probs.type(torch.float32) |
|||
|
|||
clip_loss += torch.sum(-(target_probs * torch.log(top_probs))) |
|||
# for idx_p in range(probs.shape[0]): |
|||
# top_texts = [] |
|||
# prefix_text = prefix_texts[idx_p] |
|||
# for x in top_indices[idx_p]: |
|||
# top_texts.append(prefix_text + self.lm_tokenizer.decode(x)) |
|||
# text_features = self.get_txt_features(top_texts) |
|||
# |
|||
# with torch.no_grad(): |
|||
# similiraties = (self.image_features @ text_features.T) |
|||
# target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() |
|||
# target_probs = target_probs.type(torch.float32) |
|||
# |
|||
# target = torch.zeros_like(probs[idx_p]) |
|||
# target[top_indices[idx_p]] = target_probs[0] |
|||
# target = target.unsqueeze(0) |
|||
# cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)]))) |
|||
# |
|||
# clip_loss += cur_clip_loss |
|||
# losses.append(cur_clip_loss) |
|||
|
|||
return clip_loss, losses |
|||
|
|||
def clip_loss_old(self, probs, context_tokens): |
|||
for p_ in self.clip.transformer.parameters(): |
|||
if p_.grad is not None: |
|||
p_.grad.data.zero_() |
|||
|
|||
top_size = 512 |
|||
_, top_indices = probs.topk(top_size, -1) |
|||
|
|||
prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens] |
|||
|
|||
clip_loss = 0 |
|||
losses = [] |
|||
for idx_p in range(probs.shape[0]): |
|||
top_texts = [] |
|||
prefix_text = prefix_texts[idx_p] |
|||
for x in top_indices[idx_p]: |
|||
top_texts.append(prefix_text + self.lm_tokenizer.decode(x)) |
|||
text_features = self.get_txt_features(top_texts) |
|||
|
|||
with torch.no_grad(): |
|||
similiraties = (self.image_features @ text_features.T) |
|||
target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach() |
|||
target_probs = target_probs.type(torch.float32) |
|||
|
|||
target = torch.zeros_like(probs[idx_p]) |
|||
target[top_indices[idx_p]] = target_probs[0] |
|||
target = target.unsqueeze(0) |
|||
cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)]))) |
|||
|
|||
clip_loss += cur_clip_loss |
|||
losses.append(cur_clip_loss) |
|||
|
|||
return clip_loss, losses |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -1,14 +0,0 @@ |
|||
#!/bin/bash |
|||
|
|||
# lm_model: |
|||
# 1. cambridgeltl/magic_mscoco |
|||
# 2. cambridgeltl/magic_flickr30k |
|||
CUDA_VISIBLE_DEVICES=1 python run.py \ |
|||
--beam_size 1 \ |
|||
--target_seq_length 16 \ |
|||
--reset_context_delta \ |
|||
--lm_model cambridgeltl/magic_mscoco \ |
|||
--test_image_prefix_path ../data/mscoco/test_images \ |
|||
--test_path ../data/mscoco/mscoco_test.json \ |
|||
--save_path_prefix ../inference_result/mscoco/baselines/ \ |
|||
--save_name zerocap_result.json |
@ -1,117 +0,0 @@ |
|||
import os |
|||
import tempfile |
|||
import sys |
|||
sys.path.append('CLIP') |
|||
from pathlib import Path |
|||
import cog |
|||
import argparse |
|||
import torch |
|||
import clip |
|||
from model.ZeroCLIP import CLIPTextGenerator |
|||
|
|||
def perplexity_score(text, lm_model, lm_tokenizer, device): |
|||
encodings = lm_tokenizer(f'{lm_tokenizer.bos_token + text}', return_tensors='pt') |
|||
input_ids = encodings.input_ids.to(device) |
|||
target_ids = input_ids.clone() |
|||
|
|||
outputs = lm_model(input_ids, labels=target_ids) |
|||
log_likelihood = outputs[0] |
|||
ll = log_likelihood.item() |
|||
|
|||
return ll |
|||
|
|||
class Predictor(cog.Predictor): |
|||
def setup(self): |
|||
self.args = get_args() |
|||
self.args.reset_context_delta = True |
|||
self.text_generator = CLIPTextGenerator(**vars(self.args)) |
|||
|
|||
@cog.input( |
|||
"image", |
|||
type=Path, |
|||
help="input image" |
|||
) |
|||
@cog.input( |
|||
"cond_text", |
|||
type=str, |
|||
default='Image of a', |
|||
help="conditional text", |
|||
) |
|||
@cog.input( |
|||
"beam_size", |
|||
type=int, |
|||
default=5, min=1, max=10, |
|||
help="Number of beams to use", |
|||
) |
|||
@cog.input( |
|||
"end_factor", |
|||
type=float, |
|||
default=1.01, min=1.0, max=1.10, |
|||
help="Higher value for shorter captions", |
|||
) |
|||
@cog.input( |
|||
"max_seq_length", |
|||
type=int, |
|||
default=15, min=1, max=20, |
|||
help="Maximum number of tokens to generate", |
|||
) |
|||
@cog.input( |
|||
"ce_loss_scale", |
|||
type=float, |
|||
default=0.2, min=0.0, max=0.6, |
|||
help="Scale of cross-entropy loss with un-shifted language model", |
|||
) |
|||
def predict(self, image, cond_text, beam_size, end_factor, max_seq_length, ce_loss_scale): |
|||
self.args.cond_text = cond_text |
|||
self.text_generator.end_factor = end_factor |
|||
self.text_generator.target_seq_length = max_seq_length |
|||
self.text_generator.ce_scale = ce_loss_scale |
|||
|
|||
image_features = self.text_generator.get_img_feature([str(image)], None) |
|||
captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size) |
|||
|
|||
# CLIP SCORE |
|||
encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device)) |
|||
for c in captions] |
|||
encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions] |
|||
best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item() |
|||
|
|||
# Perplexity SCORE |
|||
ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions] |
|||
best_ppl_index = torch.tensor(ppl_scores).argmin().item() |
|||
|
|||
best_clip_caption = self.args.cond_text + captions[best_clip_idx] |
|||
best_mixed = self.args.cond_text + captions[0] |
|||
best_PPL = self.args.cond_text + captions[best_ppl_index] |
|||
|
|||
final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}' |
|||
|
|||
return final |
|||
# return self.args.cond_text + captions[best_clip_idx] |
|||
|
|||
|
|||
def get_args(): |
|||
parser = argparse.ArgumentParser() |
|||
|
|||
parser.add_argument("--seed", type=int, default=0) |
|||
parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo") |
|||
parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP") |
|||
parser.add_argument("--target_seq_length", type=int, default=15) |
|||
parser.add_argument("--cond_text", type=str, default="Image of a") |
|||
parser.add_argument("--reset_context_delta", action="store_true", |
|||
help="Should we reset the context at each token gen") |
|||
parser.add_argument("--num_iterations", type=int, default=5) |
|||
parser.add_argument("--clip_loss_temperature", type=float, default=0.01) |
|||
parser.add_argument("--clip_scale", type=float, default=1) |
|||
parser.add_argument("--ce_scale", type=float, default=0.2) |
|||
parser.add_argument("--stepsize", type=float, default=0.3) |
|||
parser.add_argument("--grad_norm_factor", type=float, default=0.9) |
|||
parser.add_argument("--fusion_factor", type=float, default=0.99) |
|||
parser.add_argument("--repetition_penalty", type=float, default=1) |
|||
parser.add_argument("--end_token", type=str, default=".", help="Token to end text") |
|||
parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token") |
|||
parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens") |
|||
parser.add_argument("--beam_size", type=int, default=5) |
|||
|
|||
args = parser.parse_args('') |
|||
return args |
@ -1,129 +0,0 @@ |
|||
import os |
|||
import tempfile |
|||
import sys |
|||
sys.path.append('CLIP') |
|||
from pathlib import Path |
|||
import cog |
|||
import argparse |
|||
import torch |
|||
import clip |
|||
from model.ZeroCLIP import CLIPTextGenerator |
|||
|
|||
def perplexity_score(text, lm_model, lm_tokenizer, device): |
|||
encodings = lm_tokenizer(f'{lm_tokenizer.bos_token + text}', return_tensors='pt') |
|||
input_ids = encodings.input_ids.to(device) |
|||
target_ids = input_ids.clone() |
|||
|
|||
outputs = lm_model(input_ids, labels=target_ids) |
|||
log_likelihood = outputs[0] |
|||
ll = log_likelihood.item() |
|||
|
|||
return ll |
|||
|
|||
class Predictor(cog.Predictor): |
|||
def setup(self): |
|||
self.args = get_args() |
|||
self.args.reset_context_delta = True |
|||
self.text_generator = CLIPTextGenerator(**vars(self.args)) |
|||
|
|||
@cog.input( |
|||
"image1", |
|||
type=Path, |
|||
help="Final result will be: image1 + (image2 - image3)" |
|||
) |
|||
@cog.input( |
|||
"image2", |
|||
type=Path, |
|||
help="Final result will be: image1 + (image2 - image3)" |
|||
) |
|||
@cog.input( |
|||
"image3", |
|||
type=Path, |
|||
help="Final result will be: image1 + (image2 - image3)" |
|||
) |
|||
@cog.input( |
|||
"cond_text", |
|||
type=str, |
|||
default='Image of a', |
|||
help="conditional text", |
|||
) |
|||
@cog.input( |
|||
"beam_size", |
|||
type=int, |
|||
default=3, min=1, max=10, |
|||
help="Number of beams to use", |
|||
) |
|||
@cog.input( |
|||
"end_factors", |
|||
type=float, |
|||
default=1.06, min=1.0, max=1.10, |
|||
help="Higher value for shorter captions", |
|||
) |
|||
@cog.input( |
|||
"max_seq_lengths", |
|||
type=int, |
|||
default=3, min=1, max=20, |
|||
help="Maximum number of tokens to generate", |
|||
) |
|||
@cog.input( |
|||
"ce_loss_scale", |
|||
type=float, |
|||
default=0.2, min=0.0, max=0.6, |
|||
help="Scale of cross-entropy loss with un-shifted language model", |
|||
) |
|||
def predict(self, image1, image2, image3, cond_text, beam_size, end_factors, max_seq_lengths, ce_loss_scale): |
|||
self.args.cond_text = cond_text |
|||
self.text_generator.end_factor = end_factors |
|||
self.text_generator.target_seq_length = max_seq_lengths |
|||
self.text_generator.ce_scale = ce_loss_scale |
|||
self.text_generator.fusion_factor = 0.95 |
|||
self.text_generator.grad_norm_factor = 0.95 |
|||
|
|||
image_features = self.text_generator.get_combined_feature([str(image1), str(image2), str(image3)], [], [1, 1, -1], None) |
|||
captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size) |
|||
|
|||
# CLIP SCORE |
|||
encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device)) |
|||
for c in captions] |
|||
encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions] |
|||
best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item() |
|||
|
|||
# Perplexity SCORE |
|||
ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions] |
|||
best_ppl_index = torch.tensor(ppl_scores).argmin().item() |
|||
|
|||
best_clip_caption = self.args.cond_text + captions[best_clip_idx] |
|||
best_mixed = self.args.cond_text + captions[0] |
|||
best_PPL = self.args.cond_text + captions[best_ppl_index] |
|||
|
|||
final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}' |
|||
|
|||
return final |
|||
# return self.args.cond_text + captions[best_clip_idx] |
|||
|
|||
|
|||
def get_args(): |
|||
parser = argparse.ArgumentParser() |
|||
|
|||
parser.add_argument("--seed", type=int, default=0) |
|||
parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo") |
|||
parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP") |
|||
parser.add_argument("--target_seq_length", type=int, default=15) |
|||
parser.add_argument("--cond_text", type=str, default="Image of a") |
|||
parser.add_argument("--reset_context_delta", action="store_true", |
|||
help="Should we reset the context at each token gen") |
|||
parser.add_argument("--num_iterations", type=int, default=5) |
|||
parser.add_argument("--clip_loss_temperature", type=float, default=0.01) |
|||
parser.add_argument("--clip_scale", type=float, default=1) |
|||
parser.add_argument("--ce_scale", type=float, default=0.2) |
|||
parser.add_argument("--stepsize", type=float, default=0.3) |
|||
parser.add_argument("--grad_norm_factor", type=float, default=0.95) |
|||
parser.add_argument("--fusion_factor", type=float, default=0.95) |
|||
parser.add_argument("--repetition_penalty", type=float, default=1) |
|||
parser.add_argument("--end_token", type=str, default=".", help="Token to end text") |
|||
parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token") |
|||
parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens") |
|||
parser.add_argument("--beam_size", type=int, default=5) |
|||
|
|||
args = parser.parse_args('') |
|||
return args |
@ -1,3 +0,0 @@ |
|||
ftfy |
|||
regex |
|||
tqdm |
@ -1,131 +0,0 @@ |
|||
import argparse |
|||
import ipdb |
|||
from tqdm import tqdm |
|||
import progressbar |
|||
import torch |
|||
import ipdb |
|||
import clip |
|||
from model.ZeroCLIP import CLIPTextGenerator |
|||
from model.ZeroCLIP_batched import CLIPTextGenerator as CLIPTextGenerator_multigpu |
|||
|
|||
def get_args(): |
|||
parser = argparse.ArgumentParser() |
|||
|
|||
parser.add_argument("--test_image_prefix_path", type=str, help="the folder that stores all test images") |
|||
parser.add_argument("--test_path", type=str) |
|||
parser.add_argument("--save_path_prefix", type=str, help="save the result in which directory") |
|||
parser.add_argument("--save_name", type=str, help="the name of the saved file") |
|||
|
|||
parser.add_argument("--seed", type=int, default=0) |
|||
parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo") |
|||
parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP") |
|||
parser.add_argument("--target_seq_length", type=int, default=15) |
|||
parser.add_argument("--cond_text", type=str, default="Image of a") |
|||
parser.add_argument("--reset_context_delta", action="store_true", |
|||
help="Should we reset the context at each token gen") |
|||
parser.add_argument("--num_iterations", type=int, default=5) |
|||
parser.add_argument("--clip_loss_temperature", type=float, default=0.01) |
|||
parser.add_argument("--clip_scale", type=float, default=1) |
|||
parser.add_argument("--ce_scale", type=float, default=0.2) |
|||
parser.add_argument("--stepsize", type=float, default=0.3) |
|||
parser.add_argument("--grad_norm_factor", type=float, default=0.9) |
|||
parser.add_argument("--fusion_factor", type=float, default=0.99) |
|||
parser.add_argument("--repetition_penalty", type=float, default=1) |
|||
parser.add_argument("--end_token", type=str, default=".", help="Token to end text") |
|||
parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token") |
|||
parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens") |
|||
parser.add_argument("--beam_size", type=int, default=1) |
|||
|
|||
parser.add_argument("--multi_gpu", action="store_true") |
|||
|
|||
parser.add_argument('--run_type', |
|||
default='caption', |
|||
nargs='?', |
|||
choices=['caption', 'arithmetics']) |
|||
|
|||
parser.add_argument("--caption_img_path", type=str, default='example_images/captions/COCO_val2014_000000008775.jpg', |
|||
help="Path to image for captioning") |
|||
|
|||
parser.add_argument("--arithmetics_imgs", nargs="+", |
|||
default=['example_images/arithmetics/woman2.jpg', |
|||
'example_images/arithmetics/king2.jpg', |
|||
'example_images/arithmetics/man2.jpg']) |
|||
parser.add_argument("--arithmetics_weights", nargs="+", default=[1, 1, -1]) |
|||
|
|||
args = parser.parse_args() |
|||
|
|||
return args |
|||
|
|||
def run(args, text_generator, img_path): |
|||
image_features = text_generator.get_img_feature([img_path], None) |
|||
captions = text_generator.run(image_features, args.cond_text, beam_size=args.beam_size) |
|||
|
|||
encoded_captions = [text_generator.clip.encode_text(clip.tokenize(c).to(text_generator.device)) for c in captions] |
|||
encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions] |
|||
best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item() |
|||
return captions |
|||
|
|||
|
|||
if __name__ == '__main__': |
|||
if torch.cuda.is_available(): |
|||
print ('Cuda is available.') |
|||
cuda_available = torch.cuda.is_available() |
|||
args = get_args() |
|||
device = torch.device('cuda') |
|||
|
|||
save_path_prefix = args.save_path_prefix |
|||
import os |
|||
if os.path.exists(save_path_prefix): |
|||
pass |
|||
else: # recursively construct directory |
|||
os.makedirs(save_path_prefix, exist_ok=True) |
|||
# parse save name |
|||
save_name = args.save_name |
|||
full_save_path = save_path_prefix + '/' + save_name |
|||
print ('full save path is {}'.format(full_save_path)) |
|||
|
|||
print ('Loading data...') |
|||
import json |
|||
with open(args.test_path) as f: |
|||
item_list = json.load(f) |
|||
print ('Data loaded.') |
|||
print ('Number of test instances is {}'.format(len(item_list))) |
|||
|
|||
# ZeroCap generator |
|||
text_generator = CLIPTextGenerator(**vars(args)) |
|||
|
|||
result_list = [] |
|||
invalid_num = 0 |
|||
print ('----------------------------------------------------------------') |
|||
test_num = len(item_list) |
|||
#test_num = 10 |
|||
print ('Number of inference instances is {}'.format(test_num)) |
|||
p = progressbar.ProgressBar(test_num) |
|||
p.start() |
|||
for p_idx in tqdm(range(test_num)): |
|||
p.update(p_idx) |
|||
one_test_dict = item_list[p_idx] |
|||
|
|||
one_res_dict = { |
|||
'split':one_test_dict['split'], |
|||
'image_name':one_test_dict['image_name'], |
|||
#'file_path':one_test_dict['file_path'], |
|||
'captions':one_test_dict['captions'] |
|||
} |
|||
|
|||
image_full_path = args.test_image_prefix_path + '/' + one_test_dict['image_name'] |
|||
try: |
|||
output_text = run(args, text_generator, img_path=image_full_path) |
|||
one_res_dict['prediction'] = output_text[0] |
|||
result_list.append(one_res_dict) |
|||
except Exception as error: |
|||
print(f'[!] ERROR:', error) |
|||
invalid_num += 1 |
|||
print ('invalid number is {}'.format(invalid_num)) |
|||
continue |
|||
p.finish() |
|||
print ('Inference completed!') |
|||
|
|||
import json |
|||
with open(full_save_path, 'w') as outfile: |
|||
json.dump(result_list, outfile, indent=4) |
@ -1,19 +0,0 @@ |
|||
import os |
|||
|
|||
import pkg_resources |
|||
from setuptools import setup, find_packages |
|||
|
|||
setup( |
|||
name="zero-shot-image-to-text", |
|||
py_modules=["zero-shot-image-to-text"], |
|||
version="1.0", |
|||
description="", |
|||
packages=find_packages(), |
|||
install_requires=[ |
|||
str(r) |
|||
for r in pkg_resources.parse_requirements( |
|||
open(os.path.join(os.path.dirname(__file__), "requirements.txt")) |
|||
) |
|||
], |
|||
include_package_data=True |
|||
) |
Loading…
Reference in new issue