logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

155 lines
5.7 KiB

"""run inference for Image Text Retrieval"""
import argparse
import json
import os
from os.path import exists
import pickle
from time import time
import torch
from torch.utils.data import DataLoader
from apex import amp
from horovod import torch as hvd
from data import (PrefetchLoader,
DetectFeatLmdb, TxtTokLmdb, ItmEvalDataset, itm_eval_collate)
from model import UniterForImageTextRetrieval
from utils.logger import LOGGER
from utils.distributed import all_gather_list
from utils.misc import Struct
from utils.const import IMG_DIM
from eval.itm import itm_eval
from train_itm import inference # FIXME
def main(opts):
    """Run Image-Text Retrieval inference and dump scores + metrics.

    Loads the eval text/image LMDBs and a trained (or zero-shot) UNITER
    checkpoint, scores every text against every image via ``evaluate``,
    and on rank 0 writes the raw score matrix (pickle) and the retrieval
    metrics (json) under ``{output_dir}/itm_results_{name}``.

    Args:
        opts: parsed command-line namespace (see the argparse setup below).
    """
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, rank, opts.fp16))

    # hyper-parameters saved at training time drive feature extraction
    hps_file = f'{opts.output_dir}/log/hps.json'
    with open(hps_file) as f:  # fix: close the file (was json.load(open(...)))
        model_opts = Struct(json.load(f))

    # load DBs and image dirs
    eval_img_db = DetectFeatLmdb(opts.img_db,
                                 model_opts.conf_th, model_opts.max_bb,
                                 model_opts.min_bb, model_opts.num_bb,
                                 opts.compressed_db)
    eval_txt_db = TxtTokLmdb(opts.txt_db, -1)
    eval_dataset = ItmEvalDataset(eval_txt_db, eval_img_db, opts.batch_size)

    # Prepare model: --checkpoint is either a path or a training step number
    if exists(opts.checkpoint):
        ckpt_file = opts.checkpoint
    else:
        ckpt_file = f'{opts.output_dir}/ckpt/model_step_{opts.checkpoint}.pt'
    checkpoint = torch.load(ckpt_file)
    model = UniterForImageTextRetrieval.from_pretrained(
        f'{opts.output_dir}/log/model.json', checkpoint,
        img_dim=IMG_DIM)
    if 'rank_output' not in checkpoint:
        model.init_output()    # zero shot setting
    model.to(device)
    model = amp.initialize(model, enabled=opts.fp16, opt_level='O2')

    # batch_size=1 because ItmEvalDataset already batches internally;
    # itm_eval_collate unwraps that single mini-batch
    eval_dataloader = DataLoader(eval_dataset, batch_size=1,
                                 num_workers=opts.n_workers,
                                 pin_memory=opts.pin_mem,
                                 collate_fn=itm_eval_collate)
    eval_dataloader = PrefetchLoader(eval_dataloader)

    eval_log, results = evaluate(model, eval_dataloader)
    if rank == 0:  # fix: use cached rank; inner `and rank == 0` was redundant
        result_dir = f'{opts.output_dir}/itm_results_{opts.name}'
        # exist_ok avoids the exists()/makedirs() race of the original
        os.makedirs(result_dir, exist_ok=True)

        out_file = f'{result_dir}/results_{opts.checkpoint}.bin'
        if not exists(out_file):  # do not clobber a previous dump
            with open(out_file, 'wb') as f:
                pickle.dump(results, f)
        with open(f'{result_dir}/scores_{opts.checkpoint}.json',
                  'w') as f:
            json.dump(eval_log, f)
        LOGGER.info('evaluation finished')  # fix: no-placeholder f-string
        LOGGER.info(
            f"======================== {opts.name} =========================\n"
            f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n"
            f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n"
            f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n"
            f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n"
            f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n"
            f"text retrieval R10: {eval_log['txt_r10']*100:.2f}")
        LOGGER.info("========================================================")
@torch.no_grad()
def evaluate(model, eval_loader):
    """Score all text-image pairs across workers; rank 0 returns metrics.

    Returns ``(eval_log, results)`` where ``results`` is
    ``(score_matrix, txt_ids, img_ids)``; non-root ranks get
    ``({}, ())`` since only rank 0 computes the final scores.
    """
    model.eval()
    started = time()
    LOGGER.info("start running Image/Text Retrieval evaluation ...")

    score_matrix = inference(model, eval_loader)
    dataset = eval_loader.dataset

    # gather per-worker score shards and text ids onto every rank
    all_score = hvd.allgather(score_matrix)
    all_txt_ids = [txt_id
                   for shard in all_gather_list(dataset.ids)
                   for txt_id in shard]
    all_img_ids = dataset.all_img_ids
    assert all_score.size() == (len(all_txt_ids), len(all_img_ids))

    if hvd.rank() != 0:
        return {}, tuple()

    # NOTE: only use rank0 to compute final scores
    eval_log = itm_eval(all_score, all_txt_ids, all_img_ids,
                        dataset.txt2img, dataset.img2txts)
    results = (all_score, all_txt_ids, all_img_ids)
    LOGGER.info(f"evaluation finished in {int(time() - started)} seconds, ")
    return eval_log, results
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--txt_db", default=None, type=str,
help="The input train corpus. (LMDB)")
parser.add_argument("--img_db", default=None, type=str,
help="The input train images.")
parser.add_argument("--name", default='flickr_val', type=str,
help="affects output path")
parser.add_argument('--compressed_db', action='store_true',
help='use compressed LMDB')
parser.add_argument("--checkpoint",
default=None, type=str,
help="pretrained model (can take 'google-bert') ")
parser.add_argument("--batch_size",
default=400, type=int,
help="number of tokens in a batch")
parser.add_argument(
"--output_dir", default=None, type=str,
help="The output directory where the model checkpoints will be "
"written.")
# Prepro parameters
# device parameters
parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit float precision instead "
"of 32-bit")
parser.add_argument('--n_workers', type=int, default=4,
help="number of data workers")
parser.add_argument('--pin_mem', action='store_true',
help="pin memory")
args = parser.parse_args()
main(args)