"""
Copyright (c) Microsoft Corporation.
Licensed under the MIT license.

UNITER finetuning for VQA
"""
import argparse
import json
import os
from os.path import abspath, dirname, exists, join
from time import time

import torch
from torch.nn import functional as F
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from torch.optim import Adam, Adamax

from apex import amp
from horovod import torch as hvd

from tqdm import tqdm

from data import (TokenBucketSampler, PrefetchLoader,
                  TxtTokLmdb, ImageLmdbGroup, ConcatDatasetWithLens,
                  VqaDataset, VqaEvalDataset,
                  vqa_collate, vqa_eval_collate)
from model import UniterForVisualQuestionAnswering
from optim import AdamW, get_lr_sched

from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file
from utils.distributed import (all_reduce_and_rescale_tensors, all_gather_list,
                               broadcast_tensors)
from utils.save import ModelSaver, save_training_meta
from utils.misc import NoOp, parse_with_config, set_dropout, set_random_seed
from utils.const import BUCKET_SIZE, IMG_DIM


def _load_json(path):
    """Read and return the JSON document stored at *path*.

    The file handle is closed deterministically (unlike
    ``json.load(open(path))``).
    """
    with open(path) as f:
        return json.load(f)


def build_dataloader(dataset, collate_fn, is_train, opts):
    """Build a token-bucketed, prefetching DataLoader.

    Args:
        dataset: dataset exposing a ``lens`` attribute used for bucketing.
        collate_fn: function merging a list of examples into one batch.
        is_train: when True use ``opts.train_batch_size`` and drop the last
            incomplete batch; otherwise use ``opts.val_batch_size`` and keep
            every example.
        opts: parsed options; reads ``train_batch_size``,
            ``val_batch_size``, ``n_workers`` and ``pin_mem``.

    Returns:
        The DataLoader wrapped in a ``PrefetchLoader``.
    """
    batch_size = (opts.train_batch_size if is_train
                  else opts.val_batch_size)
    # batch size is counted in tokens (see the --*_batch_size help text)
    sampler = TokenBucketSampler(dataset.lens, bucket_size=BUCKET_SIZE,
                                 batch_size=batch_size, droplast=is_train)
    dataloader = DataLoader(dataset, batch_sampler=sampler,
                            num_workers=opts.n_workers,
                            pin_memory=opts.pin_mem, collate_fn=collate_fn)
    dataloader = PrefetchLoader(dataloader)
    return dataloader


