magic
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
292 lines
12 KiB
292 lines
12 KiB
2 years ago
|
import sys
|
||
|
import os
|
||
|
import operator
|
||
|
from operator import itemgetter
|
||
|
import torch
|
||
|
from torch import nn
|
||
|
import torch.nn.functional as F
|
||
|
import random
|
||
|
import numpy as np
|
||
|
import argparse
|
||
|
import random
|
||
|
|
||
|
def parse_prompt(text):
    '''
    Clean a prompt string.

    Steps:
      1. Strip surrounding whitespace and any leading/trailing
         '<|endoftext|>' tokens (removed as whole substrings).
      2. If the first '[' comes before the first ']' and they are at most
         6 characters apart, drop that bracketed span (brackets included).
      3. Collapse every run of whitespace into a single space.

    Args:
        text: raw prompt string.
    Returns:
        The cleaned prompt string.
    '''
    eos_token = '<|endoftext|>'
    # Note: str.strip(eos_token) would treat the token as a *set of characters*
    # and also eat legitimate letters at the ends (e.g. 'a cat' -> 'a ca'),
    # so the token is removed as a whole substring instead.
    text = text.strip()
    while text.startswith(eos_token):
        text = text[len(eos_token):].strip()
    while text.endswith(eos_token):
        text = text[:-len(eos_token)].strip()

    left_bracket_idx = text.find('[')   # first '[' (-1 if absent)
    right_bracket_idx = text.find(']')  # first ']' (-1 if absent)
    # Only short bracketed spans (e.g. a tag like '[img]') are removed.
    if -1 < left_bracket_idx < right_bracket_idx and right_bracket_idx - left_bracket_idx <= 6:
        text = text[:left_bracket_idx] + text[right_bracket_idx + 1:]

    return ' '.join(text.split())
|
||
|
|
||
|
def typical_filtering(scores, mass, min_tokens_to_keep, filter_value):
    '''
    Locally-typical filtering of a batch of logits.

    Tokens are ranked by |(-log p) - H|, the distance of their surprisal from
    the entropy H of the row's distribution. The most "typical" tokens are kept
    until their cumulative probability reaches `mass`; all other positions are
    set to `filter_value`.

    Args:
        scores: 2-D logits tensor (rows are indexed along dim 1 below).
        mass: cumulative probability mass to keep.
        min_tokens_to_keep: if > 1, at least this many tokens are kept per row.
        filter_value: value written into filtered positions (e.g. -inf).
    Returns:
        A new tensor of the same shape as `scores`.
    '''
    # calculate entropy
    normalized = torch.nn.functional.log_softmax(scores, dim=-1)
    p = torch.exp(normalized)
    ent = -(normalized * p).nansum(-1, keepdim=True)

    # shift and sort
    shifted_scores = torch.abs((-normalized) - ent)
    sorted_scores, sorted_indices = torch.sort(shifted_scores, descending=False)
    sorted_logits = scores.gather(-1, sorted_indices)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

    # Remove tokens with cumulative mass above the threshold.
    # If every cumulative probability is below `mass` (mass >= 1, or float
    # round-off just under 1), the count equals the vocab size and would be an
    # out-of-range index for `gather`, so clamp it to the last valid position.
    last_ind = (cumulative_probs < mass).sum(dim=1)
    last_ind = last_ind.clamp(min=0, max=sorted_scores.shape[-1] - 1)
    sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
    if min_tokens_to_keep > 1:
        # Always keep the `min_tokens_to_keep` most typical tokens.
        sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
    # Map the removal mask from sorted order back to vocabulary order.
    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)

    scores = scores.masked_fill(indices_to_remove, filter_value)
    return scores
|
||
|
|
||
|
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, threshold=-float('Inf'), filter_value=-np.inf):
    '''
    Filter a 1-D logit vector IN PLACE and return the same tensor.

    Three filters are applied in order:
      * top-k: if top_k > 0, logits below the k-th largest value are masked;
      * nucleus: if top_p > 0, tokens are kept (in descending order) until the
        cumulative softmax mass first exceeds top_p, and that crossing token
        is still kept. The most probable token is never masked;
      * threshold: logits strictly below `threshold` are masked.
    Masked entries are set to `filter_value`.
    '''
    assert logits.dim() == 1

    k = min(top_k, logits.size(-1))
    if k > 0:
        kth_largest = torch.topk(logits, k).values[-1]
        logits[logits < kth_largest] = filter_value

    if top_p > 0.0:
        ordered, order = torch.sort(logits, descending=True)
        cum_mass = F.softmax(ordered, dim=-1).cumsum(dim=-1)
        past_threshold = cum_mass > top_p
        # Shift right by one so the token that crosses top_p survives.
        drop = past_threshold.clone()
        drop[1:] = past_threshold[:-1]
        drop[0] = False
        logits[order[drop]] = filter_value

    logits[logits < threshold] = filter_value
    return logits
