magic
copied
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
131 lines
5.4 KiB
131 lines
5.4 KiB
import argparse
|
|
import ipdb
|
|
from tqdm import tqdm
|
|
import progressbar
|
|
import torch
|
|
import ipdb
|
|
import clip
|
|
from model.ZeroCLIP import CLIPTextGenerator
|
|
from model.ZeroCLIP_batched import CLIPTextGenerator as CLIPTextGenerator_multigpu
|
|
|
|
def get_args(argv=None):
    """Parse command-line options for the ZeroCap caption inference script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the original
            behaviour of reading the real command line.

    Returns:
        argparse.Namespace: the parsed options.

    Raises:
        SystemExit: if a required option is missing or a value fails to
            parse (argparse prints a usage message first).
    """
    parser = argparse.ArgumentParser()

    # Input/output locations. They have no meaningful default, so they are
    # required: a missing value now fails here with a usage message instead of
    # surfacing later as a TypeError on ``None``.
    parser.add_argument("--test_image_prefix_path", type=str, required=True,
                        help="the folder that stores all test images")
    parser.add_argument("--test_path", type=str, required=True)
    parser.add_argument("--save_path_prefix", type=str, required=True,
                        help="save the result in which directory")
    parser.add_argument("--save_name", type=str, required=True,
                        help="the name of the saved file")

    # Generation / model options.
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--lm_model", type=str, default="gpt-2", help="gpt-2 or gpt-neo")
    parser.add_argument("--clip_checkpoints", type=str, default="./clip_checkpoints", help="path to CLIP")
    parser.add_argument("--target_seq_length", type=int, default=15)
    parser.add_argument("--cond_text", type=str, default="Image of a")
    parser.add_argument("--reset_context_delta", action="store_true",
                        help="Should we reset the context at each token gen")
    parser.add_argument("--num_iterations", type=int, default=5)
    parser.add_argument("--clip_loss_temperature", type=float, default=0.01)
    parser.add_argument("--clip_scale", type=float, default=1)
    parser.add_argument("--ce_scale", type=float, default=0.2)
    parser.add_argument("--stepsize", type=float, default=0.3)
    parser.add_argument("--grad_norm_factor", type=float, default=0.9)
    parser.add_argument("--fusion_factor", type=float, default=0.99)
    parser.add_argument("--repetition_penalty", type=float, default=1)
    parser.add_argument("--end_token", type=str, default=".", help="Token to end text")
    parser.add_argument("--end_factor", type=float, default=1.01, help="Factor to increase end_token")
    parser.add_argument("--forbidden_factor", type=float, default=20, help="Factor to decrease forbidden tokens")
    parser.add_argument("--beam_size", type=int, default=1)

    parser.add_argument("--multi_gpu", action="store_true")

    parser.add_argument('--run_type',
                        default='caption',
                        nargs='?',
                        choices=['caption', 'arithmetics'])

    parser.add_argument("--caption_img_path", type=str, default='example_images/captions/COCO_val2014_000000008775.jpg',
                        help="Path to image for captioning")

    parser.add_argument("--arithmetics_imgs", nargs="+",
                        default=['example_images/arithmetics/woman2.jpg',
                                 'example_images/arithmetics/king2.jpg',
                                 'example_images/arithmetics/man2.jpg'])
    # type=float so weights given on the command line are numbers rather than
    # strings; argparse does not pass the (non-string) list default through
    # ``type``, so the default is unchanged.
    parser.add_argument("--arithmetics_weights", nargs="+", type=float, default=[1, 1, -1])

    return parser.parse_args(argv)
|
|
|
|
def run(args, text_generator, img_path, *, rerank=False):
    """Generate captions for a single image with a ZeroCap-style generator.

    Args:
        args: Parsed options; ``args.cond_text`` is passed as the text prompt
            and ``args.beam_size`` as the beam size to ``text_generator.run``.
        text_generator: Object exposing ``get_img_feature`` and ``run`` (and,
            when ``rerank`` is used, ``clip`` and ``device``).
        img_path: Path of the image to caption.
        rerank: If True and more than one caption was generated, move the
            caption whose CLIP text embedding scores highest against the image
            features to the front of the returned list. Defaults to False,
            which returns captions exactly as produced by
            ``text_generator.run``.

    Returns:
        The captions returned by ``text_generator.run`` (a new list with the
        best-scoring caption first when ``rerank`` is applied).
    """
    image_features = text_generator.get_img_feature([img_path], None)
    captions = text_generator.run(image_features, args.cond_text, beam_size=args.beam_size)

    # Previously the CLIP score below was always computed and then discarded;
    # it is now only evaluated when the caller asks for re-ranking, and is
    # pointless for fewer than two captions.
    if not rerank or len(captions) < 2:
        return captions

    # Scoring only needs forward passes, so skip building autograd graphs.
    with torch.no_grad():
        encoded_captions = [
            text_generator.clip.encode_text(clip.tokenize(c).to(text_generator.device))
            for c in captions
        ]
        encoded_captions = [x / x.norm(dim=-1, keepdim=True) for x in encoded_captions]
        scores = torch.cat(encoded_captions) @ image_features.t()
    best_clip_idx = scores.squeeze().argmax().item()

    ordered = list(captions)
    best_caption = ordered.pop(best_clip_idx)
    return [best_caption] + ordered
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: caption every image listed in --test_path and write
    # the predictions, together with the reference captions, to a JSON file.
    import json
    import os

    if torch.cuda.is_available():
        print('Cuda is available.')
    args = get_args()

    save_path_prefix = args.save_path_prefix
    # exist_ok=True makes this a no-op for an existing directory and creates
    # missing parents otherwise, so no separate existence check is needed.
    os.makedirs(save_path_prefix, exist_ok=True)
    full_save_path = os.path.join(save_path_prefix, args.save_name)
    print('full save path is {}'.format(full_save_path))

    print('Loading data...')
    with open(args.test_path, encoding='utf-8') as f:
        item_list = json.load(f)
    print('Data loaded.')
    print('Number of test instances is {}'.format(len(item_list)))

    # ZeroCap generator; honour --multi_gpu, which was previously parsed but
    # ignored, by selecting the batched implementation.
    generator_cls = CLIPTextGenerator_multigpu if args.multi_gpu else CLIPTextGenerator
    text_generator = generator_cls(**vars(args))

    result_list = []
    invalid_num = 0
    print('----------------------------------------------------------------')
    test_num = len(item_list)
    print('Number of inference instances is {}'.format(test_num))
    # A single progress bar: running tqdm and progressbar together made both
    # redraw over each other on the terminal.
    for one_test_dict in tqdm(item_list):
        one_res_dict = {
            'split': one_test_dict['split'],
            'image_name': one_test_dict['image_name'],
            'captions': one_test_dict['captions'],
        }
        image_full_path = os.path.join(args.test_image_prefix_path, one_test_dict['image_name'])
        try:
            output_text = run(args, text_generator, img_path=image_full_path)
            prediction = output_text[0]
        except Exception as error:
            # Deliberate best-effort: a failing image is counted and skipped
            # so one bad input does not abort the whole evaluation run.
            invalid_num += 1
            print('[!] ERROR while processing {}: {}'.format(image_full_path, error))
            print('invalid number is {}'.format(invalid_num))
            continue
        one_res_dict['prediction'] = prediction
        result_list.append(one_res_dict)
    print('Inference completed!')

    with open(full_save_path, 'w', encoding='utf-8') as outfile:
        json.dump(result_list, outfile, indent=4)
|