logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

136 lines
4.7 KiB

import json
import copy
import torch
import progressbar
import numpy as np
from PIL import Image
class CLIPIndex:
    '''
    Nearest-neighbour lookup over a pre-computed matrix of index vectors.

    Index rows are L2-normalised when loaded and image vectors are
    L2-normalised when computed, so their dot product is the cosine
    similarity. The best-scoring row number is then looked up (as a string
    key) in the mapping dictionary.
    '''

    def __init__(self, index_matrix_path, mapping_dict_path, clip):
        '''
        index_matrix_path: text file with one whitespace-separated vector per line
        mapping_dict_path: JSON file whose keys are row numbers written as strings
        clip: model object providing compute_batch_index_image_features
        '''
        print ('Loading index...')
        self.index_matrix = self.normalization(self.load_matrix(index_matrix_path))
        print ('Index loaded.')
        print (self.index_matrix.shape)
        # Same explicit encoding as load_matrix, so the result does not depend
        # on the platform's default locale.
        with open(mapping_dict_path, encoding = 'utf8') as f:
            self.mapping_dict = json.load(f)
        self.clip = clip

    def load_matrix(self, in_f):
        '''
        Parse a text file of floats (one row per line) into a numpy array.

        Blank or whitespace-only lines are skipped; otherwise they would
        become empty rows and make the resulting array ragged.
        '''
        matrix_list = []
        with open(in_f, 'r', encoding = 'utf8') as i:
            # Iterate the file directly instead of materialising every line.
            for line in i:
                fields = line.split()
                if not fields:
                    continue
                matrix_list.append([float(num) for num in fields])
        return np.array(matrix_list)

    def normalization(self, matrix):
        '''
        matrix: num_instance x num_feature

        Returns the matrix with every row scaled to unit L2 norm. All-zero
        rows are returned unchanged: dividing them by their zero norm would
        produce NaN, and NaN scores end up ranked as the best match.
        '''
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        norms[norms == 0] = 1
        return matrix / norms

    def get_image_representation(self, image_path):
        '''
        Return the L2-normalised feature vector(s) of the image at image_path.

        The result is whatever compute_batch_index_image_features yields for
        a one-image batch, converted to numpy (presumably shape
        1 x num_feature -- confirm against the clip model).
        '''
        # Features are computed while the file is still open; the context
        # manager then releases the file handle that Image.open acquired.
        with Image.open(image_path) as image_instance:
            features = self.clip.compute_batch_index_image_features([image_instance])
        image_vec = features.detach().cpu().numpy()
        return self.normalization(image_vec)

    def search_text(self, image_path):
        '''
        Return the mapping-dictionary entry of the index row that is most
        similar to the image at image_path.
        '''
        image_vec = self.get_image_representation(image_path)
        scores = np.matmul(image_vec, self.index_matrix.transpose())[0]
        # Only the best row is needed, so argmax replaces a full sort. On an
        # exact tie the lowest row number is chosen.
        top_idx = np.argmax(scores)
        return self.mapping_dict[str(top_idx)]
def parse_config():
    '''
    Build the command-line parser for this script and parse sys.argv.

    Every option is an optional string that defaults to None. Returns the
    argparse namespace holding the parsed values.
    '''
    # (flag, help text) pairs, in the order they are registered.
    option_table = (
        ("--clip_name", None),
        ("--test_image_prefix_path", "the folder that stores all test images"),
        ("--test_path", None),
        # index configuration
        ("--index_matrix_path", None),
        ("--mapping_dict_path", None),
        # save configuration
        ("--save_path_prefix", "save the result in which directory"),
        ("--save_name", "the name of the saved file"),
    )
    cli = argparse.ArgumentParser()
    for flag, help_text in option_table:
        cli.add_argument(flag, type=str, help=help_text)
    return cli.parse_args()
import argparse
if __name__ == '__main__':
    # Script entry point: load the CLIP model and the index, look up the best
    # index entry for every test image, and dump the results as JSON.
    import os

    # Query CUDA once and reuse the answer below.
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        print ('Cuda is available.')
    args = parse_config()

    save_path_prefix = args.save_path_prefix
    if not os.path.exists(save_path_prefix):
        # recursively construct directory
        os.makedirs(save_path_prefix, exist_ok=True)
    # parse save name
    save_name = args.save_name
    full_save_path = save_path_prefix + '/' + save_name
    print ('full save path is {}'.format(full_save_path))

    print ('Loading CLIP...')
    from clip import CLIP
    clip = CLIP(args.clip_name)
    if cuda_available:
        # The device object is only needed when the model is moved to the GPU.
        device = torch.device('cuda')
        clip = clip.cuda(device)
    clip.eval()
    print ('CLIP loaded!')

    clipindex = CLIPIndex(args.index_matrix_path, args.mapping_dict_path, clip)

    print ('Loading data...')
    with open(args.test_path, encoding='utf8') as f:
        item_list = json.load(f)
    print ('Data loaded.')
    print ('Number of test instances is {}'.format(len(item_list)))

    result_list = []
    invalid_num = 0
    print ('----------------------------------------------------------------')
    with torch.no_grad():
        test_num = len(item_list)
        print ('Number of inference instances is {}'.format(test_num))
        # NOTE(review): the count is passed positionally; confirm that the first
        # positional parameter of the installed progressbar package is the
        # maximum value.
        p = progressbar.ProgressBar(test_num)
        p.start()
        for p_idx, one_test_dict in enumerate(item_list):
            p.update(p_idx)
            one_res_dict = {
                'split':one_test_dict['split'],
                'image_name':one_test_dict['image_name'],
                'captions':one_test_dict['captions']
            }
            image_full_path = args.test_image_prefix_path + '/' + one_test_dict['image_name']
            try:
                output_text = clipindex.search_text(image_full_path)
            except Exception as err:
                # Best effort: count and report the failing item, then move on.
                # Catching Exception (not a bare except) keeps Ctrl-C and
                # SystemExit able to stop the run.
                invalid_num += 1
                print ('invalid number is {}'.format(invalid_num))
                print ('failed on {}: {}'.format(image_full_path, err))
            else:
                one_res_dict['prediction'] = output_text
                result_list.append(one_res_dict)
        p.finish()
    print ('Inference completed!')

    with open(full_save_path, 'w') as outfile:
        json.dump(result_list, outfile, indent=4)