logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

133 lines
6.0 KiB

"""
BERT image-text pretraining model for GQA
"""
from torch import nn
from torch.nn import functional as F
from pytorch_pretrained_bert.modeling import (
BertOnlyMLMHead)
from .model import (BertForImageTextPretraining,
_get_image_hidden,
mask_img_feat,
RegionFeatureRegression,
mask_img_feat_for_mrc,
RegionClassification)
import torch
import random
class BertForImageTextPretrainingForGQA(BertForImageTextPretraining):
    """Image-text pretraining model with MLM, MRM and MRC objectives.

    Every forward pass hands explicit ``txt_type_ids`` to the underlying
    encoder, and ``init_type_embedding`` swaps the encoder's token-type
    table for a three-row one.
    """

    def init_type_embedding(self):
        """Replace the encoder's token-type embedding with a 3-row table.

        The new table is first initialised with ``self.init_bert_weights``.
        Rows 0 and 1 are then overwritten with the matching rows of the
        existing table, and row 2 is overwritten with a copy of existing
        row 0.
        """
        widened = nn.Embedding(3, self.bert.config.hidden_size)
        widened.apply(self.init_bert_weights)
        # (source row in the existing table, destination row in the new one)
        for src_row, dst_row in ((0, 0), (1, 1), (0, 2)):
            pretrained_row = (
                self.bert.embeddings.token_type_embeddings
                .weight.data[src_row, :])
            widened.weight.data[dst_row, :].copy_(pretrained_row)
        self.bert.embeddings.token_type_embeddings = widened

    def forward(self, input_ids, position_ids, txt_type_ids, txt_lens,
                img_feat, img_pos_feat, num_bbs,
                attention_mask, labels, task, compute_loss=True):
        """Dispatch to the objective selected by ``task``.

        ``task`` selects the branch and decides how ``labels`` is read:

        * ``'mlm'``  -- ``labels`` is forwarded as the text labels;
        * ``'mrm'``  -- ``labels`` is forwarded as the image mask;
        * ``'mrc*'`` -- ``labels`` is unpacked into
          ``(img_mask, mrc_label_target)``.

        Raises:
            ValueError: if ``task`` matches none of the above.
        """
        # Leading positional arguments shared by every objective.
        shared = (input_ids, position_ids, txt_type_ids, txt_lens,
                  img_feat, img_pos_feat, num_bbs, attention_mask)

        if task == 'mlm':
            return self.forward_mlm(*shared, labels, compute_loss)

        if task == 'mrm':
            return self.forward_mrm(*shared, labels, compute_loss)

        if task.startswith('mrc'):
            img_mask, mrc_label_target = labels
            return self.forward_mrc(*shared, img_mask, mrc_label_target,
                                    task, compute_loss)

        raise ValueError('invalid task')

    # MLM
    def forward_mlm(self, input_ids, position_ids, txt_type_ids, txt_lens,
                    img_feat, img_pos_feat, num_bbs,
                    attention_mask, txt_labels, compute_loss=True):
        """Masked language modelling branch.

        Positions where ``txt_labels`` equals ``-1`` are skipped. Returns
        the unreduced cross-entropy when ``compute_loss`` is true, otherwise
        the prediction scores.
        """
        encoded = self.bert(input_ids, position_ids, txt_lens,
                            img_feat, img_pos_feat, num_bbs,
                            attention_mask,
                            output_all_encoded_layers=False,
                            txt_type_ids=txt_type_ids)

        # Keep the leading input_ids.size(1) positions (the text part).
        txt_span = input_ids.size(1)
        txt_hidden = encoded[:, :txt_span, :]

        # Score only the labelled (masked) tokens for efficiency.
        scores = self.masked_compute_scores(txt_hidden, txt_labels != -1)
        if self.vocab_pad:
            # Drop the trailing vocab_pad columns of the score matrix.
            scores = scores[:, :-self.vocab_pad]

        if not compute_loss:
            return scores
        return F.cross_entropy(scores,
                               txt_labels[txt_labels != -1],
                               reduction='none')

    # MRM
    def forward_mrm(self, input_ids, position_ids, txt_type_ids, txt_lens,
                    img_feat, img_pos_feat, num_bbs,
                    attention_mask, img_masks, compute_loss=True):
        """Masked region modelling (feature regression) branch.

        Returns the element-wise MSE between the predicted features and the
        targets produced by ``mask_img_feat`` when ``compute_loss`` is true,
        otherwise the predicted features.
        """
        # The helper yields the features fed to the encoder and the targets.
        img_feat, feat_targets = mask_img_feat(img_feat, img_masks)
        encoded = self.bert(input_ids, position_ids, txt_lens,
                            img_feat, img_pos_feat, num_bbs,
                            attention_mask,
                            output_all_encoded_layers=False,
                            txt_type_ids=txt_type_ids)

        # Get only the image part of the encoder output.
        img_hidden = _get_image_hidden(encoded, txt_lens, num_bbs)

        # Predict only at the masked regions for efficiency.
        predicted = self.masked_compute_feat(img_hidden, img_masks)

        if not compute_loss:
            return predicted
        return F.mse_loss(predicted, feat_targets, reduction='none')

    # MRC
    def forward_mrc(self, input_ids, position_ids, txt_type_ids, txt_lens,
                    img_feat, img_pos_feat, num_bbs,
                    attention_mask, img_masks,
                    label_targets, task, compute_loss=True):
        """Masked region classification branch.

        With ``compute_loss`` true the loss depends on ``task``:

        * if ``"kl"`` occurs in ``task``: element-wise KL divergence between
          the log-softmax of the predictions and ``label_targets``;
        * otherwise: unreduced cross-entropy against the arg-max of
          ``label_targets`` along the last dimension, with class index 0
          ignored.

        With ``compute_loss`` false the raw output of
        ``masked_predict_labels`` is returned.
        """
        img_feat = mask_img_feat_for_mrc(img_feat, img_masks)
        encoded = self.bert(input_ids, position_ids, txt_lens,
                            img_feat, img_pos_feat, num_bbs,
                            attention_mask,
                            output_all_encoded_layers=False,
                            txt_type_ids=txt_type_ids)

        # Get only the image part of the encoder output.
        img_hidden = _get_image_hidden(encoded, txt_lens, num_bbs)

        # Predict only at the masked regions for efficiency.
        predicted = self.masked_predict_labels(img_hidden, img_masks)

        if not compute_loss:
            return predicted

        if "kl" in task:
            log_probs = F.log_softmax(predicted, dim=-1)
            return F.kl_div(log_probs, label_targets, reduction='none')

        # Collapse the soft targets to hard class indices (arg-max).
        _, hard_targets = torch.max(label_targets, -1)
        return F.cross_entropy(predicted, hard_targets,
                               ignore_index=0, reduction='none')