logo
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

413 lines
17 KiB

# coding=utf-8
# Copied from the HuggingFace GitHub repository.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc.
# team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT pre-training runner."""
import argparse
import json
import os
from os.path import exists, join
import pickle
from time import time
import torch
from torch.nn import functional as F
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader
from apex import amp
from horovod import torch as hvd
from tqdm import tqdm
from data import (TokenBucketSampler, PrefetchLoader,
DetectFeatLmdb, TxtTokLmdb,
VeDataset, VeEvalDataset,
ve_collate, ve_eval_collate)
from model import UniterForVisualEntailment
from optim import get_lr_sched
from optim.misc import build_optimizer
from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file
from utils.distributed import (all_reduce_and_rescale_tensors, all_gather_list,
broadcast_tensors)
from utils.save import ModelSaver, save_training_meta
from utils.misc import NoOp, parse_with_config, set_dropout, set_random_seed
from utils.misc import VE_ENT2IDX as ans2label
from utils.misc import VE_IDX2ENT as label2ans
from utils.const import IMG_DIM, BUCKET_SIZE
def create_dataloader(img_path, txt_path, batch_size, is_train,
                      dset_cls, collate_fn, opts):
    """Build a prefetching, length-bucketed DataLoader for one data split.

    Args:
        img_path: path of the image-feature LMDB (read by DetectFeatLmdb).
        txt_path: path of the tokenized-text LMDB (read by TxtTokLmdb).
        batch_size: forwarded to TokenBucketSampler (the CLI help describes
            batch sizes as "batch by tokens").
        is_train: when True, text is capped at ``opts.max_txt_len`` and the
            sampler drops the last incomplete batch; otherwise ``-1`` is
            passed as the text length limit (presumably "no truncation" —
            confirm in TxtTokLmdb) and no batch is dropped.
        dset_cls: dataset class, constructed as ``dset_cls(txt_db, img_db)``.
        collate_fn: batch collation function handed to the DataLoader.
        opts: options namespace providing the LMDB/box and loader settings.

    Returns:
        A ``PrefetchLoader`` wrapping the underlying ``DataLoader``.
    """
    image_db = DetectFeatLmdb(img_path, opts.conf_th, opts.max_bb,
                              opts.min_bb, opts.num_bb, opts.compressed_db)
    text_len_limit = opts.max_txt_len if is_train else -1
    text_db = TxtTokLmdb(txt_path, text_len_limit)
    dataset = dset_cls(text_db, image_db)

    # Group examples of similar length (dataset.lens) into buckets.
    bucket_sampler = TokenBucketSampler(
        dataset.lens,
        bucket_size=BUCKET_SIZE,
        batch_size=batch_size,
        droplast=is_train,
    )
    base_loader = DataLoader(
        dataset,
        batch_sampler=bucket_sampler,
        num_workers=opts.n_workers,
        pin_memory=opts.pin_mem,
        collate_fn=collate_fn,
    )
    return PrefetchLoader(base_loader)
