logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

131 lines
5.4 KiB

import argparse
import ipdb
from tqdm import tqdm
import progressbar
import torch
import ipdb
import clip
from model.ZeroCLIP import CLIPTextGenerator
from model.ZeroCLIP_batched import CLIPTextGenerator as CLIPTextGenerator_multigpu
def get_args(argv=None):
    """Parse the command-line options for ZeroCap caption inference.

    Args:
        argv: optional list of argument strings to parse. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, exactly as before; passing
            an explicit list lets other code and tests reuse this parser.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--test_image_prefix_path", type=str, help="the folder that stores all test images")
    parser.add_argument("--test_path", type=str, help="path to the JSON file listing the test instances")
    parser.add_argument("--save_path_prefix", type=str, help="save the result in which directory")
    parser.add_argument("--save_name", type=str, help="the name of the saved file")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
    parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
    parser.add_argument("--target_seq_length", type=int, default=15)
    parser.add_argument("--cond_text", type=str, default="Image of a")
    parser.add_argument("--reset_context_delta", action="store_true",
                        help="Should we reset the context at each token gen")
    parser.add_argument("--num_iterations", type=int, default=5)
    parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
    parser.add_argument("--clip_scale", type=float, default=1)
    parser.add_argument("--ce_scale", type=float, default=0.2)
    parser.add_argument("--stepsize", type=float, default=0.3)
    parser.add_argument("--grad_norm_factor", type=float, default=0.9)
    parser.add_argument("--fusion_factor", type=float, default=0.99)
    parser.add_argument("--repetition_penalty", type=float, default=1)
    parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
    parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
    parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
    parser.add_argument("--beam_size", type=int, default=1)
    parser.add_argument("--multi_gpu", action="store_true",
                        help="use the batched multi-GPU text generator")
    parser.add_argument('--run_type',
                        default='caption',
                        nargs='?',
                        choices=['caption', 'arithmetics'])
    parser.add_argument("--caption_img_path", type=str, default='example_images/captions/COCO_val2014_000000008775.jpg',
                        help="Path to image for captioning")
    parser.add_argument("--arithmetics_imgs", nargs="+",
                        default=['example_images/arithmetics/woman2.jpg',
                                 'example_images/arithmetics/king2.jpg',
                                 'example_images/arithmetics/man2.jpg'])
    # type=float: without it, weights given on the command line arrive as
    # strings (e.g. ['1', '-1']) while the default is numeric.
    parser.add_argument("--arithmetics_weights", type=float, nargs="+", default=[1, 1, -1],
                        help="one weight per image in --arithmetics_imgs")
    return parser.parse_args(argv)
def run(args, text_generator, img_path, rerank=False):
    """Generate captions for one image with the ZeroCap text generator.

    Args:
        args: parsed CLI namespace; ``args.cond_text`` is the prompt prefix and
            ``args.beam_size`` is forwarded to ``text_generator.run``.
        text_generator: a ``CLIPTextGenerator`` (or its batched variant)
            providing ``get_img_feature``, ``run``, ``clip`` and ``device``.
        img_path: path of the image to caption.
        rerank: when True, score every caption against the image with CLIP and
            move the highest-scoring one to the front of the returned list.
            Defaults to False, which returns the captions in the generator's
            order (the previous behaviour, minus the discarded CLIP scoring).

    Returns:
        The captions produced by ``text_generator.run``.
    """
    image_features = text_generator.get_img_feature([img_path], None)
    captions = text_generator.run(image_features, args.cond_text, beam_size=args.beam_size)
    # Scoring with CLIP costs a text-encoder pass; skip it when the result
    # would not be used or when there is nothing to choose between.
    if not rerank or len(captions) < 2:
        return captions

    # Scoring needs no gradients; no_grad avoids building an autograd graph.
    with torch.no_grad():
        tokens = clip.tokenize(list(captions)).to(text_generator.device)
        text_features = text_generator.clip.encode_text(tokens)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        best_clip_idx = (text_features @ image_features.t()).squeeze().argmax().item()

    ordered = list(captions)
    best = ordered.pop(best_clip_idx)
    return [best] + ordered
if __name__ == '__main__':
    # Caption every image listed in --test_path and dump the predictions, together
    # with the reference fields of each instance, as JSON to
    # <save_path_prefix>/<save_name>.
    import json
    import os

    if torch.cuda.is_available():
        print ('Cuda is available.')
    args = get_args()

    # These options have no defaults; fail with a clear message instead of a
    # TypeError on None further down.
    missing = [name for name in ('test_path', 'test_image_prefix_path', 'save_path_prefix', 'save_name')
               if getattr(args, name) is None]
    if missing:
        raise SystemExit('Missing required arguments: ' + ', '.join('--' + name for name in missing))

    save_path_prefix = args.save_path_prefix
    # exist_ok=True makes this a no-op when the directory is already there.
    os.makedirs(save_path_prefix, exist_ok=True)
    full_save_path = os.path.join(save_path_prefix, args.save_name)
    print ('full save path is {}'.format(full_save_path))

    print ('Loading data...')
    with open(args.test_path, encoding='utf-8') as f:
        item_list = json.load(f)
    print ('Data loaded.')
    print ('Number of test instances is {}'.format(len(item_list)))

    # ZeroCap generator; --multi_gpu selects the batched implementation.
    generator_cls = CLIPTextGenerator_multigpu if args.multi_gpu else CLIPTextGenerator
    text_generator = generator_cls(**vars(args))

    result_list = []
    invalid_num = 0
    print ('----------------------------------------------------------------')
    print ('Number of inference instances is {}'.format(len(item_list)))
    for one_test_dict in tqdm(item_list):
        image_name = one_test_dict['image_name']
        one_res_dict = {
            'split': one_test_dict['split'],
            'image_name': image_name,
            'captions': one_test_dict['captions'],
        }
        image_full_path = os.path.join(args.test_image_prefix_path, image_name)
        try:
            # Indexing stays inside the try so an empty caption list counts
            # as an invalid instance rather than aborting the whole run.
            prediction = run(args, text_generator, img_path=image_full_path)[0]
        except Exception as error:
            # Best-effort: record the failure and continue with the next image.
            invalid_num += 1
            print('[!] ERROR on {}:'.format(image_name), error)
            print ('invalid number is {}'.format(invalid_num))
            continue
        one_res_dict['prediction'] = prediction
        result_list.append(one_res_dict)
    print ('Inference completed!')

    with open(full_save_path, 'w', encoding='utf-8') as outfile:
        json.dump(result_list, outfile, indent=4)