"""run inference of VCR for submission"""
|
|
import argparse
|
|
import json
|
|
import os
|
|
from os.path import exists
|
|
from time import time
|
|
|
|
import torch
|
|
from torch.nn import functional as F
|
|
from torch.utils.data import DataLoader
|
|
from tqdm import tqdm
|
|
from apex import amp
|
|
from horovod import torch as hvd
|
|
|
|
from data import (DetectFeatLmdb, VcrEvalDataset, vcr_eval_collate,
|
|
PrefetchLoader)
|
|
from torch.utils.data.distributed import DistributedSampler
|
|
from model import BertForVisualCommonsenseReasoning
|
|
|
|
from utils.logger import LOGGER
|
|
from utils.distributed import all_gather_list
|
|
from utils.misc import NoOp, Struct
|
|
NUM_SPECIAL_TOKENS = 81
|
|
|
|
|
|
def load_img_feat(dir_list, path2imgdir, opts):
    dir_ = dir_list.split(";")
    assert len(dir_) <= 2, "More than two img_dirs found"
    img_dir_gt, img_dir = None, None
    gt_dir_path, dir_path = "", ""
    for d in dir_:
        if "gt" in d:
            gt_dir_path = d
        else:
            dir_path = d
    if gt_dir_path != "":
        img_dir_gt = path2imgdir.get(gt_dir_path, None)
        if img_dir_gt is None:
            img_dir_gt = DetectFeatLmdb(gt_dir_path, -1,
                                        opts.max_bb, opts.min_bb, 100,
                                        opts.compressed_db)
            path2imgdir[gt_dir_path] = img_dir_gt
    if dir_path != "":
        img_dir = path2imgdir.get(dir_path, None)
        if img_dir is None:
            img_dir = DetectFeatLmdb(dir_path, opts.conf_th,
                                     opts.max_bb, opts.min_bb, opts.num_bb,
                                     opts.compressed_db)
            path2imgdir[dir_path] = img_dir
    return img_dir, img_dir_gt, path2imgdir
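
# Note on the --img_dir value expected by load_img_feat above (a sketch;
# the paths below are placeholders): up to two LMDB feature directories may
# be passed, separated by ';', and the one whose path contains "gt" is
# treated as the ground-truth-box feature db, e.g.
#   --img_dir "/img/vcr_gt_test/;/img/vcr_test/"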


def main(opts):
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, hvd.rank(), opts.fp16))

    hps_file = f'{opts.output_dir}/log/hps.json'
    model_opts = Struct(json.load(open(hps_file)))

    path2imgdir = {}
    # load DBs and image dirs
    val_img_dir, val_img_dir_gt, path2imgdir = load_img_feat(
        opts.img_dir, path2imgdir, model_opts)
    eval_dataset = VcrEvalDataset("test", opts.txt_db,
                                  val_img_dir_gt, val_img_dir,
                                  max_txt_len=-1)

    # Prepare model
    bert_model = json.load(open(f'{opts.txt_db}/meta.json'))['bert']
    model = BertForVisualCommonsenseReasoning.from_pretrained(
        bert_model, img_dim=2048, obj_cls=False,
        state_dict={})
    model.init_type_embedding()
    model.init_word_embedding(NUM_SPECIAL_TOKENS)
    if exists(opts.checkpoint):
        ckpt_file = opts.checkpoint
    else:
        ckpt_file = f'{opts.output_dir}/ckpt/model_step_{opts.checkpoint}.pt'
    checkpoint = torch.load(ckpt_file)
    state_dict = checkpoint.get('model_state', checkpoint)
    matched_state_dict = {}
    unexpected_keys = set()
    missing_keys = set()
    for name, param in model.named_parameters():
        missing_keys.add(name)
    for key, data in state_dict.items():
        if key in missing_keys:
            matched_state_dict[key] = data
            missing_keys.remove(key)
        else:
            unexpected_keys.add(key)
    print("Unexpected_keys:", list(unexpected_keys))
    print("Missing_keys:", list(missing_keys))
    model.load_state_dict(matched_state_dict, strict=False)
    if model_opts.cut_bert != -1:
        # cut some layers of BERT
        model.bert.encoder.layer = torch.nn.ModuleList(
            model.bert.encoder.layer[:model_opts.cut_bert])
    model.to(device)
    if opts.fp16:
        model = amp.initialize(model, enabled=opts.fp16, opt_level='O2')

    sampler = DistributedSampler(
        eval_dataset, num_replicas=n_gpu, rank=rank)
    eval_dataloader = DataLoader(eval_dataset,
                                 batch_size=opts.batch_size,
                                 sampler=sampler,
                                 num_workers=opts.n_workers,
                                 pin_memory=opts.pin_mem,
                                 collate_fn=vcr_eval_collate)
    eval_dataloader = PrefetchLoader(eval_dataloader)

    val_log, results = evaluate(model, eval_dataloader)
    result_dir = f'{opts.output_dir}/results_{opts.split}'
    if not exists(result_dir) and rank == 0:
        os.makedirs(result_dir)
    # dummy sync
    _ = None
    all_gather_list(_)
    # each rank writes its own shard of results under result_dir
    if n_gpu > 1:
        with open(f'{result_dir}/'
                  f'results_{opts.checkpoint}_rank{rank}.json',
                  'w') as f:
            json.dump(results, f)
    # dummy sync
    _ = None
    all_gather_list(_)
    # join results (each shard is a qid -> scores dict)
    if n_gpu > 1:
        results = {}
        for rank_id in range(n_gpu):
            results.update(json.load(open(
                f'{result_dir}/'
                f'results_{opts.checkpoint}_rank{rank_id}.json')))
    if rank == 0:
        with open(f'{result_dir}/'
                  f'results_{opts.checkpoint}_all.json', 'w') as f:
            json.dump(results, f)


def compute_accuracies(out_qa, labels_qa, out_qar, labels_qar):
    outputs_qa = out_qa.max(dim=-1)[1]
    outputs_qar = out_qar.max(dim=-1)[1]
    matched_qa = outputs_qa.squeeze() == labels_qa.squeeze()
    matched_qar = outputs_qar.squeeze() == labels_qar.squeeze()
    matched_joined = matched_qa & matched_qar
    n_correct_qa = matched_qa.sum().item()
    n_correct_qar = matched_qar.sum().item()
    n_correct_joined = matched_joined.sum().item()
    return n_correct_qa, n_correct_qar, n_correct_joined
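
# Worked example for the joint metric above (illustrative values only):
# with Q->A predictions [2, 0] against labels [2, 1], and QA->R predictions
# [1, 3] against labels [1, 3], matched_qa = [True, False] and
# matched_qar = [True, True], so n_correct_qa = 1, n_correct_qar = 2 and
# n_correct_joined = 1 -- an example counts toward the joint score only if
# both its answer and its rationale are correct.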


@torch.no_grad()
def evaluate(model, val_loader):
    if hvd.rank() == 0:
        val_pbar = tqdm(total=len(val_loader))
    else:
        val_pbar = NoOp()
    LOGGER.info("start running evaluation ...")
    model.eval()
    val_qa_loss, val_qar_loss = 0, 0
    tot_qa_score, tot_qar_score, tot_score = 0, 0, 0
    n_ex = 0
    st = time()
    results = {}
    for i, batch in enumerate(val_loader):
        qids, *inputs, qa_targets, qar_targets, _ = batch
        scores = model(
            *inputs, targets=None, compute_loss=False)
        # first 4 columns are Q->A scores; the remaining columns are QA->R
        scores = scores.view(len(qids), -1)
        if torch.max(qa_targets) > -1:
            vcr_qa_loss = F.cross_entropy(
                scores[:, :4], qa_targets.squeeze(-1), reduction="sum")
            if scores.shape[1] > 8:
                # rationale candidates are packed per answer; keep the 4
                # conditioned on the ground-truth answer
                qar_scores = []
                for batch_id in range(scores.shape[0]):
                    answer_ind = qa_targets[batch_id].item()
                    qar_index = [4+answer_ind*4+i
                                 for i in range(4)]
                    qar_scores.append(scores[batch_id, qar_index])
                qar_scores = torch.stack(qar_scores, dim=0)
            else:
                qar_scores = scores[:, 4:]
            vcr_qar_loss = F.cross_entropy(
                qar_scores, qar_targets.squeeze(-1), reduction="sum")
            val_qa_loss += vcr_qa_loss.item()
            val_qar_loss += vcr_qar_loss.item()

            curr_qa_score, curr_qar_score, curr_score = compute_accuracies(
                scores[:, :4], qa_targets, qar_scores, qar_targets)
            tot_qar_score += curr_qar_score
            tot_qa_score += curr_qa_score
            tot_score += curr_score
        for qid, score in zip(qids, scores):
            results[qid] = score.cpu().tolist()
        n_ex += len(qids)
        val_pbar.update(1)
    # aggregate per-rank sums before normalizing
    val_qa_loss = sum(all_gather_list(val_qa_loss))
    val_qar_loss = sum(all_gather_list(val_qar_loss))
    tot_qa_score = sum(all_gather_list(tot_qa_score))
    tot_qar_score = sum(all_gather_list(tot_qar_score))
    tot_score = sum(all_gather_list(tot_score))
    n_ex = sum(all_gather_list(n_ex))
    tot_time = time()-st
    val_qa_loss /= n_ex
    val_qar_loss /= n_ex
    val_qa_acc = tot_qa_score / n_ex
    val_qar_acc = tot_qar_score / n_ex
    val_acc = tot_score / n_ex
    val_log = {'valid/vcr_qa_loss': val_qa_loss,
               'valid/vcr_qar_loss': val_qar_loss,
               'valid/acc_qa': val_qa_acc,
               'valid/acc_qar': val_qar_acc,
               'valid/acc': val_acc,
               'valid/ex_per_s': n_ex/tot_time}
    model.train()
    LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
                f"score_qa: {val_qa_acc*100:.2f} "
                f"score_qar: {val_qar_acc*100:.2f} "
                f"score: {val_acc*100:.2f} ")
    return val_log, results


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--txt_db",
                        default=None, type=str,
                        help="The input evaluation corpus. (LMDB)")
    parser.add_argument("--img_dir",
                        default=None, type=str,
                        help="The input image feature directories (LMDB), "
                             "separated by ';'")
    parser.add_argument('--compressed_db', action='store_true',
                        help='use compressed LMDB')
    parser.add_argument("--split",
                        default="test", type=str,
                        help="The input split")
    parser.add_argument("--checkpoint",
                        default=None, type=str,
                        help="model checkpoint: either a path to a .pt file "
                             "or a training step number under "
                             "{output_dir}/ckpt")
    parser.add_argument("--batch_size",
                        default=10, type=int,
                        help="number of examples in a batch")
    parser.add_argument(
        "--output_dir", default=None, type=str,
        help="The output directory where the model checkpoints will be "
             "written.")
    # device parameters
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead "
                             "of 32-bit")
    parser.add_argument('--n_workers', type=int, default=4,
                        help="number of data workers")
    parser.add_argument('--pin_mem', action='store_true',
                        help="pin memory")

    args = parser.parse_args()

    main(args)
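
# Example invocation (a sketch only: the script file name, GPU count, paths
# and checkpoint step below are placeholders, and `horovodrun` is assumed to
# be available since the script initializes Horovod):
#
#   horovodrun -np 4 python inf_vcr.py \
#       --txt_db /db/vcr_test.db \
#       --img_dir "/img/vcr_gt_test/;/img/vcr_test/" \
#       --split test \
#       --checkpoint 8000 \
#       --output_dir /storage/vcr_exp \
#       --batch_size 10 --n_workers 4 --fp16 --pin_mem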