import os
import datetime

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from loss_func import contrastive_loss
# ContrastiveDecodingOneStepFast is needed by fast_contrastive_search below
from utlis import PlugAndPlayContrastiveDecodingOneStepFast, ContrastiveDecodingOneStepFast

train_fct = CrossEntropyLoss()
val_fct = CrossEntropyLoss(reduction='none')


class SimCTG(nn.Module):
    def __init__(self, model_name, sos_token, pad_token):
        super(SimCTG, self).__init__()
        from transformers import AutoTokenizer, GPT2LMHeadModel
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.sos_token, self.sos_token_id = self.add_special_token(sos_token)
        print('sos token is {}, sos token id is {}'.format(self.sos_token, self.sos_token_id))
        self.pad_token, self.pad_token_id = self.add_special_token(pad_token)
        print('pad token is {}, pad token id is {}'.format(self.pad_token, self.pad_token_id))
        self.eos_token, self.eos_token_id = self.tokenizer.bos_token, self.tokenizer.bos_token_id
        print('eos token is {}, eos token id is {}'.format(self.eos_token, self.eos_token_id))
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        self.vocab_size = len(self.tokenizer)
        print('Resizing model embedding...')
        self.model.resize_token_embeddings(len(self.tokenizer))
        print('Model embedding resized!')
        self.embed_dim = self.model.config.hidden_size

    def add_special_token(self, special_token):
        if special_token in self.tokenizer.vocab:
            print(special_token + ' token exists.')
        else:
            print('Add token to the tokenizer.')
            print('Original vocabulary size is {}'.format(len(self.tokenizer)))
            self.tokenizer.add_tokens([special_token])
            print('Vocabulary size after extension is {}'.format(len(self.tokenizer)))
        assert len(self.tokenizer.convert_tokens_to_ids([special_token])) == 1
        special_token_id = self.tokenizer.convert_tokens_to_ids([special_token])[0]
        return special_token, special_token_id

    def compute_logits_and_hidden_states(self, input_ids):
        # used for advanced decoding
        # input_ids: 1 x seqlen
        outputs = self.model(input_ids=input_ids, output_hidden_states=True)
        last_hidden_states = outputs.hidden_states[-1]
        logits = outputs.logits
        return last_hidden_states, logits

    def forward(self, input_ids, labels, margin):
        bsz, seqlen = input_ids.size()
        outputs = self.model(input_ids=input_ids, output_hidden_states=True)
        logits = outputs.logits
        assert logits.size() == torch.Size([bsz, seqlen, self.vocab_size])
        last_hidden_states = outputs.hidden_states[-1]
        assert last_hidden_states.size() == torch.Size([bsz, seqlen, self.embed_dim])
        # maximum likelihood (MLE) loss over the whole sequence
        mle_loss = train_fct(logits.view(-1, self.vocab_size), labels.view(-1))

        # token-level contrastive loss computed on pairwise cosine similarities
        # of the last-layer hidden states
        norm_rep = last_hidden_states / last_hidden_states.norm(dim=2, keepdim=True)
        cosine_scores = torch.matmul(norm_rep, norm_rep.transpose(1, 2))
        assert cosine_scores.size() == torch.Size([bsz, seqlen, seqlen])
        cl_loss = contrastive_loss(margin, cosine_scores, input_ids, self.pad_token_id, prefix_len=0)
        return mle_loss, cl_loss
    def eval_loss(self, input_ids, labels):
        bsz, seqlen = input_ids.size()
        outputs = self.model(input_ids=input_ids, output_hidden_states=True)
        logits = outputs.logits
        assert logits.size() == torch.Size([bsz, seqlen, self.vocab_size])
        last_hidden_states = outputs.hidden_states[-1]
        assert last_hidden_states.size() == torch.Size([bsz, seqlen, self.embed_dim])
        mle_loss = val_fct(logits.view(-1, self.vocab_size), labels.view(-1))
        assert mle_loss.size() == torch.Size([bsz * seqlen])
        # build a 0/1 mask that excludes positions labelled with -100
        mask_tmp = labels.masked_fill(~labels.eq(-100), 1.0)
        mask = mask_tmp.masked_fill(mask_tmp.eq(-100), 0.0)
        # sum of per-token losses and the number of valid tokens
        mle_loss_sum = torch.sum(mle_loss)
        token_num_sum = torch.sum(mask)
        return mle_loss_sum, token_num_sum

    def save_model(self, ckpt_save_path):
        # recursively construct the directory if it does not exist
        os.makedirs(ckpt_save_path, exist_ok=True)
        # save model
        self.model.save_pretrained(ckpt_save_path)
        # save tokenizer
        self.tokenizer.save_pretrained(ckpt_save_path)

    def parse_sentences(self, text, num_of_sentences_to_keep):
        item_list = text.split('.')
        res_list = item_list[:num_of_sentences_to_keep]
        if len(item_list) > num_of_sentences_to_keep:
            res_text = '.'.join(res_list).strip('.') + '.'
        else:
            res_text = '.'.join(res_list).strip('.').strip()
        return res_text

    def parse_generated_result(self, output, num_of_sentences_to_keep):
        output_text = self.tokenizer.decode(output)
        item_list = output_text.split(self.eos_token)
        full_text = self.eos_token.join(item_list[:2]).strip()
        full_text = self.parse_sentences(full_text, num_of_sentences_to_keep)
        generated_text = item_list[1].strip()
        generated_text = self.parse_sentences(generated_text, num_of_sentences_to_keep)
        return full_text, generated_text

    # decoding functions
    # ------------------------------------------------------- #
    def parse_output_token_list(self, output):
        output = output.tolist()
        res_list = []
        for token_id in output:
            if token_id == self.sos_token_id:
                continue
            elif token_id == self.eos_token_id:
                break
            else:
                res_list.append(token_id)
        text = self.tokenizer.decode(res_list).strip()
        return ' '.join(text.split()).strip()

    @torch.no_grad()
    def magic_search(self, input_ids, beam_width, alpha, decoding_len, beta,
                     image_instance, clip, clip_text_max_len):
        prefix_len = input_ids.size()[1]
        past_key_values, last_hidden_states, logits = None, None, None
        input_ids_for_class = input_ids.clone()

        image_embeds = clip.compute_image_representation_from_image_instance(image_instance)

        start_time = datetime.datetime.now()

        # the maximum supported generation length of SimCTG is 256;
        # to support a longer generated length, re-train the SimCTG model with longer sequences
        decoding_len = decoding_len - prefix_len
        for step in range(decoding_len):
            input_ids, past_key_values, last_hidden_states, logits, input_ids_for_class = \
                PlugAndPlayContrastiveDecodingOneStepFast(
                    self.model,
                    input_ids,
                    prefix_len,
                    beam_width,
                    alpha,
                    beta,
                    self.tokenizer,
                    image_embeds,
                    clip,
                    clip_text_max_len,
                    past_key_values,
                    last_hidden_states,
                    logits,
                    first_step=step == 0,
                    input_ids_for_class=input_ids_for_class,
                )
        end_time = datetime.datetime.now()
        time_diff = (end_time - start_time)
        execution_time = time_diff.total_seconds() * 1000
        return self.parse_output_token_list(input_ids_for_class[0])
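    # The two one-step helpers imported from utlis share the same selection rule.
    # A hedged sketch (the exact implementation lives in utlis.py, not in this file):
    # at step t, for each of the beam_width candidate tokens v with model probability
    # p(v | x_<t) and hidden state h_v,
    #
    #   score(v) = (1 - alpha) * p(v | x_<t)                    # model confidence
    #              - alpha * max_{j < t} cosine(h_v, h_{x_j})   # degeneration penalty
    #              [ + beta * CLIP-based image relevance ]      # only in magic_search
    #
    # and the highest-scoring candidate is selected. fast_contrastive_search below
    # drops the image term, i.e. it performs plain contrastive search.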
    def fast_contrastive_search(self, input_ids, beam_width, alpha, decoding_len):
        '''
            input_ids: prefix input; 1 x prefix_len
            decoding_len: how many tokens to generate
            beam_width: size of candidate pool during decoding
            alpha: regulates importance of model confidence and degeneration penalty
        '''
        self.model.eval()
        # sanity check
        assert 0. <= alpha <= 1.0

        # fast mode
        prefix_len = input_ids.size()[1]
        batch_size, seqlen = input_ids.size()
        generated = [item for item in input_ids.tolist()]
        past_key_values = None
        last_hidden_states = None
        logits = None
        decoding_len = decoding_len - prefix_len
        for step in range(decoding_len):
            input_ids, past_key_values, last_hidden_states, logits = ContrastiveDecodingOneStepFast(
                self.model,
                input_ids,
                beam_width,
                alpha,
                past_key_values,
                last_hidden_states,
                self.tokenizer,
                logits,
                first_step=step == 0,
            )
            tokens = input_ids.squeeze(dim=-1).tolist()
            for idx, t in enumerate(tokens):
                generated[idx].append(t)
        return self.parse_output_token_list(torch.LongTensor(generated[0]))

    def top_k_sampling(self, input_ids, k, decoding_len):
        _, prefix_len = input_ids.size()
        output = self.model.generate(
            input_ids,
            do_sample=True,
            max_length=decoding_len,
            top_p=1.0,
            top_k=k)
        return self.parse_output_token_list(output[0])

    def nucleus_sampling(self, input_ids, nucleus_p, decoding_len):
        _, prefix_len = input_ids.size()
        output = self.model.generate(
            input_ids,
            do_sample=True,
            max_length=decoding_len,
            top_p=nucleus_p,
            top_k=0)
        return self.parse_output_token_list(output[0])
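
# ------------------------------------------------------- #
# A minimal usage sketch, not part of the original module. It assumes a plain 'gpt2'
# checkpoint and illustrative special-token strings and hyper-parameters; the values used
# for actual training may differ. magic_search is omitted here because it additionally
# requires the CLIP wrapper and an image instance from the rest of the repository.
if __name__ == '__main__':
    model_name = 'gpt2'                                        # assumed checkpoint name
    sos_token, pad_token = '<-start_of_text->', '<-pad->'      # assumed special tokens
    generation_model = SimCTG(model_name, sos_token, pad_token)

    # encode a prefix: sos token followed by the prompt text
    prompt = sos_token + ' A cat sitting on'
    input_ids = generation_model.tokenizer(prompt, return_tensors='pt').input_ids

    # training-style forward pass: SimCTG combines the MLE loss with the contrastive loss.
    # Labels are normally the input shifted by one position during data preparation;
    # input_ids is reused here only to exercise the API. margin=0.5 is illustrative.
    mle_loss, cl_loss = generation_model(input_ids, labels=input_ids.clone(), margin=0.5)
    loss = mle_loss + cl_loss  # loss.backward() would follow in a real training loop
    print('mle_loss: {:.4f}, cl_loss: {:.4f}'.format(mle_loss.item(), cl_loss.item()))

    # contrastive search decoding without image guidance
    with torch.no_grad():
        output_text = generation_model.fast_contrastive_search(
            input_ids, beam_width=45, alpha=0.1, decoding_len=64)
    print(output_text)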