logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

415 lines
17 KiB

"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.
UNITER finetuning for VQA
"""
import argparse
import json
import os
from os.path import abspath, dirname, exists, join
from time import time
import torch
from torch.nn import functional as F
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from torch.optim import Adam, Adamax
from apex import amp
from horovod import torch as hvd
from tqdm import tqdm
from data import (TokenBucketSampler, PrefetchLoader,
TxtTokLmdb, ImageLmdbGroup, ConcatDatasetWithLens,
VqaDataset, VqaEvalDataset,
vqa_collate, vqa_eval_collate)
from model import UniterForVisualQuestionAnswering
from optim import AdamW, get_lr_sched
from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file
from utils.distributed import (all_reduce_and_rescale_tensors, all_gather_list,
broadcast_tensors)
from utils.save import ModelSaver, save_training_meta
from utils.misc import NoOp, parse_with_config, set_dropout, set_random_seed
from utils.const import BUCKET_SIZE, IMG_DIM
def build_dataloader(dataset, collate_fn, is_train, opts):
    """Create a token-bucketed, prefetching loader for ``dataset``.

    Args:
        dataset: dataset exposing a ``lens`` attribute, which is handed to
            ``TokenBucketSampler`` for bucketing.
        collate_fn: function that merges a list of samples into a batch.
        is_train: selects ``opts.train_batch_size`` when true, otherwise
            ``opts.val_batch_size``; it is also forwarded to the sampler as
            its ``droplast`` flag.
        opts: options providing the batch sizes, ``n_workers`` and
            ``pin_mem``.

    Returns:
        The ``DataLoader`` wrapped in a ``PrefetchLoader``.
    """
    if is_train:
        batch_size = opts.train_batch_size
    else:
        batch_size = opts.val_batch_size

    bucket_sampler = TokenBucketSampler(
        dataset.lens,
        bucket_size=BUCKET_SIZE,
        batch_size=batch_size,
        droplast=is_train,
    )
    loader = DataLoader(
        dataset,
        batch_sampler=bucket_sampler,
        num_workers=opts.n_workers,
        pin_memory=opts.pin_mem,
        collate_fn=collate_fn,
    )
    return PrefetchLoader(loader)
def build_optimizer(model, opts):
    """Build the optimizer; the vqa linear head may get a larger learning rate.

    Parameters are split into four groups, in this fixed order (the training
    loop indexes ``optimizer.param_groups`` by position to apply
    ``opts.lr_mul`` to the first two):

    * group 0: ``vqa_output`` parameters with weight decay
    * group 1: ``vqa_output`` bias / LayerNorm parameters, no weight decay
    * group 2: all other parameters with weight decay
    * group 3: all other bias / LayerNorm parameters, no weight decay

    Args:
        model: module whose ``named_parameters()`` are optimized.
        opts: options providing ``optim`` (``'adam'``, ``'adamax'`` or
            ``'adamw'``), ``learning_rate``, ``weight_decay`` and ``betas``.

    Returns:
        The constructed optimizer.

    Raises:
        ValueError: if ``opts.optim`` is not a supported optimizer name.
    """
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    named_params = list(model.named_parameters())
    param_optimizer = [(n, p) for n, p in named_params
                       if 'vqa_output' not in n]
    param_top = [(n, p) for n, p in named_params
                 if 'vqa_output' in n]
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_top
                    if not any(nd in n for nd in no_decay)],
         'lr': opts.learning_rate,
         'weight_decay': opts.weight_decay},
        {'params': [p for n, p in param_top
                    if any(nd in n for nd in no_decay)],
         'lr': opts.learning_rate,
         'weight_decay': 0.0},
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': opts.weight_decay},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]

    # supported: Adam, Adamax and AdamW
    if opts.optim == 'adam':
        OptimCls = Adam
    elif opts.optim == 'adamax':
        OptimCls = Adamax
    elif opts.optim == 'adamw':
        OptimCls = AdamW
    else:
        raise ValueError(f'invalid optimizer: {opts.optim}')

    # ``--betas`` is declared with nargs='+' and may arrive as strings from
    # the command line; the optimizers compare betas numerically.
    betas = tuple(float(b) for b in opts.betas)
    optimizer = OptimCls(optimizer_grouped_parameters,
                         lr=opts.learning_rate, betas=betas)
    return optimizer
def main(opts):
    """Fine-tune UNITER on VQA with Horovod data parallelism.

    Builds the train/val datasets and loaders, loads the model (optionally
    initialized from ``opts.checkpoint``), wraps model and optimizer with
    apex AMP, then trains for ``opts.num_train_steps`` optimizer updates
    using ``opts.gradient_accumulation_steps`` batches per update. Every
    ``opts.valid_steps`` updates, and once more after training, it runs
    validation, writes each rank's predictions under
    ``<output_dir>/results`` and saves a checkpoint (rank 0 only; other
    ranks use no-op savers).

    Args:
        opts: parsed command-line / config options.

    Raises:
        ValueError: if ``opts.gradient_accumulation_steps`` is below 1.
    """
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, hvd.rank(), opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    with open(f'{dirname(abspath(__file__))}'
              f'/misc/ans2label.json') as f:
        ans2label = json.load(f)
    label2ans = {label: ans for ans, label in ans2label.items()}

    # load DBs and image dirs
    all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                 opts.num_bb, opts.compressed_db)
    # train
    LOGGER.info(f"Loading Train Dataset "
                f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
    train_datasets = []
    for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
        img_db = all_img_dbs[img_path]
        txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
        train_datasets.append(VqaDataset(len(ans2label), txt_db, img_db))
    train_dataset = ConcatDatasetWithLens(train_datasets)
    train_dataloader = build_dataloader(train_dataset, vqa_collate, True, opts)
    # val
    LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}")
    val_img_db = all_img_dbs[opts.val_img_db]
    val_txt_db = TxtTokLmdb(opts.val_txt_db, -1)
    val_dataset = VqaEvalDataset(len(ans2label), val_txt_db, val_img_db)
    val_dataloader = build_dataloader(val_dataset, vqa_eval_collate,
                                      False, opts)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}

    # every text DB must record the same 'bert' entry in its meta.json
    all_dbs = opts.train_txt_dbs + [opts.val_txt_db]
    tokers = []
    for db in all_dbs:
        with open(f'{db}/meta.json') as f:
            tokers.append(json.load(f)['bert'])
    assert all(toker == tokers[0] for toker in tokers)

    model = UniterForVisualQuestionAnswering.from_pretrained(
        opts.model_config, checkpoint,
        img_dim=IMG_DIM, num_answer=len(ans2label))
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model, optimizer,
                                      enabled=opts.fp16, opt_level='O2')

    global_step = 0
    if rank == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        # close the handle explicitly so the label map is flushed to disk
        with open(join(opts.output_dir, 'ckpt', 'ans2label.json'), 'w') as f:
            json.dump(ans2label, f)
        os.makedirs(join(opts.output_dir, 'results'))  # store VQA predictions
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        # non-zero ranks: silence logging, progress bar and checkpointing
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info(" Num examples = %d", len(train_dataset) * hvd.size())
    LOGGER.info(" Batch size = %d", opts.train_batch_size)
    LOGGER.info(" Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info(" Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()
    n_examples = 0
    n_epoch = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        for step, batch in enumerate(train_dataloader):
            n_examples += batch['input_ids'].size(0)

            loss = model(batch, compute_loss=True)
            loss = loss.mean() * batch['targets'].size(1)  # instance-level bce
            # only unscale on the last micro-batch of an accumulation window
            delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0
            with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale
                                ) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from every process
                    # do this before unscaling to make sure every process uses
                    # the same gradient scale
                    grads = [p.grad.data for p in model.parameters()
                             if p.requires_grad and p.grad is not None]
                    all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())

            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling; param groups 0/1 hold the
                # 'vqa_output' parameters and get the lr_mul multiplier
                lr_this_step = get_lr_sched(global_step, opts)
                for i, param_group in enumerate(optimizer.param_groups):
                    if i == 0 or i == 1:
                        param_group['lr'] = lr_this_step * opts.lr_mul
                    elif i == 2 or i == 3:
                        param_group['lr'] = lr_this_step
                    else:
                        raise ValueError()
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss (averaged over all processes)
                losses = all_gather_list(running_loss)
                running_loss = RunningMeter(
                    'loss', sum(m.val for m in losses)/len(losses))
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f'============Step {global_step}=============')
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time()-start))
                    LOGGER.info(f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar('perf/ex_per_s',
                                         ex_per_sec, global_step)
                    LOGGER.info(f'===========================================')

                if global_step % opts.valid_steps == 0:
                    val_log, results = validate(
                        model, val_dataloader, label2ans)
                    with open(f'{opts.output_dir}/results/'
                              f'results_{global_step}_'
                              f'rank{rank}.json', 'w') as f:
                        json.dump(results, f)
                    TB_LOGGER.log_scaler_dict(val_log)
                    model_saver.save(model, global_step)
            if global_step >= opts.num_train_steps:
                break
        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"finished {n_epoch} epochs")

    # final evaluation and checkpoint after the last update
    val_log, results = validate(model, val_dataloader, label2ans)
    with open(f'{opts.output_dir}/results/'
              f'results_{global_step}_'
              f'rank{rank}_final.json', 'w') as f:
        json.dump(results, f)
    TB_LOGGER.log_scaler_dict(val_log)
    model_saver.save(model, f'{global_step}_final')
@torch.no_grad()
def validate(model, val_loader, label2ans):
    """Evaluate ``model`` on ``val_loader`` and collect predicted answers.

    Loss, score and example counts are accumulated locally, then summed
    across processes with ``all_gather_list`` before normalizing, so the
    returned metrics are global. The predictions in ``results`` are only
    those of the calling process.

    Args:
        model: VQA model; called as ``model(batch, compute_loss=False)`` to
            obtain answer logits.
        val_loader: iterable of batches containing ``'targets'`` and
            ``'qids'``.
        label2ans: mapping from predicted label index to answer string.

    Returns:
        ``(val_log, results)``: ``val_log`` has ``'valid/loss'`` (BCE summed
        over answer classes, averaged over examples), ``'valid/acc'`` and
        ``'valid/ex_per_s'``; ``results`` maps question id to answer.

    The model is put in eval mode while running and switched back to train
    mode before returning.
    """
    LOGGER.info("start running validation...")
    model.eval()
    val_loss = 0
    tot_score = 0
    n_ex = 0
    st = time()
    results = {}
    for batch in val_loader:
        scores = model(batch, compute_loss=False)
        targets = batch['targets']
        loss = F.binary_cross_entropy_with_logits(
            scores, targets, reduction='sum')
        val_loss += loss.item()
        tot_score += compute_score_with_logits(scores, targets).sum().item()
        # highest-scoring answer label for each question in the batch
        pred_labels = scores.max(dim=-1, keepdim=False)[1].cpu().tolist()
        answers = [label2ans[label] for label in pred_labels]
        for qid, answer in zip(batch['qids'], answers):
            results[qid] = answer
        n_ex += len(batch['qids'])

    # combine the per-process partial sums
    val_loss = sum(all_gather_list(val_loss))
    tot_score = sum(all_gather_list(tot_score))
    n_ex = sum(all_gather_list(n_ex))
    tot_time = time()-st
    val_loss /= n_ex
    val_acc = tot_score / n_ex
    val_log = {'valid/loss': val_loss,
               'valid/acc': val_acc,
               'valid/ex_per_s': n_ex/tot_time}
    model.train()
    LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
                f"score: {val_acc*100:.2f}")
    return val_log, results
def compute_score_with_logits(logits, labels):
    """Score the arg-max prediction of each row against soft targets.

    The index of the largest logit along dim 1 is one-hot encoded into a
    tensor shaped like ``labels`` and multiplied element-wise with
    ``labels``, so every row of the result is zero except (possibly) at the
    predicted position, where it holds that position's target value.

    Args:
        logits: prediction scores; expected to be 2-D with one row per
            example (the one-hot scatter is along dim 1).
        labels: target tensor with the same shape as ``logits``.

    Returns:
        Tensor of the default float dtype on ``labels.device``; summing it
        gives the total score.
    """
    predicted = logits.max(dim=1)[1]
    hit_mask = torch.zeros(labels.size(), device=labels.device)
    hit_mask.scatter_(1, predicted.view(-1, 1), 1)
    return hit_mask * labels
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Required parameters
    # TODO datasets
    parser.add_argument('--compressed_db', action='store_true',
                        help='use compressed LMDB')
    parser.add_argument("--model_config",
                        default=None, type=str,
                        help="json file for model architecture")
    parser.add_argument("--checkpoint",
                        default=None, type=str,
                        help="pretrained model")

    parser.add_argument(
        "--output_dir", default=None, type=str,
        help="The output directory where the model checkpoints will be "
             "written.")

    # Prepro parameters
    parser.add_argument('--max_txt_len', type=int, default=60,
                        help='max number of tokens in text (BERT BPE)')
    parser.add_argument('--conf_th', type=float, default=0.2,
                        help='threshold for dynamic bounding boxes '
                             '(-1 for fixed)')
    parser.add_argument('--max_bb', type=int, default=100,
                        help='max number of bounding boxes')
    parser.add_argument('--min_bb', type=int, default=10,
                        help='min number of bounding boxes')
    parser.add_argument('--num_bb', type=int, default=36,
                        help='static number of bounding boxes')

    # training parameters
    parser.add_argument("--train_batch_size", default=4096, type=int,
                        help="Total batch size for training. "
                             "(batch by tokens)")
    parser.add_argument("--val_batch_size", default=4096, type=int,
                        help="Total batch size for validation. "
                             "(batch by tokens)")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=16,
                        help="Number of updates steps to accumulate before "
                             "performing a backward/update pass.")
    parser.add_argument("--learning_rate", default=3e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--lr_mul", default=1.0, type=float,
                        help="multiplier for top layer lr")
    parser.add_argument("--valid_steps", default=1000, type=int,
                        help="Run validation every X steps")
    parser.add_argument("--num_train_steps", default=100000, type=int,
                        help="Total number of training updates to perform.")
    parser.add_argument("--optim", default='adam',
                        choices=['adam', 'adamax', 'adamw'],
                        help="optimizer")
    # type=float: without it, values given on the command line are strings
    parser.add_argument("--betas", default=[0.9, 0.98], nargs='+', type=float,
                        help="beta for adam optimizer")
    parser.add_argument("--decay", default='linear',
                        choices=['linear', 'invsqrt', 'constant', 'vqa'],
                        help="learning rate decay method")
    parser.add_argument("--decay_int", default=2000, type=int,
                        help="interval between VQA lr decay")
    parser.add_argument("--warm_int", default=2000, type=int,
                        help="interval for VQA lr warmup")
    parser.add_argument("--decay_st", default=20000, type=int,
                        help="when to start decay")
    parser.add_argument("--decay_rate", default=0.2, type=float,
                        help="ratio of lr decay")
    parser.add_argument("--dropout", default=0.1, type=float,
                        help="tune dropout regularization")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="weight decay (L2) regularization")
    parser.add_argument("--grad_norm", default=2.0, type=float,
                        help="gradient clipping (-1 for no clipping)")
    parser.add_argument("--warmup_steps", default=4000, type=int,
                        help="Number of training steps to perform linear "
                             "learning rate warmup for. (invsqrt decay)")

    # device parameters
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead "
                             "of 32-bit")
    parser.add_argument('--n_workers', type=int, default=4,
                        help="number of data workers")
    parser.add_argument('--pin_mem', action='store_true', help="pin memory")

    # can use config files
    parser.add_argument('--config', help='JSON config files')

    args = parse_with_config(parser)

    # output_dir defaults to None; fail with a clear message instead of a
    # TypeError from exists(None)
    if args.output_dir is None:
        raise ValueError("--output_dir is required")
    if exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory ({}) already exists and is not "
                         "empty.".format(args.output_dir))

    # options safe guard (explicit checks: asserts are stripped under -O)
    # TODO
    if args.conf_th == -1:
        if args.max_bb + args.max_txt_len + 2 > 512:
            raise ValueError("max_bb + max_txt_len + 2 must be <= 512")
    else:
        if args.num_bb + args.max_txt_len + 2 > 512:
            raise ValueError("num_bb + max_txt_len + 2 must be <= 512")

    main(args)