lightningdot
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
155 lines
5.7 KiB
155 lines
5.7 KiB
"""run inference for Image Text Retrieval"""
|
|
import argparse
|
|
import json
|
|
import os
|
|
from os.path import exists
|
|
import pickle
|
|
from time import time
|
|
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
|
|
from apex import amp
|
|
from horovod import torch as hvd
|
|
|
|
from data import (PrefetchLoader,
|
|
DetectFeatLmdb, TxtTokLmdb, ItmEvalDataset, itm_eval_collate)
|
|
from model import UniterForImageTextRetrieval
|
|
|
|
from utils.logger import LOGGER
|
|
from utils.distributed import all_gather_list
|
|
from utils.misc import Struct
|
|
from utils.const import IMG_DIM
|
|
|
|
from eval.itm import itm_eval
|
|
from train_itm import inference # FIXME
|
|
|
|
|
|
def main(opts):
|
|
hvd.init()
|
|
n_gpu = hvd.size()
|
|
device = torch.device("cuda", hvd.local_rank())
|
|
torch.cuda.set_device(hvd.local_rank())
|
|
rank = hvd.rank()
|
|
LOGGER.info("device: {} n_gpu: {}, rank: {}, "
|
|
"16-bits training: {}".format(
|
|
device, n_gpu, hvd.rank(), opts.fp16))
|
|
|
|
hps_file = f'{opts.output_dir}/log/hps.json'
|
|
model_opts = Struct(json.load(open(hps_file)))
|
|
|
|
# load DBs and image dirs
|
|
eval_img_db = DetectFeatLmdb(opts.img_db,
|
|
model_opts.conf_th, model_opts.max_bb,
|
|
model_opts.min_bb, model_opts.num_bb,
|
|
opts.compressed_db)
|
|
eval_txt_db = TxtTokLmdb(opts.txt_db, -1)
|
|
eval_dataset = ItmEvalDataset(eval_txt_db, eval_img_db, opts.batch_size)
|
|
|
|
# Prepare model
|
|
if exists(opts.checkpoint):
|
|
ckpt_file = opts.checkpoint
|
|
else:
|
|
ckpt_file = f'{opts.output_dir}/ckpt/model_step_{opts.checkpoint}.pt'
|
|
checkpoint = torch.load(ckpt_file)
|
|
model = UniterForImageTextRetrieval.from_pretrained(
|
|
f'{opts.output_dir}/log/model.json', checkpoint,
|
|
img_dim=IMG_DIM)
|
|
if 'rank_output' not in checkpoint:
|
|
model.init_output() # zero shot setting
|
|
|
|
model.to(device)
|
|
model = amp.initialize(model, enabled=opts.fp16, opt_level='O2')
|
|
|
|
eval_dataloader = DataLoader(eval_dataset, batch_size=1,
|
|
num_workers=opts.n_workers,
|
|
pin_memory=opts.pin_mem,
|
|
collate_fn=itm_eval_collate)
|
|
eval_dataloader = PrefetchLoader(eval_dataloader)
|
|
|
|
eval_log, results = evaluate(model, eval_dataloader)
|
|
if hvd.rank() == 0:
|
|
result_dir = f'{opts.output_dir}/itm_results_{opts.name}'
|
|
if not exists(result_dir) and rank == 0:
|
|
os.makedirs(result_dir)
|
|
out_file = f'{result_dir}/results_{opts.checkpoint}.bin'
|
|
if not exists(out_file):
|
|
with open(out_file, 'wb') as f:
|
|
pickle.dump(results, f)
|
|
with open(f'{result_dir}/scores_{opts.checkpoint}.json',
|
|
'w') as f:
|
|
json.dump(eval_log, f)
|
|
LOGGER.info(f'evaluation finished')
|
|
LOGGER.info(
|
|
f"======================== {opts.name} =========================\n"
|
|
f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n"
|
|
f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n"
|
|
f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n"
|
|
f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n"
|
|
f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n"
|
|
f"text retrieval R10: {eval_log['txt_r10']*100:.2f}")
|
|
LOGGER.info("========================================================")
|
|
|
|
|
|
@torch.no_grad()
def evaluate(model, eval_loader):
    """Score every text against every image and compute retrieval metrics.

    Each rank runs `inference` on its shard, then the score shards and
    text ids are all-gathered. Only rank 0 computes the final metrics;
    other ranks return ``({}, ())``.

    Returns:
        (eval_log, results): metric dict and a
        ``(score_matrix, txt_ids, img_ids)`` tuple on rank 0.
    """
    model.eval()
    start = time()
    LOGGER.info("start running Image/Text Retrieval evaluation ...")
    score_matrix = inference(model, eval_loader)

    dataset = eval_loader.dataset
    # gather this rank's scores and text ids from every worker
    all_score = hvd.allgather(score_matrix)
    all_txt_ids = []
    for ids in all_gather_list(dataset.ids):
        all_txt_ids.extend(ids)
    all_img_ids = dataset.all_img_ids
    assert all_score.size() == (len(all_txt_ids), len(all_img_ids))

    # NOTE: only use rank0 to compute final scores
    if hvd.rank() != 0:
        return {}, tuple()

    eval_log = itm_eval(all_score, all_txt_ids, all_img_ids,
                        dataset.txt2img, dataset.img2txts)
    results = (all_score, all_txt_ids, all_img_ids)

    tot_time = time()-start
    LOGGER.info(f"evaluation finished in {int(tot_time)} seconds, ")
    return eval_log, results
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
|
|
# Required parameters
|
|
parser.add_argument("--txt_db", default=None, type=str,
|
|
help="The input train corpus. (LMDB)")
|
|
parser.add_argument("--img_db", default=None, type=str,
|
|
help="The input train images.")
|
|
parser.add_argument("--name", default='flickr_val', type=str,
|
|
help="affects output path")
|
|
parser.add_argument('--compressed_db', action='store_true',
|
|
help='use compressed LMDB')
|
|
parser.add_argument("--checkpoint",
|
|
default=None, type=str,
|
|
help="pretrained model (can take 'google-bert') ")
|
|
parser.add_argument("--batch_size",
|
|
default=400, type=int,
|
|
help="number of tokens in a batch")
|
|
|
|
parser.add_argument(
|
|
"--output_dir", default=None, type=str,
|
|
help="The output directory where the model checkpoints will be "
|
|
"written.")
|
|
|
|
# Prepro parameters
|
|
|
|
# device parameters
|
|
parser.add_argument('--fp16', action='store_true',
|
|
help="Whether to use 16-bit float precision instead "
|
|
"of 32-bit")
|
|
parser.add_argument('--n_workers', type=int, default=4,
|
|
help="number of data workers")
|
|
parser.add_argument('--pin_mem', action='store_true',
|
|
help="pin memory")
|
|
|
|
args = parser.parse_args()
|
|
|
|
main(args)
|