# coding=utf-8
# copied from the huggingface github
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc.
# team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT for Referring Expression Comprehension"""
import argparse
import json
import os
from os.path import exists, join
import random
from time import time
import torch
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam, Adamax
from torch.utils.data import DataLoader
# to be replaced by torch's DistributedSampler once upgraded to PyTorch 1.2
# from torch.utils.data.distributed import DistributedSampler
from data import DistributedSampler
from apex import amp
from horovod import torch as hvd
import numpy as np
from tqdm import tqdm
from data import (ReImageFeatDir, ReferringExpressionDataset,
ReferringExpressionEvalDataset, re_collate, re_eval_collate,
PrefetchLoader)
from model import BertForReferringExpressionComprehension
from optim import warmup_linear, noam_schedule, AdamW
from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file
from utils.distributed import (all_reduce_and_rescale_tensors, all_gather_list,
broadcast_tensors)
from utils.save import ModelSaver, save_training_meta
from utils.misc import NoOp, parse_with_config
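# Overall flow: initialize Horovod for multi-GPU training, build the
# referring-expression train/val datasets and dataloaders, construct the
# BERT-based grounding model, wrap it with apex AMP, then run the training
# loop with per-epoch validation and checkpointing.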
def main(opts):
hvd.init()
n_gpu = hvd.size()
device = torch.device("cuda", hvd.local_rank())
torch.cuda.set_device(hvd.local_rank())
rank = hvd.rank()
opts.rank = rank
LOGGER.info(f"device: {device}, n_gpu: {n_gpu}, rank: {hvd.rank()}, "
f"16-bits training: {opts.fp16}")
if opts.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
"should be >= 1".format(
opts.gradient_accumulation_steps))
random.seed(opts.seed)
np.random.seed(opts.seed)
torch.manual_seed(opts.seed)
if n_gpu > 0:
torch.cuda.manual_seed_all(opts.seed)
# train_samples = None
LOGGER.info(f"Loading Train Dataset {opts.train_txt_db}, "
f"{opts.train_img_dir}")
# load DBs and image dirs
train_img_dir = ReImageFeatDir(opts.train_img_dir)
train_dataset = ReferringExpressionDataset(
opts.train_txt_db, train_img_dir,
max_txt_len=opts.max_txt_len)
val_img_dir = ReImageFeatDir(opts.val_img_dir)
val_dataset = ReferringExpressionEvalDataset(
opts.val_txt_db, val_img_dir,
max_txt_len=opts.max_txt_len)
# Prepro model
if opts.checkpoint and opts.checkpoint != 'scratch':
if opts.checkpoint == 'google-bert':
# from google-bert
checkpoint = None
else:
checkpoint = torch.load(opts.checkpoint)
else:
# from scratch
checkpoint = {}
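# checkpoint semantics: None presumably lets from_pretrained load the
# google-bert weights from its archive, {} effectively starts from scratch
# (nothing to load), and any other value is the state_dict loaded from
# --checkpoint above.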
bert_model = json.load(open(f'{opts.train_txt_db}/meta.json'))['bert']
model = BertForReferringExpressionComprehension.from_pretrained(
bert_model, img_dim=2048,
loss=opts.train_loss,
margin=opts.margin,
hard_ratio=opts.hard_ratio,
mlp=opts.mlp,
state_dict=checkpoint
)
if opts.cut_bert != -1:
# cut some layers of BERT
model.bert.encoder.layer = torch.nn.ModuleList(
model.bert.encoder.layer[:opts.cut_bert]
)
del checkpoint
for name, module in model.named_modules():
# we may want to tune dropout for smaller dataset
if isinstance(module, torch.nn.Dropout):
if module.p != opts.dropout:
module.p = opts.dropout
LOGGER.info(f'{name} set to {opts.dropout}')
model.to(device)
# make sure every process has same model params in the beginning
broadcast_tensors([p.data for p in model.parameters()], 0)
# Prepare optimizer
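# biases and LayerNorm parameters are exempted from weight decay below,
# following common BERT fine-tuning practice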
param_optimizer = list(model.named_parameters())
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in param_optimizer
if not any(nd in n for nd in no_decay)],
'weight_decay': opts.weight_decay},
{'params': [p for n, p in param_optimizer
if any(nd in n for nd in no_decay)],
'weight_decay': 0.0}
]
# currently Adam-family optimizers only (adam, adamax, adamw)
if opts.optim == 'adam':
OptimCls = Adam
elif opts.optim == 'adamax':
OptimCls = Adamax
elif opts.optim == 'adamw':
OptimCls = AdamW
else:
raise ValueError('invalid optimizer')
optimizer = OptimCls(optimizer_grouped_parameters,
lr=opts.learning_rate, betas=opts.betas)
model, optimizer = amp.initialize(model, optimizer, enabled=opts.fp16,
opt_level='O2')
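# opt_level O2 runs the model in fp16 while keeping fp32 master weights in
# the optimizer; with enabled=opts.fp16 the amp wrapper is a no-op when
# --fp16 is not passed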
global_step = 0
LOGGER.info("***** Running training *****")
LOGGER.info(" Num examples = %d", len(train_dataset))
LOGGER.info(" Batch size = %d", opts.train_batch_size)
LOGGER.info(" Accumulate steps = %d", opts.gradient_accumulation_steps)
LOGGER.info(" Num steps = %d", opts.num_train_steps)
train_sampler = DistributedSampler(
train_dataset, num_replicas=n_gpu, rank=rank, shuffle=False)
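# note: shuffle=False above because shuffling is done by the dataset itself
# at the end of each epoch (see dataset.shuffle() in the training loop)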
train_dataloader = DataLoader(train_dataset,
sampler=train_sampler,
batch_size=opts.train_batch_size,
num_workers=opts.n_workers,
pin_memory=opts.pin_mem,
collate_fn=re_collate)
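# PrefetchLoader (from data) presumably prefetches the next batch onto the
# GPU so host-to-device transfer overlaps with compute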
train_dataloader = PrefetchLoader(train_dataloader)
val_sampler = DistributedSampler(
val_dataset, num_replicas=n_gpu, rank=rank, shuffle=False)
val_dataloader = DataLoader(val_dataset,
sampler=val_sampler,
batch_size=opts.val_batch_size,
num_workers=opts.n_workers,
pin_memory=opts.pin_mem,
collate_fn=re_eval_collate)
val_dataloader = PrefetchLoader(val_dataloader)
if rank == 0:
save_training_meta(opts)
TB_LOGGER.create(join(opts.output_dir, 'log'))
pbar = tqdm(total=opts.num_train_steps)
model_saver = ModelSaver(join(opts.output_dir, 'ckpt'), 'model_epoch')
os.makedirs(join(opts.output_dir, 'results')) # store RE predictions
add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
else:
LOGGER.disabled = True
pbar = NoOp()
model_saver = NoOp()
running_loss = RunningMeter(opts.train_loss)
n_examples = 0
n_epoch = 0
best_val_acc, best_epoch = None, None
start = time()
# quick hack for amp delay_unscale bug
optimizer.zero_grad()
optimizer.step()
while True:
model.train()
for step, batch in enumerate(train_dataloader):
if global_step >= opts.num_train_steps:
break
*_, targets = batch
n_examples += targets.size(0)
loss = model(*batch, compute_loss=True)
loss = loss.sum() # sum over vectorized loss TODO: investigate
delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0
with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale
) as scaled_loss:
scaled_loss.backward()
if not delay_unscale:
# gather gradients from every process
# do this before unscaling to make sure every process uses
# the same gradient scale
grads = [p.grad.data for p in model.parameters()
if p.requires_grad and p.grad is not None]
all_reduce_and_rescale_tensors(grads, float(1))
running_loss(loss.item())
if (step + 1) % opts.gradient_accumulation_steps == 0:
global_step += 1
# learning rate scheduling
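# warmup_linear ramps the lr up over warmup_steps, then decays it linearly
# toward 0 at num_train_steps (hence the < 0 safeguard below);
# noam_schedule, per the 'invsqrt' choice, decays by inverse square root
# after warmup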
if opts.decay == 'linear':
lr_this_step = opts.learning_rate * warmup_linear(
global_step, opts.warmup_steps, opts.num_train_steps)
elif opts.decay == 'invsqrt':
lr_this_step = opts.learning_rate * noam_schedule(
global_step, opts.warmup_steps)
elif opts.decay == 'constant':
lr_this_step = opts.learning_rate
if lr_this_step < 0:
# safeguard against possible miscalculation of train steps
lr_this_step = 1e-8
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
TB_LOGGER.add_scalar('lr', lr_this_step, global_step)
# log loss
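# average the RunningMeter values gathered from all workers so the
# logged loss reflects the global batch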
losses = all_gather_list(running_loss)
running_loss = RunningMeter(
opts.train_loss, sum(l.val for l in losses)/len(losses))
TB_LOGGER.add_scalar('loss_'+opts.train_loss, running_loss.val,
global_step)
TB_LOGGER.step()
# update model params
if opts.grad_norm != -1:
grad_norm = clip_grad_norm_(amp.master_params(optimizer),
opts.grad_norm)
TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
optimizer.step()
optimizer.zero_grad()
pbar.update(1)
if global_step % 5 == 0:
torch.cuda.empty_cache()
if global_step % 100 == 0:
# monitor training throughput
tot_ex = sum(all_gather_list(n_examples))
ex_per_sec = int(tot_ex / (time()-start))
LOGGER.info(f'{tot_ex} examples trained at '
f'{ex_per_sec} ex/s')
TB_LOGGER.add_scalar('perf/ex_per_s',
ex_per_sec, global_step)
# evaluate after each epoch
val_log, _ = validate(model, val_dataloader)
TB_LOGGER.log_scaler_dict(val_log)
# save model
n_epoch += 1
model_saver.save(model, n_epoch)
LOGGER.info(f"finished {n_epoch} epochs")
# save best model
if best_val_acc is None or val_log['valid/acc'] > best_val_acc:
best_val_acc = val_log['valid/acc']
best_epoch = n_epoch
model_saver.save(model, 'best')
# shuffle training data for the next epoch
train_dataloader.loader.dataset.shuffle()
# is training finished?
if global_step >= opts.num_train_steps:
break
val_log, results = validate(model, val_dataloader)
with open(f'{opts.output_dir}/results/'
f'results_{global_step}_'
f'rank{rank}_final.json', 'w') as f:
json.dump(results, f)
TB_LOGGER.log_scaler_dict(val_log)
model_saver.save(model, f'{global_step}_final')
# print best model
LOGGER.info(f'best_val_acc = {best_val_acc*100:.2f}% '
f'at epoch {best_epoch}.')
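# Validation: a prediction counts as correct when the predicted box has
# IoU > 0.5 with the ground-truth box; correct counts and example counts
# are summed across workers via all_gather_list before computing accuracy.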
@torch.no_grad()
def validate(model, val_dataloader):
LOGGER.info("start running evaluation...")
model.eval()
tot_score = 0
n_ex = 0
st = time()
predictions = {}
for i, batch in enumerate(val_dataloader):
# inputs
(*batch_inputs, tgt_box_list, obj_boxes_list, sent_ids) = batch
# scores (n, max_num_bb)
scores = model(*batch_inputs, targets=None, compute_loss=False)
ixs = torch.argmax(scores, 1).cpu().detach().numpy() # (n, )
# pred_boxes
for ix, obj_boxes, tgt_box, sent_id in \
zip(ixs, obj_boxes_list, tgt_box_list, sent_ids):
pred_box = obj_boxes[ix]
predictions[sent_id] = {'pred_box': pred_box.tolist(),
'tgt_box': tgt_box.tolist()}
if (val_dataloader.loader.dataset.computeIoU(pred_box, tgt_box)
> .5):
tot_score += 1
n_ex += 1
tot_time = time()-st
tot_score = sum(all_gather_list(tot_score))
n_ex = sum(all_gather_list(n_ex))
val_acc = tot_score / n_ex
val_log = {'valid/acc': val_acc, 'valid/ex_per_s': n_ex/tot_time}
model.train()
LOGGER.info(f"validation ({n_ex} sents) finished in "
f"{int(tot_time)} seconds"
f", accuracy: {val_acc*100:.2f}%")
return val_log, predictions
if __name__ == '__main__':
parser = argparse.ArgumentParser()
# Required parameters
parser.add_argument("--train_txt_db",
default=None, type=str,
help="The input train corpus. (LMDB)")
parser.add_argument("--train_img_dir",
default=None, type=str,
help="The input train images.")
parser.add_argument("--val_txt_db",
default=None, type=str,
help="The input validation corpus. (LMDB)")
parser.add_argument("--val_img_dir",
default=None, type=str,
help="The input validation images.")
parser.add_argument('--img_format', default='npz',
choices=['npz', 'lmdb', 'lmdb-compress'],
help='format of image feature')
parser.add_argument("--checkpoint",
default=None, type=str,
help="pretrained model (can take 'google-bert') ")
parser.add_argument("--cut_bert", default=-1, type=int,
help="reduce BERT layers (-1 for original depth)")
parser.add_argument("--mlp", default=1, type=int,
help="number of MLP layers for RE output")
parser.add_argument(
"--output_dir", default=None, type=str,
help="The output directory where the model checkpoints will be "
"written.")
# Prepro parameters
parser.add_argument('--max_txt_len', type=int, default=60,
help='max number of tokens in text (BERT BPE)')
# training parameters
parser.add_argument("--train_batch_size",
default=128, type=int,
help="Total batch size for training. "
"(batch by examples)")
parser.add_argument("--val_batch_size",
default=256, type=int,
help="Total batch size for validation. "
"(batch by tokens)")
parser.add_argument("--train_loss",
default="cls", type=str,
choices=['cls', 'rank'],
help="loss to use during training")
parser.add_argument("--margin",
default=0.2, type=float,
help="margin of ranking loss")
parser.add_argument("--hard_ratio",
default=0.3, type=float,
help="sampling ratio of hard negatives")
parser.add_argument('--gradient_accumulation_steps',
type=int,
default=16,
help="Number of update steps to accumulate before "
"performing a backward/update pass.")
parser.add_argument("--learning_rate",
default=3e-5,
type=float,
help="The initial learning rate for Adam.")
parser.add_argument("--num_train_steps",
default=32000,
type=int,
help="Total number of training updates to perform.")
parser.add_argument("--optim", default='adam',
choices=['adam', 'adamax', 'adamw'],
help="optimizer")
parser.add_argument("--betas", default=[0.9, 0.98], nargs='+', type=float,
help="betas for the adam optimizer")
parser.add_argument("--decay", default='linear',
choices=['linear', 'invsqrt', 'constant'],
help="learning rate decay method")
parser.add_argument("--dropout",
default=0.1,
type=float,
help="tune dropout regularization")
parser.add_argument("--weight_decay",
default=0.0,
type=float,
help="weight decay (L2) regularization")
parser.add_argument("--grad_norm",
default=0.25,
type=float,
help="gradient clipping (-1 for no clipping)")
parser.add_argument("--warmup_steps",
default=4000,
type=int,
help="Number of training steps to perform linear "
"learning rate warmup for. (invsqrt decay)")
# device parameters
parser.add_argument('--seed',
type=int,
default=24,
help="random seed for initialization")
parser.add_argument('--fp16',
action='store_true',
help="Whether to use 16-bit float precision instead "
"of 32-bit")
parser.add_argument('--n_workers', type=int, default=4,
help="number of data workers")
parser.add_argument('--pin_mem', action='store_true',
help="pin memory")
# can use config files
parser.add_argument('--config', help='JSON config files')
args = parse_with_config(parser)
if exists(args.output_dir) and os.listdir(args.output_dir):
raise ValueError("Output directory ({}) already exists and is not "
"empty.".format(args.output_dir))
# options safeguard
main(args)
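# Example invocation (a sketch; the script name and all paths below are
# placeholders, only the flags are defined by the parser above):
#   horovodrun -np 4 python train_re.py \
#       --train_txt_db /db/refcoco_train.db --train_img_dir /img/train_feat \
#       --val_txt_db /db/refcoco_val.db --val_img_dir /img/val_feat \
#       --output_dir /storage/refcoco_out --fp16 --pin_mem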