def build_optimizer(model, opts):
    """Create the optimizer with four parameter groups.

    Parameters whose name contains ``'vqa_output'`` (the VQA head) form
    groups 0 and 1 so the training loop can scale their learning rate by
    ``opts.lr_mul``; all other parameters form groups 2 and 3. In each pair
    the second group holds bias / LayerNorm parameters, which get no weight
    decay.

    Args:
        model: the model whose ``named_parameters()`` are optimized.
        opts: parsed options; reads ``optim``, ``learning_rate``,
            ``weight_decay`` and ``betas``.

    Returns:
        An Adam, Adamax or AdamW optimizer instance.

    Raises:
        ValueError: if ``opts.optim`` is not 'adam', 'adamax' or 'adamw'.
    """
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    param_optimizer = [(n, p) for n, p in model.named_parameters()
                       if 'vqa_output' not in n]
    param_top = [(n, p) for n, p in model.named_parameters()
                 if 'vqa_output' in n]
    # NOTE: main() indexes optimizer.param_groups 0-3 in exactly this order
    # when applying the learning-rate schedule.
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_top
                    if not any(nd in n for nd in no_decay)],
         'lr': opts.learning_rate,
         'weight_decay': opts.weight_decay},
        {'params': [p for n, p in param_top
                    if any(nd in n for nd in no_decay)],
         'lr': opts.learning_rate,
         'weight_decay': 0.0},
        {'params': [p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': opts.weight_decay},
        {'params': [p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]

    # choose among the supported Adam-family optimizers
    if opts.optim == 'adam':
        OptimCls = Adam
    elif opts.optim == 'adamax':
        OptimCls = Adamax
    elif opts.optim == 'adamw':
        OptimCls = AdamW
    else:
        raise ValueError('invalid optimizer')
    optimizer = OptimCls(optimizer_grouped_parameters,
                         lr=opts.learning_rate, betas=opts.betas)
    return optimizer


def main(opts):
    """Finetune UNITER on VQA; runs once in every Horovod process.

    The function performs these steps:
    1. Load the train/val LMDBs.
    2. Build the model from ``opts.model_config`` plus the optional
       ``opts.checkpoint`` state.
    3. Train for ``opts.num_train_steps`` optimizer updates, all-reducing
       gradients across processes.
    4. Every ``opts.valid_steps`` updates, validate and save a checkpoint.
    5. Finish with a final validation and checkpoint.

    Only rank 0 writes TensorBoard logs, checkpoints and ``ans2label.json``.
    Every rank writes its own validation predictions to
    ``<output_dir>/results``.

    Raises:
        ValueError: if ``opts.gradient_accumulation_steps < 1`` or the text
            DBs were preprocessed with different tokenizers.
    """
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, hvd.rank(), opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    ans2label = _load_json(f'{dirname(abspath(__file__))}'
                           f'/misc/ans2label.json')
    label2ans = {label: ans for ans, label in ans2label.items()}

    # load DBs and image dirs
    all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
                                 opts.num_bb, opts.compressed_db)
    # train
    LOGGER.info(f"Loading Train Dataset "
                f"{opts.train_txt_dbs}, {opts.train_img_dbs}")
    train_datasets = []
    for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs):
        img_db = all_img_dbs[img_path]
        txt_db = TxtTokLmdb(txt_path, opts.max_txt_len)
        train_datasets.append(VqaDataset(len(ans2label), txt_db, img_db))
    train_dataset = ConcatDatasetWithLens(train_datasets)
    train_dataloader = build_dataloader(train_dataset, vqa_collate, True, opts)
    # val
    LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}")
    val_img_db = all_img_dbs[opts.val_img_db]
    val_txt_db = TxtTokLmdb(opts.val_txt_db, -1)
    val_dataset = VqaEvalDataset(len(ans2label), val_txt_db, val_img_db)
    val_dataloader = build_dataloader(val_dataset, vqa_eval_collate,
                                      False, opts)

    # Prepare model
    if opts.checkpoint:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}

    # every text DB must record the same 'bert' tokenizer in its meta.json
    all_dbs = opts.train_txt_dbs + [opts.val_txt_db]
    toker = _load_json(f'{all_dbs[0]}/meta.json')['bert']
    if not all(toker == _load_json(f'{db}/meta.json')['bert']
               for db in all_dbs):
        raise ValueError("text DBs were preprocessed with different "
                         "tokenizers: {}".format(all_dbs))
    model = UniterForVisualQuestionAnswering.from_pretrained(
        opts.model_config, checkpoint,
        img_dim=IMG_DIM, num_answer=len(ans2label))
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model, optimizer,
                                      enabled=opts.fp16, opt_level='O2')

    global_step = 0
    if rank == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        # close (and thus flush) the label map before training starts
        with open(join(opts.output_dir, 'ckpt', 'ans2label.json'), 'w') as f:
            json.dump(ans2label, f)
        os.makedirs(join(opts.output_dir, 'results'))  # store VQA predictions
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info(" Num examples = %d", len(train_dataset) * hvd.size())
    LOGGER.info(" Batch size = %d", opts.train_batch_size)
    LOGGER.info(" Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info(" Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()
    n_examples = 0
    n_epoch = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        for step, batch in enumerate(train_dataloader):
            n_examples += batch['input_ids'].size(0)

            loss = model(batch, compute_loss=True)
            # instance-level bce: rescale the mean by the number of answer
            # columns (assumes the model returns element-wise BCE -- confirm)
            loss = loss.mean() * batch['targets'].size(1)
            delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0
            with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale
                                ) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from every processes
                    # do this before unscaling to make sure every process uses
                    # the same gradient scale
                    grads = [p.grad.data for p in model.parameters()
                             if p.requires_grad and p.grad is not None]
                    all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())

            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling; groups 0/1 are the VQA head
                # (see build_optimizer) and get the lr multiplier
                lr_this_step = get_lr_sched(global_step, opts)
                for i, param_group in enumerate(optimizer.param_groups):
                    if i == 0 or i == 1:
                        param_group['lr'] = lr_this_step * opts.lr_mul
                    elif i == 2 or i == 3:
                        param_group['lr'] = lr_this_step
                    else:
                        raise ValueError(
                            f'unexpected optimizer param group {i}')
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss averaged over all processes
                losses = all_gather_list(running_loss)
                running_loss = RunningMeter(
                    'loss', sum(meter.val for meter in losses)/len(losses))
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f'============Step {global_step}=============')
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time()-start))
                    LOGGER.info(f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar('perf/ex_per_s',
                                         ex_per_sec, global_step)
                    LOGGER.info('===========================================')

                if global_step % opts.valid_steps == 0:
                    val_log, results = validate(
                        model, val_dataloader, label2ans)
                    with open(f'{opts.output_dir}/results/'
                              f'results_{global_step}_'
                              f'rank{rank}.json', 'w') as f:
                        json.dump(results, f)
                    TB_LOGGER.log_scaler_dict(val_log)
                    model_saver.save(model, global_step)
            if global_step >= opts.num_train_steps:
                break
        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"finished {n_epoch} epochs")

    # final evaluation and checkpoint
    val_log, results = validate(model, val_dataloader, label2ans)
    with open(f'{opts.output_dir}/results/'
              f'results_{global_step}_'
              f'rank{rank}_final.json', 'w') as f:
        json.dump(results, f)
    TB_LOGGER.log_scaler_dict(val_log)
    model_saver.save(model, f'{global_step}_final')