|
||
|
|
||
|
# ========== batch version ========= #
|
||
|
def ranking_fast(context_hidden, next_hidden, next_top_k_probs, alpha, beam_width):
    '''
    Pick one of `beam_width` candidates per batch item.

    Each candidate is scored as
        (1 - alpha) * p(candidate) - alpha * max_s cos(h_candidate, h_context[s])
    and the highest-scoring candidate wins.

    Args:
        context_hidden: bsz*beam x seqlen x embed_dim
        next_hidden: bsz*beam x 1 x embed_dim
        next_top_k_probs: bsz x beam
        alpha: weight of the similarity penalty.
        beam_width: number of candidates per batch item.
    Returns:
        Index tensor of shape [bsz] with the chosen candidate per batch item.
    '''
    unit_context = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    unit_candidate = next_hidden / next_hidden.norm(dim=2, keepdim=True)

    # [B*K, S, E] @ [B*K, E, 1] -> [B*K, S]: cosine to every context position.
    similarity = torch.matmul(unit_context, unit_candidate.transpose(1, 2)).squeeze(-1)
    penalty = similarity.max(dim=-1).values  # [B*K]

    confidence = next_top_k_probs.view(-1)  # [B*K]
    combined = (1.0 - alpha) * confidence - alpha * penalty
    grouped = torch.stack(torch.split(combined, beam_width))  # [B, K]

    _, best = grouped.max(dim=-1)  # [B]
    return best
|
||
|
|
||
|
def ContrastiveDecodingOneStepFast(
    model,
    ids,
    beam_width,
    alpha,
    past_key_values,
    last_hidden_states,
    vocab,
    logit_for_next_step,
    first_step=False,
):
    '''
    One step of batched contrastive search.

    The top `beam_width` tokens of the next-token distribution are run through
    `model` in a single batched call (the KV cache is tiled per candidate), then
    `ranking_fast` picks one candidate per batch item.

    Args:
        model: generation model called with HF-style keyword arguments; its
            output must expose `past_key_values`, `hidden_states`, `logits`.
        ids: input token ids [B, S], only used when `first_step` is True.
        beam_width: number of candidates K per batch item.
        alpha: weight of the similarity penalty in `ranking_fast`.
        past_key_values: KV cache from the previous step (passed to the
            initial model call when `first_step` is True).
        last_hidden_states: [B, S, E] last-layer hidden states so far
            (recomputed when `first_step` is True).
        vocab: unused; kept for interface compatibility.
        logit_for_next_step: [B, V] next-token logits (recomputed when
            `first_step` is True).
        first_step: if True, first encode `ids` with `model`.
    Returns:
        (next_id [B, 1], past_key_values for the chosen candidates,
         last_hidden_states [B, S+1, E], logits [B, V] for the next step)
    '''
    if first_step:
        output = model(
            input_ids=ids,
            past_key_values=past_key_values,
            use_cache=True,
            output_hidden_states=True,
        )
        past_key_values = output.past_key_values
        last_hidden_states = output.hidden_states[-1]  # [B, S, E]
        logit_for_next_step = output.logits[:, -1, :]  # [B, V]
    bsz, seqlen, embed_dim = last_hidden_states.size()

    # candidate set: top-K tokens and their model probabilities
    next_probs = F.softmax(logit_for_next_step, dim=-1)
    _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width)  # [B, K]
    top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids)  # [B, K]

    # compute every candidate's hidden state in one batched forward pass
    past_key_values = enlarge_past_key_values(past_key_values, beam_width)
    output = model(
        input_ids=top_k_ids.view(-1, 1),
        attention_mask=torch.ones_like(top_k_ids.view(-1, 1)),
        past_key_values=past_key_values,
        output_hidden_states=True,
        use_cache=True,
    )
    past_key_values = output.past_key_values
    logits = output.logits[:, -1, :]  # [B*K, V]
    next_hidden = output.hidden_states[-1]  # [B*K, 1, E]
    context_hidden = last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(
        bsz * beam_width, seqlen, embed_dim)  # [B*K, S, E]

    selected_idx = ranking_fast(
        context_hidden,
        next_hidden,
        top_k_probs,  # [B, K]
        alpha,
        beam_width,
    )  # [B]

    # prepare for the next step: keep only the chosen candidate per batch item
    next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1)  # [B, 1]
    next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width))  # [B, K, E]
    next_hidden = next_hidden[range(bsz), selected_idx, :]  # [B, E]
    last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)  # [B, S+1, E]
    past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx)
    logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :]  # [B, V]
    return next_id, past_key_values, last_hidden_states, logits
