lightningdot
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
180 lines
6.5 KiB
180 lines
6.5 KiB
"""run inference of VQA for submission"""
|
|
import argparse
|
|
import json
|
|
import os
|
|
from os.path import exists
|
|
from time import time
|
|
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
|
|
from apex import amp
|
|
from horovod import torch as hvd
|
|
import numpy as np
|
|
from cytoolz import concat
|
|
|
|
from data import (TokenBucketSampler, PrefetchLoader,
|
|
DetectFeatLmdb, TxtTokLmdb, VqaEvalDataset, vqa_eval_collate)
|
|
from model import UniterForVisualQuestionAnswering
|
|
|
|
from utils.logger import LOGGER
|
|
from utils.distributed import all_gather_list
|
|
from utils.misc import Struct
|
|
from utils.const import BUCKET_SIZE, IMG_DIM
|
|
|
|
|
|
def main(opts):
    """Run VQA inference with a trained UniterForVisualQuestionAnswering
    checkpoint and write submission-format answers (and optionally the raw
    logits) to ``{opts.output_dir}/results_test``.

    Runs under horovod: every rank processes a shard of the eval data; the
    collective ``all_gather_list`` calls below must be executed by all ranks,
    while only rank 0 writes output files.
    """
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, hvd.rank(), opts.fp16))

    # hyper-parameters the checkpoint was trained with
    hps_file = f'{opts.output_dir}/log/hps.json'
    # fix: use `with` so the file handles are closed (was json.load(open(...)))
    with open(hps_file) as f:
        model_opts = Struct(json.load(f))

    # answer vocabulary saved at training time; label2ans inverts it so a
    # predicted label index can be mapped back to the answer string
    ans2label_file = f'{opts.output_dir}/ckpt/ans2label.json'
    with open(ans2label_file) as f:
        ans2label = json.load(f)
    label2ans = {label: ans for ans, label in ans2label.items()}

    # load DBs and image dirs
    eval_img_db = DetectFeatLmdb(opts.img_db,
                                 model_opts.conf_th, model_opts.max_bb,
                                 model_opts.min_bb, model_opts.num_bb,
                                 opts.compressed_db)
    eval_txt_db = TxtTokLmdb(opts.txt_db, -1)
    eval_dataset = VqaEvalDataset(len(ans2label), eval_txt_db, eval_img_db)

    # Prepare model: --checkpoint is either a path or a training step number
    if exists(opts.checkpoint):
        ckpt_file = opts.checkpoint
    else:
        ckpt_file = f'{opts.output_dir}/ckpt/model_step_{opts.checkpoint}.pt'
    # fix: map to CPU first so loading does not depend on the GPU layout the
    # checkpoint was saved on; the model is moved to `device` right after
    checkpoint = torch.load(ckpt_file, map_location='cpu')
    model = UniterForVisualQuestionAnswering.from_pretrained(
        f'{opts.output_dir}/log/model.json', checkpoint,
        img_dim=IMG_DIM, num_answer=len(ans2label))
    model.to(device)
    model = amp.initialize(model, enabled=opts.fp16, opt_level='O2')

    # token-bucketed batches: opts.batch_size counts tokens, not examples
    sampler = TokenBucketSampler(eval_dataset.lens, bucket_size=BUCKET_SIZE,
                                 batch_size=opts.batch_size, droplast=False)
    eval_dataloader = DataLoader(eval_dataset,
                                 batch_sampler=sampler,
                                 num_workers=opts.n_workers,
                                 pin_memory=opts.pin_mem,
                                 collate_fn=vqa_eval_collate)
    eval_dataloader = PrefetchLoader(eval_dataloader)

    val_log, results, logits = evaluate(model, eval_dataloader, label2ans,
                                        opts.save_logits)
    result_dir = f'{opts.output_dir}/results_test'
    if rank == 0:
        # fix: exist_ok avoids a crash when the directory already exists
        # (e.g. a previous run created it)
        os.makedirs(result_dir, exist_ok=True)

    # collectives: every rank must participate, even though only rank 0 writes
    all_results = list(concat(all_gather_list(results)))
    if opts.save_logits:
        all_logits = {}
        for id2logit in all_gather_list(logits):
            all_logits.update(id2logit)
    if hvd.rank() == 0:
        with open(f'{result_dir}/'
                  f'results_{opts.checkpoint}_all.json', 'w') as f:
            json.dump(all_results, f)
        if opts.save_logits:
            np.savez(f'{result_dir}/logits_{opts.checkpoint}_all.npz',
                     **all_logits)