@torch.no_grad()
def validate(model, val_loader, label2ans):
    """Evaluate *model* on *val_loader* and collect per-question answers.

    Loss, score and example counts are summed across all processes with
    ``all_gather_list``; the predictions dict only covers this process's
    batches. The model is switched back to train mode before returning.

    Args:
        model: model called as ``model(batch, compute_loss=False)`` that
            returns answer logits.
        val_loader: iterable of batches with ``'targets'`` and ``'qids'``.
        label2ans: maps an answer label index to its answer string.

    Returns:
        ``(val_log, results)``:
        - ``val_log`` has the keys ``'valid/loss'``, ``'valid/acc'`` and
          ``'valid/ex_per_s'``.
        - ``results`` maps each qid to its predicted (argmax) answer.
    """
    LOGGER.info("start running validation...")
    model.eval()
    val_loss = 0
    tot_score = 0
    n_ex = 0
    st = time()
    results = {}
    for batch in val_loader:
        scores = model(batch, compute_loss=False)
        targets = batch['targets']
        loss = F.binary_cross_entropy_with_logits(
            scores, targets, reduction='sum')
        val_loss += loss.item()
        tot_score += compute_score_with_logits(scores, targets).sum().item()
        pred_labels = scores.max(dim=-1, keepdim=False)[1].cpu().tolist()
        answers = [label2ans[label] for label in pred_labels]
        for qid, answer in zip(batch['qids'], answers):
            results[qid] = answer
        n_ex += len(batch['qids'])
    val_loss = sum(all_gather_list(val_loss))
    tot_score = sum(all_gather_list(tot_score))
    n_ex = sum(all_gather_list(n_ex))
    tot_time = time()-st
    val_loss /= n_ex
    val_acc = tot_score / n_ex
    val_log = {'valid/loss': val_loss,
               'valid/acc': val_acc,
               'valid/ex_per_s': n_ex/tot_time}
    model.train()
    LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
                f"score: {val_acc*100:.2f}")
    return val_log, results


