magic
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
136 lines
4.7 KiB
136 lines
4.7 KiB
2 years ago
|
import json
|
||
|
import copy
|
||
|
import torch
|
||
|
import progressbar
|
||
|
import numpy as np
|
||
|
from PIL import Image
|
||
|
|
||
|
class CLIPIndex:
    '''Retrieve the indexed entry whose embedding is most similar to an image.

    Index rows and query image embeddings are both L2-normalised, so the dot
    product used for ranking is the cosine similarity. ``mapping_dict`` maps a
    row number (as a string key) to the value returned by ``search_text``.
    '''

    def __init__(self, index_matrix_path, mapping_dict_path, clip):
        '''
        index_matrix_path: path to the pre-computed index, a whitespace-separated
            text file with one embedding vector per line
        mapping_dict_path: path to a JSON file whose keys are row indices of the
            index (as strings); values are what ``search_text`` returns
            (presumably the indexed captions)
        clip: the pre-trained CLIP model; must provide
            ``compute_batch_index_image_features``
        '''
        print('Loading index...')
        self.index_matrix = self.normalization(self.load_matrix(index_matrix_path))
        print('Index loaded.')
        print(self.index_matrix.shape)
        with open(mapping_dict_path, 'r', encoding='utf8') as f:
            self.mapping_dict = json.load(f)
        self.clip = clip

    def load_matrix(self, in_f):
        '''Read a whitespace-separated float matrix, one row per non-blank line.

        in_f: path of the text file to read
        returns: numpy array of shape (num_rows, num_feature)

        Blank lines (e.g. an extra empty line at the end of the file) are
        skipped; otherwise they would become empty, ragged rows.
        '''
        matrix_list = []
        with open(in_f, 'r', encoding='utf8') as i:
            # Iterate lazily instead of readlines() to avoid holding the raw text twice.
            for line in i:
                values = line.split()
                if not values:
                    continue
                matrix_list.append([float(num) for num in values])
        return np.array(matrix_list)

    def normalization(self, matrix):
        '''L2-normalise each row of ``matrix``.

        matrix: num_instance x num_feature

        All-zero rows are returned as zero vectors instead of NaN. A NaN row
        would produce NaN similarity scores, and NaN outranks every real score
        when picking the best match.
        '''
        norms = np.linalg.norm(matrix, axis=1, keepdims=True)
        # Dividing a zero row by 1 keeps it at zero instead of 0/0 -> NaN.
        safe_norms = np.where(norms == 0, 1.0, norms)
        return matrix / safe_norms

    def get_image_representation(self, image_path):
        '''Embed the image at ``image_path`` with CLIP and L2-normalise it.

        returns: numpy array with one row per input image; here a single row,
            presumably shape (1, num_feature) — TODO confirm against CLIP wrapper
        '''
        # The context manager releases the file handle once features are computed.
        with Image.open(image_path) as image_instance:
            image_vec = self.clip.compute_batch_index_image_features([image_instance]).detach().cpu().numpy()
        return self.normalization(image_vec)

    def search_text(self, image_path):
        '''Return the ``mapping_dict`` value of the index row most similar to the image.'''
        image_vec = self.get_image_representation(image_path)
        # Both sides are unit-norm, so this dot product is the cosine similarity.
        scores = np.matmul(image_vec, self.index_matrix.transpose())[0]
        # Only the best match is needed: argmax is O(n) rather than a full sort.
        # Exact ties resolve to the lowest row index.
        top_idx = int(np.argmax(scores))
        return self.mapping_dict[str(top_idx)]
|
||
|
|
||
|
|
||
|
def parse_config(argv=None):
    '''Parse the command-line options for CLIP-index inference.

    argv: optional list of argument strings. ``None`` (the default) parses
        ``sys.argv[1:]`` exactly as before; a list lets callers parse
        programmatically.
    returns: argparse.Namespace with one attribute per option below

    Every option is required. The script uses each one unconditionally
    (files are opened, path prefixes are joined, the CLIP name is passed to
    ``CLIP``), so a missing option is reported here by argparse instead of
    failing later on a ``None`` value.
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument("--clip_name", type=str, required=True)
    parser.add_argument("--test_image_prefix_path", type=str, required=True, help="the folder that stores all test images")
    parser.add_argument("--test_path", type=str, required=True)
    # index configuration
    parser.add_argument("--index_matrix_path", type=str, required=True)
    parser.add_argument("--mapping_dict_path", type=str, required=True)
    # save configuration
    parser.add_argument("--save_path_prefix", type=str, required=True, help="save the result in which directory")
    parser.add_argument("--save_name", type=str, required=True, help="the name of the saved file")
    return parser.parse_args(argv)
|
||
|
|
||
|
import argparse
|
||
|
if __name__ == '__main__':
    # Script entry point: embed every test image with CLIP, retrieve the most
    # similar indexed entry via CLIPIndex, and dump the predictions as JSON.
    import os

    cuda_available = torch.cuda.is_available()
    if cuda_available:
        print('Cuda is available.')
    args = parse_config()
    # Only used to move the model when CUDA is present.
    device = torch.device('cuda' if cuda_available else 'cpu')

    save_path_prefix = args.save_path_prefix
    # Recursively create the output directory; a no-op if it already exists.
    os.makedirs(save_path_prefix, exist_ok=True)
    # parse save name
    save_name = args.save_name
    full_save_path = os.path.join(save_path_prefix, save_name)
    print('full save path is {}'.format(full_save_path))

    print('Loading CLIP...')
    from clip import CLIP
    clip = CLIP(args.clip_name)
    if cuda_available:
        clip = clip.cuda(device)
    clip.eval()
    print('CLIP loaded!')

    clipindex = CLIPIndex(args.index_matrix_path, args.mapping_dict_path, clip)

    print('Loading data...')
    with open(args.test_path, 'r', encoding='utf8') as f:
        item_list = json.load(f)
    print('Data loaded.')
    print('Number of test instances is {}'.format(len(item_list)))

    result_list = []
    invalid_num = 0
    print('----------------------------------------------------------------')
    with torch.no_grad():
        test_num = len(item_list)
        print('Number of inference instances is {}'.format(test_num))
        # NOTE(review): the first positional argument is ``maxval`` in the classic
        # ``progressbar`` package but ``min_value`` in ``progressbar2``; confirm
        # which package is installed before relying on the bar's length.
        p = progressbar.ProgressBar(test_num)
        p.start()
        for p_idx, one_test_dict in enumerate(item_list):
            p.update(p_idx)
            one_res_dict = {
                'split': one_test_dict['split'],
                'image_name': one_test_dict['image_name'],
                'captions': one_test_dict['captions'],
            }
            image_full_path = os.path.join(args.test_image_prefix_path, one_test_dict['image_name'])
            try:
                output_text = clipindex.search_text(image_full_path)
            except Exception as e:
                # Skip images that cannot be processed, but report why. Catching
                # Exception (not a bare except) keeps Ctrl-C / SystemExit working.
                invalid_num += 1
                print('invalid number is {}'.format(invalid_num))
                print('failed on {}: {!r}'.format(image_full_path, e))
                continue
            one_res_dict['prediction'] = output_text
            result_list.append(one_res_dict)
        p.finish()
        print('Inference completed!')

    with open(full_save_path, 'w') as outfile:
        json.dump(result_list, outfile, indent=4)
|
||
|
|