|
@torch.no_grad()
def evaluate(model, eval_loader, label2ans, save_logits=False):
    """Predict an answer for every question in ``eval_loader``.

    Args:
        model: UNITER VQA model; called with ``compute_loss=False`` so it
            returns per-answer scores.
        eval_loader: prefetching dataloader yielding batches with a
            ``'qids'`` entry.
        label2ans: dict mapping label index -> answer string.
        save_logits: when True, also collect per-question fp16 score arrays.

    Returns:
        (val_log, results, logits) where ``results`` is a list of
        ``{'answer', 'question_id'}`` dicts and ``logits`` maps qid to a
        numpy half-precision score vector (empty unless ``save_logits``).
    """
    LOGGER.info("start running evaluation...")
    model.eval()
    n_ex = 0
    st = time()
    results = []
    logits = {}
    for i, batch in enumerate(eval_loader):
        qids = batch['qids']
        scores = model(batch, compute_loss=False)
        # argmax over the answer dimension -> label indices -> answer strings
        answers = [label2ans[label]
                   for label in scores.max(dim=-1, keepdim=False
                                           )[1].cpu().tolist()]
        for qid, answer in zip(qids, answers):
            results.append({'answer': answer, 'question_id': int(qid)})
        if save_logits:
            scores = scores.cpu()
            # BUG FIX: this inner loop previously reused `i`, clobbering the
            # batch index and breaking the `i % 100 == 0` progress check below
            for j, qid in enumerate(qids):
                logits[qid] = scores[j].half().numpy()
        if i % 100 == 0 and hvd.rank() == 0:
            n_results = len(results)
            n_results *= hvd.size()   # an approximation to avoid hangs
            LOGGER.info(f'{n_results}/{len(eval_loader.dataset)} '
                        'answers predicted')
        n_ex += len(qids)
    # collective: sums example counts across all horovod ranks
    n_ex = sum(all_gather_list(n_ex))
    tot_time = time()-st
    val_log = {'valid/ex_per_s': n_ex/tot_time}
    model.train()
    LOGGER.info(f"evaluation finished in {int(tot_time)} seconds "
                f"at {int(n_ex/tot_time)} examples per second")
    return val_log, results, logits
|
def compute_score_with_logits(logits, labels):
    """Soft VQA accuracy: for each row, take the predicted answer (argmax of
    `logits`) and return its soft label score from `labels`, one-hot masked
    so all other positions are zero.

    Args:
        logits: (batch, num_answers) prediction scores.
        labels: (batch, num_answers) soft ground-truth scores.

    Returns:
        (batch, num_answers) float tensor, nonzero only at the argmax column.
    """
    predictions = torch.max(logits, 1)[1]  # argmax
    mask = torch.zeros(labels.size(), device=labels.device)
    mask.scatter_(1, predictions.view(-1, 1), 1)
    return mask * labels
|
if __name__ == "__main__":
    # Command-line entry point: parse options and launch inference.
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--txt_db", default=None, type=str,
                        help="The input train corpus. (LMDB)")
    parser.add_argument("--img_db", default=None, type=str,
                        help="The input train images.")
    parser.add_argument('--compressed_db', action='store_true',
                        help='use compressed LMDB')
    parser.add_argument("--checkpoint", default=None, type=str,
                        help="pretrained model (can take 'google-bert') ")
    parser.add_argument("--batch_size", default=8192, type=int,
                        help="number of tokens in a batch")

    parser.add_argument("--output_dir", default=None, type=str,
                        help="The output directory where the model checkpoints will be "
                             "written.")

    parser.add_argument("--save_logits", action='store_true',
                        help="Whether to save logits (for making ensemble)")

    # Prepro parameters

    # device parameters
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead "
                             "of 32-bit")
    parser.add_argument('--n_workers', type=int, default=4,
                        help="number of data workers")
    parser.add_argument('--pin_mem', action='store_true',
                        help="pin memory")

    args = parser.parse_args()

    # options safe guard
    # TODO
    main(args)