#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Pipeline to train DPR Biencoder
"""

import argparse
import glob
import logging
import math
import os
import random
import time
import json
import csv
import re

import torch
import numpy as np

from typing import List, Tuple
from collections import defaultdict
from torch import nn
from torch import Tensor as T
from torch.optim.lr_scheduler import LambdaLR

import torch.nn.functional as F
from torch.nn import LayerNorm
from transformers import BertModel, BertConfig, BertPreTrainedModel
from transformers.optimization import AdamW
from uniter_model.model.model import UniterPreTrainedModel, UniterModel, UniterConfig
from dvl.const import IMG_DIM
# from dvl.indexer.faiss_indexers import DenseIndexer
from uniter_model.model.layer import GELU, BertOnlyMLMHead, BertPooler
from uniter_model.model.model import RegionClassification, RegionFeatureRegression, pad_tensor_to_mul


logger = logging.getLogger()
logger.setLevel(logging.INFO)
if logger.hasHandlers():
    logger.handlers.clear()
console = logging.StreamHandler()
logger.addHandler(console)


def dot_product_scores(q_vectors: T, ctx_vectors: T, cosine=False) -> T:
    """
    Calculates q->ctx scores for every row in ctx_vectors.
    :param q_vectors: n1 x D query vectors
    :param ctx_vectors: n2 x D context vectors
    :param cosine: if True, rescale the dot products to cosine similarities
    :return: an n1 x n2 score matrix
    """
    # q_vectors: n1 x D, ctx_vectors: n2 x D, result n1 x n2
    r = torch.matmul(q_vectors, torch.transpose(ctx_vectors, 0, 1))
    if cosine:
        n1 = torch.norm(q_vectors, dim=-1)
        n2 = torch.norm(ctx_vectors, dim=-1)
        n_out = torch.ger(n1, n2)  # outer product of the row norms
        return r / n_out
    return r


def cosine_scores(q_vector: T, ctx_vectors: T):
    # q_vector: n1 x D, ctx_vectors: n2 x D, result n1 x n2
    return F.cosine_similarity(q_vector, ctx_vectors, dim=1)
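

# A minimal usage sketch (not part of the training pipeline): illustrates the
# shapes the scoring functions above expect and return. The tensors here are
# random placeholders, not real embeddings.
def _demo_scoring_shapes():
    q = torch.randn(3, 768)    # 3 query vectors of dim 768
    ctx = torch.randn(5, 768)  # 5 context vectors of dim 768
    scores = dot_product_scores(q, ctx)            # raw dot products, shape (3, 5)
    cos = dot_product_scores(q, ctx, cosine=True)  # same shape, norm-rescaled
    assert scores.shape == cos.shape == (3, 5)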


class BertEncoder(BertPreTrainedModel):
    def __init__(self, config, project_dim: int = 0):
        super().__init__(config)
        self.bert = BertModel(config)
        assert config.hidden_size > 0, 'Encoder hidden_size can\'t be zero'
        # self.encode_proj = nn.Linear(config.hidden_size, project_dim) if project_dim != 0 else None
        if project_dim > 0:
            self.encode_proj = nn.Sequential(
                nn.Linear(config.hidden_size, config.hidden_size * 2),
                GELU(),
                LayerNorm(config.hidden_size * 2, eps=1e-12),
                nn.Linear(config.hidden_size * 2, project_dim)
            )
        else:
            self.encode_proj = None
        self.init_weights()

    @classmethod
    def init_encoder(cls, cfg_name: str, checkpoint_path: str, project_dim: int = 0, dropout: float = 0.1,
                     **kwargs) -> 'BertEncoder':
        cfg = BertConfig.from_pretrained(cfg_name if cfg_name else 'bert-base-uncased')
        if dropout != 0:
            cfg.attention_probs_dropout_prob = dropout
            cfg.hidden_dropout_prob = dropout

        if checkpoint_path is not None and len(checkpoint_path) > 0:
            state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
            return cls.from_pretrained(cfg_name, config=cfg, project_dim=project_dim, state_dict=state_dict, **kwargs)
        else:
            return cls.from_pretrained(cfg_name, config=cfg, project_dim=project_dim, **kwargs)

    def forward(self, input_ids, attention_mask, position_ids,
                img_feat=None, img_pos_feat=None, img_masks=None, gather_index=None):
        if self.config.output_hidden_states:
            sequence_output, pooled_output, hidden_states = self.bert(input_ids=input_ids,
                                                                      token_type_ids=None,
                                                                      attention_mask=attention_mask,
                                                                      position_ids=position_ids)
        else:
            hidden_states = None
            sequence_output, pooled_output = self.bert(input_ids=input_ids,
                                                       token_type_ids=None,
                                                       attention_mask=attention_mask,
                                                       position_ids=position_ids)
        # use the [CLS] token representation rather than the BERT pooler output
        pooled_output = sequence_output[:, 0, :]
        if self.encode_proj:
            pooled_output = self.encode_proj(pooled_output)
        return sequence_output, pooled_output, hidden_states

    def get_out_size(self):
        if self.encode_proj:
            # encode_proj is an nn.Sequential; its final Linear defines the output size
            return self.encode_proj[-1].out_features
        return self.config.hidden_size


class UniterEncoder(UniterPreTrainedModel):
    def __init__(self, config, project_dim: int = 0):
        super().__init__(config)
        self.bert = UniterModel(config, IMG_DIM)
        assert config.hidden_size > 0, 'Encoder hidden_size can\'t be zero'
        # self.encode_proj = nn.Linear(config.hidden_size, project_dim) if project_dim != 0 else None  # Yen-Chun
        if project_dim > 0:
            self.encode_proj = nn.Sequential(
                nn.Linear(config.hidden_size, config.hidden_size * 2),
                GELU(),
                LayerNorm(config.hidden_size * 2, eps=1e-12),
                nn.Linear(config.hidden_size * 2, project_dim)
            )
        else:
            self.encode_proj = None
        self.apply(self.init_weights)

    @classmethod
    def init_encoder(cls, cfg_name: str, checkpoint_path: str, project_dim: int = 0, dropout: float = 0.1,
                     **kwargs) -> 'UniterEncoder':
        cfg = BertConfig.from_pretrained(cfg_name if cfg_name else 'bert-base-uncased')
        if dropout != 0:
            cfg.attention_probs_dropout_prob = dropout
            cfg.hidden_dropout_prob = dropout
        if checkpoint_path is not None and len(checkpoint_path) > 0 and checkpoint_path.lower() != 'none':
            logger.info(f'load from {checkpoint_path} for uniter encoder')
            state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
            # if 'model_dict' in state_dict:
            #     state_dict = state_dict['model_dict']
        else:
            logger.info('no checkpoint, random initialization for img encoder')
            state_dict = dict()
        return cls.from_pretrained(cfg_name, state_dict=state_dict, project_dim=project_dim, **kwargs)

    def forward(self, input_ids, attention_mask, position_ids,
                img_feat, img_pos_feat, img_masks, gather_index=None) -> Tuple[T, ...]:
        if self.config.output_hidden_states:
            sequence_output, pooled_output, hidden_states = self.bert(input_ids=input_ids,
                                                                      position_ids=position_ids,
                                                                      attention_mask=attention_mask,
                                                                      img_feat=img_feat,
                                                                      img_pos_feat=img_pos_feat,
                                                                      img_masks=img_masks,
                                                                      img_type_ids=None,
                                                                      gather_index=gather_index,
                                                                      output_all_encoded_layers=True)
        else:
            hidden_states = None
            sequence_output = self.bert(input_ids=input_ids,
                                        position_ids=position_ids,
                                        attention_mask=attention_mask,
                                        img_feat=img_feat,
                                        img_pos_feat=img_pos_feat,
                                        img_masks=img_masks,
                                        img_type_ids=None,
                                        gather_index=gather_index,
                                        output_all_encoded_layers=False)
        # pooled_output = self.bert.pooler(sequence_output)
        pooled_output = sequence_output[:, 0, :]
        if self.encode_proj:
            pooled_output = self.encode_proj(pooled_output)
        return sequence_output, pooled_output, hidden_states

    def get_out_size(self):
        if self.encode_proj:
            # encode_proj is an nn.Sequential; its final Linear defines the output size
            return self.encode_proj[-1].out_features
        return self.config.hidden_size
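

# A hedged construction sketch (placeholder names, not the project's actual
# config paths): both encoders are built through init_encoder, which loads a
# BertConfig and an optional checkpoint. The UNITER config/checkpoint names
# below are hypothetical; in practice they come from the training args.
def _demo_init_encoders():
    txt_enc = BertEncoder.init_encoder('bert-base-uncased', checkpoint_path='', project_dim=768)
    # img_enc = UniterEncoder.init_encoder('<uniter-base-config>', checkpoint_path='<uniter.pt>', project_dim=768)
    return txt_enc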


class BiEncoder(nn.Module):
    """ Bi-Encoder model component. Encapsulates query/question and context/passage encoders.
    """

    def __init__(self, args, fix_img_encoder: bool = False, fix_txt_encoder: bool = False, project_dim: int = 0):
        super(BiEncoder, self).__init__()
        logger.info('*' * 100)
        logger.info('loading img model')
        if args.img_model_type == 'uniter-base':
            self.img_model = UniterEncoder.init_encoder(args.img_model_config, checkpoint_path=args.img_checkpoint, project_dim=project_dim)
        else:
            raise ValueError(f'image encoder does not support other types ({args.img_model_type}) for now')

        logger.info('*' * 100)
        logger.info('loading txt model')
        if args.txt_model_type == 'bert-base':
            self.txt_model = BertEncoder.init_encoder(args.txt_model_config, checkpoint_path=args.txt_checkpoint, project_dim=project_dim)
        elif args.txt_model_type == 'uniter-base':
            self.txt_model = UniterEncoder.init_encoder(args.txt_model_config, checkpoint_path=args.txt_checkpoint, project_dim=project_dim)
        else:
            raise ValueError(f'txt encoder does not support other types ({args.txt_model_type}) for now')

        self.fix_img_encoder = fix_img_encoder
        self.fix_txt_encoder = fix_txt_encoder
        self.project_dim = project_dim
        if fix_txt_encoder:
            for param in self.txt_model.parameters():
                param.requires_grad = False
        if fix_img_encoder:
            for param in self.img_model.parameters():
                param.requires_grad = False

    @staticmethod
    def get_representation(sub_model, input_ids, attention_mask, position_ids, img_feat, img_pos_feat, img_masks,
                           gather_index=None, fix_encoder=False):
        if fix_encoder:
            with torch.no_grad():
                sequence_output, pooled_output, hidden_states = sub_model(input_ids, attention_mask, position_ids,
                                                                          img_feat, img_pos_feat, img_masks,
                                                                          gather_index)
        else:
            sequence_output, pooled_output, hidden_states = sub_model(input_ids, attention_mask, position_ids,
                                                                      img_feat, img_pos_feat, img_masks,
                                                                      gather_index)

        if sub_model.training:
            sequence_output.requires_grad_(requires_grad=True)
            pooled_output.requires_grad_(requires_grad=True)

        return sequence_output, pooled_output, hidden_states

    def forward(self, batch, output_all_encoded_layers=False):
        # batch keys: imgs, txts, caps
        batch = defaultdict(lambda: None, batch)

        if 'txts' in batch:
            sb = batch['txts']
            txt_seq, txt_pooled, txt_hidden = self.get_representation(self.txt_model, sb['input_ids'],
                                                                      sb['attention_mask'], sb['position_ids'],
                                                                      sb['img_feat'], sb['img_pos_feat'],
                                                                      sb['img_masks'],
                                                                      sb['gather_index'], self.fix_txt_encoder)
        else:
            txt_seq, txt_pooled = None, None

        if 'imgs' in batch:
            sb = batch['imgs']
            # NOTE: fixed to use fix_img_encoder here (the original passed
            # fix_txt_encoder for the image branch)
            img_seq, img_pooled, img_hidden = self.get_representation(self.img_model, sb['input_ids'],
                                                                      sb['attention_mask'], sb['position_ids'],
                                                                      sb['img_feat'], sb['img_pos_feat'],
                                                                      sb['img_masks'],
                                                                      sb['gather_index'], self.fix_img_encoder)
        else:
            img_seq, img_pooled = None, None

        if 'caps' in batch and batch['caps']['input_ids'] is not None:
            sb = batch['caps']
            cap_seq, cap_pooled, cap_hidden = self.get_representation(self.txt_model, sb['input_ids'],
                                                                      sb['attention_mask'], sb['position_ids'],
                                                                      sb['img_feat'], sb['img_pos_feat'],
                                                                      sb['img_masks'],
                                                                      sb['gather_index'], self.fix_txt_encoder)
        else:
            cap_seq, cap_pooled = None, None

        if output_all_encoded_layers:
            return txt_seq, img_seq, cap_seq
        else:
            return txt_pooled, img_pooled, cap_pooled
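

# A hedged sketch of the batch layout BiEncoder.forward expects: each optional
# 'txts' / 'imgs' / 'caps' sub-batch is a dict carrying the keys read above.
# The None values are placeholders for the real tensors.
def _demo_biencoder_batch():
    sub_batch_keys = ['input_ids', 'attention_mask', 'position_ids',
                      'img_feat', 'img_pos_feat', 'img_masks', 'gather_index']
    return {'txts': dict.fromkeys(sub_batch_keys),
            'imgs': dict.fromkeys(sub_batch_keys),
            'caps': dict.fromkeys(sub_batch_keys)}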


class BiEncoderForPretraining(nn.Module):
    """ MLM + MRM """
    def __init__(self, config_file, args, project_dim, img_dim, img_label_dim, nce_temp=1, ot_pos_only=False,
                 experiment=None):
        super().__init__()
        config = UniterConfig.from_json_file(config_file)
        self.bert = BiEncoder(args, project_dim=project_dim)
        self.cls = BertOnlyMLMHead(
            config, self.bert.img_model.bert.embeddings.word_embeddings.weight)  # ???
        self.feat_regress = RegionFeatureRegression(
            config.hidden_size, img_dim,
            self.bert.img_model.bert.img_embeddings.img_linear.weight)
        self.region_classifier = RegionClassification(
            config.hidden_size, img_label_dim)
        self.itm_output = nn.Linear(config.hidden_size, 2)
        self.cls_concat = args.cls_concat
        '''
        self.nce_output = BertPredictionHeadTransform(config)
        self.nce_output = nn.Sequential(BertPredictionHeadTransform(config),
                                        nn.Linear(config.hidden_size, img_dim))
        self.nce_norm = LayerNorm(config.hidden_size, eps=1e-12)
        self.nce_temp = nce_temp  # temperature
        '''
        self.ot_pos_only = ot_pos_only
        # self.apply(self.init_weights)
        self.vocab_pad = 0
        self.experiment = experiment

    def pad_vocab(self):
        # FIXME better padding after integrating huggingface ???
        emb_w = self.bert.embeddings.word_embeddings.weight.data
        padded_emb_w, n_pad = pad_tensor_to_mul(emb_w)
        padded_emb_w = nn.Parameter(padded_emb_w)
        self.bert.embeddings.word_embeddings.weight = padded_emb_w
        self.cls.predictions.decoder.weight = padded_emb_w
        self.vocab_pad = n_pad

    def forward(self, batch, task, compute_loss=True):
        batch = defaultdict(lambda: None, batch)
        if task == 'mlm':
            txt_labels = batch['txt_labels']
            return self.forward_mlm(batch, txt_labels, compute_loss)
        elif task == 'mrfr':
            img_mask_tgt = batch['img_mask_tgt']
            img_masks = batch['img_masks']
            mrfr_feat_target = batch['feat_targets']
            return self.forward_mrfr(batch, img_masks, img_mask_tgt, mrfr_feat_target, compute_loss)
        elif task == 'mrm-nce':
            raise NotImplementedError('nce does not work')
            # unreachable: the MRM-NCE path is disabled; kept for reference
            img_mask_tgt = batch['img_mask_tgt']
            img_masks = batch['img_masks']
            img_masks_in = batch['img_masks_in']
            feat_target = batch['feat_targets']
            neg_feats = batch['neg_feats']
            return self.forward_mrm_nce(batch,
                                        img_masks_in, img_masks, img_mask_tgt,
                                        feat_target, neg_feats, compute_loss)
        elif task == 'itm':
            targets = batch['targets']
            ot_inputs = batch['ot_inputs']
            return self.forward_itm(batch, targets, ot_inputs, compute_loss)
        elif task.startswith('mrc'):
            img_mask_tgt = batch['img_mask_tgt']
            img_masks = batch['img_masks']
            mrc_label_target = batch['label_targets']
            return self.forward_mrc(batch,
                                    img_masks, img_mask_tgt,
                                    mrc_label_target, task, compute_loss)
        else:
            raise ValueError('invalid task')
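
    # A hedged call sketch: the pretraining forward is dispatched by task name,
    # e.g. loss, scores = model(batch, task='mlm'), where batch carries the
    # task-specific target tensors read above ('txt_labels', 'feat_targets',
    # 'targets', 'label_targets', ...).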

    # MLM
    def forward_mlm(self, batch, txt_labels, compute_loss=True):
        txt_seq, img_seq, cap_seq = self.bert(batch, output_all_encoded_layers=True)
        # get only the text part

        # broadcast the image [CLS] vector across the text sequence
        img_cls = img_seq[:, 0:1, :].repeat(1, txt_seq.shape[1], 1)
        if self.cls_concat == 'add':
            sequence_output = txt_seq + img_cls
        elif self.cls_concat == 'multiply':
            sequence_output = txt_seq * img_cls
        elif len(self.cls_concat) == 0:
            sequence_output = txt_seq
        else:
            raise NotImplementedError(f'{self.cls_concat} not implemented yet')
        # only compute masked tokens for better efficiency
        masked_output = self._compute_masked_hidden(sequence_output,
                                                    txt_labels != -1)
        prediction_scores = self._pad_layer_unpad(masked_output, self.cls)
        if self.vocab_pad:
            prediction_scores = prediction_scores[:, :-self.vocab_pad]

        masked_lm_loss = F.cross_entropy(prediction_scores,
                                         txt_labels[txt_labels != -1],
                                         reduction='none')
        return masked_lm_loss, prediction_scores

    def _compute_masked_hidden(self, hidden, mask):
        """ get only the masked region (don't compute unnecessary hiddens) """
        mask = mask.unsqueeze(-1).expand_as(hidden)
        hidden_masked = hidden[mask].contiguous().view(-1, hidden.size(-1))
        return hidden_masked

    def _pad_layer_unpad(self, input_, layer):
        # pad the row count to a multiple (see pad_tensor_to_mul), run the
        # layer, then drop the padded rows
        input_, n_pad = pad_tensor_to_mul(input_)
        output = layer(input_)
        if n_pad:
            output = output[:-n_pad, :]
        return output
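
    # A hedged shape sketch of _compute_masked_hidden on toy tensors: it gathers
    # the hidden vectors at masked positions into a flat (n_masked, dim) matrix.
    @staticmethod
    def _demo_compute_masked_hidden():
        hidden = torch.randn(2, 4, 8)  # (batch, seq_len, dim)
        mask = torch.tensor([[True, False, True, False],
                             [False, False, True, False]])
        picked = hidden[mask.unsqueeze(-1).expand_as(hidden)].view(-1, 8)
        assert picked.shape == (3, 8)  # 3 masked positions selected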

    def mlm_eval(self, batch, gather_tgt):
        raise ValueError('Do not use this')
        # NOTE: the code below is unreachable legacy code; it also references
        # `input_ids`, which is not defined in this scope.
        sequence_output = self.bert(batch, output_all_encoded_layers=False)
        # get only the text part (excluding [CLS], [SEP])
        sequence_output = sequence_output[:, 1:input_ids.size(1)-1, :]
        # only compute masked tokens for better efficiency
        index = gather_tgt.unsqueeze(-1).expand(
            -1, -1, self.config.hidden_size)
        masked_output = torch.gather(sequence_output, dim=0, index=index)
        prediction_scores = self.cls(masked_output)
        if self.vocab_pad:
            prediction_scores = prediction_scores[..., :-self.vocab_pad]
        return prediction_scores

    # MRFR
    def forward_mrfr(self, batch, img_masks, img_mask_tgt,
                     feat_targets, compute_loss=True):
        txt_seq, img_seq, cap_seq = self.bert(batch, output_all_encoded_layers=True)
        txt_cls = txt_seq[:, 0:1, :].repeat(1, img_seq.shape[1], 1)
        if self.cls_concat == 'add':
            sequence_output = img_seq + txt_cls
        elif self.cls_concat == 'multiply':
            sequence_output = img_seq * txt_cls
        elif len(self.cls_concat) == 0:
            sequence_output = img_seq
        else:
            raise NotImplementedError(f'{self.cls_concat} not implemented yet')
        # only compute masked tokens for better efficiency
        masked_output = self._compute_masked_hidden(sequence_output,
                                                    img_mask_tgt)
        prediction_feat = self._pad_layer_unpad(masked_output,
                                                self.feat_regress)

        mrfr_loss = F.mse_loss(prediction_feat, feat_targets,
                               reduction='none')
        return mrfr_loss, prediction_feat

    # MRM-NCE
    # NOTE: this path is disabled (the dispatcher raises NotImplementedError for
    # 'mrm-nce'); it references `img_feat` and the NCE heads commented out in
    # __init__, so it will not run as written.
    def forward_mrm_nce(self, batch,
                        img_masks_in, img_masks, img_mask_tgt,
                        feat_targets, neg_feats, compute_loss=True):
        sequence_output = self.bert(batch,
                                    output_all_encoded_layers=False,
                                    img_masks=img_masks_in)

        # only compute masked tokens for better efficiency
        masked_output = self._compute_masked_hidden(sequence_output,
                                                    img_mask_tgt)

        masked_output = self._pad_layer_unpad(masked_output, self.nce_output)
        # neg within batch
        batch_neg = self._compute_masked_hidden(img_feat, ~img_masks)
        neg_feats, _ = pad_tensor_to_mul(
            torch.cat([neg_feats, batch_neg], dim=0))

        # shared image linear transform
        neg_output = self.nce_norm(
            self.bert.img_embeddings.img_linear(neg_feats))
        pos_output = self._pad_layer_unpad(feat_targets,
                                           self.bert.img_embeddings.img_linear)
        pos_output = self.nce_norm(pos_output)

        mrm_nce_loss = self.mrm_nce(masked_output, pos_output,
                                    neg_output, compute_loss=True)
        return mrm_nce_loss, masked_output

    def mrm_nce(self, masked_output, pos_output, neg_output,
                compute_loss=True):
        # dot product with the ground-truth features
        masked_score = masked_output.matmul(pos_output.t())
        # dot product with the negative samples
        neg_score = masked_output.matmul(neg_output.t())

        logits = torch.cat([masked_score, neg_score], dim=1).float()
        targets = torch.arange(0, masked_output.size(0),
                               dtype=torch.long, device=logits.device)
        loss = F.cross_entropy(logits/self.nce_temp, targets,
                               reduction='none')
        return loss, logits

    def forward_itm(self, batch, targets, ot_inputs,
                    compute_loss=True):
        txt_seq, img_seq, cap_seq = self.bert(batch, output_all_encoded_layers=False)
        # OT loss
        # NOTE: this branch was carried over from the single-stream UNITER model;
        # it references `sequence_output`, `input_ids`, `img_feat` and
        # `optimal_transport_dist`, none of which are defined in this scope, so
        # it only works when ot_inputs is None.
        if ot_inputs is not None:
            ot_scatter = ot_inputs['ot_scatter']

            b = sequence_output.size(0)
            tl = input_ids.size(1)
            il = img_feat.size(1)
            max_l = max(ot_inputs['scatter_max'] + 1, tl+il)

            ot_scatter = ot_scatter.unsqueeze(-1).expand_as(sequence_output)
            ctx_emb = torch.zeros(b, max_l, self.config.hidden_size,
                                  dtype=sequence_output.dtype,
                                  device=sequence_output.device
                                  ).scatter_(dim=1, index=ot_scatter,
                                             src=sequence_output)
            txt_emb = ctx_emb[:, :tl, :]
            img_emb = ctx_emb[:, tl:tl+il, :]

            txt_pad = ot_inputs['txt_pad']
            img_pad = ot_inputs['img_pad']
            ot_dist = optimal_transport_dist(txt_emb, img_emb,
                                             txt_pad, img_pad)
            if self.ot_pos_only:
                ot_loss = ot_dist.masked_select(targets == 1)
            else:
                ot_pos_dist = ot_dist.masked_select(targets == 1)
                ot_neg_dist = ot_dist.masked_select(targets == 0)
                ot_loss = (ot_pos_dist, ot_neg_dist)
        else:
            ot_loss = None

        # symmetric in-batch contrastive loss: txt->img and img->txt
        loss_function = BiEncoderNllLoss()
        itm_loss1, is_correct1, scores1 = loss_function.calc(txt_seq, img_seq, cap_seq,
                                                             batch['pos_ctx_indices'],
                                                             batch['neg_ctx_indices'],
                                                             0.0, self.experiment, 'none')
        itm_loss2, is_correct2, scores2 = loss_function.calc(img_seq, txt_seq, cap_seq,
                                                             batch['pos_ctx_indices'],
                                                             batch['neg_ctx_indices'],
                                                             0.0, self.experiment, 'none')
        if compute_loss:
            return itm_loss1*0.5 + itm_loss2*0.5, ot_loss
        else:
            return itm_loss1*0.5 + itm_loss2*0.5, ot_loss, is_correct1*0.5 + is_correct2*0.5

    # MRC
    def forward_mrc(self, batch, img_masks, img_mask_tgt,
                    label_targets, task, compute_loss=True):
        txt_seq, img_seq, cap_seq = self.bert(batch, output_all_encoded_layers=True)
        txt_cls = txt_seq[:, 0:1, :].repeat(1, img_seq.shape[1], 1)
        if self.cls_concat == 'add':
            sequence_output = img_seq + txt_cls
        elif self.cls_concat == 'multiply':
            sequence_output = img_seq * txt_cls
        elif len(self.cls_concat) == 0:
            sequence_output = img_seq
        else:
            raise NotImplementedError(f'{self.cls_concat} not implemented yet')

        # sequence_output = torch.cat([txt_seq, img_seq], dim=1)
        # only compute masked regions for better efficiency
        masked_output = self._compute_masked_hidden(sequence_output, img_mask_tgt)
        prediction_soft_label = self._pad_layer_unpad(masked_output,
                                                      self.region_classifier)

        if "kl" in task:
            prediction_soft_label = F.log_softmax(
                prediction_soft_label, dim=-1)
            mrc_loss = F.kl_div(
                prediction_soft_label, label_targets, reduction='none')
        else:
            # background class should not be the target
            label_targets = torch.max(label_targets[:, 1:], dim=-1)[1] + 1
            mrc_loss = F.cross_entropy(
                prediction_soft_label, label_targets,
                ignore_index=0, reduction='none')
        return mrc_loss, prediction_soft_label


def get_optimizer(model: nn.Module, learning_rate: float = 1e-5, adam_eps: float = 1e-8,
                  weight_decay: float = 0.0) -> torch.optim.Optimizer:
    # parameters whose names contain these substrings get no weight decay
    no_decay = ['bias', 'LayerNorm.weight']

    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_eps)
    return optimizer
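

# A minimal usage sketch: get_optimizer produces two parameter groups, one with
# and one without weight decay (name matching is substring-based). The tiny
# model here is a placeholder.
def _demo_get_optimizer():
    model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
    opt = get_optimizer(model, learning_rate=1e-5, weight_decay=0.01)
    assert len(opt.param_groups) == 2  # decay group + no-decay group
    return opt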


def setup_for_distributed_mode(model: nn.Module, optimizer: torch.optim.Optimizer, device: object, n_gpu: int = 1,
                               local_rank: int = -1,
                               fp16: bool = False,
                               fp16_opt_level: str = "O1",
                               teacher_model=None) -> (nn.Module, torch.optim.Optimizer):
    model.to(device)
    if teacher_model is not None:
        teacher_model.to(device)
    if fp16:
        try:
            import apex
            from apex import amp
            apex.amp.register_half_function(torch, "einsum")
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")

        if optimizer is None:
            if teacher_model is None:
                model = amp.initialize(model, optimizer, opt_level=fp16_opt_level)
            else:
                model, teacher_model = amp.initialize([model, teacher_model], optimizer, opt_level=fp16_opt_level)
        else:
            model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level)

    # if n_gpu > 1:
    #     model = torch.nn.DataParallel(model)
    # if local_rank != -1:
    #     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank],
    #                                                       output_device=local_rank,
    #                                                       find_unused_parameters=True)
    return model, optimizer


class BiEncoderNllLoss(object):

    def calc(self, q_vectors: T, ctx_vectors: T, caption_vectors: T, positive_idx_per_question: list,
             hard_negative_idx_per_question: list = None, caption_score_weight: float = 0.1,
             experiment=None, reduction='mean'):
        """
        Computes nll loss for the given lists of question and ctx vectors.
        Note that although hard_negative_idx_per_question is not currently in use, one can use it for
        loss modifications. For example - weighted NLL with different factors for hard vs regular negatives.
        :return: a tuple of the loss value, the count of correct predictions per batch, and the score matrix
        """
        scores_img = self.get_scores(q_vectors, ctx_vectors)
        if caption_vectors is not None and caption_score_weight != 0:
            scores_caption = self.get_scores(q_vectors, caption_vectors)
            scores = (1 - caption_score_weight) * scores_img + caption_score_weight * scores_caption
        else:
            scores = scores_img

        if experiment is not None:
            experiment.log_metric('score_img_diag_mean', torch.diag(scores_img).mean().item())
            experiment.log_metric('score_img_offdiag_mean', (scores_img.sum() - torch.diag(scores_img).sum()) /
                                  (torch.numel(scores_img) - len(torch.diag(scores_img))))

            experiment.log_metric('score_diag_mean', torch.diag(scores).mean().item())
            experiment.log_metric('score_offdiag_mean', (scores.sum() - torch.diag(scores).sum()) /
                                  (torch.numel(scores) - len(torch.diag(scores))))

            if caption_vectors is not None and caption_score_weight != 0:
                experiment.log_metric('score_caption_diag_mean', torch.diag(scores_caption).mean().item())
                experiment.log_metric('score_caption_offdiag_mean',
                                      (scores_caption.sum() - torch.diag(scores_caption).sum()) /
                                      (torch.numel(scores_caption) - len(torch.diag(scores_caption))))

        if len(q_vectors.size()) > 1:
            q_num = q_vectors.size(0)
            scores = scores.view(q_num, -1)

        softmax_scores = F.log_softmax(scores, dim=1)

        loss = F.nll_loss(softmax_scores, torch.tensor(positive_idx_per_question).to(softmax_scores.device),
                          reduction=reduction)

        max_score, max_idxs = torch.max(softmax_scores, 1)
        correct_predictions_count = (max_idxs == torch.tensor(positive_idx_per_question).to(max_idxs.device)).sum()
        return loss, correct_predictions_count, scores

    @staticmethod
    def get_scores(q_vector: T, ctx_vectors: T) -> T:
        f = BiEncoderNllLoss.get_similarity_function()
        return f(q_vector, ctx_vectors)

    @staticmethod
    def get_similarity_function():
        return dot_product_scores
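

# A minimal sketch of the in-batch-negatives setup BiEncoderNllLoss assumes:
# the positive context for question i sits at index i of ctx_vectors. Random
# tensors stand in for real embeddings.
def _demo_biencoder_nll():
    q = torch.randn(4, 16)
    ctx = torch.randn(4, 16)
    loss, n_correct, scores = BiEncoderNllLoss().calc(q, ctx, caption_vectors=None,
                                                      positive_idx_per_question=[0, 1, 2, 3])
    return loss, n_correct, scores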


def get_schedule_linear(optimizer, warmup_steps, training_steps, last_epoch=-1):
    """ Create a schedule with a learning rate that decreases linearly after
    linearly increasing during a warmup period.
    """
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        return max(
            0.0, float(training_steps - current_step) / float(max(1, training_steps - warmup_steps))
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)
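

# A quick sanity sketch of the schedule shape: the multiplier ramps from 0 to 1
# over the warmup steps, then decays linearly back to 0 at training_steps.
def _demo_linear_schedule():
    warmup_steps, training_steps = 10, 100
    def lr_lambda(step):
        if step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        return max(0.0, float(training_steps - step) / float(max(1, training_steps - warmup_steps)))
    assert lr_lambda(0) == 0.0 and lr_lambda(10) == 1.0 and lr_lambda(100) == 0.0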


class BiEncoderForVisualQuestionAnswering(nn.Module):
    """ Finetune multi-modal BERT for VQA
    """
    def __init__(self, args, fix_img_encoder: bool = False, fix_txt_encoder: bool = False,
                 seperate_caption_encoder: bool = False,
                 project_dim: int = 0,
                 hidden_size: int = 0, num_answer: int = 0, intersection=False):
        super(BiEncoderForVisualQuestionAnswering, self).__init__()
        self.biencoder = BiEncoder(args, fix_img_encoder, fix_txt_encoder, project_dim)
        self.intersection = intersection
        if self.intersection:
            hidden_size *= 2
        self.vqa_output = nn.Sequential(
            nn.Linear(hidden_size, hidden_size*2),
            GELU(),
            LayerNorm(hidden_size*2, eps=1e-12),
            nn.Linear(hidden_size*2, num_answer)
        )
        # apply recursively: calling init_weights directly on the Sequential
        # container was a no-op
        self.vqa_output.apply(self.init_weights)

    def forward(self, batch, compute_loss=True, targets=None) -> Tuple[T, T]:
        q_pooled, ctx_pooled, caption_pooled = self.biencoder(batch)

        if self.intersection:
            pooled_output = torch.cat([q_pooled, ctx_pooled, q_pooled*ctx_pooled, q_pooled + ctx_pooled], dim=1)
        else:
            pooled_output = torch.cat([q_pooled, ctx_pooled], dim=1)

        answer_scores = self.vqa_output(pooled_output)

        if compute_loss:
            vqa_loss = F.binary_cross_entropy_with_logits(
                answer_scores, targets, reduction='none')
            return vqa_loss
        else:
            return answer_scores

    def init_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses
            # truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
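

# A hedged shape sketch of the VQA head input: pooled question and context
# vectors are concatenated, optionally with their elementwise product and sum
# when intersection=True (doubling the classifier's input width).
def _demo_vqa_head_shapes():
    q, ctx = torch.randn(2, 768), torch.randn(2, 768)
    plain = torch.cat([q, ctx], dim=1)                    # (2, 1536)
    inter = torch.cat([q, ctx, q * ctx, q + ctx], dim=1)  # (2, 3072)
    assert inter.shape[1] == plain.shape[1] * 2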


def load_biencoder_checkpoint(bi_encoder, biencoder_checkpoint):
    if biencoder_checkpoint is not None and len(biencoder_checkpoint) > 0 and biencoder_checkpoint.lower() != 'none':
        logger.info(f'loading ckpt from {biencoder_checkpoint}')
        state_dict = torch.load(biencoder_checkpoint, map_location='cpu')
        try:
            bi_encoder.load_state_dict(state_dict['model_dict'])
        except KeyError:
            logger.info('loading from pre-trained model instead')
            # keep only the keys under the 'bert.' prefix, stripping the prefix
            for k in list(state_dict.keys()):
                if k.startswith('bert.'):
                    state_dict[k[5:]] = state_dict.pop(k)
                else:
                    state_dict.pop(k)
            bi_encoder.load_state_dict(state_dict, strict=True)
    else:
        logger.info('no checkpoint provided, pass')