magic
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
234 lines
9.9 KiB
234 lines
9.9 KiB
2 years ago
|
import os
|
||
|
import sys
|
||
|
import operator
|
||
|
from tqdm import tqdm
|
||
|
from operator import itemgetter
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
import random
|
||
|
import argparse
|
||
|
import numpy as np
|
||
|
import torch.nn.functional as F
|
||
|
from torch.nn import CrossEntropyLoss
|
||
|
from loss_func import contrastive_loss
|
||
|
from utlis import PlugAndPlayContrastiveDecodingOneStepFast
|
||
|
|
||
|
import seaborn as sns
|
||
|
import matplotlib.pyplot as plt
|
||
|
import pandas as pd
|
||
|
import datetime
|
||
|
|
||
|
# Training objective: token-level cross-entropy averaged over every position
# whose label is not -100 (CrossEntropyLoss's default ignore_index).
train_fct = CrossEntropyLoss(ignore_index=-100, reduction='mean')
# Evaluation objective: unreduced per-position cross-entropy (ignored positions
# contribute 0), so the caller can sum it and normalise by its own token count.
val_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
|
||
|
class SimCTG(nn.Module):
    """GPT-2 language model wrapper with SimCTG training losses and decoding helpers.

    Wraps a Hugging Face ``GPT2LMHeadModel`` and its tokenizer. It registers
    the start-of-sequence and padding tokens, adding them to the vocabulary
    when missing, and exposes:

    * ``forward``: MLE loss plus the contrastive loss on last-layer states;
    * ``eval_loss``: summed token-level loss and scored-token count;
    * decoding entry points: ``magic_search`` (image-guided),
      ``fast_contrastive_search``, ``top_k_sampling`` and ``nucleus_sampling``.
    """

    def __init__(self, model_name, sos_token, pad_token):
        """Load tokenizer and model from ``model_name`` and register special tokens.

        Args:
            model_name: name or path accepted by ``from_pretrained``.
            sos_token: start-of-sequence token string; added to the tokenizer if absent.
            pad_token: padding token string; added to the tokenizer if absent.
        """
        super().__init__()
        # Local import: importing this module does not pull in transformers.
        from transformers import AutoTokenizer, GPT2LMHeadModel
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.sos_token, self.sos_token_id = self.add_special_token(sos_token)
        print('sos token is {}, sos token id is {}'.format(self.sos_token, self.sos_token_id))
        self.pad_token, self.pad_token_id = self.add_special_token(pad_token)
        print('pad token is {}, pad token id is {}'.format(self.pad_token, self.pad_token_id))
        # The tokenizer's BOS token doubles as the end-of-sequence marker. For
        # GPT-2 tokenizers BOS and EOS are the same token; NOTE(review): confirm
        # this holds for any other checkpoint passed as model_name.
        self.eos_token, self.eos_token_id = self.tokenizer.bos_token, self.tokenizer.bos_token_id
        print('eos token is {}, eos token id is {}'.format(self.eos_token, self.eos_token_id))
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        # Vocabulary size after the special tokens were added. The embedding
        # matrix is resized to match, so the new ids get embedding rows.
        self.vocab_size = len(self.tokenizer)
        print('Resizing model embedding...')
        self.model.resize_token_embeddings(len(self.tokenizer))
        print('Model embedding resized!')
        self.embed_dim = self.model.config.hidden_size

    def add_special_token(self, special_token):
        """Ensure ``special_token`` is in the tokenizer vocabulary and look up its id.

        Adds the token with ``tokenizer.add_tokens`` when it is not already in
        ``tokenizer.vocab``. The model's embedding matrix is not resized here;
        ``__init__`` resizes it after all special tokens are registered.

        Returns:
            (special_token, special_token_id)
        """
        if special_token in self.tokenizer.vocab:
            print(special_token + ' token exists.')
        else:
            print('Add token to the tokenizer.')
            print('Original vocabulary size is {}'.format(len(self.tokenizer)))
            self.tokenizer.add_tokens([special_token])
            print('Vocabulary size after extension is {}'.format(len(self.tokenizer)))
            assert len(self.tokenizer.convert_tokens_to_ids([special_token])) == 1
        special_token_id = self.tokenizer.convert_tokens_to_ids([special_token])[0]
        return special_token, special_token_id

    def compute_logits_and_hidden_states(self, input_ids):
        """Run one forward pass and return ``(last_hidden_states, logits)``.

        Used by the advanced decoding code. ``input_ids`` is a tensor of token
        ids, e.g. 1 x seqlen. Note the return order: hidden states first.
        """
        outputs = self.model(input_ids=input_ids, output_hidden_states=True)
        last_hidden_states = outputs.hidden_states[-1]
        logits = outputs.logits
        return last_hidden_states, logits

    def forward(self, input_ids, labels, margin):
        """Compute the SimCTG training losses.

        Args:
            input_ids: (bsz, seqlen) token ids.
            labels: target ids with bsz * seqlen elements. Positions labelled
                -100 are ignored by ``train_fct``.
            margin: margin forwarded to ``contrastive_loss``.

        Returns:
            (mle_loss, cl_loss):
                mle_loss: mean cross-entropy over non-ignored positions.
                cl_loss: the value of ``contrastive_loss`` on the pairwise
                    cosine-similarity matrix of the last-layer hidden states.
        """
        bsz, seqlen = input_ids.size()
        outputs = self.model(input_ids=input_ids, output_hidden_states=True)
        logits = outputs.logits
        assert logits.size() == torch.Size([bsz, seqlen, self.vocab_size])
        last_hidden_states = outputs.hidden_states[-1]
        assert last_hidden_states.size() == torch.Size([bsz, seqlen, self.embed_dim])
        mle_loss = train_fct(logits.view(-1, self.vocab_size), labels.view(-1))

        # L2-normalise each token representation so the batched matmul gives
        # cosine similarities between every pair of positions in a sequence.
        norm_rep = last_hidden_states / last_hidden_states.norm(dim=2, keepdim=True)
        cosine_scores = torch.matmul(norm_rep, norm_rep.transpose(1, 2))
        assert cosine_scores.size() == torch.Size([bsz, seqlen, seqlen])
        # NOTE(review): pad_token_id is passed presumably so padding positions
        # are excluded from the loss; see loss_func.contrastive_loss.
        cl_loss = contrastive_loss(margin, cosine_scores, input_ids, self.pad_token_id, prefix_len=0)
        return mle_loss, cl_loss

    def eval_loss(self, input_ids, labels):
        """Compute the summed token-level MLE loss and the number of scored tokens.

        Positions whose label is -100 are excluded in two ways. ``val_fct``
        uses ``CrossEntropyLoss``'s default ``ignore_index`` (-100), which
        gives 0 loss there. They are also not counted in ``token_num_sum``.

        Args:
            input_ids: (bsz, seqlen) token ids.
            labels: target ids with bsz * seqlen elements.

        Returns:
            (mle_loss_sum, token_num_sum): scalar tensors. Their ratio is the
            mean per-token loss over the batch.
        """
        bsz, seqlen = input_ids.size()
        outputs = self.model(input_ids=input_ids, output_hidden_states=True)
        logits = outputs.logits
        assert logits.size() == torch.Size([bsz, seqlen, self.vocab_size])
        last_hidden_states = outputs.hidden_states[-1]
        assert last_hidden_states.size() == torch.Size([bsz, seqlen, self.embed_dim])
        mle_loss = val_fct(logits.view(-1, self.vocab_size), labels.view(-1))
        assert mle_loss.size() == torch.Size([bsz * seqlen])
        mle_loss_sum = torch.sum(mle_loss)
        # Number of positions that contribute to the loss (label != -100).
        token_num_sum = torch.sum(labels.ne(-100))
        return mle_loss_sum, token_num_sum

    def save_model(self, ckpt_save_path):
        """Save the wrapped model and tokenizer to ``ckpt_save_path``.

        The directory, including any missing parents, is created if needed.
        """
        # exist_ok=True makes this a no-op for an existing directory. It raises
        # if the path exists as a regular file, instead of saving nowhere.
        os.makedirs(ckpt_save_path, exist_ok=True)
        self.model.save_pretrained(ckpt_save_path)
        self.tokenizer.save_pretrained(ckpt_save_path)

    def parse_sentences(self, text, num_of_sentences_to_keep):
        """Keep only the first ``num_of_sentences_to_keep`` '.'-separated pieces of ``text``.

        When pieces were dropped, the result ends with exactly one '.'.
        Otherwise, surrounding '.' characters and then whitespace are stripped,
        and no '.' is re-appended.
        """
        item_list = text.split('.')
        res_list = item_list[:num_of_sentences_to_keep]
        if len(item_list) > num_of_sentences_to_keep:
            res_text = '.'.join(res_list).strip('.') + '.'
        else:
            res_text = '.'.join(res_list).strip('.').strip()
        return res_text

    def parse_generated_result(self, output, num_of_sentences_to_keep):
        """Decode ``output`` and split it on the EOS token.

        Returns:
            (full_text, generated_text):
                full_text: the decoded text up to, but not including, the
                    second EOS occurrence.
                generated_text: the segment between the first and second EOS
                    occurrences, or '' when the decoded text has no EOS token.
            Both are truncated with ``parse_sentences``.
        """
        output_text = self.tokenizer.decode(output)
        item_list = output_text.split(self.eos_token)
        full_text = self.eos_token.join(item_list[:2]).strip()
        full_text = self.parse_sentences(full_text, num_of_sentences_to_keep)
        # With no EOS token in the text there is no generated segment.
        generated_text = item_list[1].strip() if len(item_list) > 1 else ''
        generated_text = self.parse_sentences(generated_text, num_of_sentences_to_keep)
        return full_text, generated_text

    # decoding functions
    # ------------------------------------------------------- #

    def parse_output_token_list(self, output):
        """Convert a tensor of token ids to cleaned-up text.

        ``output`` is expected to be 1-D; the callers in this class pass
        row 0 of a batch. SOS ids are skipped and decoding stops at the first
        EOS id. Runs of whitespace in the decoded text are collapsed to single
        spaces.
        """
        output = output.tolist()
        res_list = []
        for token_id in output:
            if token_id == self.sos_token_id:
                continue
            elif token_id == self.eos_token_id:
                break
            else:
                res_list.append(token_id)
        text = self.tokenizer.decode(res_list).strip()
        return ' '.join(text.split()).strip()

    @torch.no_grad()
    def magic_search(self, input_ids, beam_width, alpha, decoding_len, beta, image_instance, clip,
                     clip_text_max_len):
        """Image-guided decoding (MAGIC search), one token per step.

        Args:
            input_ids: prefix token ids; ``input_ids.size()[1]`` is the prefix length.
            beam_width: candidate pool size forwarded to the step function.
            alpha, beta: weights forwarded to ``PlugAndPlayContrastiveDecodingOneStepFast``.
                NOTE(review): presumably degeneration-penalty and image-relevance
                weights; confirm in utlis.
            decoding_len: total sequence length including the prefix; runs
                ``decoding_len - prefix_len`` decoding steps.
            image_instance: image passed to
                ``clip.compute_image_representation_from_image_instance``.
            clip: CLIP wrapper; used for the image embedding and forwarded to
                the step function.
            clip_text_max_len: forwarded to the step function.

        Returns:
            Text of the first sequence of the final ``input_ids_for_class``,
            cleaned by ``parse_output_token_list``.
        """
        prefix_len = input_ids.size()[1]
        past_key_values, last_hidden_states, logits = None, None, None
        # Seeded with the prefix. The step function returns an updated version
        # each step, and the final one is what gets decoded.
        input_ids_for_class = input_ids.clone()
        image_embeds = clip.compute_image_representation_from_image_instance(image_instance)

        # The maximum supported length of generation for SimCTG is 256. To
        # support longer generation, re-train the SimCTG model with longer
        # sequences.
        decoding_len = decoding_len - prefix_len
        for step in range(decoding_len):
            input_ids, past_key_values, last_hidden_states, logits, input_ids_for_class = \
                PlugAndPlayContrastiveDecodingOneStepFast(
                    self.model,
                    input_ids,
                    prefix_len,
                    beam_width,
                    alpha,
                    beta,
                    self.tokenizer,
                    image_embeds,
                    clip,
                    clip_text_max_len,
                    past_key_values,
                    last_hidden_states,
                    logits,
                    first_step=step == 0,
                    input_ids_for_class=input_ids_for_class,
                )
        return self.parse_output_token_list(input_ids_for_class[0])

    @torch.no_grad()
    def fast_contrastive_search(self, input_ids, beam_width, alpha, decoding_len):
        """Contrastive search decoding (SimCTG).

        Args:
            input_ids: prefix input; 1 x prefix_len.
            beam_width: size of the candidate pool during decoding.
            alpha: regulates the importance of model confidence vs. the
                degeneration penalty; must lie in [0, 1].
            decoding_len: total sequence length INCLUDING the prefix;
                ``decoding_len - prefix_len`` new tokens are generated.

        Returns:
            Text of the first sequence (prefix + generated tokens), cleaned by
            ``parse_output_token_list``.
        """
        # Not imported at module level. It comes from the same project module
        # as PlugAndPlayContrastiveDecodingOneStepFast.
        from utlis import ContrastiveDecodingOneStepFast

        self.model.eval()
        # sanity check
        assert 0. <= alpha <= 1.0

        prefix_len = input_ids.size()[1]
        # One token-id list per batch row, seeded with that row's prefix.
        generated = input_ids.tolist()
        past_key_values = None
        last_hidden_states = None
        logits = None
        for step in range(decoding_len - prefix_len):
            input_ids, past_key_values, last_hidden_states, logits = ContrastiveDecodingOneStepFast(
                self.model,
                input_ids,
                beam_width,
                alpha,
                past_key_values,
                last_hidden_states,
                self.tokenizer,
                logits,
                first_step=step == 0,
            )
            # input_ids now holds the chosen next token per row
            # (presumably shape batch x 1).
            tokens = input_ids.squeeze(dim=-1).tolist()
            for idx, t in enumerate(tokens):
                generated[idx].append(t)
        return self.parse_output_token_list(torch.LongTensor(generated[0]))

    def top_k_sampling(self, input_ids, k, decoding_len):
        """Top-k sampling via ``model.generate``.

        ``decoding_len`` is passed as ``max_length``, the total length
        including the prefix. ``top_p=1.0`` leaves nucleus filtering off.
        Returns the text of the first sequence, prefix included, cleaned by
        ``parse_output_token_list``.
        """
        # Unpacking enforces a 2-D (batch, prefix_len) input; values are unused.
        _batch_size, _prefix_len = input_ids.size()
        output = self.model.generate(
            input_ids,
            do_sample=True,
            max_length=decoding_len,
            top_p=1.0,
            top_k=k)
        return self.parse_output_token_list(output[0])

    def nucleus_sampling(self, input_ids, nucleus_p, decoding_len):
        """Nucleus (top-p) sampling via ``model.generate``.

        ``decoding_len`` is passed as ``max_length``, the total length
        including the prefix. ``top_k=0`` disables top-k filtering. Returns
        the text of the first sequence, prefix included, cleaned by
        ``parse_output_token_list``.
        """
        # Unpacking enforces a 2-D (batch, prefix_len) input; values are unused.
        _batch_size, _prefix_len = input_ids.size()
        output = self.model.generate(
            input_ids,
            do_sample=True,
            max_length=decoding_len,
            top_p=nucleus_p,
            top_k=0)
        return self.parse_output_token_list(output[0])
|