"""Build a flat-file MIPS index of CLIP text embeddings.

Reads a JSON file of items (each with a "captions" list), embeds every
caption with CLIP, then writes:
  1. a JSON dict mapping index position -> caption text, and
  2. a text file with one whitespace-joined embedding vector per line.
"""
import sys
import argparse
import json
import os

import torch
import numpy as np


def parse_config():
    """Parse the command-line configuration for the indexing run."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--clip_name", type=str, default="openai/clip-vit-base-patch32")
    parser.add_argument("--text_file_path", type=str)
    # save configuration
    parser.add_argument("--save_index_prefix", type=str, help='where to save the mips index')
    parser.add_argument("--save_index_name", type=str)
    parser.add_argument("--save_mapping_dict_name", type=str,
                        help="a json file that stores a dictionary. The dictionary contains mapping between mips index and caption text")
    # inference configuration
    parser.add_argument("--batch_size", type=int, help="the batch size used to conduct inference with CLIP")
    return parser.parse_args()


def load_batch_text(text_file_path, batch_size):
    """Load all captions from *text_file_path* and split them into batches.

    Args:
        text_file_path: path to a JSON list of items, each holding a
            "captions" list of strings.
        batch_size: number of captions per batch.

    Returns:
        list[list[str]]: caption batches. The final batch may be shorter
        than ``batch_size``.
    """
    with open(text_file_path) as f:
        item_list = json.load(f)
    text_list = [cap for item in item_list for cap in item["captions"]]
    print('Number of text instances is {}'.format(len(text_list)))
    # BUGFIX: the original ``data_num // batch_size`` loop silently dropped
    # the trailing ``data_num % batch_size`` captions; keep the partial tail
    # batch so every caption is embedded and indexed.
    return [text_list[s:s + batch_size] for s in range(0, len(text_list), batch_size)]


if __name__ == '__main__':
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        print('Cuda is available.')
    args = parse_config()
    device = torch.device('cuda')

    # Recursively create the output directory; exist_ok makes the previous
    # os.path.exists() pre-check redundant.
    os.makedirs(args.save_index_prefix, exist_ok=True)

    print('Loading CLIP...')
    # Imported lazily: clip is a project-local module only needed when the
    # file is executed as a script (not when load_batch_text is imported).
    from clip import CLIP
    model = CLIP(args.clip_name)
    if cuda_available:
        model = model.cuda(device)
    model.eval()
    print('CLIP loaded!')

    print('Loading text data...')
    batch_text_list = load_batch_text(args.text_file_path, args.batch_size)
    print('Text data loaded.')

    res_text_vec_list, res_text_list = [], []
    batch_num = len(batch_text_list)
    print('Number of batches is {}'.format(batch_num))
    print('Start inference...')
    # Imported lazily so the module stays importable without the third-party
    # progressbar package installed.
    import progressbar
    p = progressbar.ProgressBar(batch_num)
    p.start()
    with torch.no_grad():
        for p_idx in range(batch_num):
            p.update(p_idx)
            one_text_batch = batch_text_list[p_idx]
            one_batch_vec = model.compute_batch_index_text_representation(one_text_batch).detach().cpu()
            one_batch_vec_list = one_batch_vec.unbind(dim=0)
            # Pair each embedding with its caption so positions stay aligned.
            for one_vec, one_text in zip(one_batch_vec_list, one_text_batch):
                res_text_vec_list.append(one_vec.numpy())
                res_text_list.append(one_text)
    p.finish()
    assert len(res_text_vec_list) == len(res_text_list)
    print('Inference completed!')

    # Save the index-position -> caption mapping as JSON.
    index_text_mapping_dict = {k: text for k, text in enumerate(res_text_list)}
    mapping_list_save_path = args.save_index_prefix + '/' + args.save_mapping_dict_name
    with open(mapping_list_save_path, 'w') as outfile:
        json.dump(index_text_mapping_dict, outfile, indent=4)
    print('Mapping dictionary saved!')

    # Save the vectors: one whitespace-joined embedding per line, in the
    # same order as the mapping dict keys.
    print('Start buiding index...')
    index_save_path = args.save_index_prefix + '/' + args.save_index_name
    with open(index_save_path, 'w', encoding='utf8') as o:
        for vec in res_text_vec_list:
            one_text = ' '.join([str(num) for num in vec]).strip()
            o.writelines(one_text + '\n')
    print('Index completed!')