magic
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
449 lines
18 KiB
449 lines
18 KiB
2 years ago
|
import numpy as np
|
||
|
from torch import nn
|
||
|
from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer
|
||
|
from transformers.models.gpt_neo import GPTNeoForCausalLM
|
||
|
import torch
|
||
|
import clip
|
||
|
from PIL import Image
|
||
|
from datetime import datetime
|
||
|
import sys
|
||
|
|
||
|
class TextCLIP(nn.Module):
|
||
|
def __init__(self, model):
|
||
|
super(TextCLIP, self).__init__()
|
||
|
self.model = model
|
||
|
|
||
|
def forward(self, text):
|
||
|
return self.model.encode_text(text)
|
||
|
|
||
|
|
||
|
class ImageCLIP(nn.Module):
|
||
|
def __init__(self, model):
|
||
|
super(ImageCLIP, self).__init__()
|
||
|
self.model = model
|
||
|
|
||
|
def forward(self, image):
|
||
|
return self.model.encode_image(image)
|
||
|
|
||
|
def log_info(text, verbose=True):
|
||
|
if verbose:
|
||
|
dt_string = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
|
||
|
print(f'{dt_string} | {text}')
|
||
|
sys.stdout.flush()
|
||
|
|
||
|
|
||
|
def add_context(x, y):
|
||
|
return (x[0] + y[0], x[1] + y[1])
|
||
|
|
||
|
|
||
|
def convert_models_to_fp32(model):
|
||
|
for p in model.parameters():
|
||
|
p.data = p.data.float()
|
||
|
|
||
|
|
||
|
class CLIPTextGenerator:
|
||
|
def __init__(self,
|
||
|
seed=0,
|
||
|
lm_model='gpt-2',
|
||
|
forbidden_tokens_file_path='./forbidden_tokens.npy',
|
||
|
clip_checkpoints='./clip_checkpoints',
|
||
|
target_seq_length=15,
|
||
|
reset_context_delta=True,
|
||
|
num_iterations=5,
|
||
|
clip_loss_temperature=0.01,
|
||
|
clip_scale=1.,
|
||
|
ce_scale=0.2,
|
||
|
stepsize=0.3,
|
||
|
grad_norm_factor=0.9,
|
||
|
fusion_factor=0.99,
|
||
|
repetition_penalty=1.,
|
||
|
end_token='.',
|
||
|
end_factor=1.01,
|
||
|
forbidden_factor=20,
|
||
|
**kwargs):
|
||
|
|
||
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||
|
|
||
|
# set Random seed
|
||
|
torch.manual_seed(seed)
|
||
|
np.random.seed(seed)
|
||
|
|
||
|
# Initialize Language model
|
||
|
self.context_prefix = ''
|
||
|
|
||
|
if lm_model == 'gpt-neo':
|
||
|
self.lm_tokenizer = GPT2Tokenizer.from_pretrained('EleutherAI/gpt-neo-125M')
|
||
|
self.lm_model = GPTNeoForCausalLM.from_pretrained('EleutherAI/gpt-neo-125M', output_hidden_states=True)
|
||
|
elif lm_model == 'gpt-2':
|
||
|
self.lm_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
|
||
|
self.lm_model = GPT2LMHeadModel.from_pretrained('gpt2-medium', output_hidden_states=True)
|
||
|
self.context_prefix = self.lm_tokenizer.bos_token
|
||
|
|
||
|
self.lm_model.to(self.device)
|
||
|
self.lm_model.eval()
|
||
|
|
||
|
self.forbidden_tokens = np.load(forbidden_tokens_file_path)
|
||
|
self.capital_letter_tokens = [self.lm_tokenizer.encoder[x] for x in self.lm_tokenizer.encoder.keys() if
|
||
|
(x[0] == 'Ä ' and len(x) > 1 and x[1].isupper())]
|
||
|
|
||
|
# Freeze LM weights
|
||
|
for param in self.lm_model.parameters():
|
||
|
param.requires_grad = False
|
||
|
|
||
|
# Initialize CLIP
|
||
|
self.clip, self.clip_preprocess = clip.load("ViT-B/32", device=self.device,
|
||
|
download_root=clip_checkpoints, jit=False)
|
||
|
self.clip_image = ImageCLIP(self.clip)
|
||
|
self.clip_image = torch.nn.DataParallel(self.clip_image)
|
||
|
self.clip_text = TextCLIP(self.clip)
|
||
|
self.clip_text = torch.nn.DataParallel(self.clip_text)
|
||
|
|
||
|
# Init arguments
|
||
|
self.target_seq_length = target_seq_length
|
||
|
self.reset_context_delta = reset_context_delta
|
||
|
self.num_iterations = num_iterations
|
||
|
self.clip_loss_temperature = clip_loss_temperature
|
||
|
self.clip_scale = clip_scale
|
||
|
self.ce_scale = ce_scale
|
||
|
self.stepsize = stepsize
|
||
|
self.grad_norm_factor = grad_norm_factor
|
||
|
self.fusion_factor = fusion_factor
|
||
|
self.repetition_penalty = repetition_penalty
|
||
|
self.end_token = self.lm_tokenizer.encode(end_token)[0]
|
||
|
self.end_factor = end_factor
|
||
|
self.ef_idx = 1
|
||
|
self.forbidden_factor = forbidden_factor
|
||
|
|
||
|
def get_img_feature(self, img_path, weights):
|
||
|
imgs = [Image.open(x) for x in img_path]
|
||
|
clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]
|
||
|
|
||
|
with torch.no_grad():
|
||
|
image_fts = [self.clip_image(x) for x in clip_imgs]
|
||
|
|
||
|
if weights is not None:
|
||
|
image_features = sum([x * weights[i] for i, x in enumerate(image_fts)])
|
||
|
else:
|
||
|
image_features = sum(image_fts)
|
||
|
|
||
|
image_features = torch.nn.functional.normalize(image_features, dim=-1)
|
||
|
return image_features.detach()
|
||
|
|
||
|
def get_txt_features(self, text):
|
||
|
clip_texts = clip.tokenize(text).to(self.device)
|
||
|
|
||
|
with torch.no_grad():
|
||
|
text_features = self.clip_text(clip_texts)
|
||
|
|
||
|
text_features = torch.nn.functional.normalize(text_features, dim=-1)
|
||
|
return text_features.detach()
|
||
|
|
||
|
def get_combined_feature(self, img_path, texts, weights_i, weights_t):
|
||
|
imgs = [Image.open(x) for x in img_path]
|
||
|
clip_imgs = [self.clip_preprocess(x).unsqueeze(0).to(self.device) for x in imgs]
|
||
|
clip_texts = [clip.tokenize(x).to(self.device) for x in texts]
|
||
|
|
||
|
with torch.no_grad():
|
||
|
image_fts = [self.clip.encode_image(x) for x in clip_imgs]
|
||
|
text_fts = [self.clip.encode_text(x) for x in clip_texts]
|
||
|
|
||
|
features = sum([x * weights_i[i] for i, x in enumerate(image_fts)])
|
||
|
if weights_t is not None:
|
||
|
features += sum([x * weights_t[i] for i, x in enumerate(text_fts)])
|
||
|
|
||
|
features = features / features.norm(dim=-1, keepdim=True)
|
||
|
return features.detach()
|
||
|
|
||
|
def run(self, image_features, cond_text, beam_size):
|
||
|
self.image_features = image_features
|
||
|
|
||
|
context_tokens = self.lm_tokenizer.encode(self.context_prefix + cond_text)
|
||
|
|
||
|
output_tokens, output_text = self.generate_text(context_tokens, beam_size)
|
||
|
|
||
|
return output_text
|
||
|
|
||
|
def generate_text(self, context_tokens, beam_size):
|
||
|
context_tokens = torch.tensor(context_tokens, device=self.device, dtype=torch.long).unsqueeze(0)
|
||
|
|
||
|
gen_tokens = None
|
||
|
scores = None
|
||
|
seq_lengths = torch.ones(beam_size, device=self.device)
|
||
|
is_stopped = torch.zeros(beam_size, device=self.device, dtype=torch.bool)
|
||
|
|
||
|
for i in range(self.target_seq_length):
|
||
|
probs = self.get_next_probs(i, context_tokens)
|
||
|
logits = probs.log()
|
||
|
|
||
|
if scores is None:
|
||
|
scores, next_tokens = logits.topk(beam_size, -1)
|
||
|
context_tokens = context_tokens.expand(beam_size, *context_tokens.shape[1:])
|
||
|
next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
|
||
|
|
||
|
if gen_tokens is None:
|
||
|
gen_tokens = next_tokens
|
||
|
else:
|
||
|
gen_tokens = gen_tokens.expand(beam_size, *gen_tokens.shape[1:])
|
||
|
gen_tokens = torch.cat((gen_tokens, next_tokens), dim=1)
|
||
|
else:
|
||
|
logits[is_stopped] = -float(np.inf)
|
||
|
logits[is_stopped, 0] = 0
|
||
|
scores_sum = scores[:, None] + logits
|
||
|
seq_lengths[~is_stopped] += 1
|
||
|
scores_sum_average = scores_sum / seq_lengths[:, None]
|
||
|
scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
|
||
|
beam_size, -1)
|
||
|
next_tokens_source = next_tokens // scores_sum.shape[1]
|
||
|
seq_lengths = seq_lengths[next_tokens_source]
|
||
|
next_tokens = next_tokens % scores_sum.shape[1]
|
||
|
next_tokens = next_tokens.unsqueeze(1)
|
||
|
gen_tokens = gen_tokens[next_tokens_source]
|
||
|
gen_tokens = torch.cat((gen_tokens, next_tokens), dim=-1)
|
||
|
context_tokens = context_tokens[next_tokens_source]
|
||
|
scores = scores_sum_average * seq_lengths
|
||
|
is_stopped = is_stopped[next_tokens_source]
|
||
|
|
||
|
context_tokens = torch.cat((context_tokens, next_tokens), dim=1)
|
||
|
is_stopped = is_stopped + next_tokens.eq(self.end_token).squeeze()
|
||
|
|
||
|
####
|
||
|
tmp_scores = scores / seq_lengths
|
||
|
tmp_output_list = gen_tokens.cpu().numpy()
|
||
|
tmp_output_texts = [
|
||
|
self.lm_tokenizer.decode(tmp_output)
|
||
|
for tmp_output, tmp_length in zip(tmp_output_list, seq_lengths)
|
||
|
]
|
||
|
tmp_order = tmp_scores.argsort(descending=True)
|
||
|
tmp_output_texts = [tmp_output_texts[i] + ' %% ' + str(tmp_scores[i].cpu().numpy()) for i in tmp_order]
|
||
|
log_info(tmp_output_texts, verbose=True)
|
||
|
####
|
||
|
|
||
|
if is_stopped.all():
|
||
|
break
|
||
|
|
||
|
scores = scores / seq_lengths
|
||
|
output_list = gen_tokens.cpu().numpy()
|
||
|
output_texts = [
|
||
|
self.lm_tokenizer.decode(output[: int(length)])
|
||
|
for output, length in zip(output_list, seq_lengths)
|
||
|
]
|
||
|
order = scores.argsort(descending=True)
|
||
|
output_texts = [output_texts[i] for i in order]
|
||
|
|
||
|
return context_tokens, output_texts
|
||
|
|
||
|
def get_next_probs(self, i, context_tokens):
|
||
|
last_token = context_tokens[:, -1:]
|
||
|
|
||
|
if self.reset_context_delta and context_tokens.size(1) > 1:
|
||
|
context = self.lm_model(context_tokens[:, :-1])["past_key_values"]
|
||
|
|
||
|
# Logits of LM with unshifted context
|
||
|
logits_before_shift = self.lm_model(context_tokens)["logits"]
|
||
|
logits_before_shift = logits_before_shift[:, -1, :]
|
||
|
probs_before_shift = nn.functional.softmax(logits_before_shift, dim=-1)
|
||
|
|
||
|
if context:
|
||
|
context = self.shift_context(i, context, last_token, context_tokens, probs_before_shift)
|
||
|
|
||
|
lm_output = self.lm_model(last_token, past_key_values=context)
|
||
|
logits, past = (
|
||
|
lm_output["logits"],
|
||
|
lm_output["past_key_values"],
|
||
|
)
|
||
|
logits = logits[:, -1, :]
|
||
|
|
||
|
logits = self.update_special_tokens_logits(context_tokens, i, logits)
|
||
|
|
||
|
probs = nn.functional.softmax(logits, dim=-1)
|
||
|
probs = (probs ** self.fusion_factor) * (probs_before_shift ** (1 - self.fusion_factor))
|
||
|
probs = probs / probs.sum()
|
||
|
|
||
|
return probs
|
||
|
|
||
|
def shift_context(self, i, context, last_token, context_tokens, probs_before_shift):
|
||
|
context_delta = [tuple([np.zeros(x.shape).astype("float32") for x in p]) for p in context]
|
||
|
|
||
|
for i in range(self.num_iterations):
|
||
|
curr_shift = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_]) for p_ in
|
||
|
context_delta]
|
||
|
|
||
|
for p0, p1 in curr_shift:
|
||
|
p0.retain_grad()
|
||
|
p1.retain_grad()
|
||
|
|
||
|
shifted_context = list(map(add_context, context, curr_shift))
|
||
|
|
||
|
shifted_outputs = self.lm_model(last_token, past_key_values=shifted_context)
|
||
|
logits = shifted_outputs["logits"][:, -1, :]
|
||
|
probs = nn.functional.softmax(logits, dim=-1)
|
||
|
|
||
|
loss = 0.0
|
||
|
|
||
|
# CLIP LOSS
|
||
|
clip_loss, clip_losses = self.clip_loss(probs, context_tokens)
|
||
|
loss += self.clip_scale * clip_loss
|
||
|
|
||
|
# CE/Fluency loss
|
||
|
ce_loss = self.ce_scale * ((probs * probs.log()) - (probs * probs_before_shift.log())).sum(-1)
|
||
|
loss += ce_loss.sum()
|
||
|
|
||
|
loss.backward()
|
||
|
|
||
|
# --------- Specific Gen ---------
|
||
|
final_grads = self.norm_grad(context, context_tokens, curr_shift)
|
||
|
|
||
|
# --------- update context ---------
|
||
|
context_delta = list(map(add_context, final_grads, context_delta))
|
||
|
|
||
|
for p0, p1 in curr_shift:
|
||
|
p0.grad.data.zero_()
|
||
|
p1.grad.data.zero_()
|
||
|
|
||
|
new_context = []
|
||
|
for p0, p1 in context:
|
||
|
new_context.append((p0.detach(), p1.detach()))
|
||
|
context = new_context
|
||
|
|
||
|
context_delta = [tuple([torch.from_numpy(x).requires_grad_(True).to(device=self.device) for x in p_])
|
||
|
for p_ in context_delta]
|
||
|
context = list(map(add_context, context, context_delta))
|
||
|
|
||
|
new_context = []
|
||
|
for p0, p1 in context:
|
||
|
new_context.append((p0.detach(), p1.detach()))
|
||
|
context = new_context
|
||
|
|
||
|
return context
|
||
|
|
||
|
def norm_grad(self, context, context_tokens, curr_shift, ):
|
||
|
factor = 1
|
||
|
sep_grads = None
|
||
|
window_mask = torch.ones_like(context[0][0]).to(self.device)
|
||
|
|
||
|
for b in range(context_tokens.shape[0]):
|
||
|
tmp_sep_norms = [[(torch.norm(x.grad[b:(b + 1)] * window_mask[b:(b + 1)]) + 1e-15) for x in p_]
|
||
|
for p_ in curr_shift]
|
||
|
|
||
|
# normalize gradients
|
||
|
tmp_grad = [tuple([-self.stepsize * factor * (
|
||
|
x.grad[b:(b + 1)] * window_mask[b:(b + 1)] / tmp_sep_norms[i][
|
||
|
j] ** self.grad_norm_factor).data.cpu().numpy()
|
||
|
for j, x in enumerate(p_)])
|
||
|
for i, p_ in enumerate(curr_shift)]
|
||
|
if sep_grads is None:
|
||
|
sep_grads = tmp_grad
|
||
|
else:
|
||
|
for l_index in range(len(sep_grads)):
|
||
|
sep_grads[l_index] = list(sep_grads[l_index])
|
||
|
for k_index in range(len(sep_grads[0])):
|
||
|
sep_grads[l_index][k_index] = np.concatenate(
|
||
|
(sep_grads[l_index][k_index], tmp_grad[l_index][k_index]), axis=0)
|
||
|
sep_grads[l_index] = tuple(sep_grads[l_index])
|
||
|
final_grads = sep_grads
|
||
|
|
||
|
return final_grads
|
||
|
|
||
|
def update_special_tokens_logits(self, context_tokens, i, logits):
|
||
|
for beam_id in range(context_tokens.shape[0]):
|
||
|
for token_idx in set(context_tokens[beam_id][-4:].tolist()):
|
||
|
factor = self.repetition_penalty if logits[beam_id, token_idx] > 0 else (1 / self.repetition_penalty)
|
||
|
logits[beam_id, token_idx] /= factor
|
||
|
|
||
|
if i >= self.ef_idx:
|
||
|
factor = self.end_factor if logits[beam_id, self.end_token] > 0 else (1 / self.end_factor)
|
||
|
logits[beam_id, self.end_token] *= factor
|
||
|
if i == 0:
|
||
|
start_factor = 1.6
|
||
|
factor = start_factor if logits[beam_id, self.end_token] > 0 else (1 / start_factor)
|
||
|
logits[beam_id, self.end_token] /= factor
|
||
|
|
||
|
for token_idx in list(self.forbidden_tokens):
|
||
|
factor = self.forbidden_factor if logits[beam_id, token_idx] > 0 else (1 / self.forbidden_factor)
|
||
|
logits[beam_id, token_idx] /= factor
|
||
|
|
||
|
return logits
|
||
|
|
||
|
def clip_loss(self, probs, context_tokens):
|
||
|
for p_ in self.clip.transformer.parameters():
|
||
|
if p_.grad is not None:
|
||
|
p_.grad.data.zero_()
|
||
|
|
||
|
top_size = 512
|
||
|
top_probs, top_indices = probs.topk(top_size, -1)
|
||
|
|
||
|
prefix_texts = [self.lm_tokenizer.decode(x, skip_special_tokens=True) for x in context_tokens]
|
||
|
|
||
|
clip_loss = 0
|
||
|
losses = []
|
||
|
|
||
|
top_texts = []
|
||
|
for idx_p in range(probs.shape[0]):
|
||
|
prefix_text = prefix_texts[idx_p]
|
||
|
for x in top_indices[idx_p]:
|
||
|
top_texts.append(prefix_text + self.lm_tokenizer.decode(x))
|
||
|
|
||
|
text_features = self.get_txt_features(top_texts)#.reshape(probs.size(0), top_size, -1)
|
||
|
|
||
|
with torch.no_grad():
|
||
|
similiraties = (self.image_features @ text_features.T).reshape(probs.size(0), -1)
|
||
|
similiraties = similiraties.reshape(probs.size(0), -1)
|
||
|
target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach()
|
||
|
target_probs = target_probs.type(torch.float32)
|
||
|
|
||
|
clip_loss += torch.sum(-(target_probs * torch.log(top_probs)))
|
||
|
# for idx_p in range(probs.shape[0]):
|
||
|
# top_texts = []
|
||
|
# prefix_text = prefix_texts[idx_p]
|
||
|
# for x in top_indices[idx_p]:
|
||
|
# top_texts.append(prefix_text + self.lm_tokenizer.decode(x))
|
||
|
# text_features = self.get_txt_features(top_texts)
|
||
|
#
|
||
|
# with torch.no_grad():
|
||
|
# similiraties = (self.image_features @ text_features.T)
|
||
|
# target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach()
|
||
|
# target_probs = target_probs.type(torch.float32)
|
||
|
#
|
||
|
# target = torch.zeros_like(probs[idx_p])
|
||
|
# target[top_indices[idx_p]] = target_probs[0]
|
||
|
# target = target.unsqueeze(0)
|
||
|
# cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)])))
|
||
|
#
|
||
|
# clip_loss += cur_clip_loss
|
||
|
# losses.append(cur_clip_loss)
|
||
|
|
||
|
return clip_loss, losses
|
||
|
|
||
|
def clip_loss_old(self, probs, context_tokens):
|
||
|
for p_ in self.clip.transformer.parameters():
|
||
|
if p_.grad is not None:
|
||
|
p_.grad.data.zero_()
|
||
|
|
||
|
top_size = 512
|
||
|
_, top_indices = probs.topk(top_size, -1)
|
||
|
|
||
|
prefix_texts = [self.lm_tokenizer.decode(x).replace(self.lm_tokenizer.bos_token, '') for x in context_tokens]
|
||
|
|
||
|
clip_loss = 0
|
||
|
losses = []
|
||
|
for idx_p in range(probs.shape[0]):
|
||
|
top_texts = []
|
||
|
prefix_text = prefix_texts[idx_p]
|
||
|
for x in top_indices[idx_p]:
|
||
|
top_texts.append(prefix_text + self.lm_tokenizer.decode(x))
|
||
|
text_features = self.get_txt_features(top_texts)
|
||
|
|
||
|
with torch.no_grad():
|
||
|
similiraties = (self.image_features @ text_features.T)
|
||
|
target_probs = nn.functional.softmax(similiraties / self.clip_loss_temperature, dim=-1).detach()
|
||
|
target_probs = target_probs.type(torch.float32)
|
||
|
|
||
|
target = torch.zeros_like(probs[idx_p])
|
||
|
target[top_indices[idx_p]] = target_probs[0]
|
||
|
target = target.unsqueeze(0)
|
||
|
cur_clip_loss = torch.sum(-(target * torch.log(probs[idx_p:(idx_p + 1)])))
|
||
|
|
||
|
clip_loss += cur_clip_loss
|
||
|
losses.append(cur_clip_loss)
|
||
|
|
||
|
return clip_loss, losses
|