# coding=utf-8
# copied from the huggingface GitHub repository
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc.
# team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""UNITER pre-training runner."""
import argparse
from collections import defaultdict
import json
import math
import os
from os.path import exists, join
from time import time
import torch
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torch.nn.utils import clip_grad_norm_
from apex import amp
from horovod import torch as hvd
from tqdm import tqdm
from data import (TokenBucketSampler, TokenBucketSamplerForItm,
MetaLoader, PrefetchLoader,
TxtTokLmdb, ImageLmdbGroup, ConcatDatasetWithLens,
MlmDataset, MlmEvalDataset,
BlindMlmDataset, BlindMlmEvalDataset,
MrfrDataset, OnlyImgMrfrDataset,
MrcDataset, OnlyImgMrcDataset,
mlm_collate, mlm_eval_collate,
mlm_blind_collate, mlm_blind_eval_collate,
mrfr_collate, mrfr_only_img_collate,
mrc_collate, mrc_only_img_collate,
ItmDataset, itm_collate, itm_ot_collate)
from data.mrm_nce import NegativeImageSampler, MrmNceDataset, mrm_nce_collate
from model import UniterForPretraining
from optim import get_lr_sched
from optim.misc import build_optimizer
from utils.logger import LOGGER, TB_LOGGER, RunningMeter, add_log_to_file
from utils.distributed import (all_reduce_and_rescale_tensors, all_gather_list,
broadcast_tensors)
from utils.save import ModelSaver, save_training_meta
from utils.misc import NoOp, parse_with_config, set_dropout, set_random_seed
from utils.const import IMG_DIM, IMG_LABEL_DIM, BUCKET_SIZE
WARM_STEP = 500
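# Batch sizes below are counted in tokens ("batch by tokens" in the CLI help);
# TokenBucketSampler buckets examples by length and fills each batch up to
# that token budget, dropping the last incomplete batch only during training.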
def build_dataloader(dataset, collate_fn, is_train, opts):
if is_train:
batch_size = opts.train_batch_size
else:
batch_size = opts.val_batch_size
sampler = TokenBucketSampler(dataset.lens, bucket_size=BUCKET_SIZE,
batch_size=batch_size, droplast=is_train)
loader = DataLoader(dataset, batch_sampler=sampler,
num_workers=opts.n_workers, pin_memory=opts.pin_mem,
collate_fn=collate_fn)
return loader
def build_dataloader_itm(dataset, collate_fn, is_train, opts):
if is_train:
batch_size = opts.train_batch_size
else:
batch_size = opts.val_batch_size
sampler = TokenBucketSamplerForItm(
dataset, bucket_size=BUCKET_SIZE,
batch_size=batch_size, droplast=is_train)
loader = DataLoader(dataset, batch_sampler=sampler,
num_workers=opts.n_workers, pin_memory=opts.pin_mem,
collate_fn=collate_fn)
return loader
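# Per-task dataset builders: each returns a (dataset, collate_fn) pair that
# create_dataloaders unpacks into build_dataloader*. For training, txt_db and
# img_db are lists (one LMDB per corpus) concatenated with
# ConcatDatasetWithLens; for validation they are single LMDBs.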
def build_mlm_dataset(txt_db, img_db, blind, is_train, opts):
if is_train:
if blind:
collate_fn = mlm_blind_collate
datasets = [BlindMlmDataset(t) for t in txt_db]
else:
collate_fn = mlm_collate
datasets = [MlmDataset(t, i) for t, i in zip(txt_db, img_db)]
dataset = ConcatDatasetWithLens(datasets)
else:
if blind:
collate_fn = mlm_blind_collate
dataset = BlindMlmDataset(txt_db)
else:
collate_fn = mlm_collate
dataset = MlmDataset(txt_db, img_db)
return dataset, collate_fn
def build_mrfr_dataset(txt_db, img_db, only_i, is_train, opts):
collate_fn = (mrfr_only_img_collate if only_i
else mrfr_collate)
if is_train:
if only_i:
datasets = [OnlyImgMrfrDataset(opts.mrm_prob, i) for i in img_db]
else:
datasets = [MrfrDataset(opts.mrm_prob, t, i)
for t, i in zip(txt_db, img_db)]
dataset = ConcatDatasetWithLens(datasets)
else:
if only_i:
dataset = OnlyImgMrfrDataset(opts.mrm_prob, img_db)
else:
dataset = MrfrDataset(opts.mrm_prob, txt_db, img_db)
return dataset, collate_fn
def build_mrm_nce_dataset(txt_db, img_db, only_i, is_train, opts):
assert not only_i
neg_sampler = NegativeImageSampler(img_db, opts.neg_size)
collate_fn = mrm_nce_collate(neg_sampler)
if is_train:
datasets = [MrmNceDataset(opts.mrm_prob, t, i)
for t, i in zip(txt_db, img_db)]
dataset = ConcatDatasetWithLens(datasets)
else:
dataset = MrmNceDataset(opts.mrm_prob, txt_db, img_db)
return dataset, collate_fn
def build_mrc_dataset(txt_db, img_db, only_i, is_train, opts):
collate_fn = (mrc_only_img_collate if only_i
else mrc_collate)
if is_train:
if only_i:
datasets = [OnlyImgMrcDataset(opts.mrm_prob, i) for i in img_db]
else:
datasets = [MrcDataset(opts.mrm_prob, t, i)
for t, i in zip(txt_db, img_db)]
dataset = ConcatDatasetWithLens(datasets)
else:
if only_i:
dataset = OnlyImgMrcDataset(opts.mrm_prob, img_db)
else:
dataset = MrcDataset(opts.mrm_prob, txt_db, img_db)
return dataset, collate_fn
def build_itm_dataset(txt_db, img_db, is_train, opts):
if is_train:
datasets = [ItmDataset(t, i, opts.itm_neg_prob)
for t, i in zip(txt_db, img_db)]
dataset = ConcatDatasetWithLens(datasets)
else:
dataset = ItmDataset(txt_db, img_db, opts.itm_neg_prob)
collate_fn = itm_ot_collate if opts.itm_ot_lambda > 0 else itm_collate
return dataset, collate_fn
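# Builds one dataloader per (task, corpus) pair. Each `datasets` entry comes
# from the JSON config and lists its text LMDBs ('db'), image-feature LMDBs
# ('img'), the pretraining 'tasks' to run, and (training only) a per-task
# 'mix_ratio'.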
def create_dataloaders(datasets, is_train, opts, all_img_dbs=None):
if all_img_dbs is None:
all_img_dbs = ImageLmdbGroup(opts.conf_th, opts.max_bb, opts.min_bb,
opts.num_bb, opts.compressed_db)
dataloaders = {}
for dset in datasets:
if is_train:
assert len(dset['db']) == len(dset['img'])
assert len(dset['tasks']) == len(dset['mix_ratio'])
img_db = [all_img_dbs[path] for path in dset['img']]
else:
assert len(dset['db']) == len(dset['img']) == 1
img_db = all_img_dbs[dset['img'][0]]
for i, t in enumerate(dset['tasks']):
task = f'{t}_{dset["name"]}'
if is_train:
LOGGER.info(f"Loading {task} train dataset "
f"{dset['db']}, {[img.img_dir for img in img_db]}")
txt_db = [TxtTokLmdb(path, opts.max_txt_len)
for path in dset['db']]
else:
LOGGER.info(f"Loading {task} validation dataset, "
f"{dset['db']}, {img_db.img_dir}")
txt_db = TxtTokLmdb(dset['db'][0], -1)
if task.startswith('mlm'):
blind = 'blind' in task
dataset = build_mlm_dataset(txt_db, img_db,
blind, is_train, opts)
elif task.startswith('mrfr'):
only_i = 'only_i' in task
dataset = build_mrfr_dataset(txt_db, img_db,
only_i, is_train, opts)
elif task.startswith('mrm-nce'):
only_i = 'only_i' in task
dataset = build_mrm_nce_dataset(txt_db, img_db,
only_i, is_train, opts)
elif task.startswith('mrc'):
only_i = 'only_i' in task
dataset = build_mrc_dataset(txt_db, img_db,
only_i, is_train, opts)
elif task.startswith('itm'):
dataset = build_itm_dataset(txt_db, img_db, is_train, opts)
else:
raise ValueError(f'Undefined task {task}')
LOGGER.info(f"{len(dataset[0])*hvd.size()} samples loaded")
if task.startswith('itm'):
                # ITM handles distributed data splitting inside the dataset,
                # not the sampler
loader = build_dataloader_itm(*dataset, is_train, opts)
else:
loader = build_dataloader(*dataset, is_train, opts)
if is_train:
ratio = dset['mix_ratio'][i]
dataloaders[task] = (loader, ratio)
else:
dataloaders[task] = PrefetchLoader(loader)
return dataloaders, all_img_dbs
def main(opts):
hvd.init()
n_gpu = hvd.size()
device = torch.device("cuda", hvd.local_rank())
torch.cuda.set_device(hvd.local_rank())
rank = hvd.rank()
opts.rank = rank
LOGGER.info("device: {} n_gpu: {}, rank: {}, "
"16-bits training: {}".format(
device, n_gpu, hvd.rank(), opts.fp16))
if opts.gradient_accumulation_steps < 1:
raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
"should be >= 1".format(
opts.gradient_accumulation_steps))
set_random_seed(opts.seed)
if rank == 0:
save_training_meta(opts)
TB_LOGGER.create(join(opts.output_dir, 'log'))
pbar = tqdm(total=opts.num_train_steps)
        model_saver = ModelSaver(join(opts.output_dir, 'ckpt'))
add_log_to_file(join(opts.output_dir, 'log', 'log.txt'))
else:
LOGGER.disabled = True
pbar = NoOp()
model_saver = NoOp()
all_dbs = [db for datasets in [opts.train_datasets, opts.val_datasets]
for dset in datasets for db in dset['db']]
tokenizer = json.load(open(f'{all_dbs[0]}/meta.json'))['bert']
assert all(tokenizer == json.load(open(f'{db}/meta.json'))['bert']
for db in all_dbs)
# build data loaders
train_dataloaders, all_img_dbs = create_dataloaders(
opts.train_datasets, True, opts)
val_dataloaders, _ = create_dataloaders(
opts.val_datasets, False, opts, all_img_dbs)
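    # MetaLoader draws each step's batch from one of the task-specific train
    # loaders (the per-task mix_ratio above weights the mixing) and yields
    # (task_name, batch) pairs; PrefetchLoader wraps it to pre-load the next
    # batch onto the GPU.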
meta_loader = MetaLoader(train_dataloaders,
accum_steps=opts.gradient_accumulation_steps,
distributed=n_gpu > 1)
meta_loader = PrefetchLoader(meta_loader)
# Prepare model
if opts.checkpoint:
checkpoint = torch.load(opts.checkpoint)
else:
checkpoint = {}
model = UniterForPretraining.from_pretrained(
opts.model_config, checkpoint,
img_dim=IMG_DIM, img_label_dim=IMG_LABEL_DIM,
nce_temp=opts.nce_temp, ot_pos_only=opts.ot_pos_only)
model.pad_vocab() # tensor core padding for vocabulary
model.to(device)
model.train()
# make sure every process has same model parameters in the beginning
broadcast_tensors([p.data for p in model.parameters()], 0)
set_dropout(model, opts.dropout)
# Prepare optimizer
optimizer = build_optimizer(model, opts)
task2scaler = {t: i for i, t in enumerate(train_dataloaders.keys())}
model, optimizer = amp.initialize(model, optimizer,
num_losses=len(task2scaler),
enabled=opts.fp16, opt_level='O2')
global_step = 0
LOGGER.info(f"***** Running training with {n_gpu} GPUs *****")
LOGGER.info(" Batch size = %d", opts.train_batch_size)
LOGGER.info(" Accumulate steps = %d", opts.gradient_accumulation_steps)
LOGGER.info(" Num steps = %d", opts.num_train_steps)
# to compute training statistics
task2loss = {task: RunningMeter(f'loss/{task}')
for task in train_dataloaders.keys()}
# ITM w/ OT
if opts.itm_ot_lambda > 0:
for task in train_dataloaders.keys():
if task.startswith('itm'):
task2loss[f'{task}_xe'] = RunningMeter(f'loss/{task}_xe')
task2loss[f'{task}_ot'] = RunningMeter(f'loss/{task}_ot')
if not opts.ot_pos_only:
task2loss[f'{task}_ot_pos'] = RunningMeter(
f'loss/{task}_ot_pos')
task2loss[f'{task}_ot_neg'] = RunningMeter(
f'loss/{task}_ot_neg')
n_examples = defaultdict(int)
n_in_units = defaultdict(int)
n_loss_units = defaultdict(int)
n_neg_nce = defaultdict(int)
grad_norm = 0
start = time()
# quick hack for amp delay_unscale bug
optimizer.zero_grad()
optimizer.step()
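    # Training loop: every iteration processes one batch from a single task;
    # the prefix of `name` (before '_') selects the loss head inside
    # UniterForPretraining, and the optimizer only steps once every
    # gradient_accumulation_steps micro-batches.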
for step, (name, batch) in enumerate(meta_loader):
# forward pass
assert all(name == n for n in all_gather_list(name))
n_examples[name] += batch['input_ids'].size(0)
n_in_units[name] += (batch['attn_masks'] == 1).sum().item()
if 'nce' in name:
n_neg_nce[name] += batch['neg_feats'].size(0)
task = name.split('_')[0]
loss = model(batch, task=task, compute_loss=True)
if task.startswith('itm'):
            # ITM returns (binary matching loss, optional optimal-transport
            # word-region alignment loss)
itm_loss, ot_loss = loss
n_loss_units[name] += itm_loss.size(0)
itm_loss = itm_loss.mean()
if ot_loss is not None:
if not opts.ot_pos_only:
ot_pos, ot_neg = ot_loss
ot_loss = (ot_pos.sum() - ot_neg.sum()
) / (ot_pos.size(0) + ot_neg.size(0))
                    # NOTE: beware of empty tensors here
ot_pos = ot_pos.mean().item()
if not math.isnan(ot_pos):
task2loss[f'{name}_ot_pos'](ot_pos)
ot_neg = ot_neg.mean().item()
if not math.isnan(ot_neg):
task2loss[f'{name}_ot_neg'](ot_neg)
else:
ot_loss = ot_loss.mean()
loss = itm_loss + opts.itm_ot_lambda * ot_loss
task2loss[f'{name}_xe'](itm_loss.item())
task2loss[f'{name}_ot'](ot_loss.item())
else:
loss = itm_loss
else:
n_loss_units[name] += loss.size(0)
loss = loss.mean() # loss is not normalized in model
# backward pass
delay_unscale = (step+1) % opts.gradient_accumulation_steps != 0
with amp.scale_loss(loss, optimizer, delay_unscale=delay_unscale,
loss_id=task2scaler[name]) as scaled_loss:
scaled_loss.backward()
if not delay_unscale:
            # gather gradients from every process
# do this before unscaling to make sure every process uses
# the same gradient scale
grads = [p.grad.data for p in model.parameters()
if p.requires_grad and p.grad is not None]
all_reduce_and_rescale_tensors(grads, float(1))
task2loss[name](loss.item())
# optimizer update and logging
if (step + 1) % opts.gradient_accumulation_steps == 0:
global_step += 1
# learning rate scheduling
lr_this_step = get_lr_sched(global_step, opts)
for param_group in optimizer.param_groups:
param_group['lr'] = lr_this_step
TB_LOGGER.add_scalar('lr', lr_this_step, global_step)
# log loss
for t, l in task2loss.items():
loss = sum(v for v in all_gather_list(l.val)
if v is not None) / hvd.size()
task2loss[t] = RunningMeter(f'loss/{t}', loss)
TB_LOGGER.log_scaler_dict({l.name: l.val
for l in task2loss.values()
if l.val is not None})
TB_LOGGER.step()
# update model params
if opts.grad_norm != -1:
'''
if global_step % 10 == 0 and not opts.fp16:
bias = model.bert.img_embeddings.img_linear.bias
weight = model.bert.img_embeddings.img_linear.weight
print(f"bnorm: {bias.norm()}")
print(f"wnorm: {weight.norm()}")
print(f"bgnorm: {bias.grad.norm()}")
print(f"wgnorm: {weight.grad.norm()}")
mask = model.bert.img_embeddings.mask_embedding.weight
print(f"mnorm: {mask.norm()}")
print(f"mgnorm: {mask.grad.norm()}")
print([(n, p.grad.norm().item())
for n, p in model.named_parameters()
if p.grad is not None
and p.grad.norm().item() > grad_norm/10])
'''
grad_norm = clip_grad_norm_(amp.master_params(optimizer),
opts.grad_norm)
TB_LOGGER.add_scalar('grad_norm', grad_norm, global_step)
optimizer.step()
optimizer.zero_grad()
pbar.update(1)
if global_step % 100 == 0:
# monitor training throughput
LOGGER.info(f'==============Step {global_step}===============')
for t in train_dataloaders.keys():
assert all(tt == t for tt in all_gather_list(t))
tot_ex = sum(all_gather_list(n_examples[t]))
ex_per_sec = int(tot_ex / (time()-start))
tot_in = sum(all_gather_list(n_in_units[t]))
in_per_sec = int(tot_in / (time()-start))
tot_l = sum(all_gather_list(n_loss_units[t]))
l_per_sec = int(tot_l / (time()-start))
LOGGER.info(f'{t}: {tot_ex} examples trained at '
f'{ex_per_sec} ex/s')
TB_LOGGER.add_scalar(f'perf/{t}_ex_per_s', ex_per_sec,
global_step)
TB_LOGGER.add_scalar(f'perf/{t}_in_per_s', in_per_sec,
global_step)
TB_LOGGER.add_scalar(f'perf/{t}_loss_per_s', l_per_sec,
global_step)
if 'nce' in t:
avg_neg = sum(all_gather_list(n_neg_nce[t])
) / hvd.size() // step
LOGGER.info(f'{t}: averaging '
f'{avg_neg} negative samples')
LOGGER.info(f'===============================================')
if global_step % opts.valid_steps == 0:
LOGGER.info(f'Step {global_step}: start validation')
validate(model, val_dataloaders)
model_saver.save(model, global_step, optimizer)
if global_step >= opts.num_train_steps:
break
if global_step % opts.valid_steps != 0:
LOGGER.info(f'Step {global_step}: start validation')
validate(model, val_dataloaders)
model_saver.save(model, global_step)
def validate(model, val_dataloaders):
model.eval()
for task, loader in val_dataloaders.items():
LOGGER.info(f"validate on {task} task")
if task.startswith('mlm'):
val_log = validate_mlm(model, loader)
elif task.startswith('mrfr'):
val_log = validate_mrfr(model, loader)
elif task.startswith('mrm-nce'):
val_log = validate_mrm_nce(model, loader)
elif task.startswith('mrc'):
val_log = validate_mrc(model, loader, task)
elif task.startswith('itm'):
val_log = validate_itm(model, loader)
else:
raise ValueError(f'Undefined task {task}')
val_log = {f'{task}_{k}': v for k, v in val_log.items()}
TB_LOGGER.log_scaler_dict(
{f'valid_{task}/{k}': v for k, v in val_log.items()})
model.train()
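# The validate_* helpers below sum their counters across Horovod workers with
# all_gather_list before computing the reported metrics.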
@torch.no_grad()
def validate_mlm(model, val_loader):
LOGGER.info("start running MLM validation...")
val_loss = 0
n_correct = 0
n_word = 0
st = time()
for i, batch in enumerate(val_loader):
scores = model(batch, task='mlm', compute_loss=False)
labels = batch['txt_labels']
labels = labels[labels != -1]
loss = F.cross_entropy(scores, labels, reduction='sum')
val_loss += loss.item()
n_correct += (scores.max(dim=-1)[1] == labels).sum().item()
n_word += labels.numel()
val_loss = sum(all_gather_list(val_loss))
n_correct = sum(all_gather_list(n_correct))
n_word = sum(all_gather_list(n_word))
tot_time = time()-st
val_loss /= n_word
acc = n_correct / n_word
val_log = {'loss': val_loss,
'acc': acc,
'tok_per_s': n_word/tot_time}
LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
f"acc: {acc*100:.2f}")
return val_log
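# Older MLM validation variant; not called by validate() above.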
@torch.no_grad()
def validate_mlm_old(model, val_loader):
LOGGER.info("start running MLM validation...")
val_loss = 0
n_correct = 0
n_word = 0
st = time()
for i, batch in enumerate(val_loader):
scores = model.forward(batch, task='mlm', compute_loss=False)
loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-1,
reduction='sum')
scores = scores.contiguous().view(-1, model.config.vocab_size)
labels = batch['txt_labels'].contiguous().view(-1)
loss = loss_fct(scores, labels)
val_loss += loss.item()
n_correct += accuracy_count(scores, labels)
n_word += batch['txt_labels'].numel()
val_loss = sum(all_gather_list(val_loss))
n_correct = sum(all_gather_list(n_correct))
n_word = sum(all_gather_list(n_word))
tot_time = time()-st
val_loss /= n_word
acc = n_correct / n_word
val_log = {'loss': val_loss,
'acc': acc,
'tok_per_s': n_word/tot_time}
LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
f"acc: {acc*100:.2f}")
return val_log
def accuracy_count(out, labels):
outputs = out.max(dim=-1)[1]
mask = labels != -1
n_correct = (outputs == labels).masked_select(mask).sum().item()
return n_correct
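# MRFR validation: the summed regression loss returned by the model is
# normalized by IMG_DIM and by the number of masked regions (n_feat).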
@torch.no_grad()
def validate_mrfr(model, val_loader):
LOGGER.info("start running MRFR validation...")
val_loss = 0
n_feat = 0
st = time()
for i, batch in enumerate(val_loader):
loss = model(batch, task='mrfr', compute_loss=True)
val_loss += loss.sum().item() / IMG_DIM
n_feat += batch['img_mask_tgt'].sum().item()
val_loss = sum(all_gather_list(val_loss))
n_feat = sum(all_gather_list(n_feat))
tot_time = time()-st
val_loss /= n_feat
val_log = {'loss': val_loss,
'feat_per_s': n_feat/tot_time}
LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
f"loss: {val_loss:.2f}")
return val_log
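# MRM-NCE validation: each masked-region prediction is scored against its
# positive feature and the sampled negatives; row i of the logits should rank
# its own positive (column i) highest, hence targets = arange(...).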
@torch.no_grad()
def validate_mrm_nce(model, val_loader):
LOGGER.info("start running MRM-NCE validation...")
val_loss = 0
val_l2 = 0
n_correct = 0
cosine = 0
n_feat = 0
n_neg = 0
st = time()
for i, batch in enumerate(val_loader):
feats, pos_feats, neg_feats = model(batch, task='mrm-nce',
compute_loss=False)
logits = model.mrm_nce(feats, pos_feats, neg_feats,
compute_loss=False)
targets = torch.arange(0, logits.size(0),
dtype=torch.long, device=logits.device)
val_loss += F.cross_entropy(logits, targets, reduction='sum').item()
val_l2 += F.mse_loss(feats, pos_feats, reduction='sum'
).item() / feats.size(-1)
n_correct += (logits.max(dim=-1)[1] == targets).sum().item()
cosine += F.cosine_similarity(feats, pos_feats, dim=-1).sum().item()
nf = batch['img_mask_tgt'].sum().item()
n_feat += nf
n_neg += neg_feats.size(0) * nf
val_loss = sum(all_gather_list(val_loss))
val_l2 = sum(all_gather_list(val_l2))
n_correct = sum(all_gather_list(n_correct))
cosine = sum(all_gather_list(cosine))
n_feat = sum(all_gather_list(n_feat))
n_neg = sum(all_gather_list(n_neg))
tot_time = time()-st
val_loss /= n_feat
val_acc = n_correct / n_feat
val_log = {'loss': val_loss,
'acc': val_acc,
'l2': val_l2 / n_feat,
'cosine': cosine / n_feat,
'feat_per_s': n_feat/tot_time}
LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
f"loss: {val_loss:.2f}, acc: {val_acc*100:.2f} "
f"(average {n_neg/n_feat:.0f} negatives)")
return val_log
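# MRC validation: 'mrc-kl'-style tasks compare the predicted class
# distribution to the soft label targets with KL divergence; plain 'mrc' uses
# cross-entropy against the argmax label, excluding the background class
# (index 0).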
@torch.no_grad()
def validate_mrc(model, val_loader, task):
LOGGER.info("start running MRC validation...")
val_loss = 0
n_feat = 0
st = time()
tot_score = 0
for i, batch in enumerate(val_loader):
prediction_soft_label = model(
batch, task=task, compute_loss=False)
if "kl" in task:
prediction_soft_label = F.log_softmax(
prediction_soft_label, dim=-1)
label_targets = batch['label_targets']
loss = F.kl_div(
prediction_soft_label, label_targets, reduction='sum')
tot_score += compute_accuracy_for_soft_targets(
prediction_soft_label, label_targets)
else:
# background class should not be the target
cls_label_targets = label_targets[:, 1:].max(dim=-1)[1] + 1
loss = F.cross_entropy(
prediction_soft_label, cls_label_targets,
ignore_index=0, reduction='sum')
tot_score += compute_accuracy_for_soft_targets(
prediction_soft_label[:, 1:], label_targets[:, 1:])
val_loss += loss.item()
n_feat += batch['img_mask_tgt'].sum().item()
val_loss = sum(all_gather_list(val_loss))
tot_score = sum(all_gather_list(tot_score))
n_feat = sum(all_gather_list(n_feat))
tot_time = time()-st
val_loss /= n_feat
val_acc = tot_score / n_feat
val_log = {'loss': val_loss,
'acc': val_acc,
'feat_per_s': n_feat/tot_time}
LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
f"score: {val_acc*100:.2f}")
return val_log
def compute_accuracy_for_soft_targets(out, labels):
outputs = out.max(dim=-1)[1]
labels = labels.max(dim=-1)[1] # argmax
n_correct = (outputs == labels).sum().item()
return n_correct
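# ITM validation: two-way matched/unmatched classification accuracy, plus the
# OT (word-region alignment) statistics when that loss is enabled.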
@torch.no_grad()
def validate_itm(model, val_loader):
LOGGER.info("start running ITM validation...")
val_loss = 0
tot_ot_loss = 0
tot_ot_pos = 0
tot_ot_neg = 0
tot_score = 0
n_ex = 0
st = time()
for i, batch in enumerate(val_loader):
scores, ot_loss = model(batch, task='itm', compute_loss=False)
if ot_loss is not None:
if isinstance(ot_loss, tuple):
ot_pos, ot_neg = ot_loss
ot_pos = ot_pos.sum().item()
ot_neg = ot_neg.sum().item()
tot_ot_pos += ot_pos
tot_ot_neg += ot_neg
tot_ot_loss += ot_pos - ot_neg
else:
tot_ot_loss += ot_loss.sum().item()
targets = batch['targets']
loss = F.cross_entropy(scores, targets, reduction='sum')
val_loss += loss.item()
tot_score += (scores.max(dim=-1)[1] == targets).sum().item()
n_ex += len(targets)
val_loss = sum(all_gather_list(val_loss))
tot_score = sum(all_gather_list(tot_score))
n_ex = sum(all_gather_list(n_ex))
tot_time = time()-st
val_loss /= n_ex
val_acc = tot_score / n_ex
val_log = {'valid/loss': val_loss,
'valid/acc': val_acc,
'valid/ex_per_s': n_ex/tot_time}
if ot_loss is not None:
tot_ot_loss = sum(all_gather_list(tot_ot_loss))
tot_ot_pos = sum(all_gather_list(tot_ot_pos))
tot_ot_neg = sum(all_gather_list(tot_ot_neg))
val_log['valid/ot_loss'] = tot_ot_loss / n_ex
val_log['valid/ot_pos'] = tot_ot_pos / n_ex
val_log['valid/ot_neg'] = tot_ot_neg / n_ex
LOGGER.info(f"validation finished in {int(tot_time)} seconds, "
f"score: {val_acc*100:.2f}")
return val_log
if __name__ == "__main__":
parser = argparse.ArgumentParser()
# Required parameters
# NOTE: train tasks and val tasks cannot take command line arguments
parser.add_argument('--compressed_db', action='store_true',
help='use compressed LMDB')
parser.add_argument("--model_config", type=str,
help="path to model structure config json")
parser.add_argument("--checkpoint", default=None, type=str,
help="path to model checkpoint (*.pt)")
parser.add_argument(
"--output_dir", default=None, type=str,
help="The output directory where the model checkpoints will be "
"written.")
parser.add_argument('--mrm_prob', default=0.15, type=float,
help='probability to mask in MRM training')
parser.add_argument('--neg_size', default=128, type=int,
help='negative image size for NCE')
parser.add_argument('--nce_temp', default=1.0, type=float,
help='softmax temperature for NCE')
parser.add_argument('--itm_neg_prob', default=0.5, type=float,
                        help='probability to make negative examples '
                             'in ITM training')
parser.add_argument('--itm_ot_lambda', default=0.0, type=float,
help='weight of OT (optimal transport) loss')
parser.add_argument('--ot_pos_only', action='store_true',
help='use OT distance of positive pairs only')
# Prepro parameters
parser.add_argument('--max_txt_len', type=int, default=60,
help='max number of tokens in text (BERT BPE)')
parser.add_argument('--conf_th', type=float, default=0.2,
help='threshold for dynamic bounding boxes '
'(-1 for fixed)')
parser.add_argument('--max_bb', type=int, default=100,
help='max number of bounding boxes')
parser.add_argument('--min_bb', type=int, default=10,
help='min number of bounding boxes')
parser.add_argument('--num_bb', type=int, default=36,
help='static number of bounding boxes')
# training parameters
parser.add_argument("--train_batch_size", default=4096, type=int,
help="Total batch size for training. "
"(batch by tokens)")
parser.add_argument("--val_batch_size", default=4096, type=int,
help="Total batch size for validation. "
"(batch by tokens)")
parser.add_argument('--gradient_accumulation_steps', type=int, default=16,
help="Number of updates steps to accumualte before "
"performing a backward/update pass.")
parser.add_argument("--learning_rate", default=3e-5, type=float,
help="The initial learning rate for Adam.")
parser.add_argument("--valid_steps", default=1000, type=int,
help="Run validation every X steps")
parser.add_argument("--num_train_steps", default=100000, type=int,
help="Total number of training updates to perform.")
parser.add_argument("--optim", default='adamw',
choices=['adam', 'adamax', 'adamw'],
help="optimizer")
parser.add_argument("--betas", default=[0.9, 0.98], nargs='+',
help="beta for adam optimizer")
parser.add_argument("--decay", default='linear',
choices=['linear', 'invsqrt'],
help="learning rate decay method")
parser.add_argument("--dropout", default=0.1, type=float,
help="tune dropout regularization")
parser.add_argument("--weight_decay", default=0.01, type=float,
help="weight decay (L2) regularization")
parser.add_argument("--grad_norm", default=2.0, type=float,
help="gradient clipping (-1 for no clipping)")
parser.add_argument("--warmup_steps", default=10000, type=int,
help="Number of training steps to perform linear "
"learning rate warmup for. (invsqrt decay)")
# device parameters
parser.add_argument('--seed', type=int, default=42,
help="random seed for initialization")
parser.add_argument('--fp16', action='store_true',
help="Whether to use 16-bit float precision instead "
"of 32-bit")
parser.add_argument('--n_workers', type=int, default=4,
help="number of data workers")
parser.add_argument('--pin_mem', action='store_true', help="pin memory")
# can use config files
parser.add_argument('--config', required=True, help='JSON config files')
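    # parse_with_config merges the JSON file passed via --config with the
    # command-line flags; entries such as train_datasets / val_datasets exist
    # only in the config file.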
args = parse_with_config(parser)
if exists(args.output_dir) and os.listdir(args.output_dir):
raise ValueError("Output directory ({}) already exists and is not "
"empty.".format(args.output_dir))
# options safe guard
if args.conf_th == -1:
assert args.max_bb + args.max_txt_len + 2 <= 512
else:
assert args.num_bb + args.max_txt_len + 2 <= 512
main(args)