lightningdot
"""
|
|
Bert for VQA model
|
|
"""
|
|
from collections import defaultdict
|
|
|
|
from torch import nn
|
|
from torch.nn import functional as F
|
|
|
|
from .layer import GELU
|
|
from .model import UniterPreTrainedModel, UniterModel
|
|
|
|
LayerNorm = nn.LayerNorm
|
|
|
|
class UniterForVisualQuestionAnswering(UniterPreTrainedModel):
|
|
""" Finetune multi-modal BERT for VQA
|
|
"""
|
|
def __init__(self, config, img_dim, num_answer):
|
|
super().__init__(config)
|
|
self.bert = UniterModel(config, img_dim)
|
|
self.vqa_output = nn.Sequential(
|
|
nn.Linear(config.hidden_size, config.hidden_size*2),
|
|
GELU(),
|
|
LayerNorm(config.hidden_size*2, eps=1e-12),
|
|
nn.Linear(config.hidden_size*2, num_answer)
|
|
)
|
|
self.apply(self.init_weights)
|
|
|
|
def forward(self, batch, compute_loss=True):
|
|
batch = defaultdict(lambda: None, batch)
|
|
input_ids = batch['input_ids']
|
|
position_ids = batch['position_ids']
|
|
img_feat = batch['img_feat']
|
|
img_pos_feat = batch['img_pos_feat']
|
|
attn_masks = batch['attn_masks']
|
|
gather_index = batch['gather_index']
|
|
sequence_output = self.bert(input_ids, position_ids,
|
|
img_feat, img_pos_feat,
|
|
attn_masks, gather_index,
|
|
output_all_encoded_layers=False)
|
|
pooled_output = self.bert.pooler(sequence_output)
|
|
answer_scores = self.vqa_output(pooled_output)
|
|
|
|
if compute_loss:
|
|
targets = batch['targets']
|
|
vqa_loss = F.binary_cross_entropy_with_logits(
|
|
answer_scores, targets, reduction='none')
|
|
return vqa_loss
|
|
else:
|
|
return answer_scores
|
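
For reference, below is a minimal usage sketch of this class. It is not part of the file above: the `config` object is assumed to be a UNITER/BERT-style config loaded elsewhere (only `hidden_size` is read directly here), and all sizes, the 7-dim region position features, and the dummy `gather_index` are illustrative assumptions chosen to match the `batch` keys read in `forward`.

import torch

# hypothetical sizes, for illustration only
batch_size, txt_len, num_regions = 2, 12, 36
img_dim, num_answer = 2048, 3129

# `config` is assumed to be a UNITER/BERT-style config object loaded elsewhere
model = UniterForVisualQuestionAnswering(config, img_dim=img_dim,
                                         num_answer=num_answer)

seq_len = txt_len + num_regions
batch = {
    'input_ids': torch.randint(0, 30522, (batch_size, txt_len)),    # BERT-base vocab assumed
    'position_ids': torch.arange(txt_len).unsqueeze(0),
    'img_feat': torch.randn(batch_size, num_regions, img_dim),
    'img_pos_feat': torch.randn(batch_size, num_regions, 7),        # assumed 7-dim box features
    'attn_masks': torch.ones(batch_size, seq_len, dtype=torch.long),
    'gather_index': torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1),
    'targets': torch.zeros(batch_size, num_answer),                 # soft answer labels
}

# inference: raw answer logits of shape (batch_size, num_answer)
logits = model(batch, compute_loss=False)

# training: per-element BCE loss, reduced by the caller
loss = model(batch, compute_loss=True).sum() / batch_size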