def main(opts):
    """Finetune UNITER for visual entailment with Horovod data parallelism.

    Initializes Horovod and the local CUDA device, builds train/val/test
    loaders, loads an optional checkpoint into ``UniterForVisualEntailment``
    and trains with apex AMP and gradient accumulation until
    ``opts.num_train_steps`` optimizer updates have been made.  Every
    ``opts.valid_steps`` updates, and once more after training, the model is
    evaluated on the val and test splits.  Each rank writes its own
    prediction JSON under ``<output_dir>/results`` and a checkpoint is saved.
    Only rank 0 creates the log/ckpt/results directories and saves
    checkpoints.  Other ranks use ``NoOp`` stand-ins and a disabled LOGGER.

    Args:
        opts: parsed command-line / config options namespace.

    Raises:
        ValueError: if ``opts.gradient_accumulation_steps < 1``.
    """
    hvd.init()
    n_gpu = hvd.size()
    device = torch.device("cuda", hvd.local_rank())
    torch.cuda.set_device(hvd.local_rank())
    rank = hvd.rank()
    opts.rank = rank
    LOGGER.info("device: {} n_gpu: {}, rank: {}, "
                "16-bits training: {}".format(
                    device, n_gpu, hvd.rank(), opts.fp16))

    if opts.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                             opts.gradient_accumulation_steps))

    set_random_seed(opts.seed)

    LOGGER.info(f"Loading Train Dataset {opts.train_txt_db}, "
                f"{opts.train_img_db}")
    train_dataloader = create_dataloader(opts.train_img_db, opts.train_txt_db,
                                         opts.train_batch_size, True,
                                         VeDataset, ve_collate, opts)
    val_dataloader = create_dataloader(opts.val_img_db, opts.val_txt_db,
                                       opts.val_batch_size, False,
                                       VeEvalDataset, ve_eval_collate, opts)
    test_dataloader = create_dataloader(opts.test_img_db, opts.test_txt_db,
                                        opts.val_batch_size, False,
                                        VeEvalDataset, ve_eval_collate, opts)

    # Prepare model
    if opts.checkpoint:
        checkpoint = torch.load(opts.checkpoint)
    else:
        checkpoint = {}
    model = UniterForVisualEntailment.from_pretrained(
        opts.model_config, state_dict=checkpoint, img_dim=IMG_DIM)
    model.to(device)
    # make sure every process has same model parameters in the beginning
    broadcast_tensors([p.data for p in model.parameters()], 0)
    set_dropout(model, opts.dropout)

    # Prepare optimizer
    optimizer = build_optimizer(model, opts)
    model, optimizer = amp.initialize(model, optimizer,
                                      enabled=opts.fp16, opt_level='O2')

    global_step = 0
    if rank == 0:
        save_training_meta(opts)
        TB_LOGGER.create(join(opts.output_dir, 'log'))
        pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
        # Context manager so the pickle file is flushed and closed.
        with open(join(opts.output_dir, 'ckpt', 'ans2label.pkl'), 'wb') as f:
            pickle.dump(ans2label, f)
        os.makedirs(join(opts.output_dir, 'results'))  # store VE predictions
        add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
    else:
        LOGGER.disabled = True
        pbar = NoOp()
        model_saver = NoOp()

    LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
    LOGGER.info("  Num examples = %d", len(train_dataloader.dataset))
    LOGGER.info("  Batch size = %d", opts.train_batch_size)
    LOGGER.info("  Accumulate steps = %d", opts.gradient_accumulation_steps)
    LOGGER.info("  Num steps = %d", opts.num_train_steps)

    running_loss = RunningMeter('loss')
    model.train()
    n_examples = 0
    n_epoch = 0
    start = time()
    # quick hack for amp delay_unscale bug
    optimizer.zero_grad()
    optimizer.step()
    while True:
        for step, batch in enumerate(train_dataloader):
            n_examples += batch['input_ids'].size(0)

            loss = model(batch, compute_loss=True)
            loss = loss.mean() * batch['targets'].size(1)  # instance-level bce
            # Keep gradients scaled until the last micro-batch of an
            # accumulation window.
            delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0
            with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale
                                ) as scaled_loss:
                scaled_loss.backward()
                if not delay_unscale:
                    # gather gradients from every processes
                    # do this before unscaling (i.e. inside this context) to
                    # make sure every process uses the same gradient scale
                    grads = [p.grad.data for p in model.parameters()
                             if p.requires_grad and p.grad is not None]
                    all_reduce_and_rescale_tensors(grads, float(1))

            running_loss(loss.item())

            if (step + 1) % opts.gradient_accumulation_steps == 0:
                global_step += 1

                # learning rate scheduling
                lr_this_step = get_lr_sched(global_step, opts)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
                TB_LOGGER.add_scalar('lr', lr_this_step, global_step)

                # log loss (averaged over all ranks)
                losses = all_gather_list(running_loss)
                running_loss = RunningMeter(
                    'loss', sum(l.val for l in losses)/len(losses))
                TB_LOGGER.add_scalar('loss', running_loss.val, global_step)
                TB_LOGGER.step()

                # update model params
                if opts.grad_norm != -1:
                    grad_norm = clip_grad_norm_(amp.master_params(optimizer),
                                                opts.grad_norm)
                    TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
                optimizer.step()
                optimizer.zero_grad()
                pbar.update(1)

                if global_step % 100 == 0:
                    # monitor training throughput
                    LOGGER.info(f'============Step {global_step}=============')
                    tot_ex = sum(all_gather_list(n_examples))
                    ex_per_sec = int(tot_ex / (time()-start))
                    LOGGER.info(f'{tot_ex} examples trained at '
                                f'{ex_per_sec} ex/s')
                    TB_LOGGER.add_scalar('perf/ex_per_s',
                                         ex_per_sec, global_step)
                    LOGGER.info('===========================================')

                if global_step % opts.valid_steps == 0:
                    for split, loader in [("val", val_dataloader),
                                          ("test", test_dataloader)]:
                        LOGGER.info(f"Step {global_step}: start running "
                                    f"validation on {split} split...")
                        val_log, results = validate(
                            model, loader, label2ans, split)
                        with open(f'{opts.output_dir}/results/'
                                  f'{split}_results_{global_step}_'
                                  f'rank{rank}.json', 'w') as f:
                            json.dump(results, f)
                        TB_LOGGER.log_scaler_dict(val_log)
                    model_saver.save(model, global_step)
            if global_step >= opts.num_train_steps:
                break
        if global_step >= opts.num_train_steps:
            break
        n_epoch += 1
        LOGGER.info(f"Step {global_step}: finished {n_epoch} epochs")

    # Final evaluation; outputs carry a "_final" suffix.
    for split, loader in [("val", val_dataloader),
                          ("test", test_dataloader)]:
        LOGGER.info(f"Step {global_step}: start running "
                    f"validation on {split} split...")
        val_log, results = validate(model, loader, label2ans, split)
        with open(f'{opts.output_dir}/results/'
                  f'{split}_results_{global_step}_'
                  f'rank{rank}_final.json', 'w') as f:
            json.dump(results, f)
        TB_LOGGER.log_scaler_dict(val_log)
    model_saver.save(model, f'{global_step}_final')