def compute_score_with_logits(logits, labels):
    """Score each example by the target value of its argmax prediction.

    Args:
        logits: 2-D tensor of answer logits (rows are examples).
        labels: tensor of the same shape holding the target values.

    Returns:
        A tensor shaped like *labels* that is zero everywhere except at each
        row's argmax column, where it equals the corresponding label value.
    """
    logits = torch.max(logits, 1)[1]  # argmax
    one_hots = torch.zeros(*labels.size(), device=labels.device)
    one_hots.scatter_(1, logits.view(-1, 1), 1)
    scores = (one_hots * labels)
    return scores


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Required parameters
    # TODO datasets
    parser.add_argument('--compressed_db', action='store_true',
                        help='use compressed LMDB')
    parser.add_argument("--model_config",
                        default=None, type=str,
                        help="json file for model architecture")
    parser.add_argument("--checkpoint",
                        default=None, type=str,
                        help="pretrained model")

    parser.add_argument(
        "--output_dir", default=None, type=str,
        help="The output directory where the model checkpoints will be "
             "written.")

    # Prepro parameters
    parser.add_argument('--max_txt_len', type=int, default=60,
                        help='max number of tokens in text (BERT BPE)')
    parser.add_argument('--conf_th', type=float, default=0.2,
                        help='threshold for dynamic bounding boxes '
                             '(-1 for fixed)')
    parser.add_argument('--max_bb', type=int, default=100,
                        help='max number of bounding boxes')
    parser.add_argument('--min_bb', type=int, default=10,
                        help='min number of bounding boxes')
    parser.add_argument('--num_bb', type=int, default=36,
                        help='static number of bounding boxes')

    # training parameters
    parser.add_argument("--train_batch_size", default=4096, type=int,
                        help="Total batch size for training. "
                             "(batch by tokens)")
    parser.add_argument("--val_batch_size", default=4096, type=int,
                        help="Total batch size for validation. "
                             "(batch by tokens)")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=16,
                        help="Number of updates steps to accumualte before "
                             "performing a backward/update pass.")
    parser.add_argument("--learning_rate", default=3e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--lr_mul", default=1.0, type=float,
                        help="multiplier for top layer lr")
    parser.add_argument("--valid_steps", default=1000, type=int,
                        help="Run validation every X steps")
    parser.add_argument("--num_train_steps", default=100000, type=int,
                        help="Total number of training updates to perform.")
    parser.add_argument("--optim", default='adam',
                        choices=['adam', 'adamax', 'adamw'],
                        help="optimizer")
    # type=float: without it, values given on the command line stay strings
    parser.add_argument("--betas", default=[0.9, 0.98], nargs='+',
                        type=float,
                        help="beta for adam optimizer")
    parser.add_argument("--decay", default='linear',
                        choices=['linear', 'invsqrt', 'constant', 'vqa'],
                        help="learning rate decay method")
    parser.add_argument("--decay_int", default=2000, type=int,
                        help="interval between VQA lr decy")
    parser.add_argument("--warm_int", default=2000, type=int,
                        help="interval for VQA lr warmup")
    parser.add_argument("--decay_st", default=20000, type=int,
                        help="when to start decay")
    parser.add_argument("--decay_rate", default=0.2, type=float,
                        help="ratio of lr decay")
    parser.add_argument("--dropout", default=0.1, type=float,
                        help="tune dropout regularization")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="weight decay (L2) regularization")
    parser.add_argument("--grad_norm", default=2.0, type=float,
                        help="gradient clipping (-1 for no clipping)")
    parser.add_argument("--warmup_steps", default=4000, type=int,
                        help="Number of training steps to perform linear "
                             "learning rate warmup for. (invsqrt decay)")

    # device parameters
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")
    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit float precision instead "
                             "of 32-bit")
    parser.add_argument('--n_workers', type=int, default=4,
                        help="number of data workers")
    parser.add_argument('--pin_mem', action='store_true', help="pin memory")

    # can use config files
    parser.add_argument('--config', help='JSON config files')

    args = parse_with_config(parser)

    # output_dir may come from the CLI or a config file, so check after both
    if args.output_dir is None:
        raise ValueError("--output_dir must be given (on the command line "
                         "or in the config file)")
    if exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory ({}) already exists and is not "
                         "empty.".format(args.output_dir))

    # options safe guard (explicit raise: asserts vanish under `python -O`)
    # TODO
    if args.conf_th == -1:
        if args.max_bb + args.max_txt_len + 2 > 512:
            raise ValueError("max_bb + max_txt_len + 2 must be <= 512")
    else:
        if args.num_bb + args.max_txt_len + 2 > 512:
            raise ValueError("num_bb + max_txt_len + 2 must be <= 512")

    main(args)