|
||
|
|
||
|
def enlarge_past_key_values(past_key_values, beam_width):
    '''
    Tile a KV cache so that every batch row is repeated `beam_width` times.

    Each cached tensor goes from [B, num_head, seq_len, esz] to
    [B*K, num_head, seq_len, esz]. Copies of the same batch row are
    consecutive: rows 0..K-1 come from batch item 0, and so on.

    Returns:
        A list (one entry per layer) of lists of tiled tensors.
    '''
    def _tile(tensor):
        # Repeat along a new axis 1, then fold it into the batch axis.
        batch, heads, length, head_dim = tensor.size()
        tiled = tensor.unsqueeze(1).expand(batch, beam_width, heads, length, head_dim)
        return tiled.reshape(batch * beam_width, heads, length, head_dim)

    return [[_tile(tensor) for tensor in layer] for layer in past_key_values]
|
||
|
|
||
|
def select_past_key_values(past_key_values, beam_width, selected_idx):
    '''
    Undo the tiling of `enlarge_past_key_values` by keeping one candidate.

    Each cached tensor of shape [B*K, num_head, seq_len, esz] is regrouped into
    [B, K, ...], and row `selected_idx[b]` is taken for every batch item b,
    giving [B, num_head, seq_len, esz].

    Args:
        selected_idx: [B] index of the chosen candidate per batch item.
    Returns:
        A list (one entry per layer) of lists of selected tensors.
    '''
    def _pick(tensor):
        rows, _, _, _ = tensor.size()
        n_batch = int(rows // beam_width)
        grouped = torch.stack(torch.split(tensor, beam_width, dim=0))  # [B, K, H, S, D]
        return grouped[range(n_batch), selected_idx, :, :, :]

    return [[_pick(tensor) for tensor in layer] for layer in past_key_values]
|
||
|
|
||
|
# ========== fast plug and play version ========= #
|
||
|
def plug_and_play_fast_ranking(
    context_hidden,
    next_hidden,
    next_top_k_ids,
    next_top_k_probs,
    alpha,
    beta,
    batch_class_score,
    beam_width,
):
    '''
    Pick one of `beam_width` candidates per batch item, adding an external
    (e.g. image-text) score to the contrastive-search score:

        (1 - alpha) * p(candidate)
        - alpha * max_s cos(h_candidate, h_context[s])
        + beta * class_score(candidate)

    Args:
        context_hidden: [B*K, context_len, embed_dim] context hidden states,
            one copy per candidate.
        next_hidden: [B*K, 1, embed_dim] hidden state of each candidate.
        next_top_k_ids: candidate token ids; not used in the computation,
            kept for interface compatibility.
        next_top_k_probs: [B, K] model probability of each candidate.
        alpha: weight of the similarity penalty.
        beta: weight of the class score.
        batch_class_score: B*K scores (any shape with B*K elements), ordered
            candidate-by-candidate within each batch item.
        beam_width: number of candidates K per batch item.
    Returns:
        Index tensor [B] with the chosen candidate per batch item.
    '''
    norm_context_hidden = context_hidden / context_hidden.norm(dim=2, keepdim=True)
    norm_next_hidden = next_hidden / next_hidden.norm(dim=2, keepdim=True)
    cosine_matrix = torch.matmul(norm_context_hidden, norm_next_hidden.transpose(1, 2)).squeeze(-1)  # [B*K, S]
    scores, _ = torch.max(cosine_matrix, dim=-1)  # [B*K]
    next_top_k_probs = next_top_k_probs.view(-1)  # [B*K]
    # Reshape to the full B*K length (not just K) so batches with B > 1 work;
    # a size mismatch still raises instead of silently broadcasting.
    class_score = batch_class_score.view(next_top_k_probs.size())  # [B*K]
    scores = (1.0 - alpha) * next_top_k_probs - alpha * scores + beta * class_score
    scores = torch.stack(torch.split(scores, beam_width))  # [B, K]
    selected_idx = scores.max(dim=-1)[1]  # [B]
    return selected_idx
|
||
|
|
||
|
def PlugAndPlayContrastiveDecodingOneStepFast(model, input_ids, prefix_len, beam_width, alpha, beta,
        simctg_tokenizer, image_embeds, clip, clip_text_max_len, past_key_values, last_hidden_states,
        logit_for_next_step, first_step=False, input_ids_for_class=None):
    '''
    One step of contrastive decoding re-ranked by an image-text score.

    The top `beam_width` next tokens are run through `model` in one batched
    call. Each candidate continuation is decoded to text and scored against
    `image_embeds` via `clip`, and `plug_and_play_fast_ranking` picks one
    candidate per batch item.

    Args:
        model: the generation model, e.g., gpt2. It is called with HF-style
            keyword arguments, and its output must expose `past_key_values`,
            `hidden_states` and `logits`.
        input_ids: token ids fed to `model` when `first_step` is True
            (originally documented as 1 x seqlen).
        prefix_len: number of leading ids of each row of `input_ids_for_class`
            excluded from the text sent to `clip`.
        beam_width: number of candidates K per batch item.
        alpha: weight of the similarity penalty.
        beta: weight of the image-text score.
        simctg_tokenizer: tokenizer whose `decode` turns ids into text.
        image_embeds: passed unchanged to
            `clip.compute_image_text_similarity_via_raw_text`.
        clip: object providing `compute_image_text_similarity_via_raw_text`.
        clip_text_max_len: only the last `clip_text_max_len` ids after
            `prefix_len` are decoded for scoring.
        past_key_values: KV cache from the previous step (passed to the
            initial model call when `first_step` is True).
        last_hidden_states: [B, S, E] last-layer hidden states so far
            (recomputed when `first_step` is True).
        logit_for_next_step: [B, V] next-token logits (recomputed when
            `first_step` is True).
        first_step: if True, first encode `input_ids` with `model`.
        input_ids_for_class: [B, S] ids of the text so far, used for scoring;
            required despite the `None` default.
    Returns:
        (next_id [B, 1], past_key_values, last_hidden_states [B, S+1, E],
         logits [B, V], input_ids_for_class [B, S+1])
    Raises:
        ValueError: if `input_ids_for_class` is None.
    '''
    if input_ids_for_class is None:
        # Fail before the (expensive) model calls; the re-ranking below needs these ids.
        raise ValueError('input_ids_for_class is required for plug-and-play decoding')

    if first_step:
        output = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True, output_hidden_states=True)
        past_key_values = output.past_key_values
        last_hidden_states = output.hidden_states[-1]  # [B, S, E]
        logit_for_next_step = output.logits[:, -1, :]  # [B, V]
    bsz, seqlen, embed_dim = last_hidden_states.size()

    # candidate set: top-K tokens and their model probabilities
    next_probs = F.softmax(logit_for_next_step, dim=-1)
    _, top_k_ids = torch.topk(logit_for_next_step, dim=-1, k=beam_width)  # [B, K]
    top_k_probs = torch.gather(next_probs, dim=1, index=top_k_ids)  # [B, K]

    # compute every candidate's hidden state in one batched forward pass
    past_key_values = enlarge_past_key_values(past_key_values, beam_width)
    output = model(
        input_ids=top_k_ids.view(-1, 1),
        attention_mask=torch.ones_like(top_k_ids.view(-1, 1)),
        past_key_values=past_key_values,
        output_hidden_states=True,
        use_cache=True,
    )
    past_key_values = output.past_key_values
    logits = output.logits[:, -1, :]  # [B*K, V]
    next_hidden = output.hidden_states[-1]  # [B*K, 1, E]
    context_hidden = last_hidden_states.unsqueeze(1).expand(-1, beam_width, -1, -1).reshape(
        bsz * beam_width, seqlen, embed_dim)  # [B*K, S, E]

    # prepare for the classification model: append each candidate to its row's ids
    input_ids_for_class_ = torch.cat([
        input_ids_for_class.unsqueeze(1).expand(-1, beam_width, -1).reshape(bsz * beam_width, seqlen),
        top_k_ids.view(-1, 1),
    ], dim=-1)  # [B*K, S+1]

    # we only consider the class score of the generated text continuation
    batch_text_list = [
        simctg_tokenizer.decode(one_input_id[prefix_len:][-clip_text_max_len:])
        for one_input_id in input_ids_for_class_
    ]
    # NOTE(review): the score tensor's shape comes from `clip`; the ranking
    # below expects B*K elements in the same order as batch_text_list.
    batch_score = clip.compute_image_text_similarity_via_raw_text(image_embeds, batch_text_list)

    selected_idx = plug_and_play_fast_ranking(
        context_hidden,
        next_hidden,
        top_k_ids,
        top_k_probs,
        alpha,
        beta,
        batch_score,
        beam_width,
    )  # [B]

    # prepare for the next step: keep only the chosen candidate per batch item
    next_id = top_k_ids[range(len(top_k_ids)), selected_idx].unsqueeze(-1)  # [B, 1]
    next_hidden = torch.stack(torch.split(next_hidden.squeeze(dim=1), beam_width))  # [B, K, E]
    next_hidden = next_hidden[range(bsz), selected_idx, :]  # [B, E]
    last_hidden_states = torch.cat([last_hidden_states, next_hidden.unsqueeze(1)], dim=1)  # [B, S+1, E]
    past_key_values = select_past_key_values(past_key_values, beam_width, selected_idx)
    logits = torch.stack(torch.split(logits, beam_width))[range(bsz), selected_idx, :]  # [B, V]
    input_ids_for_class = torch.cat([input_ids_for_class, next_id], dim=-1)
    return next_id, past_key_values, last_hidden_states, logits, input_ids_for_class
|
||
|
|
||
|
|