logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

118 lines
4.5 KiB

import os
import tempfile
import sys
sys.path.append('CLIP')
from pathlib import Path
import cog
import argparse
import torch
import clip
from model.ZeroCLIP import CLIPTextGenerator
def perplexity_score(text, lm_model, lm_tokenizer, device):
    """Return the language-model loss of ``text``; lower means more fluent.

    The text is prefixed with the tokenizer's BOS token, tokenized to a PyTorch
    tensor, and fed to ``lm_model`` with ``labels`` equal to the input ids. The
    first model output is returned as a Python float. Assuming ``lm_model`` is a
    Hugging Face causal LM (TODO confirm against CLIPTextGenerator), that output
    is the mean token cross-entropy, i.e. log-perplexity, so callers can rank
    captions directly with ``argmin``.

    Args:
        text: Caption to score.
        lm_model: Callable model accepting ``(input_ids, labels=...)``.
        lm_tokenizer: Tokenizer exposing ``bos_token`` and returning an object
            with an ``input_ids`` tensor.
        device: Device the token ids are moved to before the forward pass.

    Returns:
        The loss value as a float.
    """
    encodings = lm_tokenizer(lm_tokenizer.bos_token + text, return_tensors='pt')
    input_ids = encodings.input_ids.to(device)
    # Pure scoring: no gradients are needed, so skip building an autograd graph.
    with torch.no_grad():
        outputs = lm_model(input_ids, labels=input_ids.clone())
    # With labels supplied, the first output is the loss (a negative log-likelihood).
    return outputs[0].item()
class Predictor(cog.Predictor):
    """Cog predictor that captions an image with ZeroCLIP and re-ranks the beams.

    ``predict`` reports three candidates, each prefixed with the conditioning text:
    the beam whose normalized CLIP text embedding has the highest dot product with
    the image features, the beam with the lowest ``perplexity_score``, and the
    first beam returned by the generator.
    """

    def setup(self):
        """Build the default argument namespace and the caption generator once."""
        self.args = get_args()
        # This predictor always resets the context delta at each generated token.
        self.args.reset_context_delta = True
        self.text_generator = CLIPTextGenerator(**vars(self.args))

    @cog.input(
        "image",
        type=Path,
        help="input image"
    )
    @cog.input(
        "cond_text",
        type=str,
        default='Image of a',
        help="conditional text",
    )
    @cog.input(
        "beam_size",
        type=int,
        default=5, min=1, max=10,
        help="Number of beams to use",
    )
    @cog.input(
        "end_factor",
        type=float,
        default=1.01, min=1.0, max=1.10,
        help="Higher value for shorter captions",
    )
    @cog.input(
        "max_seq_length",
        type=int,
        default=15, min=1, max=20,
        help="Maximum number of tokens to generate",
    )
    @cog.input(
        "ce_loss_scale",
        type=float,
        default=0.2, min=0.0, max=0.6,
        help="Scale of cross-entropy loss with un-shifted language model",
    )
    def predict(self, image, cond_text, beam_size, end_factor, max_seq_length, ce_loss_scale):
        """Generate captions for ``image`` and return the three re-ranked picks.

        Side effects: overwrites ``self.args.cond_text`` and the generator's
        ``end_factor``, ``target_seq_length`` and ``ce_scale`` attributes. These
        values persist until the next call.

        Returns:
            A three-line string: best-by-CLIP, best-by-fluency, and the first
            beam from ``run`` (labelled "Best mixed"; presumably ``run`` orders
            beams by its combined score — TODO confirm).
        """
        self.args.cond_text = cond_text
        generator = self.text_generator
        generator.end_factor = end_factor
        generator.target_seq_length = max_seq_length
        generator.ce_scale = ce_loss_scale

        image_features = generator.get_img_feature([str(image)], None)
        # Generation stays outside no_grad: it may rely on autograd internally.
        captions = generator.run(image_features, self.args.cond_text, beam_size=beam_size)

        # Re-ranking is inference-only; disable autograd so no per-caption graphs are built.
        with torch.no_grad():
            best_clip_idx = self._best_clip_index(captions, image_features)
            best_ppl_idx = self._best_fluency_index(captions)

        prefix = self.args.cond_text
        best_clip_caption = prefix + captions[best_clip_idx]
        best_ppl_caption = prefix + captions[best_ppl_idx]
        best_mixed_caption = prefix + captions[0]
        return (f'Best CLIP: {best_clip_caption} \nBest fluency: {best_ppl_caption} '
                f'\nBest mixed: {best_mixed_caption}')

    def _best_clip_index(self, captions, image_features):
        """Index of the caption whose L2-normalized CLIP text embedding best matches ``image_features``."""
        generator = self.text_generator
        encoded = [generator.clip.encode_text(clip.tokenize(c).to(generator.device))
                   for c in captions]
        encoded = [e / e.norm(dim=-1, keepdim=True) for e in encoded]
        # One similarity per caption; the image features are used as given (not re-normalized here).
        similarity = torch.cat(encoded) @ image_features.t()
        return similarity.squeeze().argmax().item()

    def _best_fluency_index(self, captions):
        """Index of the caption with the lowest language-model loss (``perplexity_score``)."""
        generator = self.text_generator
        scores = [perplexity_score(c, generator.lm_model, generator.lm_tokenizer, generator.device)
                  for c in captions]
        return torch.tensor(scores).argmin().item()
def get_args(argv=None):
    """Return the ZeroCLIP generation settings as an ``argparse.Namespace``.

    Args:
        argv: Optional list of command-line style overrides, e.g.
            ``["--beam_size", "3"]``. When omitted, an empty list is parsed, so
            the process's real ``sys.argv`` is ignored and every option takes
            its default value.

    Returns:
        Namespace whose attributes are passed as keyword arguments to
        ``CLIPTextGenerator`` by the predictor.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
    parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
    parser.add_argument("--target_seq_length", type=int, default=15)
    parser.add_argument("--cond_text", type=str, default="Image of a")
    parser.add_argument("--reset_context_delta", action="store_true",
                        help="Should we reset the context at each token gen")
    parser.add_argument("--num_iterations", type=int, default=5)
    parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
    # Float defaults for float options: argparse does not apply `type` to
    # non-string defaults, so an int default would otherwise stay an int.
    parser.add_argument("--clip_scale", type=float, default=1.0)
    parser.add_argument("--ce_scale", type=float, default=0.2)
    parser.add_argument("--stepsize", type=float, default=0.3)
    parser.add_argument("--grad_norm_factor", type=float, default=0.9)
    parser.add_argument("--fusion_factor", type=float, default=0.99)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
    parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
    parser.add_argument("--forbidden_factor", type=float, default=20.0, help="Factor to decrease forbidden tokens")
    parser.add_argument("--beam_size", type=int, default=5)
    # parse_args(None) would read sys.argv, so map "no overrides" to an empty list.
    return parser.parse_args([] if argv is None else list(argv))