
update the magic.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
49e3970b48
  1. .DS_Store (BIN)
  2. .README.md.swp (BIN)
  3. README.md (81)
  4. language_model/README.md (167)
  5. language_model/dataclass.py (157)
  6. language_model/loss_func.py (80)
  7. language_model/simctg.py (233)
  8. language_model/train.py (107)
  9. language_model/train_flickr30k.sh (17)
  10. language_model/train_mscoco.sh (17)
  11. language_model/trainer.py (165)
  12. language_model/utlis.py (291)
  13. magic.py (29)
  14. zerocap/README.md (89)
  15. zerocap/cog.yaml (12)
  16. zerocap/flickr30k_zerocap.sh (14)
  17. zerocap/forbidden_tokens.npy (BIN)
  18. zerocap/model/ZeroCLIP.py (389)
  19. zerocap/model/ZeroCLIP_batched.py (449)
  20. zerocap/model/__init__.py (0)
  21. zerocap/model/__pycache__/ZeroCLIP.cpython-36.pyc (BIN)
  22. zerocap/model/__pycache__/ZeroCLIP.cpython-37.pyc (BIN)
  23. zerocap/model/__pycache__/ZeroCLIP_batched.cpython-36.pyc (BIN)
  24. zerocap/model/__pycache__/ZeroCLIP_batched.cpython-37.pyc (BIN)
  25. zerocap/model/__pycache__/__init__.cpython-36.pyc (BIN)
  26. zerocap/model/__pycache__/__init__.cpython-37.pyc (BIN)
  27. zerocap/mscoco_zerocap.sh (14)
  28. zerocap/predict.py (117)
  29. zerocap/predict_arithmetic.py (129)
  30. zerocap/requirements.txt (3)
  31. zerocap/run.py (131)
  32. zerocap/setup.py (19)

BIN
.DS_Store

Binary file not shown.

BIN
.README.md.swp

Binary file not shown.

81
README.md

@@ -1,2 +1,81 @@
# magic
# Image Captioning with MAGIC
*author: David Wang*
<br />
## Description
This operator generates a caption describing the content of the given image with [MAGIC](https://arxiv.org/abs/2205.02655). MAGIC is a simple yet efficient plug-and-play framework which directly combines an off-the-shelf LM (i.e., GPT-2) and an image-text matching model (i.e., CLIP) for image-grounded text generation. During decoding, MAGIC influences the generation of the LM by introducing a CLIP-induced score, called the magic score, which regularizes the generated result to be semantically related to the given image while remaining coherent with the previously generated context. This is an adaptation of [yxuansu / MAGIC](https://github.com/yxuansu/MAGIC).
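At each decoding step, MAGIC restricts the choice to the language model's top-k candidates and selects the token that balances model confidence, a degeneration penalty, and the magic score. The selection rule can be sketched as follows (this mirrors the ranking implemented in `plug_and_play_fast_ranking` in `language_model/utlis.py` from this commit; see the paper for the exact formulation):

$$x_t = \arg\max_{v \in V^{(k)}} \Big\{ (1-\alpha)\, p_\theta(v \mid x_{<t}) \; - \; \alpha \max_{1 \le j < t} s(h_v, h_{x_j}) \; + \; \beta \, f_{\mathrm{magic}}(v \mid \mathcal{I}, x_{<t}, V^{(k)}) \Big\}$$

Here $V^{(k)}$ is the top-k candidate set proposed by the LM, $s(\cdot,\cdot)$ is the cosine similarity between hidden states (the degeneration penalty), $f_{\mathrm{magic}}$ is the CLIP-based relevance of the candidate continuation to the image $\mathcal{I}$, and $\alpha$, $\beta$ weight the two regularizers.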
<br />
## Code Example
Load an image from path './image.jpg' to generate the caption.
*Write the pipeline in simplified style*:
```python
import towhee
towhee.glob('./image.jpg') \
.image_decode() \
.image_captioning.magic(model_name='magic_mscoco') \
.show()
```
<img src="./cap.png" alt="result1" style="height:20px;"/>
*Write the same pipeline with explicit input/output name specifications:*
```python
import towhee
towhee.glob['path']('./image.jpg') \
.image_decode['path', 'img']() \
.image_captioning.magic['img', 'text'](model_name='magic_mscoco') \
.select['img', 'text']() \
.show()
```
<img src="./tabular.png" alt="result2" style="height:60px;"/>
<br />
## Factory Constructor
Create the operator via the following factory method (a minimal usage sketch follows the supported model names below)
***magic(model_name)***
**Parameters:**
***model_name:*** *str*
​ The model name of MAGIC. Supported model names:
- magic_mscoco
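Below is a minimal sketch of creating the operator with the supported model name and calling it on a decoded image. It is illustrative only: the `ops` accessor usage, the direct import of `Magic`, and calling the operator like a function are assumptions, and the pipelines in the Code Example section remain the documented usage.

```python
# hypothetical standalone usage of the operator class defined in magic.py (a sketch, not the documented API)
from towhee import ops       # assumption: the ops accessor can instantiate hub operators directly
from magic import Magic      # assumption: this operator repository is on the Python path

img = ops.image_decode()('./image.jpg')    # assumption: decodes to a towhee image, as in the pipelines above
op = Magic(model_name='magic_mscoco')      # builds the SimCTG language model and the CLIP matcher (see magic.py)
caption = op(img)                          # assumption: the operator is invoked by calling it on the image
print(caption)                             # -> str, the generated caption
```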
<br />
## Interface
An image captioning operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generates the corresponding caption.
**Parameters:**
***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)*
​ The image to generate the caption for.
**Returns:** *str*
​ The caption generated by the model.

167
language_model/README.md

@@ -0,0 +1,167 @@
## Unsupervised Domain Adaptation of Language Model
****
### Catalogue:
* <a href='#mscoco'>1. MSCOCO Benchmark</a>
* <a href='#mscoco_data_preparation'>1.1. MSCOCO Data Preparation</a>
* <a href='#mscoco_training'>1.2. Unsupervised Domain Adaptation on MSCOCO</a>
* <a href='#flickr30k'>2. Flickr30k Benchmark</a>
* <a href='#flickr30k_data_preparation'>2.1. Flickr30k Data Preparation</a>
* <a href='#flickr30k_training'>2.2. Unsupervised Domain Adaptation on Flickr30k</a>
* <a href='#unsupervised_baselines'>3. Unsupervised Baselines</a>
* <a href='#contrastive_search'>3.1. Contrastive Search</a>
* <a href='#top_k_sampling'>3.2. Top-k Sampling</a>
* <a href='#nucleus_sampling'>3.3. Nucleus Sampling</a>
****
<span id='mscoco'/>
#### 1. MSCOCO Benchmark:
We first describe how to perform unsupervised domain adaptation of the language model on the text corpus of the MSCOCO benchmark.
<span id='mscoco_data_preparation'/>
##### 1.1. MSCOCO Data Preparation:
To prepare the MSCOCO benchmark, please follow the instructions [[here]](https://github.com/yxuansu/MAGIC/tree/main/image_captioning/data#1-mscoco-benchmark).
<span id='mscoco_training'/>
##### 1.2. Unsupervised Domain Adaptation on MSCOCO:
After preparing the MSCOCO data, run the following command to train the language model.
```bash
chmod +x ./train_mscoco.sh
./train_mscoco.sh
```
The arguments are as follows:
* `--model_name`: The name of the Hugging Face pre-trained GPT model (e.g. gpt2, gpt2-large).
* `--train_path`: The file path of training set.
* `--dev_path`: The file path of validation set.
* `--test_path`: The file path of test set.
* `--add_eos_token_to_data`: Whether to add an eos token at the end of each text sequence.
* `--margin`: The contrastive margin $\rho$.
* `--max_len`: The maximum length of training samples.
* `--number_of_gpu`: The number of available GPUs.
* `--batch_size_per_gpu`: The batch size for each GPU.
* `--gradient_accumulation_steps`: How many forward computations between two gradient updates.
* `--effective_batch_size`: The overall batch size. It equals batch_size_per_gpu x gradient_accumulation_steps x number_of_gpu (see the sanity check after this list).
* `--total_steps`: The number of total gradient update steps.
* `--print_every`: How many steps between printing intermediate results.
* `--save_every`: How many steps to save one checkpoint.
* `--learning_rate`: The learning rate.
* `--save_path_prefix`: Where to save the checkpoints.
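As a quick sanity check of the batch-size relation above, the values used by `train_mscoco.sh` later in this commit multiply out as expected (plain arithmetic, nothing model-specific):

```python
# values copied from train_mscoco.sh in this commit
batch_size_per_gpu = 32
gradient_accumulation_steps = 4
number_of_gpu = 1

effective_batch_size = batch_size_per_gpu * gradient_accumulation_steps * number_of_gpu
assert effective_batch_size == 128  # matches --effective_batch_size 128 in the script
```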
****
<span id='flickr30k'/>
#### 2. Flickr30k Benchmark:
We then describe how to perform unsupervised domain adaptation of the language model on the text corpus of the Flickr30k benchmark.
<span id='flickr30k_data_preparation'/>
##### 2.1. Flickr30k Data Preparation:
To prepare the Flickr30k benchmark, please follow the instructions [[here]](https://github.com/yxuansu/MAGIC/tree/main/image_captioning/data#2-flickr30k-benchmark).
<span id='flickr30k_training'/>
##### 2.2. Unsupervised Domain Adaptation on Flickr30k:
After preparing the Flickr30k data, run the following command to train the language model.
```bash
chmod +x ./train_flickr30k.sh
./train_flickr30k.sh
```
****
<span id='unsupervised_baselines'/>
#### 3. Unsupervised Baselines:
Here, we illustrate how to use the language model to run the unsupervised baselines described in our paper. Note that all of these methods are **unsupervised**, as the language model is text-only and does not take images as input.
```python
# first, load the language model
import torch
from simctg import SimCTG
sos_token, pad_token = r'<-start_of_text->', r'<-pad->'
# we use the language model adapted on MSCOCO as an example.
language_model_name = r'cambridgeltl/magic_mscoco'
generation_model = SimCTG(language_model_name, sos_token, pad_token)
generation_model.eval()
# then, prepare the input ids. Note that, the text is always generated from the same start of sentence token.
tokens = generation_model.tokenizer.tokenize(sos_token)
input_ids = generation_model.tokenizer.convert_tokens_to_ids(tokens)
input_ids = torch.LongTensor(input_ids).view(1,-1)
```
<span id='contrastive_search'/>
##### 3.1. Contrastive Search :
```python
'''
use contrastive search to generate the result.
note that, contrastive search is a deterministic decoding method, thus the generated text is always the same.
'''
beam_width, alpha, decoding_len = 45, 0.1, 16
output_text = generation_model.fast_contrastive_search(input_ids, beam_width, alpha, decoding_len)
print (output_text)
'''
A man is riding a skateboard down a street.
'''
```
The arguments are as follows:
* `--input_ids`: The id of the start of sentence token.
* `--beam_width`: k in the contrastive search.
* `--alpha`: alpha in the contrastive search.
* `--decoding_len`: Number of tokens to generate.
<span id='top_k_sampling'/>
##### 3.2. Top-k Sampling :
```python
'''
use top-k sampling to generate the result.
note that this method is stochastic, so the generated text can differ from run to run.
'''
top_k, decoding_len = 40, 16
output_text = generation_model.top_k_sampling(input_ids, top_k, decoding_len)
print (output_text)
'''
some very different types of vases with flowers together
'''
```
The arguments are as follows:
* `--input_ids`: The id of the start of sentence token.
* `--k`: The k in top-k sampling.
* `--decoding_len`: Number of tokens to generate.
<span id='nucleus_sampling'/>
##### 3.3. Nucleus Sampling :
```python
'''
use nucleus sampling to generate the result.
note that this method is stochastic, so the generated text can differ from run to run.
'''
nucleus_p, decoding_len = 0.95, 16
output_text = generation_model.nucleus_sampling(input_ids, nucleus_p, decoding_len)
print (output_text)
'''
Two young girls enjoying a hot dog hot dog bun.
'''
```
The arguments are as follows:
* `--input_ids`: The id of the start of sentence token.
* `--nucleus_p`: The probability in nucleus sampling.
* `--decoding_len`: Number of tokens to generate.

157
language_model/dataclass.py

@@ -0,0 +1,157 @@
import json
import random
import torch
import numpy as np
import progressbar
from torch.nn.utils import rnn
class Data:
def __init__(self, model_name, train_path, dev_path, test_path, max_len,
sos_token, pad_token, add_eos_token_to_data):
'''
model_name: gpt2
train_path: training data path
dev_path: validation data path
test_path: test data path
max_len: maximum length for training sequences
sos_token: initialized sos token <-start_of_text->
pad_token: used to pad the sequences <-pad->
add_eos_token_to_data: whether we want to the model learn to generate eos token;
if so, the model could automatically stop generation by generating eos token
'''
from transformers import GPT2TokenizerFast
self.tokenizer = GPT2TokenizerFast.from_pretrained(model_name)
self.sos_token, self.sos_token_id = self.add_special_token(sos_token)
print ('sos token is {}, sos token id is {}'.format(self.sos_token, self.sos_token_id))
self.pad_token, self.pad_token_id = self.add_special_token(pad_token)
print ('pad token is {}, pad token id is {}'.format(self.pad_token, self.pad_token_id))
self.eos_token, self.eos_token_id = self.tokenizer.bos_token, self.tokenizer.bos_token_id
print ('eos token is {}, eos token id is {}'.format(self.eos_token, self.eos_token_id))
self.add_eos_token_to_data = add_eos_token_to_data
self.max_len = max_len
self.train_token_list, self.train_token_id_list = self.process_one_file(train_path)
self.dev_token_list, self.dev_token_id_list = self.process_one_file(dev_path)
self.test_token_list, self.test_token_id_list = self.process_one_file(test_path)
self.train_num, self.dev_num, self.test_num = len(self.train_token_list), len(self.dev_token_list), \
len(self.test_token_list)
print ('train number:{}, dev number:{}, test number:{}'.format(self.train_num, self.dev_num, self.test_num))
self.train_idx_list = [i for i in range(self.train_num)]
random.shuffle(self.train_idx_list)
self.dev_idx_list = [j for j in range(self.dev_num)]
self.test_idx_list = [j for j in range(self.test_num)]
self.dev_current_idx, self.test_current_idx = 0, 0
def add_special_token(self, special_token):
if special_token in self.tokenizer.vocab:
print (special_token + ' token exists.')
else:
print ('Add token to the tokenizer.')
print ('Original vocabulary size is {}'.format(len(self.tokenizer)))
self.tokenizer.add_tokens([special_token])
print ('Vocabulary size after extension is {}'.format(len(self.tokenizer)))
assert len(self.tokenizer.convert_tokens_to_ids([special_token])) == 1
special_token_id = self.tokenizer.convert_tokens_to_ids([special_token])[0]
return special_token, special_token_id
def process_one_file(self, path):
print ('Processing {}'.format(path))
with open(path) as f:
item_list = json.load(f)
lines = []
for item in item_list:
captions_list = item['captions']
for one_caption in captions_list:
lines.append(one_caption.strip())
res_token_list, res_token_id_list = [], []
n = len(lines)
p = progressbar.ProgressBar(n)
p.start()
for i in range(n):
p.update(i)
text = lines[i].strip('\n')
self.process_one_text(text, res_token_list, res_token_id_list)
p.finish()
print ('{} processed!'.format(path))
return res_token_list, res_token_id_list
def process_one_text(self, text, res_token_list, res_token_id_list):
tokens = self.tokenizer.tokenize(text, max_length=self.max_len, truncation=True)
if len(tokens) <= 1: # filter out too short sequence
return
tokens = [self.sos_token] + tokens[:self.max_len]
if self.add_eos_token_to_data:
tokens = tokens + [self.eos_token]
token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
res_token_list.append(tokens)
res_token_id_list.append(token_ids)
return
def pad_batch(self, batch_id_list):
batch_id_list = [torch.LongTensor(item) for item in batch_id_list]
batch_tensor = rnn.pad_sequence(batch_id_list, batch_first=True, padding_value=self.pad_token_id)
batch_mask = torch.ones_like(batch_tensor)
batch_mask = batch_mask.masked_fill(batch_tensor.eq(self.pad_token_id), 0.0).type(torch.FloatTensor)
return batch_tensor, batch_mask
def process_output(self, batch_tgt_id_list):
batch_tgt_id_list = [torch.LongTensor(item) for item in batch_tgt_id_list]
batch_tgt_tensor, _ = self.pad_batch(batch_tgt_id_list) # padded target sequence
batch_tgt_input_tensor = batch_tgt_tensor[:, :-1].clone()
batch_tgt_output_tensor = batch_tgt_tensor[:, 1:].clone()
return batch_tgt_input_tensor, batch_tgt_output_tensor
def parse_batch(self, batch_id_list):
batch_input, batch_labels = self.process_output(batch_id_list)
batch_labels[batch_labels[:, :] == self.pad_token_id] = -100
return batch_input, batch_labels
def get_next_train_batch(self, batch_size):
batch_idx_list = random.sample(self.train_idx_list, batch_size)
batch_id_list, batch_token_list = [], []
for idx in batch_idx_list:
batch_id_list.append(self.train_token_id_list[idx])
batch_token_list.append(self.train_token_list[idx])
batch_input_tensor, batch_labels = self.parse_batch(batch_id_list)
return batch_input_tensor, batch_labels, batch_token_list
def get_next_validation_batch(self, batch_size, mode):
batch_id_list, batch_token_list = [], []
if mode == 'dev':
curr_select_idx, instance_num = self.dev_current_idx, self.dev_num
tgt_token_id_list, tgt_token_list = self.dev_token_id_list, self.dev_token_list
elif mode == 'test':
curr_select_idx, instance_num = self.test_current_idx, self.test_num
tgt_token_id_list, tgt_token_list = self.test_token_id_list, self.test_token_list
else:
raise Exception('Wrong Validation Mode!!!')
if curr_select_idx + batch_size < instance_num:
for i in range(batch_size):
curr_idx = curr_select_idx + i
batch_id_list.append(tgt_token_id_list[curr_idx])
batch_token_list.append(tgt_token_list[curr_idx])
if mode == 'dev':
self.dev_current_idx += batch_size
else:
self.test_current_idx += batch_size
else:
for i in range(batch_size):
curr_idx = curr_select_idx + i
if curr_idx > instance_num - 1:
curr_idx = 0
if mode == 'dev':
self.dev_current_idx = 0
else:
self.test_current_idx = 0
batch_id_list.append(tgt_token_id_list[curr_idx])
batch_token_list.append(tgt_token_list[curr_idx])
if mode == 'dev':
self.dev_current_idx = 0
else:
self.test_current_idx = 0
batch_input_tensor, batch_labels = self.parse_batch(batch_id_list)
return batch_input_tensor, batch_labels, batch_token_list

80
language_model/loss_func.py

@@ -0,0 +1,80 @@
import torch
def compute_valid_token_num(valid_len_list):
res = 0
for one_len in valid_len_list:
res += one_len * (one_len - 1)
return res
def build_mask_matrix(seqlen, valid_len_list, prefix_len = 0):
'''
prefix_len: the length of prefix that we do not want to compute CL loss for.
(1) if a sequence of length 4 contains zero padding token (i.e., the valid length is 4),
then the loss padding matrix looks like
[0., 1., 1., 1.],
[1., 0., 1., 1.],
[1., 1., 0., 1.],
[1., 1., 1., 0.]
(2) if a sequence of length 4 contains 1 padding token (i.e., the valid length is 3),
then the loss padding matrix looks like
[0., 1., 1., 0.],
[1., 0., 1., 0.],
[1., 1., 0., 0.],
[0., 0., 0., 0.]
'''
res_list = []
base_mask = torch.ones(seqlen, seqlen) - torch.eye(seqlen, seqlen)
base_mask = base_mask.type(torch.FloatTensor)
bsz = len(valid_len_list)
for i in range(bsz):
one_base_mask = base_mask.clone()
one_valid_len = valid_len_list[i]
one_base_mask[:,one_valid_len:] = 0.
one_base_mask[one_valid_len:, :] = 0.
if prefix_len > 0:
one_base_mask[:prefix_len, :prefix_len] = 0.
res_list.append(one_base_mask)
res_mask = torch.stack(res_list, dim = 0)#torch.FloatTensor(res_list)
#print (res_mask)
assert res_mask.size() == torch.Size([bsz, seqlen, seqlen])
return res_mask
def contrastive_loss(margin, score_matrix, input_ids, pad_token_id, prefix_len=0):
'''
margin: predefined margin to push similarity score away
score_matrix: bsz x seqlen x seqlen
input_ids: bsz x seqlen
pad_token_id: indicating which tokens are padding token
'''
bsz, seqlen, _ = score_matrix.size()
gold_score = torch.diagonal(score_matrix, offset=0, dim1=1, dim2=2) # bsz x seqlen
gold_score = torch.unsqueeze(gold_score, -1)
assert gold_score.size() == torch.Size([bsz, seqlen, 1])
difference_matrix = gold_score - score_matrix
assert difference_matrix.size() == torch.Size([bsz, seqlen, seqlen])
loss_matrix = margin - difference_matrix # bsz x seqlen x seqlen
loss_matrix = torch.nn.functional.relu(loss_matrix)
### input mask
input_mask = torch.ones_like(input_ids).type(torch.FloatTensor)
if loss_matrix.is_cuda:
input_mask = input_mask.cuda(loss_matrix.get_device())
input_mask = input_mask.masked_fill(input_ids.eq(pad_token_id), 0.0)
if loss_matrix.is_cuda:
input_mask = input_mask.cuda(loss_matrix.get_device())
valid_len_list = torch.sum(input_mask, dim = -1).tolist()
loss_mask = build_mask_matrix(seqlen, [int(item) for item in valid_len_list], prefix_len)
if score_matrix.is_cuda:
loss_mask = loss_mask.cuda(score_matrix.get_device())
masked_loss_matrix = loss_matrix * loss_mask
loss_matrix = torch.sum(masked_loss_matrix, dim = -1)
assert loss_matrix.size() == input_ids.size()
loss_matrix = loss_matrix * input_mask
cl_loss = torch.sum(loss_matrix) / torch.sum(loss_mask)
return cl_loss

233
language_model/simctg.py

@@ -0,0 +1,233 @@
import os
import sys
import operator
from tqdm import tqdm
from operator import itemgetter
import torch
from torch import nn
import random
import argparse
import numpy as np
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from loss_func import contrastive_loss
from utlis import PlugAndPlayContrastiveDecodingOneStepFast
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import datetime
train_fct = CrossEntropyLoss()
val_fct = CrossEntropyLoss(reduction='none')
class SimCTG(nn.Module):
def __init__(self, model_name, sos_token, pad_token):
super(SimCTG, self).__init__()
from transformers import AutoTokenizer, GPT2LMHeadModel
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.sos_token, self.sos_token_id = self.add_special_token(sos_token)
print ('sos token is {}, sos token id is {}'.format(self.sos_token, self.sos_token_id))
self.pad_token, self.pad_token_id = self.add_special_token(pad_token)
print ('pad token is {}, pad token id is {}'.format(self.pad_token, self.pad_token_id))
self.eos_token, self.eos_token_id = self.tokenizer.bos_token, self.tokenizer.bos_token_id
print ('eos token is {}, eos token id is {}'.format(self.eos_token, self.eos_token_id))
self.model = GPT2LMHeadModel.from_pretrained(model_name)
self.vocab_size = len(self.tokenizer)
print ('Resizing model embedding...')
self.model.resize_token_embeddings(len(self.tokenizer))
print ('Model embedding resized!')
self.embed_dim = self.model.config.hidden_size
def add_special_token(self, special_token):
if special_token in self.tokenizer.vocab:
print (special_token + ' token exists.')
else:
print ('Add token to the tokenizer.')
print ('Original vocabulary size is {}'.format(len(self.tokenizer)))
self.tokenizer.add_tokens([special_token])
print ('Vocabulary size after extension is {}'.format(len(self.tokenizer)))
assert len(self.tokenizer.convert_tokens_to_ids([special_token])) == 1
special_token_id = self.tokenizer.convert_tokens_to_ids([special_token])[0]
return special_token, special_token_id
def compute_logits_and_hidden_states(self, input_ids):
# used for advanced decoding
# input_ids: 1 x seqlen
outputs = self.model(input_ids=input_ids, output_hidden_states=True)
last_hidden_states = outputs.hidden_states[-1]
logits = outputs.logits
return last_hidden_states, logits
def forward(self, input_ids, labels, margin):
bsz, seqlen = input_ids.size()
outputs = self.model(input_ids=input_ids, output_hidden_states=True)
logits = outputs.logits
assert logits.size() == torch.Size([bsz, seqlen, self.vocab_size])
last_hidden_states = outputs.hidden_states[-1]
assert last_hidden_states.size() == torch.Size([bsz, seqlen, self.embed_dim])
mle_loss = train_fct(logits.view(-1, self.vocab_size), labels.view(-1))
norm_rep = last_hidden_states / last_hidden_states.norm(dim=2, keepdim=True)
cosine_scores = torch.matmul(norm_rep, norm_rep.transpose(1,2))
assert cosine_scores.size() == torch.Size([bsz, seqlen, seqlen])
cl_loss = contrastive_loss(margin, cosine_scores, input_ids, self.pad_token_id, prefix_len=0)
return mle_loss, cl_loss
def eval_loss(self, input_ids, labels):
bsz, seqlen = input_ids.size()
outputs = self.model(input_ids=input_ids, output_hidden_states=True)
logits = outputs.logits
assert logits.size() == torch.Size([bsz, seqlen, self.vocab_size])
last_hidden_states = outputs.hidden_states[-1]
assert last_hidden_states.size() == torch.Size([bsz, seqlen, self.embed_dim])
mle_loss = val_fct(logits.view(-1, self.vocab_size), labels.view(-1))
assert mle_loss.size() == torch.Size([bsz * seqlen])
mask_tmp = labels.masked_fill(~labels.eq(-100), 1.0)
mask = mask_tmp.masked_fill(mask_tmp.eq(-100), 0.0)
# sum
mle_loss_sum = torch.sum(mle_loss)
token_num_sum = torch.sum(mask)
return mle_loss_sum, token_num_sum
def save_model(self, ckpt_save_path):
import os
if os.path.exists(ckpt_save_path):
pass
else: # recursively construct directory
os.makedirs(ckpt_save_path, exist_ok=True)
# save model
self.model.save_pretrained(ckpt_save_path)
# save tokenizer
self.tokenizer.save_pretrained(ckpt_save_path)
def parse_sentences(self, text, num_of_sentences_to_keep):
item_list = text.split('.')
res_list = item_list[:num_of_sentences_to_keep]
if len(item_list) > num_of_sentences_to_keep:
res_text = '.'.join(res_list).strip('.') + '.'
else:
res_text = '.'.join(res_list).strip('.').strip()
return res_text
def parse_generated_result(self, output, num_of_sentences_to_keep):
output_text = self.tokenizer.decode(output)
item_list = output_text.split(self.eos_token)
full_text = self.eos_token.join(item_list[:2]).strip()
full_text = self.parse_sentences(full_text, num_of_sentences_to_keep)
generated_text = item_list[1].strip()
generated_text = self.parse_sentences(generated_text, num_of_sentences_to_keep)
return full_text, generated_text
# decoding functions
# ------------------------------------------------------- #
def parse_output_token_list(self, output):
output = output.tolist()
res_list = []
for token_id in output:
if token_id == self.sos_token_id:
continue
elif token_id == self.eos_token_id:
break
else:
res_list.append(token_id)
text = self.tokenizer.decode(res_list).strip()
return ' '.join(text.split()).strip()
@torch.no_grad()
def magic_search(self, input_ids, beam_width, alpha, decoding_len, beta, image_instance, clip,
clip_text_max_len):#, add_token_level_score=False):
prefix_len = input_ids.size()[1]
#from utlis import PlugAndPlayContrastiveDecodingOneStepFast
past_key_values, last_hidden_states, logits = None, None, None
generated = [item for item in input_ids.tolist()]
input_ids_for_class = input_ids.clone()
image_embeds = clip.compute_image_representation_from_image_instance(image_instance)
start_time = datetime.datetime.now()
# the maximum supported length of generation for SimCTG is 256
# to support longer generated length, you can re-train the SimCTG model with longer sequences
decoding_len = decoding_len - prefix_len
for step in range(decoding_len):
input_ids, past_key_values, last_hidden_states, logits, input_ids_for_class = \
PlugAndPlayContrastiveDecodingOneStepFast(
self.model,
input_ids,
prefix_len,
beam_width,
alpha,
beta,
self.tokenizer,
image_embeds,
clip,
clip_text_max_len,
past_key_values,
last_hidden_states,
logits,
first_step=step==0,
input_ids_for_class=input_ids_for_class,
)
end_time = datetime.datetime.now()
time_diff = (end_time - start_time)
execution_time = time_diff.total_seconds() * 1000
return self.parse_output_token_list(input_ids_for_class[0])
def fast_contrastive_search(self, input_ids, beam_width, alpha, decoding_len):
'''
input_ids: prefix input; 1 x prefix_len
decoding_len: how many tokens to generate
beam_width: size of candidate pool during decoding
alpha: regulates importance of model confidence and degeneration penalty
'''
self.model.eval()
#from utlis import ContrastiveDecodingOneStepFast
# sanity check
assert alpha >= 0. and alpha <= 1.0
# fast mode
prefix_len = input_ids.size()[1]
batch_size, seqlen = input_ids.size()
#generated = [[] for _ in range(batch_size)]
generated = [item for item in input_ids.tolist()]
past_key_values = None
last_hidden_states = None
logits = None
decoding_len = decoding_len - prefix_len
for step in range(decoding_len):
input_ids, past_key_values, last_hidden_states, logits = ContrastiveDecodingOneStepFast(
self.model,
input_ids,
beam_width,
alpha,
past_key_values,
last_hidden_states,
self.tokenizer,
logits,
first_step=step == 0,
)
tokens = input_ids.squeeze(dim=-1).tolist()
for idx, t in enumerate(tokens):
generated[idx].append(t)
return self.parse_output_token_list(torch.LongTensor(generated[0]))
def top_k_sampling(self, input_ids, k, decoding_len):
_, prefix_len = input_ids.size()
output = self.model.generate(
input_ids,
do_sample=True,
max_length=decoding_len,
top_p=1.0,
top_k=k)
return self.parse_output_token_list(output[0])
def nucleus_sampling(self, input_ids, nucleus_p, decoding_len):
_, prefix_len = input_ids.size()
output = self.model.generate(
input_ids,
do_sample=True,
max_length=decoding_len,
top_p=nucleus_p,
top_k=0)
return self.parse_output_token_list(output[0])

107
language_model/train.py

@@ -0,0 +1,107 @@
# coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
import argparse, os
import random
import numpy as np
import time
import logging
import progressbar
import logging
logging.getLogger('transformers.generation_utils').disabled = True
def parse_config():
parser = argparse.ArgumentParser()
# data configuration
parser.add_argument("--model_name", type=str, default='gpt2')
parser.add_argument("--train_path", type=str)
parser.add_argument("--dev_path", type=str)
parser.add_argument("--test_path", type=str)
parser.add_argument("--max_len", type=int)
parser.add_argument("--add_eos_token_to_data", type=str)
# mini-batch training configuration
parser.add_argument("--number_of_gpu", type=int, help="Number of available GPUs.")
parser.add_argument("--batch_size_per_gpu", type=int, help='batch size for each gpu.')
parser.add_argument("--gradient_accumulation_steps", type=int, help="gradient accumulation step.")
parser.add_argument("--effective_batch_size", type=int,
help="effective_bsz = batch_size_per_gpu x number_of_gpu x gradient_accumulation_steps")
# pre-training configuration
parser.add_argument("--total_steps", type=int,
help="total effective training steps")
parser.add_argument("--print_every", type=int,
help="how many update steps to print one intermediate result")
parser.add_argument("--save_every", type=int,
help="how many update steps to save one model")
# learning configuration
parser.add_argument("--learning_rate", type=float, default=2e-5)
parser.add_argument("--margin", type=float)
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--save_path_prefix", type=str, help="directory to save the model parameters.")
return parser.parse_args()
def load_previous_best_model(path):
import os
filenames = os.listdir(path)
for file in filenames:
if file.startswith('training_step'):
return path + '/' + file
raise Exception('No best model found!')
import argparse
if __name__ == '__main__':
if torch.cuda.is_available():
print ('Cuda is available.')
cuda_available = torch.cuda.is_available()
multi_gpu_training = False
if cuda_available:
if torch.cuda.device_count() > 1:
multi_gpu_training = True
print ('Using Multi-GPU training, number of GPU is {}'.format(torch.cuda.device_count()))
else:
print ('Using single GPU training.')
else:
pass
args = parse_config()
device = torch.device('cuda')
model_name = args.model_name
sos_token, pad_token = r'<-start_of_text->', r'<-pad->'
add_eos_token_to_data = args.add_eos_token_to_data
if add_eos_token_to_data == 'True':
add_eos_token_to_data = True
print ('Add eos token to data!')
elif add_eos_token_to_data == 'False':
add_eos_token_to_data = False
print ('Do not add eos token to data!')
else:
raise Exception('Wrong eos configuration for data!!!')
print ('Loading data...')
from dataclass import Data
data = Data(model_name, args.train_path, args.dev_path, args.test_path, args.max_len,
sos_token, pad_token, add_eos_token_to_data)
print ('Data loaded.')
from trainer import model_training
print ('############################################################')
print ('Start Training...')
from simctg import SimCTG
print ('Initializaing SimCTG model...')
model = SimCTG(model_name, sos_token, pad_token)
if cuda_available:
if multi_gpu_training:
model = nn.DataParallel(model) # multi-gpu training
else:
pass
model = model.to(device)
else:
pass
print ('Model loaded')
total_steps, print_every, save_every = args.total_steps, args.print_every, args.save_every
ckpt_save_path = args.save_path_prefix
model = model_training(args, data, model, total_steps, print_every, save_every,
ckpt_save_path, cuda_available, device)
print ('Training stage completed!')
print ('############################################################')

17
language_model/train_flickr30k.sh

@@ -0,0 +1,17 @@
CUDA_VISIBLE_DEVICES=0 python train.py\
--model_name gpt2\
--train_path ../data/flickr30k/flickr30k_train.json\
--dev_path ../data/flickr30k/flickr30k_val.json\
--test_path ../data/flickr30k/flickr30k_test.json\
--add_eos_token_to_data True\
--margin 0.5\
--max_len 64\
--number_of_gpu 1\
--batch_size_per_gpu 32\
--gradient_accumulation_steps 4\
--effective_batch_size 128\
--total_steps 10000\
--print_every 50\
--save_every 250\
--learning_rate 2e-5\
--save_path_prefix ./magic_flickr30k/

17
language_model/train_mscoco.sh

@@ -0,0 +1,17 @@
CUDA_VISIBLE_DEVICES=0 python train.py\
--model_name gpt2\
--train_path ../data/mscoco/mscoco_train.json\
--dev_path ../data/mscoco/mscoco_val.json\
--test_path ../data/mscoco/mscoco_test.json\
--add_eos_token_to_data True\
--margin 0.5\
--max_len 64\
--number_of_gpu 1\
--batch_size_per_gpu 32\
--gradient_accumulation_steps 4\
--effective_batch_size 128\
--total_steps 20000\
--print_every 100\
--save_every 500\
--learning_rate 2e-5\
--save_path_prefix ./magic_mscoco/

165
language_model/trainer.py

@@ -0,0 +1,165 @@
# coding=utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.multiprocessing as mp
import argparse, os
import random
import numpy as np
import time
import logging
import progressbar
import logging
logging.getLogger('transformers.generation_utils').disabled = True
def eval_model(args, model, data, cuda_available, device):
dataset_batch_size = args.batch_size_per_gpu * args.number_of_gpu
eval_step = int(data.test_num / dataset_batch_size) + 1
val_loss, token_sum = 0., 0.
model.eval()
with torch.no_grad():
p = progressbar.ProgressBar(eval_step)
p.start()
for idx in range(eval_step):
p.update(idx)
batch_input_tensor, batch_labels, _ = \
data.get_next_validation_batch(batch_size=dataset_batch_size, mode='test')
if cuda_available:
batch_input_tensor = batch_input_tensor.cuda(device)
batch_labels = batch_labels.cuda(device)
one_val_loss, one_val_token_sum = model.eval_loss(batch_input_tensor, batch_labels)
one_val_loss = torch.sum(one_val_loss)
one_val_token_sum = torch.sum(one_val_token_sum)
val_loss += one_val_loss.item()
token_sum += one_val_token_sum.item()
p.finish()
model.train()
val_loss = val_loss / token_sum
return val_loss
def model_training(args, data, model, total_steps, print_every, save_every, ckpt_save_path, cuda_available, device):
import os
if os.path.exists(ckpt_save_path):
pass
else: # recursively construct directory
os.makedirs(ckpt_save_path, exist_ok=True)
max_save_num = 1
batch_size_per_gpu, gradient_accumulation_steps, number_of_gpu, effective_batch_size = \
args.batch_size_per_gpu, args.gradient_accumulation_steps, args.number_of_gpu, args.effective_batch_size
assert effective_batch_size == batch_size_per_gpu * gradient_accumulation_steps * number_of_gpu
warmup_steps = int(0.1 * total_steps) # 10% of training steps are used for warmup
print ('total training steps is {}, warmup steps is {}'.format(total_steps, warmup_steps))
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
optimizer = AdamW(model.parameters(), lr=args.learning_rate)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)
optimizer.zero_grad()
effective_batch_acm = 0
all_batch_step = 1
print_valid, save_valid = False, False
train_loss, train_cl_loss, min_val_loss = 0., 0., 1e10
train_ave_bleu = 0.
print ('--------------------------------------------------------------------------')
print ('Start Training:')
model.train()
number_of_saves = 0
while effective_batch_acm < total_steps:
all_batch_step += 1
train_batch_input_tensor, train_batch_labels, _ = data.get_next_train_batch(batch_size_per_gpu * number_of_gpu)
if cuda_available:
train_batch_input_tensor = train_batch_input_tensor.cuda(device)
train_batch_labels = train_batch_labels.cuda(device)
mle_loss, cl_loss = model(train_batch_input_tensor, train_batch_labels, args.margin)
loss = mle_loss + cl_loss
loss = loss.mean()
loss.backward()
train_loss += mle_loss.item()
train_cl_loss += cl_loss.item()
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
# parameter update
if all_batch_step % gradient_accumulation_steps == 0:
optimizer.step()
scheduler.step()
optimizer.zero_grad()
effective_batch_acm += 1
print_valid, save_valid = True, True
# print intermediate result
if effective_batch_acm % print_every == 0 and print_valid:
denominator = (effective_batch_acm - (number_of_saves * save_every)) * gradient_accumulation_steps
one_train_loss = train_loss / denominator
one_train_cl_loss = train_cl_loss / denominator
print ('At training steps {}, training MLE loss is {}, train CL loss is {}'.format(effective_batch_acm,
one_train_loss, one_train_cl_loss))
print_valid = False
# saving result
if effective_batch_acm % save_every == 0 and save_valid:
number_of_saves += 1
save_valid = False
one_train_loss = train_loss / (save_every * gradient_accumulation_steps)
one_train_cl_loss = train_cl_loss / (save_every * gradient_accumulation_steps)
model.eval()
one_val_loss = eval_model(args, model, data, cuda_available, device)
model.train()
print ('At training steps {}, training MLE loss is {}, train CL loss is {}, validation loss is {}'.format(effective_batch_acm,
one_train_loss, one_train_cl_loss, one_val_loss))
train_loss, train_cl_loss = 0., 0.
if one_val_loss < min_val_loss:
# in finetuning stage, we always save the model
min_val_loss = min(one_val_loss, min_val_loss)
print ('Saving model...')
one_val_ppl = np.exp(one_val_loss)
one_val_ppl = round(one_val_ppl, 3)
save_name = 'training_step_{}_train_mle_loss_{}_train_cl_loss_{}_dev_loss_{}_dev_ppl_{}'.format(effective_batch_acm,
round(one_train_loss,5), round(one_train_cl_loss,5), round(one_val_loss,5), one_val_ppl)
model_save_path = ckpt_save_path + '/' + save_name
import os
if os.path.exists(model_save_path):
pass
else: # recursively construct directory
os.makedirs(model_save_path, exist_ok=True)
if cuda_available and torch.cuda.device_count() > 1:
model.module.save_model(model_save_path)
else:
model.save_model(model_save_path)
print ('Model Saved!')
# --------------------------------------------------------------------------------------------- #
# removing extra checkpoints...
import os
from operator import itemgetter
fileData = {}
test_output_dir = ckpt_save_path
for fname in os.listdir(test_output_dir):
if fname.startswith('training_step'):
fileData[fname] = os.stat(test_output_dir + '/' + fname).st_mtime
else:
pass
sortedFiles = sorted(fileData.items(), key=itemgetter(1))
if len(sortedFiles) < max_save_num:
pass
else:
delete = len(sortedFiles) - max_save_num
for x in range(0, delete):
one_folder_name = test_output_dir + '/' + sortedFiles[x][0]
os.system('rm -r ' + one_folder_name)
print ('-----------------------------------')
# --------------------------------------------------------------------------------------------- #
return model

291
language_model/utlis.py

@@ -0,0 +1,291 @@
import sys
import os
import operator
from operator import itemgetter
import torch
from torch import nn
import torch.nn.functional as F
import random
import numpy as np
import argparse
import random
def parse_prompt(text):
'''
process the prompt text;
'''
eos_token = '<|endoftext|>'
text = text.strip(eos_token).strip()
left_bracket_idx, right_bracket_idx = -1, -1
for idx in range(len(text)):
char = text[idx]
if char == '[' and left_bracket_idx == -1: # first [ is met
left_bracket_idx = idx
elif char == ']' and right_bracket_idx == -1: # first ] is met
right_bracket_idx = idx
else:
pass
res_text = ''
remove = False
if left_bracket_idx > -1 and right_bracket_idx > left_bracket_idx:
if right_bracket_idx - left_bracket_idx <= 6:
remove = True
else:
pass
for idx in range(len(text)):
if remove:
if idx >= left_bracket_idx and idx <= right_bracket_idx:
continue
else:
res_text += text[idx]
else:
res_text += text[idx]
res_text = res_text.strip()
res_text = ' '.join(res_text.split()).strip()
return res_text
def typical_filtering(scores, mass, min_tokens_to_keep, filter_value):
# calculate entropy
normalized = torch.nn.functional.log_softmax(scores, dim=-1)
p = torch.exp(normalized)
ent = -(normalized * p).nansum(-1, keepdim=True)
# shift and sort
shifted_scores = torch.abs((-normalized) - ent)
sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
sorted_logits = scores.gather(-1, sorted_indices)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
# Remove tokens with cumulative mass above the threshold
last_ind = (cumulative_probs < mass).sum(dim=1)
last_ind[last_ind < 0] = 0
sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
if min_tokens_to_keep > 1:
# Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
sorted_indices_to_remove[..., : min_tokens_to_keep] = 0
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
scores = scores.masked_fill(indices_to_remove, filter_value)
return scores
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, threshold=-float('Inf'), filter_value=-np.inf):
assert logits.dim() == 1
top_k = min(top_k, logits.size(-1))
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = filter_value
indices_to_remove = logits < threshold
logits[indices_to_remove] = filter_value
return logits
# ========== batch version ========= #
def ranking_fast(context_hidden, next_hidden, next_top_k_probs, alpha, beam_width):
'''
context_hidden: bsz*beam x seqlen x embed_dim
next_hidden: bsz*beam x 1 x embed_dim
next_top_k_probs: bsz x beam
'''
_, context_len, embed_dim = context_hidden.size()
norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1) # [B*K, S]
scores, _ = torch.max(cosine_matrix, dim=-1) # [B*K]
next_top_k_probs = next_top_k_probs.view(-1) # [B*K]
scores = (1.0 - alpha) * next_top_k_probs - alpha * scores
scores = torch.stack(torch.split(scores, beam_width)) # [B, K]
selected_idx = scores.max(dim=-1)[1] # [B]
return selected_idx
def ContrastiveDecodingOneStepFast(
model,
ids,
beam_width,
alpha,
past_key_values,
last_hidden_states,
vocab,
logit_for_next_step,
first_step=False,
):
# input_ids: [B, S]
if first_step:
output = model(
input_ids=ids,
past_key_values=past_key_values,
use_cache=True,
output_hidden_states=True
)
past_key_values = output.past_key_values
last_hidden_states = output.hidden_states[-1] # [B, S, E]
logit_for_next_step = output.logits[:, -1, :] # [B, V]
bsz, seqlen, embed_dim = last_hidden_states.size()
p = random.uniform(0, 1)
next_probs = F.softmax(logit_for_next_step, dim=-1)
_, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width) # [B, K]
top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids) # [B, K]
# compute new hidden
past_key_values = enlarge_past_key_values(past_key_values, beam_width)
output = model(
input_ids=top_k_ids.view(-1, 1),
attention_mask=torch.ones_like(top_k_ids.view(-1, 1)),
past_key_values=past_key_values,
output_hidden_states=True,
use_cache=True,
)
past_key_values = output.past_key_values
logits = output.logits[:, -1, :] # [B*K, V]
next_hidden = output.hidden_states[-1] # [B*K, 1, E]
context_hidden = last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz*beam_width, seqlen, embed_dim) # [B*K, S, E]
selected_idx = ranking_fast(
context_hidden,
next_hidden,
top_k_probs, # [B, K]
alpha,
beam_width,
) # [B]
# prepare for the next step
next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1) # [B, 1]
next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width)) # [B, K, E]
next_hidden = next_hidden[range(bsz), selected_idx, :] # [B, E]
last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1) # [B, S, E]
past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx)
logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :] # [B, V]
# next_id: [B, 1]
return next_id, past_key_values, last_hidden_states, logits
def enlarge_past_key_values(past_key_values, beam_width):
# from [B, num_head, seq_len, esz] to [B*K, num_head, seq_len, esz]
new_key_values = []
for layer in past_key_values:
items = []
for item in layer:
# item is the key and value matrix
bsz, num_head, seq_len, esz = item.size()
item = item.unsqueeze(1).expand(-1, beam_width, -1, -1, -1).reshape(bsz*beam_width, num_head, seq_len, esz) # [bsz*beam, num_head, seq_len, esz]
items.append(item)
new_key_values.append(items)
return new_key_values
def select_past_key_values(past_key_values, beam_width, selected_idx):
'''select_idx: [B]'''
new_key_values = []
for layer in past_key_values:
items = []
for item in layer:
bsz_and_beam, num_head, seq_len, esz = item.size()
bsz = int(bsz_and_beam//beam_width)
item = torch.stack(torch.split(item, beam_width, dim=0)) # [B, K, num_head, seq_len, esz]
item = item[range(bsz), selected_idx, :, :, :] # [B, num_head, seq_len, esz]
items.append(item)
new_key_values.append(items)
return new_key_values
# ========== fast plug and play version ========= #
def plug_and_play_fast_ranking(
context_hidden,
next_hidden,
next_top_k_ids,
next_top_k_probs,
alpha,
beta,
batch_class_score,
beam_width,
):
'''
context_hidden: beam_width x context_len x embed_dim
next_hidden: beam_width x 1 x embed_dim
next_top_k_ids: beam_width x 1
batch_class_score: beam_width x 1
'''
_, context_len, embed_dim = context_hidden.size()
norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1,2)).squeeze(-1)
scores, _ = torch.max(cosine_matrix, dim = -1)
next_top_k_probs = next_top_k_probs.view(-1)
scores = (1.0 - alpha) * next_top_k_probs - alpha * scores + beta * batch_class_score.view([beam_width])
scores = torch.stack(torch.split(scores, beam_width))
selected_idx = scores.max(dim=-1)[1]
return selected_idx
def PlugAndPlayContrastiveDecodingOneStepFast(model, input_ids, prefix_len, beam_width, alpha, beta,
simctg_tokenizer, image_embeds, clip, clip_text_max_len, past_key_values, last_hidden_states,
logit_for_next_step, first_step=False, input_ids_for_class=None):#, add_token_level_score=False):
'''
model: the generation model, e.g., gpt2
input_ids: 1 x seqlen
'''
if first_step:
output = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True, output_hidden_states=True)
past_key_values = output.past_key_values
last_hidden_states = output.hidden_states[-1] # [B, S, E]
logit_for_next_step = output.logits[:, -1, :] # [B, V]
bsz, seqlen, embed_dim = last_hidden_states.size()
next_probs = F.softmax(logit_for_next_step, dim = -1)
_, top_k_ids = torch.topk(logit_for_next_step, dim = -1, k = beam_width)
top_k_probs = torch.gather(next_probs, dim = 1, index=top_k_ids)
# compute the new hidden
past_key_values = enlarge_past_key_values(past_key_values, beam_width)
output = model(
input_ids=top_k_ids.view(-1, 1) ,
attention_mask=torch.ones_like(top_k_ids.view(-1, 1)),
past_key_values=past_key_values,
output_hidden_states=True,
use_cache=True,
)
past_key_values = output.past_key_values
logits = output.logits[:, -1, :]
next_hidden = output.hidden_states[-1]
context_hidden = last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(bsz*beam_width, seqlen, embed_dim)
# prepare for the classification model
input_ids_for_class_ = torch.cat([
input_ids_for_class.unsqueeze(1).expand(-1, beam_width, -1).reshape(bsz*beam_width, seqlen),
top_k_ids.view(-1, 1)
], dim=-1
)
batch_text_list = []
for one_input_id in input_ids_for_class_:
one_text = simctg_tokenizer.decode(one_input_id[prefix_len:][-clip_text_max_len:])
# we only consider the class score of the generated text continuation
batch_text_list.append(one_text)
batch_score = clip.compute_image_text_similarity_via_raw_text(image_embeds, batch_text_list)
selected_idx = plug_and_play_fast_ranking(
context_hidden,
next_hidden,
top_k_ids,
top_k_probs,
alpha,
beta,
batch_score,
beam_width,
)
# prepare for the next step
next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1)
next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width))
next_hidden = next_hidden[range(bsz), selected_idx, :]
last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)
past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx)
logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :]
input_ids_for_class = torch.cat([input_ids_for_class, next_id], dim=-1)
return next_id, past_key_values, last_hidden_states, logits, input_ids_for_class

29
magic.py

@@ -29,7 +29,6 @@ from towhee.types.arg import arg, to_image_color
from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee import register
from towhee.models import clip
class Magic(NNOperator):
"""
@@ -38,22 +37,32 @@ class Magic(NNOperator):
    def __init__(self, model_name: str):
        super().__init__()
        path = str(pathlib.Path(__file__).parent)
        sys.path.append(path)
        sys.path.append(path + '/clip')
        sys.path.append(path + '/language_model')
        print(sys.path)
        from clip import CLIP
        from simctg import SimCTG
        sys.path.pop()
        sys.path.pop()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Load Language Model
        language_model_name = r'cambridgeltl/magic_mscoco' # or r'/path/to/downloaded/cambridgeltl/magic_mscoco'
        cfg = self._configs()[model_name]
        language_model_name = cfg['language_model'] # or r'/path/to/downloaded/cambridgeltl/magic_mscoco'
        sos_token, pad_token = r'<-start_of_text->', r'<-pad->'
        self.generation_model = SimCTG(language_model_name, sos_token, pad_token).to(self.device)
        self.generation_model.eval()
        model_name = r"openai/clip-vit-base-patch32" # or r"/path/to/downloaded/openai/clip-vit-base-patch32"
        model_name = cfg['clip_model'] # or r"/path/to/downloaded/openai/clip-vit-base-patch32"
        self.clip = CLIP(model_name).to(self.device)
        self.clip.to(self.device)
        self.clip.eval()
        sos_token = r'<-start_of_text->'
        start_token = self.generation_model.tokenizer.tokenize(sos_token)
        start_token_id = self.generation_model.tokenizer.convert_tokens_to_ids(start_token)
        self.input_ids = torch.LongTensor(start_token_id).view(1,-1).to(self.device)

    def _preprocess(self, img):
        img = to_pil(img)
@@ -87,13 +96,15 @@ class Magic(NNOperator):
        k, alpha, beta, decoding_len = 45, 0.1, 2.0, 16
        eos_token = '<|endoftext|>'
        with torch.no_grad():
            output = generation_model.magic_search(input_ids, k,
                alpha, decoding_len, beta, image_instance, clip, 60)
            print(type(img))
            output = self.generation_model.magic_search(self.input_ids, k,
                alpha, decoding_len, beta, img, self.clip, 60)
        return out
        return output

    def _configs(self):
        config = {}
        config['expansionnet_rf'] = {}
        config['expansionnet_rf']['weights'] = 'rf_model.pth'
        config['magic_mscoco'] = {}
        config['magic_mscoco']['language_model'] = 'cambridgeltl/magic_mscoco'
        config['magic_mscoco']['clip_model'] = 'openai/clip-vit-base-patch32'
        return config

89
zerocap/README.md

@@ -1,89 +0,0 @@
### Our Implementation of the ZeroCap Baseline Model
****
### Catalogue:
* <a href='#environment'>1. Environment Preparation</a>
* <a href='#mscoco'>2. Image Captioning on MSCOCO</a>
* <a href='#flickr30k'>3. Image Captioning on Flickr30k</a>
* <a href='#flickr30k_to_mscoco'>4. Cross Domain Image Captioning on MSCOCO</a>
* <a href='#mscoco_to_flickr30k'>5. Cross Domain Image Captioning on Flickr30k</a>
* <a href='#citation'>6. Citation</a>
* <a href='#acknowledgements'>7. Acknowledgements</a>
****
<span id='environment'/>
#### 1. Environment Preparation:
To install the correct environment, please run the following command:
```yaml
pip install -r requirements.txt
```
****
<span id='mscoco'/>
#### 2. Image Captioning on MSCOCO:
To perform image captioning on MSCOCO, please run the following command:
```yaml
chmod +x ./mscoco_zerocap.sh
./mscoco_zerocap.sh
```
****
<span id='flickr30k'/>
#### 3. Image Captioning on Flickr30k:
To perform image captioning on Flickr30k, please run the following command:
```yaml
chmod +x ./flickr30k_zerocap.sh
./flickr30k_zerocap.sh
```
****
<span id='flickr30k_to_mscoco'/>
#### 4. Cross Domain Image Captioning on MSCOCO:
To perform image captioning on MSCOCO with the language model from Flickr30k domain, please run the following command:
```yaml
chmod +x ./flickr30k_to_mscoco_zerocap.sh
./flickr30k_to_mscoco_zerocap.sh
```
****
<span id='mscoco_to_flickr30k'/>
#### 5. Cross Domain Image Captioning on Flickr30k:
To perform image captioning on Flickr30k with the language model from MSCOCO domain, please run the following command:
```yaml
chmod +x ./mscoco_to_flickr30k_zerocap.sh
./mscoco_to_flickr30k_zerocap.sh
```
****
<span id='citation'/>
#### 6. Citation:
If you find our code helpful, please cite the original paper as
```bibtex
@article{tewel2021zero,
title={Zero-Shot Image-to-Text Generation for Visual-Semantic Arithmetic},
author={Tewel, Yoad and Shalev, Yoav and Schwartz, Idan and Wolf, Lior},
journal={arXiv preprint arXiv:2111.14447},
year={2021}
}
```
****
<span id='acknowledgements'/>
#### 7. Acknowledgements:
We thank the authors for releasing their code. Our reimplementation of the baseline is based on their original codebase [[here]](https://github.com/yoadtew/zero-shot-image-to-text).

12
zerocap/cog.yaml

@@ -1,12 +0,0 @@
build:
gpu: true
python_version: "3.8"
system_packages:
- "libgl1-mesa-glx"
- "libglib2.0-0"
python_packages:
- "git+https://github.com/openai/CLIP.git"
- "git+https://github.com/YoadTew/zero-shot-image-to-text.git"
predict: "predict.py:Predictor"
#predict: "predict_arithmetic.py:Predictor"

14
zerocap/flickr30k_zerocap.sh

@@ -1,14 +0,0 @@
#!/bin/bash
# lm_model:
# 1. cambridgeltl/magic_mscoco
# 2. cambridgeltl/magic_flickr30k
CUDA_VISIBLE_DEVICES=1 python run.py \
--beam_size 1 \
--target_seq_length 16 \
--reset_context_delta \
--lm_model cambridgeltl/magic_flickr30k \
--test_image_prefix_path ../data/flickr30k/test_images \
--test_path ../data/flickr30k/flickr30k_test.json \
--save_path_prefix ../inference_result/flickr30k/baselines/ \
--save_name zerocap_result.json

BIN
zerocap/forbidden_tokens.npy

Binary file not shown.

389
zerocap/model/ZeroCLIP.py

@@ -1,389 +0,0 @@
import numpy as np
from torch import nn
from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer
from transformers.models.gpt_neo import GPTNeoForCausalLM
import torch
import clip
from PIL import Image
from datetime import datetime
import sys
def log_info(text, verbose=True):
if verbose:
dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
print(f'{dt_string} | {text}')
sys.stdout.flush()
def add_context(x, y):
return (x[0] + y[0], x[1] + y[1])
def convert_models_to_fp32(model):
for p in model.parameters():
p.data = p.data.float()
class CLIPTextGenerator:
def __init__(self,
seed=0,
lm_model='gpt-2',
forbidden_tokens_file_path='./forbidden_tokens.npy',
clip_checkpoints='./clip_checkpoints',
target_seq_length=15,
reset_context_delta=True,
num_iterations=5,
clip_loss_temperature=0.01,
clip_scale=1.,
ce_scale=0.2,
stepsize=0.3,
grad_norm_factor=0.9,
fusion_factor=0.99,
repetition_penalty=1.,
end_token='.',
end_factor=1.01,
forbidden_factor=20,
**kwargs):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# set Random seed
torch.manual_seed(seed)
np.random.seed(seed)
# Initialize Language model
self.context_prefix = ''
self.lm_tokenizer = GPT2Tokenizer.from_pretrained(lm_model)
self.lm_model = GPT2LMHeadModel.from_pretrained(lm_model, output_hidden_states=True)
self.context_prefix = self.lm_tokenizer.bos_token
self.lm_model.to(self.device)
self.lm_model.eval()
self.forbidden_tokens = np.load(forbidden_tokens_file_path)
self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if
(x[0] == 'Ġ' and len(x) > 1 and x[1].isupper())]
# Freeze LM weights
for param in self.lm_model.parameters():
param.requires_grad = False
# Initialize CLIP
self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device,
download_root=clip_checkpoints, jit=False)
# convert_models_to_fp32(self.clip)
# Init arguments
self.target_seq_length = target_seq_length
self.reset_context_delta = reset_context_delta
self.num_iterations = num_iterations
self.clip_loss_temperature = clip_loss_temperature
self.clip_scale = clip_scale
self.ce_scale = ce_scale
self.stepsize = stepsize
self.grad_norm_factor = grad_norm_factor
self.fusion_factor = fusion_factor
self.repetition_penalty = repetition_penalty
self.end_token = self.lm_tokenizer.encode(end_token)[0]
self.end_factor = end_factor
self.ef_idx = 1
self.forbidden_factor = forbidden_factor
def get_img_feature(self, img_path, weights):
imgs = [Image.open(x) for x in img_path]
clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]
with torch.no_grad():
image_fts = [self.clip.encode_image(x) for x in clip_imgs]
if weights is not None:
image_features = sum([x * weights[i] for i, x in enumerate(image_fts)])
else:
image_features = sum(image_fts)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
return image_features.detach()
def get_txt_features(self, text):
clip_texts = clip.tokenize(text).to(self.device)
with torch.no_grad():
text_features = self.clip.encode_text(clip_texts)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
return text_features.detach()
def get_combined_feature(self, img_path, texts, weights_i, weights_t):
imgs = [Image.open(x) for x in img_path]
clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]
clip_texts = [clip.tokenize(x).to(self.device) for x in texts]
with torch.no_grad():
image_fts = [self.clip.encode_image(x) for x in clip_imgs]
text_fts = [self.clip.encode_text(x) for x in clip_texts]
features = sum([x * weights_i[i] for i, x in enumerate(image_fts)])
if weights_t is not None:
features += sum([x * weights_t[i] for i, x in enumerate(text_fts)])
features = features / features.norm(dim=-1, keepdim=True)
return features.detach()
def run(self, image_features, cond_text, beam_size):
self.image_features = image_features
context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text)
output_tokens, output_text = self.generate_text(context_tokens, beam_size)
return output_text
def generate_text(self, context_tokens, beam_size):
context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0)
gen_tokens = None
scores = None
seq_lengths = torch.ones(beam_size, device=self.device)
is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool)
for i in range(self.target_seq_length):
probs = self.get_next_probs(i, context_tokens)
logits = probs.log()
if scores is None:
scores, next_tokens = logits.topk(beam_size, -1)
context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:])
next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
if gen_tokens is None:
gen_tokens = next_tokens
else:
gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:])
gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1)
else:
logits[is_stopped] = -float(np.inf)
logits[is_stopped, 0] = 0
scores_sum = scores[:, None] + logits
seq_lengths[~is_stopped] += 1
scores_sum_average = scores_sum / seq_lengths[:, None]
scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
beam_size, -1)
next_tokens_source = next_tokens // scores_sum.shape[1]
seq_lengths = seq_lengths[next_tokens_source]
next_tokens = next_tokens % scores_sum.shape[1]
next_tokens = next_tokens.unsqueeze(1)
gen_tokens = gen_tokens[next_tokens_source]
gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1)
context_tokens = context_tokens[next_tokens_source]
scores = scores_sum_average * seq_lengths
is_stopped = is_stopped[next_tokens_source]
context_tokens = torch.cat((context_tokens, next_tokens), dim=1)
is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze()
####
tmp_scores = scores / seq_lengths
tmp_output_list = gen_tokens.cpu().numpy()
tmp_output_texts = [
self.lm_tokenizer.decode(tmp_output)
for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths)
]
tmp_order = tmp_scores.argsort(descending=True)
tmp_output_texts = [tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order]
log_info(tmp_output_texts, verbose=True)
####
if is_stopped.all():
break
scores = scores / seq_lengths
output_list = gen_tokens.cpu().numpy()
output_texts = [
self.lm_tokenizer.decode(output[: int(length)])
for output, length in zip(output_list, seq_lengths)
]
order = scores.argsort(descending=True)
output_texts = [output_texts[i] for i in order]
return context_tokens, output_texts
def get_next_probs(self, i, context_tokens):
last_token = context_tokens[:, -1:]
if self.reset_context_delta and context_tokens.size(1) > 1:
context = self.lm_model(context_tokens[:, :-1])["past_key_values"]
# Logits of LM with unshifted context
logits_before_shift = self.lm_model(context_tokens)["logits"]
logits_before_shift = logits_before_shift[:, -1, :]
probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1)
if context:
context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift)
lm_output = self.lm_model(last_token, past_key_values=context)
logits, past = (
lm_output["logits"],
lm_output["past_key_values"],
)
logits = logits[:, -1, :]
logits = self.update_special_tokens_logits(context_tokens, i, logits)
probs = nn.functional.softmax(logits, dim=-1)
probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor))
probs = probs / probs.sum()
return probs
def shift_context(self, i, context, last_token, context_tokens, probs_before_shift):
context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context]
window_mask = torch.ones_like(context[0][0]).to(self.device)
for i in range(self.num_iterations):
curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in
context_delta]
for p0, p1 in curr_shift:
p0.retain_grad()
p1.retain_grad()
shifted_context = list(map(add_context, context, curr_shift))
shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context)
logits = shifted_outputs["logits"][:, -1, :]
probs = nn.functional.softmax(logits, dim=-1)
loss = 0.0
# CLIP LOSS
clip_loss, clip_losses = self.clip_loss(probs, context_tokens)
loss += self.clip_scale * clip_loss
# CE/Fluency loss
ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * probs_before_shift.log())).sum(-1)
loss += ce_loss.sum()
loss.backward()
# ---------- Weights ----------
combined_scores_k = -(ce_loss)
combined_scores_c = -(self.clip_scale * torch.stack(clip_losses))
# minmax
if combined_scores_k.shape[0] == 1:
tmp_weights_c = tmp_weights_k = torch.ones(*combined_scores_k.shape).to(self.device)
else:
tmp_weights_k = ((combined_scores_k - combined_scores_k.min())) / (
combined_scores_k.max() - combined_scores_k.min())
tmp_weights_c = ((combined_scores_c - combined_scores_c.min())) / (
combined_scores_c.max() - combined_scores_c.min())
tmp_weights = 0.5 * tmp_weights_k + 0.5 * tmp_weights_c
tmp_weights = tmp_weights.view(tmp_weights.shape[0], 1, 1, 1)
factor = 1
# --------- Specific Gen ---------
sep_grads = None
for b in range(context_tokens.shape[0]):
tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_]
for p_ in curr_shift]
# normalize gradients
tmp_grad = [tuple([-self.stepsize * factor * (
x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][
j] ** self.grad_norm_factor).data.cpu().numpy()
for j, x in enumerate(p_)])
for i, p_ in enumerate(curr_shift)]
if sep_grads is None:
sep_grads = tmp_grad
else:
for l_index in range(len(sep_grads)):
sep_grads[l_index] = list(sep_grads[l_index])
for k_index in range(len(sep_grads[0])):
sep_grads[l_index][k_index] = np.concatenate(
(sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0)
sep_grads[l_index] = tuple(sep_grads[l_index])
final_grads = sep_grads
# --------- update context ---------
context_delta = list(map(add_context, final_grads, context_delta))
for p0, p1 in curr_shift:
p0.grad.data.zero_()
p1.grad.data.zero_()
new_context = []
for p0, p1 in context:
new_context.append((p0.detach(), p1.detach()))
context = new_context
context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_])
for p_ in context_delta]
context = list(map(add_context, context, context_delta))
new_context = []
for p0, p1 in context:
new_context.append((p0.detach(), p1.detach()))
context = new_context
return context
def update_special_tokens_logits(self, context_tokens, i, logits):
for beam_id in range(context_tokens.shape[0]):
for token_idx in set(context_tokens[beam_id][-4:].tolist()):
factor = self.repetition_penalty if logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty)
logits[beam_id, token_idx] /= factor
if i >= self.ef_idx:
factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor)
logits[beam_id, self.end_token] *= factor
if i == 0:
start_factor = 1.6
factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor)
logits[beam_id, self.end_token] /= factor
for token_idx in list(self.forbidden_tokens):
factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor)
logits[beam_id, token_idx] /= factor
return logits
def clip_loss(self, probs, context_tokens):
for p_ in self.clip.transformer.parameters():
if p_.grad is not None:
p_.grad.data.zero_()
top_size = 512
_, top_indices = probs.topk(top_size, -1)
prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens]
clip_loss = 0
losses = []
for idx_p in range(probs.shape[0]):
top_texts = []
prefix_text = prefix_texts[idx_p]
for x in top_indices[idx_p]:
top_texts.append(prefix_text + self.lm_tokenizer.decode(x))
text_features = self.get_txt_features(top_texts)
with torch.no_grad():
similiraties = (self.image_features @ text_features.T)
target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach()
target_probs = target_probs.type(torch.float32)
target = torch.zeros_like(probs[idx_p])
target[top_indices[idx_p]] = target_probs[0]
target = target.unsqueeze(0)
cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)])))
clip_loss += cur_clip_loss
losses.append(cur_clip_loss)
return clip_loss, losses
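For reference, a minimal sketch of how the `CLIPTextGenerator` class above is typically driven: the image is encoded once with `get_img_feature`, then `run` performs the CLIP-guided beam search and returns the beams best-first. The checkpoint name and image path below are illustrative placeholders, not values taken from this diff.

```python
# Hedged sketch: driving the CLIPTextGenerator defined above on a single image.
# 'cambridgeltl/magic_mscoco' and './example.jpg' are placeholders for illustration.
from model.ZeroCLIP import CLIPTextGenerator

generator = CLIPTextGenerator(lm_model='cambridgeltl/magic_mscoco')

# get_img_feature takes a list of image paths and optional per-image weights.
image_features = generator.get_img_feature(['./example.jpg'], None)

# run returns the decoded beams sorted by score, best first.
captions = generator.run(image_features, 'Image of a', beam_size=1)
print(captions[0])
```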

449
zerocap/model/ZeroCLIP_batched.py

@ -1,449 +0,0 @@
import numpy as np
from torch import nn
from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer
from transformers.models.gpt_neo import GPTNeoForCausalLM
import torch
import clip
from PIL import Image
from datetime import datetime
import sys
class TextCLIP(nn.Module):
def __init__(self, model):
super(TextCLIP, self).__init__()
self.model = model
def forward(self, text):
return self.model.encode_text(text)
class ImageCLIP(nn.Module):
def __init__(self, model):
super(ImageCLIP, self).__init__()
self.model = model
def forward(self, image):
return self.model.encode_image(image)
def log_info(text, verbose=True):
if verbose:
dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
print(f'{dt_string} | {text}')
sys.stdout.flush()
def add_context(x, y):
return (x[0] + y[0], x[1] + y[1])
def convert_models_to_fp32(model):
for p in model.parameters():
p.data = p.data.float()
class CLIPTextGenerator:
def __init__(self,
seed=0,
lm_model='gpt-2',
forbidden_tokens_file_path='./forbidden_tokens.npy',
clip_checkpoints='./clip_checkpoints',
target_seq_length=15,
reset_context_delta=True,
num_iterations=5,
clip_loss_temperature=0.01,
clip_scale=1.,
ce_scale=0.2,
stepsize=0.3,
grad_norm_factor=0.9,
fusion_factor=0.99,
repetition_penalty=1.,
end_token='.',
end_factor=1.01,
forbidden_factor=20,
**kwargs):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# set Random seed
torch.manual_seed(seed)
np.random.seed(seed)
# Initialize Language model
self.context_prefix = ''
if lm_model == 'gpt-neo':
self.lm_tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-125M')
self.lm_model = GPTNeoForCausalLM.from_pretrained('EleutherAI/gpt-neo-125M', output_hidden_states=True)
elif lm_model == 'gpt-2':
self.lm_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
self.lm_model = GPT2LMHeadModel.from_pretrained('gpt2-medium', output_hidden_states=True)
self.context_prefix = self.lm_tokenizer.bos_token
self.lm_model.to(self.device)
self.lm_model.eval()
self.forbidden_tokens = np.load(forbidden_tokens_file_path)
self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if
(x[0] == 'Ġ' and len(x) > 1 and x[1].isupper())]
# Freeze LM weights
for param in self.lm_model.parameters():
param.requires_grad = False
# Initialize CLIP
self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device,
download_root=clip_checkpoints, jit=False)
self.clip_image = ImageCLIP(self.clip)
self.clip_image = torch.nn.DataParallel(self.clip_image)
self.clip_text = TextCLIP(self.clip)
self.clip_text = torch.nn.DataParallel(self.clip_text)
# Init arguments
self.target_seq_length = target_seq_length
self.reset_context_delta = reset_context_delta
self.num_iterations = num_iterations
self.clip_loss_temperature = clip_loss_temperature
self.clip_scale = clip_scale
self.ce_scale = ce_scale
self.stepsize = stepsize
self.grad_norm_factor = grad_norm_factor
self.fusion_factor = fusion_factor
self.repetition_penalty = repetition_penalty
self.end_token = self.lm_tokenizer.encode(end_token)[0]
self.end_factor = end_factor
self.ef_idx = 1
self.forbidden_factor = forbidden_factor
def get_img_feature(self, img_path, weights):
imgs = [Image.open(x) for x in img_path]
clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]
with torch.no_grad():
image_fts = [self.clip_image(x) for x in clip_imgs]
if weights is not None:
image_features = sum([x * weights[i] for i, x in enumerate(image_fts)])
else:
image_features = sum(image_fts)
image_features = torch.nn.functional.normalize(image_features, dim=-1)
return image_features.detach()
def get_txt_features(self, text):
clip_texts = clip.tokenize(text).to(self.device)
with torch.no_grad():
text_features = self.clip_text(clip_texts)
text_features = torch.nn.functional.normalize(text_features, dim=-1)
return text_features.detach()
def get_combined_feature(self, img_path, texts, weights_i, weights_t):
imgs = [Image.open(x) for x in img_path]
clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]
clip_texts = [clip.tokenize(x).to(self.device) for x in texts]
with torch.no_grad():
image_fts = [self.clip.encode_image(x) for x in clip_imgs]
text_fts = [self.clip.encode_text(x) for x in clip_texts]
features = sum([x * weights_i[i] for i, x in enumerate(image_fts)])
if weights_t is not None:
features += sum([x * weights_t[i] for i, x in enumerate(text_fts)])
features = features / features.norm(dim=-1, keepdim=True)
return features.detach()
def run(self, image_features, cond_text, beam_size):
self.image_features = image_features
context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text)
output_tokens, output_text = self.generate_text(context_tokens, beam_size)
return output_text
def generate_text(self, context_tokens, beam_size):
context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0)
gen_tokens = None
scores = None
seq_lengths = torch.ones(beam_size, device=self.device)
is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool)
for i in range(self.target_seq_length):
probs = self.get_next_probs(i, context_tokens)
logits = probs.log()
if scores is None:
scores, next_tokens = logits.topk(beam_size, -1)
context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:])
next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
if gen_tokens is None:
gen_tokens = next_tokens
else:
gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:])
gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1)
else:
logits[is_stopped] = -float(np.inf)
logits[is_stopped, 0] = 0
scores_sum = scores[:, None] + logits
seq_lengths[~is_stopped] += 1
scores_sum_average = scores_sum / seq_lengths[:, None]
scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
beam_size, -1)
next_tokens_source = next_tokens // scores_sum.shape[1]
seq_lengths = seq_lengths[next_tokens_source]
next_tokens = next_tokens % scores_sum.shape[1]
next_tokens = next_tokens.unsqueeze(1)
gen_tokens = gen_tokens[next_tokens_source]
gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1)
context_tokens = context_tokens[next_tokens_source]
scores = scores_sum_average * seq_lengths
is_stopped = is_stopped[next_tokens_source]
context_tokens = torch.cat((context_tokens, next_tokens), dim=1)
is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze()
####
tmp_scores = scores / seq_lengths
tmp_output_list = gen_tokens.cpu().numpy()
tmp_output_texts = [
self.lm_tokenizer.decode(tmp_output)
for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths)
]
tmp_order = tmp_scores.argsort(descending=True)
tmp_output_texts = [tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order]
log_info(tmp_output_texts, verbose=True)
####
if is_stopped.all():
break
scores = scores / seq_lengths
output_list = gen_tokens.cpu().numpy()
output_texts = [
self.lm_tokenizer.decode(output[: int(length)])
for output, length in zip(output_list, seq_lengths)
]
order = scores.argsort(descending=True)
output_texts = [output_texts[i] for i in order]
return context_tokens, output_texts
def get_next_probs(self, i, context_tokens):
last_token = context_tokens[:, -1:]
if self.reset_context_delta and context_tokens.size(1) > 1:
context = self.lm_model(context_tokens[:, :-1])["past_key_values"]
# Logits of LM with unshifted context
logits_before_shift = self.lm_model(context_tokens)["logits"]
logits_before_shift = logits_before_shift[:, -1, :]
probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1)
if context:
context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift)
lm_output = self.lm_model(last_token, past_key_values=context)
logits, past = (
lm_output["logits"],
lm_output["past_key_values"],
)
logits = logits[:, -1, :]
logits = self.update_special_tokens_logits(context_tokens, i, logits)
probs = nn.functional.softmax(logits, dim=-1)
probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor))
probs = probs / probs.sum()
return probs
def shift_context(self, i, context, last_token, context_tokens, probs_before_shift):
context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context]
for i in range(self.num_iterations):
curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in
context_delta]
for p0, p1 in curr_shift:
p0.retain_grad()
p1.retain_grad()
shifted_context = list(map(add_context, context, curr_shift))
shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context)
logits = shifted_outputs["logits"][:, -1, :]
probs = nn.functional.softmax(logits, dim=-1)
loss = 0.0
# CLIP LOSS
clip_loss, clip_losses = self.clip_loss(probs, context_tokens)
loss += self.clip_scale * clip_loss
# CE/Fluency loss
ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * probs_before_shift.log())).sum(-1)
loss += ce_loss.sum()
loss.backward()
# --------- Specific Gen ---------
final_grads = self.norm_grad(context, context_tokens, curr_shift)
# --------- update context ---------
context_delta = list(map(add_context, final_grads, context_delta))
for p0, p1 in curr_shift:
p0.grad.data.zero_()
p1.grad.data.zero_()
new_context = []
for p0, p1 in context:
new_context.append((p0.detach(), p1.detach()))
context = new_context
context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_])
for p_ in context_delta]
context = list(map(add_context, context, context_delta))
new_context = []
for p0, p1 in context:
new_context.append((p0.detach(), p1.detach()))
context = new_context
return context
def norm_grad(self, context, context_tokens, curr_shift, ):
factor = 1
sep_grads = None
window_mask = torch.ones_like(context[0][0]).to(self.device)
for b in range(context_tokens.shape[0]):
tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_]
for p_ in curr_shift]
# normalize gradients
tmp_grad = [tuple([-self.stepsize * factor * (
x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][
j] ** self.grad_norm_factor).data.cpu().numpy()
for j, x in enumerate(p_)])
for i, p_ in enumerate(curr_shift)]
if sep_grads is None:
sep_grads = tmp_grad
else:
for l_index in range(len(sep_grads)):
sep_grads[l_index] = list(sep_grads[l_index])
for k_index in range(len(sep_grads[0])):
sep_grads[l_index][k_index] = np.concatenate(
(sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0)
sep_grads[l_index] = tuple(sep_grads[l_index])
final_grads = sep_grads
return final_grads
def update_special_tokens_logits(self, context_tokens, i, logits):
for beam_id in range(context_tokens.shape[0]):
for token_idx in set(context_tokens[beam_id][-4:].tolist()):
factor = self.repetition_penalty if logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty)
logits[beam_id, token_idx] /= factor
if i >= self.ef_idx:
factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor)
logits[beam_id, self.end_token] *= factor
if i == 0:
start_factor = 1.6
factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor)
logits[beam_id, self.end_token] /= factor
for token_idx in list(self.forbidden_tokens):
factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor)
logits[beam_id, token_idx] /= factor
return logits
def clip_loss(self, probs, context_tokens):
for p_ in self.clip.transformer.parameters():
if p_.grad is not None:
p_.grad.data.zero_()
top_size = 512
top_probs, top_indices = probs.topk(top_size, -1)
prefix_texts = [self.lm_tokenizer.decode(x, skip_special_tokens=True) for x in context_tokens]
clip_loss = 0
losses = []
top_texts = []
for idx_p in range(probs.shape[0]):
prefix_text = prefix_texts[idx_p]
for x in top_indices[idx_p]:
top_texts.append(prefix_text + self.lm_tokenizer.decode(x))
text_features = self.get_txt_features(top_texts)#.reshape(probs.size(0), top_size, -1)
with torch.no_grad():
similiraties = (self.image_features @ text_features.T).reshape(probs.size(0), -1)
similiraties = similiraties.reshape(probs.size(0), -1)
target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach()
target_probs = target_probs.type(torch.float32)
clip_loss += torch.sum(-(target_probs * torch.log(top_probs)))
# for idx_p in range(probs.shape[0]):
# top_texts = []
# prefix_text = prefix_texts[idx_p]
# for x in top_indices[idx_p]:
# top_texts.append(prefix_text + self.lm_tokenizer.decode(x))
# text_features = self.get_txt_features(top_texts)
#
# with torch.no_grad():
# similiraties = (self.image_features @ text_features.T)
# target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach()
# target_probs = target_probs.type(torch.float32)
#
# target = torch.zeros_like(probs[idx_p])
# target[top_indices[idx_p]] = target_probs[0]
# target = target.unsqueeze(0)
# cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)])))
#
# clip_loss += cur_clip_loss
# losses.append(cur_clip_loss)
return clip_loss, losses
def clip_loss_old(self, probs, context_tokens):
for p_ in self.clip.transformer.parameters():
if p_.grad is not None:
p_.grad.data.zero_()
top_size = 512
_, top_indices = probs.topk(top_size, -1)
prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens]
clip_loss = 0
losses = []
for idx_p in range(probs.shape[0]):
top_texts = []
prefix_text = prefix_texts[idx_p]
for x in top_indices[idx_p]:
top_texts.append(prefix_text + self.lm_tokenizer.decode(x))
text_features = self.get_txt_features(top_texts)
with torch.no_grad():
similiraties = (self.image_features @ text_features.T)
target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach()
target_probs = target_probs.type(torch.float32)
target = torch.zeros_like(probs[idx_p])
target[top_indices[idx_p]] = target_probs[0]
target = target.unsqueeze(0)
cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)])))
clip_loss += cur_clip_loss
losses.append(cur_clip_loss)
return clip_loss, losses
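The batched variant above differs from `ZeroCLIP.py` mainly in wrapping the CLIP image and text encoders in `torch.nn.DataParallel` and normalizing features with `torch.nn.functional.normalize`. A hedged sketch of how one might choose between the two generator classes based on the `--multi_gpu` flag defined in `run.py`; the deleted `run.py` imports both classes but does not show this switch, so the selection logic here is an assumption.

```python
# Hedged sketch: selecting the single-GPU or DataParallel generator.
# Whether run.py actually switched on --multi_gpu is an assumption, not shown in the diff.
from model.ZeroCLIP import CLIPTextGenerator
from model.ZeroCLIP_batched import CLIPTextGenerator as CLIPTextGenerator_multigpu

def build_generator(args):
    cls = CLIPTextGenerator_multigpu if args.multi_gpu else CLIPTextGenerator
    return cls(**vars(args))
```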

0
zerocap/model/__init__.py

BIN
zerocap/model/__pycache__/ZeroCLIP.cpython-36.pyc

Binary file not shown.

BIN
zerocap/model/__pycache__/ZeroCLIP.cpython-37.pyc

Binary file not shown.

BIN
zerocap/model/__pycache__/ZeroCLIP_batched.cpython-36.pyc

Binary file not shown.

BIN
zerocap/model/__pycache__/ZeroCLIP_batched.cpython-37.pyc

Binary file not shown.

BIN
zerocap/model/__pycache__/__init__.cpython-36.pyc

Binary file not shown.

BIN
zerocap/model/__pycache__/__init__.cpython-37.pyc

Binary file not shown.

14
zerocap/mscoco_zerocap.sh

@ -1,14 +0,0 @@
#!/bin/bash
# lm_model:
# 1. cambridgeltl/magic_mscoco
# 2. cambridgeltl/magic_flickr30k
CUDA_VISIBLE_DEVICES=1 python run.py \
--beam_size 1 \
--target_seq_length 16 \
--reset_context_delta \
--lm_model cambridgeltl/magic_mscoco \
--test_image_prefix_path ../data/mscoco/test_images \
--test_path ../data/mscoco/mscoco_test.json \
--save_path_prefix ../inference_result/mscoco/baselines/ \
--save_name zerocap_result.json

117
zerocap/predict.py

@ -1,117 +0,0 @@
import os
import tempfile
import sys
sys.path.append('CLIP')
from pathlib import Path
import cog
import argparse
import torch
import clip
from model.ZeroCLIP import CLIPTextGenerator
def perplexity_score(text, lm_model, lm_tokenizer, device):
encodings = lm_tokenizer(f'{lm_tokenizer.bos_token + text}', return_tensors='pt')
input_ids = encodings.input_ids.to(device)
target_ids = input_ids.clone()
outputs = lm_model(input_ids, labels=target_ids)
log_likelihood = outputs[0]
ll = log_likelihood.item()
return ll
class Predictor(cog.Predictor):
def setup(self):
self.args = get_args()
self.args.reset_context_delta = True
self.text_generator = CLIPTextGenerator(**vars(self.args))
@cog.input(
"image",
type=Path,
help="input image"
)
@cog.input(
"cond_text",
type=str,
default='Image of a',
help="conditional text",
)
@cog.input(
"beam_size",
type=int,
default=5, min=1, max=10,
help="Number of beams to use",
)
@cog.input(
"end_factor",
type=float,
default=1.01, min=1.0, max=1.10,
help="Higher value for shorter captions",
)
@cog.input(
"max_seq_length",
type=int,
default=15, min=1, max=20,
help="Maximum number of tokens to generate",
)
@cog.input(
"ce_loss_scale",
type=float,
default=0.2, min=0.0, max=0.6,
help="Scale of cross-entropy loss with un-shifted language model",
)
def predict(self, image, cond_text, beam_size, end_factor, max_seq_length, ce_loss_scale):
self.args.cond_text = cond_text
self.text_generator.end_factor = end_factor
self.text_generator.target_seq_length = max_seq_length
self.text_generator.ce_scale = ce_loss_scale
image_features = self.text_generator.get_img_feature([str(image)], None)
captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size)
# CLIP SCORE
encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device))
for c in captions]
encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item()
# Perplexity SCORE
ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions]
best_ppl_index = torch.tensor(ppl_scores).argmin().item()
best_clip_caption = self.args.cond_text + captions[best_clip_idx]
best_mixed = self.args.cond_text + captions[0]
best_PPL = self.args.cond_text + captions[best_ppl_index]
final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}'
return final
# return self.args.cond_text + captions[best_clip_idx]
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
parser.add_argument("--target_seq_length", type=int, default=15)
parser.add_argument("--cond_text", type=str, default="Image of a")
parser.add_argument("--reset_context_delta", action="store_true",
help="Should we reset the context at each token gen")
parser.add_argument("--num_iterations", type=int, default=5)
parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
parser.add_argument("--clip_scale", type=float, default=1)
parser.add_argument("--ce_scale", type=float, default=0.2)
parser.add_argument("--stepsize", type=float, default=0.3)
parser.add_argument("--grad_norm_factor", type=float, default=0.9)
parser.add_argument("--fusion_factor", type=float, default=0.99)
parser.add_argument("--repetition_penalty", type=float, default=1)
parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
parser.add_argument("--beam_size", type=int, default=5)
args = parser.parse_args('')
return args
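`predict.py` exposes the generator through the legacy `cog.Predictor` interface. A hedged sketch of calling the same class directly from Python rather than through the cog runtime; the image path is a placeholder, and the keyword values simply mirror the `@cog.input` defaults declared above.

```python
# Hedged sketch: exercising the Predictor outside the cog runtime.
# './example.jpg' is a placeholder; defaults mirror the @cog.input declarations above.
from pathlib import Path
from predict import Predictor

p = Predictor()
p.setup()
result = p.predict(
    image=Path('./example.jpg'),
    cond_text='Image of a',
    beam_size=5,
    end_factor=1.01,
    max_seq_length=15,
    ce_loss_scale=0.2,
)
print(result)  # 'Best CLIP: ...', 'Best fluency: ...', 'Best mixed: ...'
```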

129
zerocap/predict_arithmetic.py

@ -1,129 +0,0 @@
import os
import tempfile
import sys
sys.path.append('CLIP')
from pathlib import Path
import cog
import argparse
import torch
import clip
from model.ZeroCLIP import CLIPTextGenerator
def perplexity_score(text, lm_model, lm_tokenizer, device):
encodings = lm_tokenizer(f'{lm_tokenizer.bos_token + text}', return_tensors='pt')
input_ids = encodings.input_ids.to(device)
target_ids = input_ids.clone()
outputs = lm_model(input_ids, labels=target_ids)
log_likelihood = outputs[0]
ll = log_likelihood.item()
return ll
class Predictor(cog.Predictor):
def setup(self):
self.args = get_args()
self.args.reset_context_delta = True
self.text_generator = CLIPTextGenerator(**vars(self.args))
@cog.input(
"image1",
type=Path,
help="Final result will be: image1 + (image2 - image3)"
)
@cog.input(
"image2",
type=Path,
help="Final result will be: image1 + (image2 - image3)"
)
@cog.input(
"image3",
type=Path,
help="Final result will be: image1 + (image2 - image3)"
)
@cog.input(
"cond_text",
type=str,
default='Image of a',
help="conditional text",
)
@cog.input(
"beam_size",
type=int,
default=3, min=1, max=10,
help="Number of beams to use",
)
@cog.input(
"end_factors",
type=float,
default=1.06, min=1.0, max=1.10,
help="Higher value for shorter captions",
)
@cog.input(
"max_seq_lengths",
type=int,
default=3, min=1, max=20,
help="Maximum number of tokens to generate",
)
@cog.input(
"ce_loss_scale",
type=float,
default=0.2, min=0.0, max=0.6,
help="Scale of cross-entropy loss with un-shifted language model",
)
def predict(self, image1, image2, image3, cond_text, beam_size, end_factors, max_seq_lengths, ce_loss_scale):
self.args.cond_text = cond_text
self.text_generator.end_factor = end_factors
self.text_generator.target_seq_length = max_seq_lengths
self.text_generator.ce_scale = ce_loss_scale
self.text_generator.fusion_factor = 0.95
self.text_generator.grad_norm_factor = 0.95
image_features = self.text_generator.get_combined_feature([str(image1), str(image2), str(image3)], [], [1, 1, -1], None)
captions = self.text_generator.run(image_features, self.args.cond_text, beam_size=beam_size)
# CLIP SCORE
encoded_captions = [self.text_generator.clip.encode_text(clip.tokenize(c).to(self.text_generator.device))
for c in captions]
encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item()
# Perplexity SCORE
ppl_scores = [perplexity_score(x, self.text_generator.lm_model, self.text_generator.lm_tokenizer, self.text_generator.device) for x in captions]
best_ppl_index = torch.tensor(ppl_scores).argmin().item()
best_clip_caption = self.args.cond_text + captions[best_clip_idx]
best_mixed = self.args.cond_text + captions[0]
best_PPL = self.args.cond_text + captions[best_ppl_index]
final = f'Best CLIP: {best_clip_caption} \nBest fluency: {best_PPL} \nBest mixed: {best_mixed}'
return final
# return self.args.cond_text + captions[best_clip_idx]
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
parser.add_argument("--target_seq_length", type=int, default=15)
parser.add_argument("--cond_text", type=str, default="Image of a")
parser.add_argument("--reset_context_delta", action="store_true",
help="Should we reset the context at each token gen")
parser.add_argument("--num_iterations", type=int, default=5)
parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
parser.add_argument("--clip_scale", type=float, default=1)
parser.add_argument("--ce_scale", type=float, default=0.2)
parser.add_argument("--stepsize", type=float, default=0.3)
parser.add_argument("--grad_norm_factor", type=float, default=0.95)
parser.add_argument("--fusion_factor", type=float, default=0.95)
parser.add_argument("--repetition_penalty", type=float, default=1)
parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
parser.add_argument("--beam_size", type=int, default=5)
args = parser.parse_args('')
return args

3
zerocap/requirements.txt

@ -1,3 +0,0 @@
ftfy
regex
tqdm

131
zerocap/run.py

@ -1,131 +0,0 @@
import argparse
import ipdb
from tqdm import tqdm
import progressbar
import torch
import ipdb
import clip
from model.ZeroCLIP import CLIPTextGenerator
from model.ZeroCLIP_batched import CLIPTextGenerator as CLIPTextGenerator_multigpu
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--test_image_prefix_path", type=str, help="the folder that stores all test images")
parser.add_argument("--test_path", type=str)
parser.add_argument("--save_path_prefix", type=str, help="save the result in which directory")
parser.add_argument("--save_name", type=str, help="the name of the saved file")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
parser.add_argument("--target_seq_length", type=int, default=15)
parser.add_argument("--cond_text", type=str, default="Image of a")
parser.add_argument("--reset_context_delta", action="store_true",
help="Should we reset the context at each token gen")
parser.add_argument("--num_iterations", type=int, default=5)
parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
parser.add_argument("--clip_scale", type=float, default=1)
parser.add_argument("--ce_scale", type=float, default=0.2)
parser.add_argument("--stepsize", type=float, default=0.3)
parser.add_argument("--grad_norm_factor", type=float, default=0.9)
parser.add_argument("--fusion_factor", type=float, default=0.99)
parser.add_argument("--repetition_penalty", type=float, default=1)
parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
parser.add_argument("--beam_size", type=int, default=1)
parser.add_argument("--multi_gpu", action="store_true")
parser.add_argument('--run_type',
default='caption',
nargs='?',
choices=['caption', 'arithmetics'])
parser.add_argument("--caption_img_path", type=str, default='example_images/captions/COCO_val2014_000000008775.jpg',
help="Path to image for captioning")
parser.add_argument("--arithmetics_imgs", nargs="+",
default=['example_images/arithmetics/woman2.jpg',
'example_images/arithmetics/king2.jpg',
'example_images/arithmetics/man2.jpg'])
parser.add_argument("--arithmetics_weights", nargs="+", default=[1, 1, -1])
args = parser.parse_args()
return args
def run(args, text_generator, img_path):
image_features = text_generator.get_img_feature([img_path], None)
captions = text_generator.run(image_features, args.cond_text, beam_size=args.beam_size)
encoded_captions = [text_generator.clip.encode_text(clip.tokenize(c).to(text_generator.device)) for c in captions]
encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item()
return captions
if __name__ == '__main__':
if torch.cuda.is_available():
print ('Cuda is available.')
cuda_available = torch.cuda.is_available()
args = get_args()
device = torch.device('cuda')
save_path_prefix = args.save_path_prefix
import os
if os.path.exists(save_path_prefix):
pass
else: # recursively construct directory
os.makedirs(save_path_prefix, exist_ok=True)
# parse save name
save_name = args.save_name
full_save_path = save_path_prefix + '/' + save_name
print ('full save path is {}'.format(full_save_path))
print ('Loading data...')
import json
with open(args.test_path) as f:
item_list = json.load(f)
print ('Data loaded.')
print ('Number of test instances is {}'.format(len(item_list)))
# ZeroCap generator
text_generator = CLIPTextGenerator(**vars(args))
result_list = []
invalid_num = 0
print ('----------------------------------------------------------------')
test_num = len(item_list)
#test_num = 10
print ('Number of inference instances is {}'.format(test_num))
p = progressbar.ProgressBar(test_num)
p.start()
for p_idx in tqdm(range(test_num)):
p.update(p_idx)
one_test_dict = item_list[p_idx]
one_res_dict = {
'split':one_test_dict['split'],
'image_name':one_test_dict['image_name'],
#'file_path':one_test_dict['file_path'],
'captions':one_test_dict['captions']
}
image_full_path = args.test_image_prefix_path + '/' + one_test_dict['image_name']
try:
output_text = run(args, text_generator, img_path=image_full_path)
one_res_dict['prediction'] = output_text[0]
result_list.append(one_res_dict)
except Exception as error:
print(f'[!] ERROR:', error)
invalid_num += 1
print ('invalid number is {}'.format(invalid_num))
continue
p.finish()
print ('Inference completed!')
import json
with open(full_save_path, 'w') as outfile:
json.dump(result_list, outfile, indent=4)
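`run.py` writes its predictions as a JSON list; each entry copies the test item's metadata and adds the top beam under `prediction`. A sketch of one entry's shape, with placeholder values:

```python
# Hedged sketch: shape of one entry in the saved result file (all values are placeholders).
one_res_dict = {
    'split': 'test',
    'image_name': 'example_image.jpg',
    'captions': ['a reference caption', 'another reference caption'],
    'prediction': 'a generated caption',  # output_text[0], the best-first beam
}
```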

19
zerocap/setup.py

@ -1,19 +0,0 @@
import os
import pkg_resources
from setuptools import setup, find_packages
setup(
name="zero-shot-image-to-text",
py_modules=["zero-shot-image-to-text"],
version="1.0",
description="",
packages=find_packages(),
install_requires=[
str(r)
for r in pkg_resources.parse_requirements(
open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
)
],
include_package_data=True
)