@torch.no_grad()
def validate(model, val_loader, label2ans, split='val'):
    """Evaluate ``model`` on one split and aggregate metrics over all ranks.

    Runs the model in eval mode, with gradients disabled by the decorator.
    It accumulates the summed BCE-with-logits loss and the soft score of the
    argmax prediction for every example.  Loss, score and example counts
    are summed across workers via ``all_gather_list``.  The predictions in
    ``results`` are NOT gathered, so each rank returns only its own.  The
    model is switched back to train mode before returning.

    Args:
        model: callable as ``model(batch, compute_loss=False)``, returning
            logits with the same shape as ``batch['targets']``.
        val_loader: iterable of batch dicts containing ``'targets'`` and
            ``'qids'``.
        label2ans: mapping from predicted class index to label.
        split: name embedded in the metric keys (e.g. ``'val'``, ``'test'``).

    Returns:
        ``(val_log, results)``.  ``val_log`` maps
        ``'valid/{split}_loss'``, ``'valid/{split}_acc'`` and
        ``'valid/{split}_ex_per_s'`` to numbers.  ``results`` maps each of
        this rank's qids to its predicted label.  If no rank saw any example,
        loss and accuracy are reported as 0 instead of raising
        ZeroDivisionError.
    """
    model.eval()
    val_loss = 0
    tot_score = 0
    n_ex = 0
    st = time()
    results = {}
    for batch in val_loader:
        scores = model(batch, compute_loss=False)
        targets = batch['targets']
        loss = F.binary_cross_entropy_with_logits(
            scores, targets, reduction='sum')
        val_loss += loss.item()
        tot_score += compute_score_with_logits(scores, targets).sum().item()
        # argmax over the last (class) dimension, mapped to labels
        pred_ids = scores.max(dim=-1, keepdim=False)[1].cpu().tolist()
        answers = [label2ans[idx] for idx in pred_ids]
        qids = batch['qids']
        for qid, answer in zip(qids, answers):
            results[qid] = answer
        n_ex += len(qids)
    # Sum per-rank statistics so every rank reports global metrics.
    val_loss = sum(all_gather_list(val_loss))
    tot_score = sum(all_gather_list(tot_score))
    n_ex = sum(all_gather_list(n_ex))
    tot_time = time()-st
    # An empty split would otherwise crash here after the full pass.
    denom = max(n_ex, 1)
    val_loss /= denom
    val_acc = tot_score / denom
    val_log = {f'valid/{split}_loss': val_loss,
               f'valid/{split}_acc': val_acc,
               f'valid/{split}_ex_per_s': n_ex/tot_time}
    model.train()
    LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
                f"score: {val_acc*100:.2f}")
    return val_log, results
