# coding=utf-8
# copied from the Hugging Face GitHub repository
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc.
# team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT pre-training runner."""
import argparse
import json
import os
from os.path import exists, join
import random
from time import time
import torch
from torch.nn import functional as F
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam, Adamax
from torch.utils.data import DataLoader
from data.data import ConcatDetectFeatBertTokDataset as ConcatDataset
from apex import amp
from horovod import torch as hvd
import numpy as np
from tqdm import tqdm
from data import (DistributedTokenBucketSampler,
DetectFeatLmdb, MlmDatasetForVCR, mlm_collate_for_vcr,
MrmDatasetForVCR, mrm_collate_for_vcr,
MrcDatasetForVCR, mrc_collate_for_vcr,
MetaLoader, PrefetchLoader)
from model import BertForImageTextPretrainingForVCR
from optim import warmup_linear, noam_schedule, vqa_schedule, AdamW
from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file
from utils.distributed import (all_reduce_and_rescale_tensors, all_gather_list,
broadcast_tensors)
from utils.save import ModelSaver, save_training_meta
from utils.misc import NoOp, parse_with_config
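# NUM_SPECIAL_TOKENS: extra tokens added to the word embedding for VCR
#                     (see model.init_word_embedding below)
# IMG_DIM:            dimensionality of the detector region features
# IMG_LABEL_DIM:      size of the region-class distribution used by MRC
#                     (also used to rescale the MRC-kl loss)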
NUM_SPECIAL_TOKENS = 81
IMG_DIM = 2048
IMG_LABEL_DIM = 1601
def parse_tasks(datasets):
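    """Flatten the dataset config into parallel lists of task names,
    (db, img) path pairs and, if given, sampling mix ratios.

    Each config entry spawns one task per name in its 'tasks' list; task
    names are suffixed with the dataset name so they stay unique.
    """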
task_names = []
dset_paths = []
mix_ratio = []
for i, dset in enumerate(datasets):
assert len(dset['db']) == len(dset['img'])
if 'mix_ratio' in dset:
assert len(dset['tasks']) == len(dset['mix_ratio'])
mix_ratio.extend(dset['mix_ratio'])
task_names.extend(f'{t}_{dset["name"]}' for t in dset['tasks'])
n_task = len(dset['tasks'])
dset_paths.extend([(dset['db'], dset['img'])] * n_task)
assert len(task_names) == len(set(task_names)) == len(dset_paths)
if mix_ratio:
assert len(task_names) == len(mix_ratio)
return task_names, dset_paths, mix_ratio
else:
return task_names, dset_paths
def build_sampler(lens, batch_size, eval_, bucket_size=8192):
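    """Token-bucket batch sampler sharded across Horovod workers; the last
    incomplete batch is dropped only during training (droplast=not eval_)."""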
droplast = not eval_
sampler = DistributedTokenBucketSampler(
hvd.size(), hvd.rank(), lens,
bucket_size=bucket_size, batch_size=batch_size, droplast=droplast)
return sampler
def build_mlm_train_dataloader(txt_db, img_dir_gt, img_dir,
n_gpu, opts):
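    """Build the MLM training loader: one MlmDatasetForVCR per (txt_db,
    img_dir) pair and per VCR task (qa / qar), concatenated and bucketed by
    token length. The MRM / MRC builders below follow the same pattern."""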
LOGGER.info(f"Loading MLM Train Dataset {txt_db}, "
f"{[i.img_dir for i in img_dir]}"
f"{[i.img_dir for i in img_dir_gt]}")
train_datasets = [MlmDatasetForVCR(
db, dir_gt_, dir_, opts.max_txt_len, task=t)
for db, dir_gt_, dir_ in zip(txt_db, img_dir_gt, img_dir)
for t in opts.vcr_task]
train_dataset = ConcatDataset(train_datasets)
train_sampler = build_sampler(train_dataset.lens,
opts.train_batch_size, eval_=False)
train_dataloader = DataLoader(train_dataset,
batch_sampler=train_sampler,
num_workers=opts.n_workers,
pin_memory=opts.pin_mem,
collate_fn=mlm_collate_for_vcr)
LOGGER.info(f"{len(train_dataset)} samples loaded")
return train_dataloader
def build_mrm_train_dataloader(txt_db, img_dir_gt, img_dir,
n_gpu, opts):
LOGGER.info(f"Loading MRM Train Dataset {txt_db}, "
f"{[i.img_dir for i in img_dir]}"
f"{[i.img_dir for i in img_dir_gt]}")
train_datasets = [MrmDatasetForVCR(
opts.mrm_prob, db, dir_gt_,
dir_, opts.max_txt_len, task=t)
for db, dir_gt_, dir_ in zip(txt_db, img_dir_gt, img_dir)
for t in opts.vcr_task]
train_dataset = ConcatDataset(train_datasets)
train_sampler = build_sampler(train_dataset.lens,
opts.train_batch_size, eval_=False)
train_dataloader = DataLoader(train_dataset,
batch_sampler=train_sampler,
num_workers=opts.n_workers,
pin_memory=opts.pin_mem,
collate_fn=mrm_collate_for_vcr)
LOGGER.info(f"{len(train_dataset)} samples loaded")
return train_dataloader
def build_mrc_train_dataloader(txt_db, img_dir_gt, img_dir,
n_gpu, opts):
LOGGER.info(f"Loading MRC Train Dataset {txt_db}, "
f"{[i.img_dir for i in img_dir]}"
f"{[i.img_dir for i in img_dir_gt]}")
train_datasets = [MrcDatasetForVCR(
opts.mrc_prob, db, dir_gt_,
dir_, opts.max_txt_len, task=t)
for db, dir_gt_, dir_ in zip(txt_db, img_dir_gt, img_dir)
for t in opts.vcr_task]
train_dataset = ConcatDataset(train_datasets)
train_sampler = build_sampler(train_dataset.lens,
opts.train_batch_size, eval_=False)
train_dataloader = DataLoader(train_dataset,
batch_sampler=train_sampler,
num_workers=opts.n_workers,
pin_memory=opts.pin_mem,
collate_fn=mrc_collate_for_vcr)
LOGGER.info(f"{len(train_dataset)} samples loaded")
return train_dataloader
def build_mlm_val_dataloader(txt_db, img_dir_gt, img_dir,
n_gpu, opts):
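    """Build the MLM validation loader: a single (txt_db, img_dir) pair with
    no text-length truncation (max_txt_len=-1)."""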
LOGGER.info(f"Loading MLM Val Dataset {txt_db}, "
f"{img_dir_gt.img_dir}, {img_dir.img_dir}")
val_datasets = [MlmDatasetForVCR(
txt_db, img_dir_gt, img_dir, -1, task=t)
for t in opts.vcr_task]
val_dataset = ConcatDataset(val_datasets)
val_sampler = build_sampler(val_dataset.lens,
opts.val_batch_size, eval_=True)
val_dataloader = DataLoader(val_dataset,
batch_sampler=val_sampler,
num_workers=opts.n_workers,
pin_memory=opts.pin_mem,
collate_fn=mlm_collate_for_vcr)
LOGGER.info(f"{len(val_dataset)} samples loaded")
return val_dataloader
def build_mrm_val_dataloader(txt_db, img_dir_gt, img_dir,
n_gpu, opts):
LOGGER.info(f"Loading MRM Val Dataset {txt_db}, "
f"{img_dir_gt.img_dir}, {img_dir.img_dir}")
val_datasets = [MrmDatasetForVCR(
opts.mrm_prob, txt_db, img_dir_gt,
img_dir, -1, task=t)
for t in opts.vcr_task]
val_dataset = ConcatDataset(val_datasets)
val_sampler = build_sampler(val_dataset.lens,
opts.val_batch_size, eval_=True)
val_dataloader = DataLoader(val_dataset,
batch_sampler=val_sampler,
num_workers=opts.n_workers,
pin_memory=opts.pin_mem,
collate_fn=mrm_collate_for_vcr)
LOGGER.info(f"{len(val_dataset)} samples loaded")
return val_dataloader
def build_mrc_val_dataloader(txt_db, img_dir_gt, img_dir,
n_gpu, opts):
LOGGER.info(f"Loading MRC Val Dataset {txt_db}, "
f"{img_dir_gt.img_dir}, {img_dir.img_dir}")
val_datasets = [MrcDatasetForVCR(
opts.mrc_prob, txt_db, img_dir_gt,
img_dir, -1, task=t)
for t in opts.vcr_task]
val_dataset = ConcatDataset(val_datasets)
val_sampler = build_sampler(val_dataset.lens,
opts.val_batch_size, eval_=True)
val_dataloader = DataLoader(val_dataset,
batch_sampler=val_sampler,
num_workers=opts.n_workers,
pin_memory=opts.pin_mem,
collate_fn=mrc_collate_for_vcr)
LOGGER.info(f"{len(val_dataset)} samples loaded")
return val_dataloader
def load_img_feat(dir_list, path2imgdir, opts):
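    """Resolve a ';'-separated image-feature spec into (img_dir, img_dir_gt).

    Entries whose path contains 'gt' are loaded with conf_th=-1 (ground-truth
    boxes); the others use the confidence-thresholded detector features.
    `path2imgdir` caches DetectFeatLmdb handles so each LMDB is opened once.
    """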
dir_ = dir_list.split(";")
assert len(dir_) <= 2, "More than two img_dirs found"
img_dir_gt, img_dir = None, None
gt_dir_path, dir_path = "", ""
for d in dir_:
if "gt" in d:
gt_dir_path = d
else:
dir_path = d
if gt_dir_path != "":
img_dir_gt = path2imgdir.get(gt_dir_path, None)
if img_dir_gt is None:
img_dir_gt = DetectFeatLmdb(gt_dir_path, -1,
opts.max_bb, opts.min_bb, 100,
opts.compressed_db)
path2imgdir[gt_dir_path] = img_dir_gt
if dir_path != "":
img_dir = path2imgdir.get(dir_path, None)
if img_dir is None:
img_dir = DetectFeatLmdb(dir_path, opts.conf_th,
opts.max_bb, opts.min_bb, opts.num_bb,
opts.compressed_db)
path2imgdir[dir_path] = img_dir
return img_dir, img_dir_gt, path2imgdir
def main(opts):
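    """Distributed (Horovod) multi-task pre-training: build one loader per
    task (MLM / MRM / MRC), mix them with MetaLoader, and train with gradient
    accumulation, validating and checkpointing every `valid_steps` updates."""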
hvd.init()
n_gpu = hvd.size()
device = torch.device("cuda", hvd.local_rank())
torch.cuda.set_device(hvd.local_rank())
rank = hvd.rank()
opts.rank = rank
LOGGER.info("device: {} n_gpu: {}, rank: {}, "
"16-bits training: {}".format(
device, n_gpu, hvd.rank(), opts.fp16))
if opts.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
"should be >= 1".format(
opts.gradient_accumulation_steps))
random.seed(opts.seed)
np.random.seed(opts.seed)
torch.manual_seed(opts.seed)
if n_gpu > 0:
torch.cuda.manual_seed_all(opts.seed)
if rank == 0:
save_training_meta(opts)
TB_LOGGER.create(join(opts.output_dir, 'log'))
pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
else:
LOGGER.disabled = True
pbar = NoOp()
model_saver = NoOp()
all_dbs = [db for datasets in [opts.train_datasets, opts.val_datasets]
for dset in datasets for db in dset['db']]
bert_model = json.load(open(f'{all_dbs[0]}/meta.json'))['bert']
assert all(bert_model == json.load(open(f'{db}/meta.json'))['bert']
for db in all_dbs)
train_tasks, train_data_paths, mix_ratio = parse_tasks(opts.train_datasets)
train_dataloaders = []
path2imgdir = {}
for (dbs, dirs), task in zip(train_data_paths, train_tasks):
img_dirs = []
img_gt_dirs = []
for db, dir_list in zip(dbs, dirs):
img_dir, img_dir_gt, path2imgdir = load_img_feat(
dir_list, path2imgdir, opts)
img_dirs.append(img_dir)
img_gt_dirs.append(img_dir_gt)
if task.startswith('mlm'):
loader = build_mlm_train_dataloader(dbs, img_gt_dirs, img_dirs,
n_gpu, opts)
elif task.startswith('mrm'):
loader = build_mrm_train_dataloader(dbs, img_gt_dirs, img_dirs,
n_gpu, opts)
elif task.startswith('mrc'):
loader = build_mrc_train_dataloader(dbs, img_gt_dirs, img_dirs,
n_gpu, opts)
else:
raise ValueError(f'Undefined task {task}')
train_dataloaders.append(loader)
val_tasks, val_data_paths = parse_tasks(opts.val_datasets)
val_dataloaders = []
for (db, dir_), task in zip(val_data_paths, val_tasks):
assert len(db) == len(dir_) == 1
db = db[0]
dir_ = dir_[0]
img_dir, img_dir_gt, path2imgdir = load_img_feat(
dir_, path2imgdir, opts)
if task.startswith('mlm'):
loader = build_mlm_val_dataloader(db, img_dir_gt, img_dir, n_gpu, opts)
elif task.startswith('mrm'):
loader = build_mrm_val_dataloader(db, img_dir_gt, img_dir, n_gpu, opts)
elif task.startswith('mrc'):
loader = build_mrc_val_dataloader(db, img_dir_gt, img_dir, n_gpu, opts)
else:
raise ValueError(f'Undefined task {task}')
val_dataloaders.append(PrefetchLoader(loader))
meta_loader = MetaLoader(train_dataloaders,
mix_ratio=mix_ratio, names=train_tasks,
accum_steps=opts.gradient_accumulation_steps,
distributed=n_gpu > 1)
meta_loader = PrefetchLoader(meta_loader)
named_val_loaders = list(zip(val_tasks, val_dataloaders))
# Prepare model
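    # 'google-bert' means no extra checkpoint: the model below is then
    # initialized from the pre-trained `bert_model` weights only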
if opts.checkpoint:
if opts.checkpoint == 'google-bert':
checkpoint = None
else:
checkpoint = torch.load(opts.checkpoint)
else:
checkpoint = {}
model = BertForImageTextPretrainingForVCR.from_pretrained(
bert_model, img_dim=IMG_DIM, img_label_dim=IMG_LABEL_DIM,
state_dict=checkpoint)
model.init_type_embedding()
model.init_word_embedding(NUM_SPECIAL_TOKENS)
model.pad_vocab() # tensor core padding for vocabulary
if opts.cut_bert != -1:
# cut some layers of BERT
model.bert.encoder.layer = torch.nn.ModuleList(
model.bert.encoder.layer[:opts.cut_bert])
for name, module in model.named_modules():
# we might want to tune dropout for smaller dataset
if isinstance(module, torch.nn.Dropout):
if module.p != opts.dropout:
module.p = opts.dropout
LOGGER.info(f'{name} set to {opts.dropout}')
model.to(device)
# make sure every process has same model parameters in the beginning
broadcast_tensors([p.data for p in model.parameters()], 0)
# Prepare optimizer
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer
if not any(nd in n for nd in no_decay)],
'weight_decay': opts.weight_decay},
{'params': [p for n, p in param_optimizer
if any(nd in n for nd in no_decay)],
'weight_decay': 0.0}
]
if opts.optim == 'adam':
OptimCls = Adam
elif opts.optim == 'adamax':
OptimCls = Adamax
elif opts.optim == 'adamw':
OptimCls = AdamW
else:
raise ValueError('invalid optimizer')
optimizer = OptimCls(optimizer_grouped_parameters,
lr=opts.learning_rate, betas=opts.betas)
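    # Apex AMP opt_level O2: mixed-precision training with FP32 master
    # weights; a no-op when --fp16 is not set (enabled=False)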
model, optimizer = amp.initialize(model, optimizer,
enabled=opts.fp16, opt_level='O2')
global_step = 0
    if rank == 0:
        os.makedirs(join(opts.output_dir, 'results'))  # store VQA predictions
LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
LOGGER.info(" Batch size = %d", opts.train_batch_size)
LOGGER.info(" Accumulate steps = %d", opts.gradient_accumulation_steps)
LOGGER.info(" Num steps = %d", opts.num_train_steps)
task2loss = {task: RunningMeter(f'loss/{task}') for task in train_tasks}
model.train()
n_examples = 0
n_epoch = 0
start = time()
# quick hack for amp delay_unscale bug
optimizer.zero_grad()
optimizer.step()
while True:
for step, (name, batch) in enumerate(meta_loader):
input_ids, *_ = batch
n_examples += input_ids.size(0)
task = name.split('_')[0]
loss = model(*batch, task=task, compute_loss=True)
loss = loss.mean() # loss is not normalized
if task == 'mrckl':
# MRCkl normalization; safeguard fp16 overflow
loss = loss.float() * IMG_LABEL_DIM
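            # keep gradients scaled (skip unscaling and all-reduce) until the
            # last micro-batch of the accumulation window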
delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0
with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale
) as scaled_loss:
scaled_loss.backward()
if not delay_unscale:
                # gather gradients from every process
# do this before unscaling to make sure every process uses
# the same gradient scale
grads = [p.grad.data for p in model.parameters()
if p.requires_grad and p.grad is not None]
all_reduce_and_rescale_tensors(grads, float(1))
task2loss[name](loss.item())
if (step + 1) % opts.gradient_accumulation_steps == 0:
global_step += 1
# learning rate scheduling
if opts.decay == 'linear':
lr_this_step = opts.learning_rate * warmup_linear(
global_step, opts.warmup_steps, opts.num_train_steps)
elif opts.decay == 'invsqrt':
lr_this_step = opts.learning_rate * noam_schedule(
global_step, opts.warmup_steps)
elif opts.decay == 'constant':
lr_this_step = opts.learning_rate
elif opts.decay == 'vqa':
lr_this_step = opts.learning_rate * vqa_schedule(
global_step, opts.warm_int, opts.decay_int,
opts.decay_st, opts.decay_rate)
if lr_this_step < 0:
                    # safeguard against possible miscalculation of train steps
lr_this_step = 1e-8
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
TB_LOGGER.add_scalar('lr', lr_this_step, global_step)
# log loss
for t, l in task2loss.items():
loss = sum(v for v in all_gather_list(l.val)
if v is not None) / hvd.size()
task2loss[t] = RunningMeter(f'loss/{t}', loss)
TB_LOGGER.log_scaler_dict({l.name: l.val
for l in task2loss.values()
if l.val is not None})
TB_LOGGER.step()
# update model params
if opts.grad_norm != -1:
grad_norm = clip_grad_norm_(amp.master_params(optimizer),
opts.grad_norm)
TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
optimizer.step()
optimizer.zero_grad()
pbar.update(1)
if global_step % 5 == 0:
torch.cuda.empty_cache()
if global_step % 100 == 0:
# monitor training throughput
tot_ex = sum(all_gather_list(n_examples))
ex_per_sec = int(tot_ex / (time()-start))
LOGGER.info(f'{tot_ex} examples trained at '
f'{ex_per_sec} ex/s')
TB_LOGGER.add_scalar('perf/ex_per_s',
ex_per_sec, global_step)
if global_step % opts.valid_steps == 0:
validate(model, named_val_loaders)
model_saver.save(model, global_step)
if global_step >= opts.num_train_steps:
break
        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"finished {n_epoch} epochs")
    if global_step % opts.valid_steps != 0:
        validate(model, named_val_loaders)
        model_saver.save(model, global_step)
def validate(model, named_val_loaders):
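    """Run the task-specific validation loop for every (task, loader) pair
    and push the metrics to TensorBoard under `valid_<task>/`."""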
model.eval()
for task, loader in named_val_loaders:
LOGGER.info(f"validate on {task} task")
if task.startswith('mlm'):
val_log = validate_mlm(model, loader)
elif task.startswith('mrm'):
val_log = validate_mrm(model, loader)
elif task.startswith('mrc'):
val_log = validate_mrc(model, loader, task)
else:
raise ValueError(f'Undefined task {task}')
val_log = {f'{task}_{k}': v for k, v in val_log.items()}
TB_LOGGER.log_scaler_dict(
{f'valid_{task}/{k}': v for k, v in val_log.items()})
model.train()
@torch.no_grad()
def validate_mrc(model, val_loader, task):
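    """MRC validation: KL divergence against the detector's soft labels for
    the '*kl' task variant, otherwise cross-entropy (and accuracy) against
    the argmax hard label."""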
LOGGER.info("start running MRC validation...")
val_loss = 0
n_feat = 0
st = time()
tot_score = 0
for i, batch in enumerate(val_loader):
*_, label = batch
feat_mask, label_targets = label
prediction_soft_label = model(
*batch, task=task, compute_loss=False)
if "kl" in task:
prediction_soft_label = F.log_softmax(
prediction_soft_label, dim=-1)
loss = F.kl_div(
prediction_soft_label, label_targets, reduction='sum')
tot_score += compute_accuracy_for_mrc(
prediction_soft_label, label_targets)
else:
cls_label_targets = label_targets.max(dim=-1)[1] # argmax
loss = F.cross_entropy(
prediction_soft_label, cls_label_targets,
ignore_index=0, reduction='sum')
tot_score += compute_accuracy_for_mrc(
prediction_soft_label[:, 1:], label_targets[:, 1:])
val_loss += loss.item()
n_feat += feat_mask.sum().item()
val_loss = sum(all_gather_list(val_loss))
tot_score = sum(all_gather_list(tot_score))
n_feat = sum(all_gather_list(n_feat))
tot_time = time()-st
val_loss /= n_feat
val_acc = tot_score / n_feat
val_log = {'loss': val_loss,
'acc': val_acc,
'feat_per_s': n_feat/tot_time}
LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
f"score: {val_acc*100:.2f}")
return val_log
@torch.no_grad()
def validate_mrm(model, val_loader):
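    """MRM validation: the model's masked-region loss is summed and
    normalized by the number of masked features times IMG_DIM."""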
LOGGER.info("start running MRM validation...")
val_loss = 0
n_feat = 0
st = time()
for i, batch in enumerate(val_loader):
*_, feat_mask = batch
loss = model(*batch, task='mrm', compute_loss=True)
val_loss += loss.sum().item()
n_feat += feat_mask.sum().item()
val_loss = sum(all_gather_list(val_loss))
n_feat = sum(all_gather_list(n_feat))
tot_time = time()-st
val_loss /= (n_feat * IMG_DIM)
val_log = {'loss': val_loss,
'feat_per_s': n_feat/tot_time}
LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
f"loss: {val_loss:.2f}")
return val_log
@torch.no_grad()
def validate_mlm(model, val_loader):
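    """MLM validation: reports the average loss over masked tokens; accuracy
    is not computed in this version (the scores-based path is disabled)."""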
LOGGER.info(f"start running MLM validation ...")
val_loss = 0
n_correct = 0
n_word = 0
st = time()
for i, batch in enumerate(val_loader):
*inputs, txt_labels = batch
loss = model.forward(*batch, task='mlm', compute_loss=True)
# loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-1,
# reduction='sum')
# loss = loss_fct(scores, txt_labels)
        loss = loss.sum()  # assumes per-token MLM losses, as in the training loop
        val_loss += loss.item()
        # scores are not returned with compute_loss=True, so accuracy is skipped
        # n_correct += accuracy_count(scores, txt_labels)
        n_word += (txt_labels != -1).sum().item()  # count masked tokens only
val_loss = sum(all_gather_list(val_loss))
n_correct = sum(all_gather_list(n_correct))
n_word = sum(all_gather_list(n_word))
tot_time = time()-st
val_loss /= n_word
acc = n_correct / n_word
val_log = {'loss': val_loss,
'acc': acc,
'tok_per_s': n_word/tot_time}
LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
f"acc: {acc*100:.2f}"
f"loss: {val_loss}")
return val_log
def compute_accuracy_for_mrc(out, labels):
outputs = out.max(dim=-1)[1]
labels = labels.max(dim=-1)[1] # argmax
n_correct = (outputs == labels).sum().item()
return n_correct
def accuracy_count(out, labels):
outputs = out.max(dim=-1)[1]
mask = labels != -1
n_correct = (outputs == labels).masked_select(mask).sum().item()
return n_correct
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
# NOTE: train tasks and val tasks cannot take command line arguments
parser.add_argument('--compressed_db', action='store_true',
help='use compressed LMDB')
parser.add_argument("--vcr_task",
default=["qar"], type=str, nargs='+',
choices=['qa', 'qar'],
help="VCR tasks: qa or qar")
parser.add_argument('--tasks', default=None, type=str, nargs='+',
help="specify pretraining tasks")
parser.add_argument('--mrm_prob', default=0.15, type=float,
help='probability to mask in MRM training')
parser.add_argument('--mrc_prob', default=0.15, type=float,
help='probability to mask in MRC training')
parser.add_argument("--checkpoint",
default=None, type=str,
help="pretrained model (can take 'google-bert') ")
parser.add_argument("--cut_bert", default=-1, type=int,
help="reduce BERT layers (-1 for original depth)")
parser.add_argument(
"--output_dir", default=None, type=str,
help="The output directory where the model checkpoints will be "
"written.")
# Prepro parameters
parser.add_argument('--max_txt_len', type=int, default=60,
help='max number of tokens in text (BERT BPE)')
parser.add_argument('--conf_th', type=float, default=0.2,
help='threshold for dynamic bounding boxes '
'(-1 for fixed)')
parser.add_argument('--max_bb', type=int, default=100,
help='max number of bounding boxes')
parser.add_argument('--min_bb', type=int, default=10,
help='min number of bounding boxes')
parser.add_argument('--num_bb', type=int, default=36,
help='static number of bounding boxes')
# training parameters
parser.add_argument("--train_batch_size",
default=4096, type=int,
help="Total batch size for training. "
"(batch by tokens)")
parser.add_argument("--val_batch_size",
default=4096, type=int,
help="Total batch size for validation. "
"(batch by tokens)")
parser.add_argument('--gradient_accumulation_steps',
type=int,
default=16,
help="Number of updates steps to accumualte before "
"performing a backward/update pass.")
parser.add_argument("--learning_rate",
default=3e-5,
type=float,
help="The initial learning rate for Adam.")
parser.add_argument("--valid_steps",
default=1000,
type=int,
help="Run validation every X steps")
parser.add_argument("--num_train_steps",
default=100000,
type=int,
help="Total number of training updates to perform.")
    parser.add_argument('--mask_prob', default=0.15, type=float,
                        help='probability to mask in MLM training')
parser.add_argument("--optim", default='adam',
choices=['adam', 'adamax', 'adamw'],
help="optimizer")
parser.add_argument("--betas", default=[0.9, 0.98], nargs='+',
help="beta for adam optimizer")
parser.add_argument("--decay", default='linear',
choices=['linear', 'invsqrt', 'constant', 'vqa'],
help="learning rate decay method")
parser.add_argument("--decay_int", default=2000, type=int,
help="interval between VQA lr decy")
parser.add_argument("--warm_int", default=2000, type=int,
help="interval for VQA lr warmup")
parser.add_argument("--decay_st", default=20000, type=int,
help="when to start decay")
parser.add_argument("--decay_rate", default=0.2, type=float,
help="ratio of lr decay")
parser.add_argument("--dropout",
default=0.1,
type=float,
help="tune dropout regularization")
parser.add_argument("--weight_decay",
default=0.0,
type=float,
help="weight decay (L2) regularization")
parser.add_argument("--grad_norm",
default=0.25,
type=float,
help="gradient clipping (-1 for no clipping)")
parser.add_argument("--warmup_steps",
default=4000,
type=int,
help="Number of training steps to perform linear "
"learning rate warmup for. (invsqrt decay)")
# device parameters
parser.add_argument('--seed',
type=int,
default=42,
help="random seed for initialization")
parser.add_argument('--fp16',
action='store_true',
help="Whether to use 16-bit float precision instead "
"of 32-bit")
parser.add_argument('--n_workers', type=int, default=4,
help="number of data workers")
parser.add_argument('--pin_mem', action='store_true',
help="pin memory")
# can use config files
parser.add_argument('--config', help='JSON config files')
args = parse_with_config(parser)
if exists(args.output_dir) and os.listdir(args.output_dir):
raise ValueError("Output directory ({}) already exists and is not "
"empty.".format(args.output_dir))
# options safe guard
# TODO
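    # text tokens + image regions + 2 (presumably the special tokens) must
    # fit in BERT's 512-position limit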
if args.conf_th == -1:
assert args.max_bb + args.max_txt_len + 2 <= 512
else:
assert args.num_bb + args.max_txt_len + 2 <= 512
assert len(args.vcr_task) > 0, "Must choose at least one vcr task"
main(args)