""" Copyright (c) Microsoft Corporation. Licensed under the MIT license. UNITER finetuning for Image-Text Retrieval """ import argparse import os from os.path import exists, join from time import time import torch from torch.nn.utils import clip_grad_norm_ from torch.utils.data import DataLoader, ConcatDataset from apex import amp from horovod import torch as hvd from tqdm import tqdm from data import (PrefetchLoader, TxtTokLmdb, ImageLmdbGroup, ItmRankDatasetHardNegFromText, ItmRankDatasetHardNegFromImage, itm_rank_hnv2_collate, ItmValDataset, itm_val_collate, ItmEvalDataset, itm_eval_collate) from model import UniterForImageTextRetrievalHardNeg from optim import get_lr_sched from optim.misc import build_optimizer from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file from utils.distributed import (all_reduce_and_rescale_tensors, all_gather_list, broadcast_tensors) from utils.save import ModelSaver, save_training_meta from utils.misc import NoOp, parse_with_config, set_dropout, set_random_seed from utils.const import IMG_DIM from eval.itm import itm_eval def build_dataloader(dataset, collate_fn, is_train, opts): dataloader = DataLoader(dataset, batch_size=1, shuffle=is_train, drop_last=is_train, num_workers=opts.n_workers, pin_memory=opts.pin_mem, collate_fn=collate_fn) dataloader = PrefetchLoader(dataloader) return dataloader def main(opts): hvd.init() n_gpu = hvd.size() device = torch.device("cuda", hvd.local_rank()) torch.cuda.set_device(hvd.local_rank()) rank = hvd.rank() opts.rank = rank LOGGER.info("device: {} n_gpu: {}, rank: {}, " "16-bits training: {}".format( device, n_gpu, hvd.rank(), opts.fp16)) set_random_seed(opts.seed) if hvd.rank() == 0: save_training_meta(opts) TB_LOGGER.create(join(opts.output_dir, 'log')) pbar = tqdm(total=opts.num_train_steps) model_saver = ModelSaver(join(opts.output_dir, 'ckpt')) add_log_to_file(join(opts.output_dir, 'log', 'log.txt')) # store ITM predictions os.makedirs(join(opts.output_dir, 'results_val')) os.makedirs(join(opts.output_dir, 'results_test')) os.makedirs(join(opts.output_dir, 'results_train')) else: LOGGER.disabled = True pbar = NoOp() model_saver = NoOp() # train_examples = None LOGGER.info(f"Loading Train Dataset {opts.train_txt_dbs}, " f"{opts.train_img_dbs}") # check multiple DBs assert len(opts.train_txt_dbs) == len(opts.train_img_dbs), \ "train txt_db and img_db have different length" # load DBs and image dirs all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb, opts.num_bb, opts.compressed_db) # train LOGGER.info(f"Loading Train Dataset " f"{opts.train_txt_dbs}, {opts.train_img_dbs}") train_datasets_t = [] train_datasets_i = [] for txt_path, img_path in zip(opts.train_txt_dbs, opts.train_img_dbs): img_db = all_img_dbs[img_path] txt_db = TxtTokLmdb(txt_path, opts.max_txt_len) train_datasets_t.append( ItmRankDatasetHardNegFromText(txt_db, img_db, opts.negative_size)) train_datasets_i.append( ItmRankDatasetHardNegFromImage(txt_db, img_db, opts.negative_size)) train_dataset_t = ConcatDataset(train_datasets_t) train_dataset_i = ConcatDataset(train_datasets_i) train_dataloader_t = build_dataloader( train_dataset_t, itm_rank_hnv2_collate, True, opts) train_dataloader_i = build_dataloader( train_dataset_i, itm_rank_hnv2_collate, True, opts) # val LOGGER.info(f"Loading Val Dataset {opts.val_txt_db}, {opts.val_img_db}") val_img_db = all_img_dbs[opts.val_img_db] val_txt_db = TxtTokLmdb(opts.val_txt_db, -1) val_dataset = ItmValDataset(val_txt_db, val_img_db, opts.inf_minibatch_size) val_dataloader = 

    # eval
    LOGGER.info(f"Loading val, test Dataset for full evaluation: "
                f"{opts.val_txt_db}, {opts.val_img_db}, "
                f"{opts.test_txt_db}, {opts.test_img_db}")
    eval_dataset_val = ItmEvalDataset(val_txt_db, val_img_db,
                                      opts.inf_minibatch_size)
    eval_loader_val = build_dataloader(eval_dataset_val, itm_eval_collate,
                                       False, opts)
    test_img_db = all_img_dbs[opts.test_img_db]
    test_txt_db = TxtTokLmdb(opts.test_txt_db, -1)
    eval_dataset_test = ItmEvalDataset(test_txt_db, test_img_db,
                                       opts.inf_minibatch_size)
    eval_loader_test = build_dataloader(eval_dataset_test, itm_eval_collate,
                                        False, opts)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}
    model = UniterForImageTextRetrievalHardNeg.from_pretrained(
        opts.model_config, state_dict=checkpoint, img_dim=IMG_DIM,
        margin=opts.margin, hard_size=opts.hard_neg_size)
    model.init_output()  # the pretraining ITM head differs from the ranking head
    model.to(device)
    # make sure every process has the same model parameters at the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model, optimizer,
                                      enabled=opts.fp16, opt_level='O2')

    LOGGER.info(f"***** Running training on {n_gpu} GPUs *****")
    LOGGER.info(" Num examples = %d",
                sum(all_gather_list(len(train_dataset_t))))
    LOGGER.info(" Batch size = %d", opts.train_batch_size)
    LOGGER.info(" Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()

    global_step = 0
    step = 0
    n_examples = 0
    n_hard_ex = 0
    n_epoch = 0
    start = time()
    train_iter_i = iter(train_dataloader_i)
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        for batch in train_dataloader_t:

            # hard text from image
            try:
                batch_i = next(train_iter_i)
            except StopIteration:
                train_iter_i = iter(train_dataloader_i)
                batch_i = next(train_iter_i)
            n_examples += batch_i['attn_masks'].size(0)
            loss = model(batch_i, sample_from='i', compute_loss=True)
            n_hard_ex += loss.numel()
            loss = loss.mean() / opts.train_batch_size
            with amp.scale_loss(loss, optimizer, delay_unscale=True
                                ) as scaled_loss:
                scaled_loss.backward()

            # hard image from text
            n_examples += batch['attn_masks'].size(0)
            loss = model(batch, sample_from='t', compute_loss=True)
            n_hard_ex += loss.numel()
            # NOTE: we use gradient accumulation to implement train_batch_size
            loss = loss.mean() / opts.train_batch_size
            step += 1
            delay_unscale = step % opts.train_batch_size != 0
            with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale
                                ) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from all processes
                    # do this before unscaling to make sure every process uses
                    # the same gradient scale
                    grads = [p.grad.data
                             for p in model.parameters()
                             if p.requires_grad and p.grad is not None]
                    all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())

            if step % opts.train_batch_size == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss
                losses = all_gather_list(running_loss)
                running_loss = RunningMeter(
                    'loss', sum(l.val for l in losses)/len(losses))
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()
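
                # NOTE: at this point `opts.train_batch_size` image-side and
                # text-side passes have been accumulated (each loss above is
                # pre-divided by train_batch_size), so a single optimizer
                # update is performed below.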

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f'------------Step {global_step}-------------')
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time()-start))
                    tot_hn = sum(all_gather_list(n_hard_ex))
                    hn_per_sec = int(tot_hn / (time()-start))
                    LOGGER.info(f'{tot_ex} ({tot_hn}) examples (hard) '
                                f'trained at {ex_per_sec} ({hn_per_sec}) ex/s')
                    TB_LOGGER.add_scalar('perf/ex_per_s',
                                         ex_per_sec, global_step)
                    TB_LOGGER.add_scalar('perf/hn_per_s',
                                         hn_per_sec, global_step)
                    LOGGER.info('-------------------------------------------')

                if global_step % opts.valid_steps == 0:
                    if opts.full_val:
                        LOGGER.info(
                            f"========================== Step {global_step} "
                            f"==========================")
                        val_log = evaluate(model, eval_loader_val)
                        TB_LOGGER.log_scaler_dict(
                            {f"valid/{k}": v for k, v in val_log.items()})
                        if hvd.rank() == 0:
                            LOGGER.info(
                                f"image retrieval R1: "
                                f"{val_log['img_r1']*100:.2f},\n"
                                f"image retrieval R5: "
                                f"{val_log['img_r5']*100:.2f},\n"
                                f"image retrieval R10: "
                                f"{val_log['img_r10']*100:.2f}\n"
                                f"text retrieval R1: "
                                f"{val_log['txt_r1']*100:.2f},\n"
                                f"text retrieval R5: "
                                f"{val_log['txt_r5']*100:.2f},\n"
                                f"text retrieval R10: "
                                f"{val_log['txt_r10']*100:.2f}")
                            LOGGER.info("================================="
                                        "=================================")
                    else:
                        val_log = validate(model, val_dataloader)
                        TB_LOGGER.log_scaler_dict(val_log)
                    model_saver.save(model, global_step)
            if global_step >= opts.num_train_steps:
                break
        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"finished {n_epoch} epochs")
    pbar.close()

    # final validation
    val_log = validate(model, val_dataloader)
    TB_LOGGER.log_scaler_dict(val_log)
    model_saver.save(model, f'{global_step}_final')

    # evaluation
    for split, loader in [('val', eval_loader_val),
                          ('test', eval_loader_test)]:
        eval_log = evaluate(model, loader)
        TB_LOGGER.log_scaler_dict({f"eval/{split}_{k}": v
                                   for k, v in eval_log.items()})
        if hvd.rank() != 0:
            continue
        LOGGER.info(
            f"========================= {split} ===========================\n"
            f"image retrieval R1: {eval_log['img_r1']*100:.2f},\n"
            f"image retrieval R5: {eval_log['img_r5']*100:.2f},\n"
            f"image retrieval R10: {eval_log['img_r10']*100:.2f}\n"
            f"text retrieval R1: {eval_log['txt_r1']*100:.2f},\n"
            f"text retrieval R5: {eval_log['txt_r5']*100:.2f},\n"
            f"text retrieval R10: {eval_log['txt_r10']*100:.2f}")
        LOGGER.info("=========================================================")


@torch.no_grad()
def validate(model, val_loader):
    if hvd.rank() == 0:
        pbar = tqdm(total=len(val_loader))
    else:
        pbar = NoOp()
    LOGGER.info("start running Image Retrieval validation ...")
    model.eval()
    n_ex = 0
    st = time()

    recall_at_1, recall_at_5, recall_at_10 = 0, 0, 0
    for batch in val_loader:
        scores = model(batch, compute_loss=False)
        _, indices = scores.squeeze(1).topk(10, dim=0)
        rank = (indices == 0).nonzero()
        if rank.numel():
            rank = rank.item()
            if rank < 1:
                recall_at_1 += 1
            if rank < 5:
                recall_at_5 += 1
            if rank < 10:
                recall_at_10 += 1
        n_ex += 1
        pbar.update(1)
    n_ex = sum(all_gather_list(n_ex))
    recall_at_1 = sum(all_gather_list(recall_at_1)) / n_ex
    recall_at_5 = sum(all_gather_list(recall_at_5)) / n_ex
    recall_at_10 = sum(all_gather_list(recall_at_10)) / n_ex
    tot_time = time()-st
    val_log = {'valid/ex_per_s': n_ex/tot_time,
               'valid/recall_1': recall_at_1,
               'valid/recall_5': recall_at_5,
               'valid/recall_10': recall_at_10}
    model.train()
    LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
                f"recall_1: {recall_at_1*100:.2f}, "
                f"recall_5: {recall_at_5*100:.2f}, "
                f"recall_10: {recall_at_10*100:.2f}")
    pbar.close()
    return val_log


@torch.no_grad()
def evaluate(model, eval_loader):
    st = time()
    LOGGER.info("start running Image/Text Retrieval evaluation ...")
    score_matrix = inference(model, eval_loader)
    dset = eval_loader.dataset
    all_score = hvd.allgather(score_matrix)
    all_txt_ids = [i for ids in all_gather_list(dset.ids)
                   for i in ids]
    all_img_ids = dset.all_img_ids
    assert all_score.size() == (len(all_txt_ids), len(all_img_ids))
    if hvd.rank() != 0:
        return {}

    # NOTE: only use rank0 to compute final scores
    # TODO: store score_matrix and ids
    eval_log = itm_eval(all_score, all_txt_ids, all_img_ids,
                        dset.txt2img, dset.img2txts)
    tot_time = time()-st
    LOGGER.info(f"evaluation finished in {int(tot_time)} seconds")
    return eval_log


@torch.no_grad()
def inference(model, eval_loader):
    model.eval()
    if hvd.rank() == 0:
        pbar = tqdm(total=len(eval_loader))
    else:
        pbar = NoOp()
    score_matrix = torch.zeros(len(eval_loader.dataset),
                               len(eval_loader.dataset.all_img_ids),
                               device=torch.device("cuda"),
                               dtype=torch.float16)
    for i, mini_batches in enumerate(eval_loader):
        j = 0
        for batch in mini_batches:
            scores = model(batch, compute_loss=False)
            bs = scores.size(0)
            score_matrix.data[i, j:j+bs] = scores.data.squeeze(1).half()
            j += bs
        assert j == score_matrix.size(1)
        pbar.update(1)
    model.train()
    pbar.close()
    return score_matrix


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument('--compressed_db', action='store_true',
                        help='use compressed LMDB')
    parser.add_argument("--checkpoint", default=None, type=str,
                        help="pretrained MLM")
    parser.add_argument("--output_dir", default=None, type=str,
                        help="The output directory where the model "
                             "checkpoints will be written.")

    # Prepro parameters
    parser.add_argument('--max_txt_len', type=int, default=60,
                        help='max number of tokens in text (BERT BPE)')
    parser.add_argument('--conf_th', type=float, default=0.2,
                        help='threshold for dynamic bounding boxes '
                             '(-1 for fixed)')
    parser.add_argument('--max_bb', type=int, default=100,
                        help='max number of bounding boxes')
    parser.add_argument('--min_bb', type=int, default=10,
                        help='min number of bounding boxes')
    parser.add_argument('--num_bb', type=int, default=36,
                        help='static number of bounding boxes')

    # training parameters
    parser.add_argument("--train_batch_size", default=32, type=int,
                        help="batch size (# positive examples) for training. "
                             "(implemented with gradient accumulation)")
    parser.add_argument("--negative_size", default=511, type=int,
                        help="Number of negative samples per positive sample "
                             "(forward only)")
    parser.add_argument("--hard_neg_size", default=31, type=int,
                        help="Number of hard negative samples "
                             "per positive sample (actually used to train)")
    parser.add_argument("--inf_minibatch_size", default=512, type=int,
                        help="batch size for running inference. "
                             "(used for validation and evaluation)")
" "(used for validation and evaluation)") parser.add_argument("--margin", default=0.2, type=float, help="margin of ranking loss") parser.add_argument("--learning_rate", default=3e-5, type=float, help="The initial learning rate for Adam.") parser.add_argument("--valid_steps", default=1000, type=int, help="Run validation every X steps") parser.add_argument("--num_train_steps", default=100000, type=int, help="Total number of training updates to perform.") parser.add_argument("--optim", default='adam', choices=['adam', 'adamax', 'adamw'], help="optimizer") parser.add_argument("--betas", default=[0.9, 0.98], nargs='+', help="beta for adam optimizer") parser.add_argument("--decay", default='linear', choices=['linear', 'invsqrt', 'constant'], help="learning rate decay method") parser.add_argument("--dropout", default=0.1, type=float, help="tune dropout regularization") parser.add_argument("--weight_decay", default=0.01, type=float, help="weight decay (L2) regularization") parser.add_argument("--grad_norm", default=0.25, type=float, help="gradient clipping (-1 for no clipping)") parser.add_argument("--warmup_steps", default=4000, type=int, help="Number of training steps to perform linear " "learning rate warmup for. (invsqrt decay)") # device parameters parser.add_argument('--seed', type=int, default=42, help="random seed for initialization") parser.add_argument('--full_val', action='store_true', help="Always run full evaluation during training") parser.add_argument('--fp16', action='store_true', help="Whether to use 16-bit float precision instead " "of 32-bit") parser.add_argument('--n_workers', type=int, default=4, help="number of data workers") parser.add_argument('--pin_mem', action='store_true', help="pin memory") # can use config files parser.add_argument('--config', help='JSON config files') args = parse_with_config(parser) if exists(args.output_dir) and os.listdir(args.output_dir): raise ValueError("Output directory ({}) already exists and is not " "empty.".format(args.output_dir)) # options safe guard if args.conf_th == -1: assert args.max_bb + args.max_txt_len + 2 <= 512 else: assert args.num_bb + args.max_txt_len + 2 <= 512 # for tensor core assert (args.negative_size+1) % 8 == (args.hard_neg_size+1) % 8 == 0 main(args)