def compute_score_with_logits(logits, labels):
    """Score each row's argmax prediction against (possibly soft) labels.

    Returns a float tensor shaped like ``labels``.  It is zero everywhere
    except at each row's predicted class (argmax of ``logits`` over dim 1),
    where it holds that row's label value.  Summing the result gives the
    total score.
    """
    _, predicted = torch.max(logits, 1)
    pred_onehot = torch.zeros(*labels.size(), device=labels.device)
    pred_onehot.scatter_(1, predicted.unsqueeze(1), 1)
    return pred_onehot * labels
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--train_txt_db",
default=None, type=str,
help="The input train corpus. (LMDB)")
parser.add_argument("--train_img_db",
default=None, type=str,
help="The input train images.")
parser.add_argument("--val_txt_db",
default=None, type=str,
help="The input validation corpus. (LMDB)")
parser.add_argument("--val_img_db",
default=None, type=str,
help="The input validation images.")
parser.add_argument("--test_txt_db",
default=None, type=str,
help="The input test corpus. (LMDB)")
parser.add_argument("--test_img_db",
default=None, type=str,
help="The input test images.")
parser.add_argument('--compressed_db', action='store_true',
help='use compressed LMDB')
parser.add_argument("--model_config",
default=None, type=str,
help="json file for model architecture")
parser.add_argument("--checkpoint",
default=None, type=str,
help="pretrained model (can take 'google-bert') ")
parser.add_argument(
"--output_dir", default=None, type=str,
help="The output directory where the model checkpoints will be "
"written.")
# Prepro parameters
parser.add_argument('--max_txt_len', type=int, default=60,
help='max number of tokens in text (BERT BPE)')
parser.add_argument('--conf_th', type=float, default=0.2,
help='threshold for dynamic bounding boxes '
'(-1 for fixed)')
parser.add_argument('--max_bb', type=int, default=100,
help='max number of bounding boxes')
parser.add_argument('--min_bb', type=int, default=10,
help='min number of bounding boxes')
parser.add_argument('--num_bb', type=int, default=36,
help='static number of bounding boxes')
# training parameters
parser.add_argument("--train_batch_size",
default=4096, type=int,
help="Total batch size for training. "
"(batch by tokens)")
parser.add_argument("--val_batch_size",
default=4096, type=int,
help="Total batch size for validation. "
"(batch by tokens)")
parser.add_argument('--gradient_accumulation_steps',
type=int,
default=16,
help="Number of updates steps to accumualte before "
"performing a backward/update pass.")
parser.add_argument("--learning_rate",
default=3e-5,
type=float,
help="The initial learning rate for Adam.")
parser.add_argument("--valid_steps",
default=1000,
type=int,
help="Run validation every X steps")
parser.add_argument("--num_train_steps",
default=100000,
type=int,
help="Total number of training updates to perform.")
parser.add_argument("--optim", default='adam',
choices=['adam', 'adamax', 'adamw'],
help="optimizer")
parser.add_argument("--betas", default=[0.9, 0.98], nargs='+',
help="beta for adam optimizer")
parser.add_argument("--decay", default='linear',
choices=['linear', 'invsqrt', 'constant'],
help="learning rate decay method")
parser.add_argument("--dropout",
default=0.1,
type=float,
help="tune dropout regularization")
parser.add_argument("--weight_decay",
default=0.0,
type=float,
help="weight decay (L2) regularization")
parser.add_argument("--grad_norm",
default=0.25,
type=float,
help="gradient clipping (-1 for no clipping)")
parser.add_argument("--warmup_steps",
default=4000,
type=int,
help="Number of training steps to perform linear "
"learning rate warmup for. (invsqrt decay)")
# device parameters
parser.add_argument('--seed',
type=int,
default=42,
help="random seed for initialization")
parser.add_argument('--fp16',
action='store_true',
help="Whether to use 16-bit float precision instead "
"of 32-bit")
parser.add_argument('--n_workers', type=int, default=4,
help="number of data workers")
parser.add_argument('--pin_mem', action='store_true',
help="pin memory")
# can use config files
parser.add_argument('--config', help='JSON config files')
args = parse_with_config(parser)
if exists(args.output_dir) and os.listdir(args.output_dir):
raise ValueError("Output directory ({}) already exists and is not "
"empty.".format(args.output_dir))
# options safe guard
# TODO
if args.conf_th == -1:
assert args.max_bb + args.max_txt_len + 2 <= 512
else:
assert args.num_bb + args.max_txt_len + 2 <= 512
main(args)