clipcap
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
167 lines
6.7 KiB
167 lines
6.7 KiB
2 years ago
|
import clip
|
||
|
import torch
|
||
|
import skimage.io as io
|
||
|
import PIL.Image
|
||
|
import numpy as np
|
||
|
import torch.nn.functional as nnf
|
||
|
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
|
||
|
from tqdm import tqdm, trange
|
||
|
from clipcap_model import MLP, ClipCaptionModel, ClipCaptionPrefix
|
||
|
|
||
|
is_gpu = False
|
||
|
device = CUDA(0) if is_gpu else "cpu"
|
||
|
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
|
||
|
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||
|
CPU = torch.device('cpu')
|
||
|
|
||
|
|
||
|
def generate2(
|
||
|
model,
|
||
|
tokenizer,
|
||
|
tokens=None,
|
||
|
prompt=None,
|
||
|
embed=None,
|
||
|
entry_count=1,
|
||
|
entry_length=67, # maximum number of words
|
||
|
top_p=0.8,
|
||
|
temperature=1.,
|
||
|
stop_token: str = '.',
|
||
|
):
|
||
|
model.eval()
|
||
|
generated_num = 0
|
||
|
generated_list = []
|
||
|
stop_token_index = tokenizer.encode(stop_token)[0]
|
||
|
filter_value = -float("Inf")
|
||
|
device = next(model.parameters()).device
|
||
|
|
||
|
with torch.no_grad():
|
||
|
|
||
|
for entry_idx in trange(entry_count):
|
||
|
if embed is not None:
|
||
|
generated = embed
|
||
|
else:
|
||
|
if tokens is None:
|
||
|
tokens = torch.tensor(tokenizer.encode(prompt))
|
||
|
tokens = tokens.unsqueeze(0).to(device)
|
||
|
|
||
|
generated = model.gpt.transformer.wte(tokens)
|
||
|
|
||
|
for i in range(entry_length):
|
||
|
|
||
|
outputs = model.gpt(inputs_embeds=generated)
|
||
|
logits = outputs.logits
|
||
|
logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
|
||
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||
|
cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
|
||
|
sorted_indices_to_remove = cumulative_probs > top_p
|
||
|
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ ..., :-1].clone()
|
||
|
sorted_indices_to_remove[..., 0] = 0
|
||
|
|
||
|
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
||
|
logits[:, indices_to_remove] = filter_value
|
||
|
next_token = torch.argmax(logits, -1).unsqueeze(0)
|
||
|
next_token_embed = model.gpt.transformer.wte(next_token)
|
||
|
if tokens is None:
|
||
|
tokens = next_token
|
||
|
else:
|
||
|
tokens = torch.cat((tokens, next_token), dim=1)
|
||
|
generated = torch.cat((generated, next_token_embed), dim=1)
|
||
|
if stop_token_index == next_token.item():
|
||
|
break
|
||
|
|
||
|
output_list = list(tokens.squeeze().cpu().numpy())
|
||
|
output_text = tokenizer.decode(output_list)
|
||
|
generated_list.append(output_text)
|
||
|
|
||
|
return generated_list[0]
|
||
|
|
||
|
def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
|
||
|
entry_length=67, temperature=1., stop_token: str = '.'):
|
||
|
|
||
|
model.eval()
|
||
|
stop_token_index = tokenizer.encode(stop_token)[0]
|
||
|
tokens = None
|
||
|
scores = None
|
||
|
device = next(model.parameters()).device
|
||
|
seq_lengths = torch.ones(beam_size, device=device)
|
||
|
is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
|
||
|
with torch.no_grad():
|
||
|
if embed is not None:
|
||
|
generated = embed
|
||
|
else:
|
||
|
if tokens is None:
|
||
|
tokens = torch.tensor(tokenizer.encode(prompt))
|
||
|
tokens = tokens.unsqueeze(0).to(device)
|
||
|
generated = model.gpt.transformer.wte(tokens)
|
||
|
for i in range(entry_length):
|
||
|
outputs = model.gpt(inputs_embeds=generated)
|
||
|
logits = outputs.logits
|
||
|
logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
|
||
|
logits = logits.softmax(-1).log()
|
||
|
if scores is None:
|
||
|
scores, next_tokens = logits.topk(beam_size, -1)
|
||
|
generated = generated.expand(beam_size, *generated.shape[1:])
|
||
|
next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
|
||
|
if tokens is None:
|
||
|
tokens = next_tokens
|
||
|
else:
|
||
|
tokens = tokens.expand(beam_size, *tokens.shape[1:])
|
||
|
tokens = torch.cat((tokens, next_tokens), dim=1)
|
||
|
else:
|
||
|
logits[is_stopped] = -float(np.inf)
|
||
|
logits[is_stopped, 0] = 0
|
||
|
scores_sum = scores[:, None] + logits
|
||
|
seq_lengths[~is_stopped] += 1
|
||
|
scores_sum_average = scores_sum / seq_lengths[:, None]
|
||
|
scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
|
||
|
next_tokens_source = next_tokens // scores_sum.shape[1]
|
||
|
seq_lengths = seq_lengths[next_tokens_source]
|
||
|
next_tokens = next_tokens % scores_sum.shape[1]
|
||
|
next_tokens = next_tokens.unsqueeze(1)
|
||
|
tokens = tokens[next_tokens_source]
|
||
|
tokens = torch.cat((tokens, next_tokens), dim=1)
|
||
|
generated = generated[next_tokens_source]
|
||
|
scores = scores_sum_average * seq_lengths
|
||
|
is_stopped = is_stopped[next_tokens_source]
|
||
|
next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
|
||
|
generated = torch.cat((generated, next_token_embed), dim=1)
|
||
|
is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
|
||
|
if is_stopped.all():
|
||
|
break
|
||
|
scores = scores / seq_lengths
|
||
|
output_list = tokens.cpu().numpy()
|
||
|
output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
|
||
|
order = scores.argsort(descending=True)
|
||
|
output_texts = [output_texts[i] for i in order]
|
||
|
return output_texts
|
||
|
|
||
|
prefix_length = 10
|
||
|
|
||
|
model = ClipCaptionModel(prefix_length)
|
||
|
model_path = '/Users/zilliz/git/image_captioning/git/clipcap/weights/coco_weights.pt'
|
||
|
model.load_state_dict(torch.load(model_path, map_location=CPU))
|
||
|
model = model.eval()
|
||
|
|
||
|
use_beam_search = False #@param {type:"boolean"}
|
||
|
use_beam_search = True #@param {type:"boolean"}
|
||
|
|
||
|
UPLOADED_FILE = 'einstein.jpg'
|
||
|
image = io.imread(UPLOADED_FILE)
|
||
|
pil_image = PIL.Image.fromarray(image)
|
||
|
|
||
|
image = preprocess(pil_image).unsqueeze(0).to(device)
|
||
|
with torch.no_grad():
|
||
|
# if type(model) is ClipCaptionE2E:
|
||
|
# prefix_embed = model.forward_image(image)
|
||
|
# else:
|
||
|
prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
|
||
|
prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
|
||
|
if use_beam_search:
|
||
|
generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed)[0]
|
||
|
else:
|
||
|
generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)
|
||
|
|
||
|
print(generated_text_prefix)
|
||
|
|
||
|
|