import argparse import ipdb from tqdm import tqdm import progressbar import torch import ipdb import clip from model.ZeroCLIP import CLIPTextGenerator from model.ZeroCLIP_batched import CLIPTextGenerator as CLIPTextGenerator_multigpu def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--test_image_prefix_path", type=str, help="the folder that stores all test images") parser.add_argument("--test_path", type=str) parser.add_argument("--save_path_prefix", type=str, help="save the result in which directory") parser.add_argument("--save_name", type=str, help="the name of the saved file") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo") parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP") parser.add_argument("--target_seq_length", type=int, default=15) parser.add_argument("--cond_text", type=str, default="Image of a") parser.add_argument("--reset_context_delta", action="store_true", help="Should we reset the context at each token gen") parser.add_argument("--num_iterations", type=int, default=5) parser.add_argument("--clip_loss_temperature", type=float, default=0.01) parser.add_argument("--clip_scale", type=float, default=1) parser.add_argument("--ce_scale", type=float, default=0.2) parser.add_argument("--stepsize", type=float, default=0.3) parser.add_argument("--grad_norm_factor", type=float, default=0.9) parser.add_argument("--fusion_factor", type=float, default=0.99) parser.add_argument("--repetition_penalty", type=float, default=1) parser.add_argument("--end_token", type=str, default=".", help="Token to end text") parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token") parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens") parser.add_argument("--beam_size", type=int, default=1) parser.add_argument("--multi_gpu", action="store_true") parser.add_argument('--run_type', default='caption', nargs='?', choices=['caption', 'arithmetics']) parser.add_argument("--caption_img_path", type=str, default='example_images/captions/COCO_val2014_000000008775.jpg', help="Path to image for captioning") parser.add_argument("--arithmetics_imgs", nargs="+", default=['example_images/arithmetics/woman2.jpg', 'example_images/arithmetics/king2.jpg', 'example_images/arithmetics/man2.jpg']) parser.add_argument("--arithmetics_weights", nargs="+", default=[1, 1, -1]) args = parser.parse_args() return args def run(args, text_generator, img_path): image_features = text_generator.get_img_feature([img_path], None) captions = text_generator.run(image_features, args.cond_text, beam_size=args.beam_size) encoded_captions = [text_generator.clip.encode_text(clip.tokenize(c).to(text_generator.device)) for c in captions] encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions] best_clip_idx = (torch.cat(encoded_captions) @ image_features.t()).squeeze().argmax().item() return captions if __name__ == '__main__': if torch.cuda.is_available(): print ('Cuda is available.') cuda_available = torch.cuda.is_available() args = get_args() device = torch.device('cuda') save_path_prefix = args.save_path_prefix import os if os.path.exists(save_path_prefix): pass else: # recursively construct directory os.makedirs(save_path_prefix, exist_ok=True) # parse save name save_name = args.save_name full_save_path = save_path_prefix + '/' + save_name print ('full save path is {}'.format(full_save_path)) print ('Loading data...') import json with open(args.test_path) as f: item_list = json.load(f) print ('Data loaded.') print ('Number of test instances is {}'.format(len(item_list))) # ZeroCap generator text_generator = CLIPTextGenerator(**vars(args)) result_list = [] invalid_num = 0 print ('----------------------------------------------------------------') test_num = len(item_list) #test_num = 10 print ('Number of inference instances is {}'.format(test_num)) p = progressbar.ProgressBar(test_num) p.start() for p_idx in tqdm(range(test_num)): p.update(p_idx) one_test_dict = item_list[p_idx] one_res_dict = { 'split':one_test_dict['split'], 'image_name':one_test_dict['image_name'], #'file_path':one_test_dict['file_path'], 'captions':one_test_dict['captions'] } image_full_path = args.test_image_prefix_path + '/' + one_test_dict['image_name'] try: output_text = run(args, text_generator, img_path=image_full_path) one_res_dict['prediction'] = output_text[0] result_list.append(one_res_dict) except Exception as error: print(f'[!] ERROR:', error) invalid_num += 1 print ('invalid number is {}'.format(invalid_num)) continue p.finish() print ('Inference completed!') import json with open(full_save_path, 'w') as outfile: json.dump(result_list, outfile, indent=4)