"""
Bert for Referring Expression Comprehension
"""
import sys
import random

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from pytorch_pretrained_bert.modeling import BertPreTrainedModel, BertLayerNorm

from .model import BertVisionLanguageEncoder


class BertForReferringExpressionComprehension(BertPreTrainedModel):
    """Finetune multi-modal BERT for Referring Expression Comprehension."""

    def __init__(self, config, img_dim, loss="cls",
                 margin=0.2, hard_ratio=0.3, mlp=1):
        super().__init__(config)
        self.bert = BertVisionLanguageEncoder(config, img_dim)
        if mlp == 1:
            self.re_output = nn.Linear(config.hidden_size, 1)
        elif mlp == 2:
            self.re_output = nn.Sequential(
                nn.Linear(config.hidden_size, config.hidden_size),
                nn.ReLU(),
                BertLayerNorm(config.hidden_size, eps=1e-12),
                nn.Linear(config.hidden_size, 1)
            )
        else:
            sys.exit("MLP restricted to be 1 or 2 layers.")
        self.loss = loss
        assert self.loss in ['cls', 'rank']
        if self.loss == 'rank':
            self.margin = margin
            self.hard_ratio = hard_ratio
        else:
            self.crit = nn.CrossEntropyLoss(reduction='none')

        # initialize
        self.apply(self.init_bert_weights)

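    # Two training objectives are supported: "cls" treats comprehension as
    # classification over the candidate regions with per-example
    # cross-entropy, while "rank" applies a sigmoid margin ranking loss
    # between the target region and one sampled negative (see
    # `sample_neg_ix` below).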
    def forward(self, input_ids, position_ids, txt_lens, img_feat,
                img_pos_feat, num_bbs, attn_masks, obj_masks,
                targets, compute_loss=True):
        sequence_output = self.bert(input_ids, position_ids, txt_lens,
                                    img_feat, img_pos_feat, num_bbs,
                                    attn_masks,
                                    output_all_encoded_layers=False)
        # get only the region part
        sequence_output = self._get_image_hidden(sequence_output, txt_lens,
                                                 num_bbs)

        # re score (n, max_num_bb)
        scores = self.re_output(sequence_output).squeeze(2)
        scores = scores.masked_fill(obj_masks, -1e4)  # mask out non-objects
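        # -1e4 is effectively -inf at this scale: padded, non-object
        # positions get negligible softmax mass under the cross-entropy
        # loss and can never be the argmax at inference time.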

        # loss
        if compute_loss:
            if self.loss == 'cls':
                ce_loss = self.crit(scores, targets)  # (n, ) as no reduction
                return ce_loss
            else:
                # ranking
                _n = len(num_bbs)
                # positive (target)
                pos_ix = targets
                pos_sc = scores.gather(1, pos_ix.view(_n, 1))  # (n, 1)
                pos_sc = torch.sigmoid(pos_sc).view(-1)  # (n, ), in (0, 1)
                # negative
                neg_ix = self.sample_neg_ix(scores, targets, num_bbs)
                neg_sc = scores.gather(1, neg_ix.view(_n, 1))  # (n, 1)
                neg_sc = torch.sigmoid(neg_sc).view(-1)  # (n, ), in (0, 1)
                # ranking
                mm_loss = torch.clamp(self.margin + neg_sc - pos_sc, 0)  # (n, )
                return mm_loss
        else:
            # (n, max_num_bb)
            return scores

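    # The margin ranking loss above is, per example:
    #     loss_i = max(0, margin + sigmoid(s_neg) - sigmoid(s_pos))
    # so it reaches zero once the target region outscores the sampled
    # negative by at least `margin` in sigmoid space.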
    def sample_neg_ix(self, scores, targets, num_bbs):
        """
        Inputs:
        :scores  (n, max_num_bb)
        :targets (n, )
        :num_bbs list of [num_bb]
        Return:
        :neg_ix  (n, ) easy/hard negative (!= target)
        """
        neg_ix = []
        cand_ixs = torch.argsort(scores, dim=-1, descending=True)  # (n, num_bb)
        for i in range(len(num_bbs)):
            num_bb = num_bbs[i]
            if np.random.uniform(0, 1, 1) < self.hard_ratio:
                # sample hard negative, w/ highest score
                for ix in cand_ixs[i].tolist():
                    if ix != targets[i]:
                        assert ix < num_bb, f'ix={ix}, num_bb={num_bb}'
                        neg_ix.append(ix)
                        break
            else:
                # sample easy negative, i.e., random one
                ix = random.randint(0, num_bb-1)  # [0, num_bb-1]
                while ix == targets[i]:
                    ix = random.randint(0, num_bb-1)
                neg_ix.append(ix)
        neg_ix = torch.tensor(neg_ix).type(targets.type())
        assert neg_ix.numel() == targets.numel()
        return neg_ix

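    # Each encoded sequence is laid out as [text tokens | region tokens],
    # so the region hidden states of example i live at positions
    # [txt_lens[i], txt_lens[i] + num_bbs[i]) and are right-padded with
    # zeros up to max(num_bbs).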
    def _get_image_hidden(self, sequence_output, txt_lens, num_bbs):
        """
        Extracting the img_hidden part from sequence_output.
        Inputs:
        - sequence_output: (n, txt_len+num_bb, hid_size)
        - txt_lens       : [txt_len]
        - num_bbs        : [num_bb]
        Output:
        - img_hidden     : (n, max_num_bb, hid_size)
        """
        outputs = []
        max_bb = max(num_bbs)
        hid_size = sequence_output.size(-1)
        for seq_out, len_, nbb in zip(sequence_output.split(1, dim=0),
                                      txt_lens, num_bbs):
            img_hid = seq_out[:, len_:len_+nbb, :]
            if nbb < max_bb:
                img_hid = torch.cat(
                    [img_hid, self._get_pad(img_hid, max_bb-nbb, hid_size)],
                    dim=1)
            outputs.append(img_hid)

        img_hidden = torch.cat(outputs, dim=0)
        return img_hidden

    def _get_pad(self, t, len_, hidden_size):
        pad = torch.zeros(1, len_, hidden_size, dtype=t.dtype, device=t.device)
        return pad
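

# Usage sketch (illustrative only; the config values, feature dimension, and
# input tensors below are assumptions, not fixed by this module):
#
#     from pytorch_pretrained_bert.modeling import BertConfig
#
#     config = BertConfig(vocab_size_or_config_json_file=30522)
#     model = BertForReferringExpressionComprehension(
#         config, img_dim=2048, loss='rank', margin=0.2, hard_ratio=0.3,
#         mlp=2)
#     # training: forward returns a per-example loss vector of shape (n,)
#     loss = model(input_ids, position_ids, txt_lens, img_feat, img_pos_feat,
#                  num_bbs, attn_masks, obj_masks, targets,
#                  compute_loss=True).mean()
#     # inference: forward returns region scores of shape (n, max_num_bb)
#     scores = model(input_ids, position_ids, txt_lens, img_feat,
#                    img_pos_feat, num_bbs, attn_masks, obj_masks,
#                    targets=None, compute_loss=False)
#     pred = scores.argmax(dim=1)  # predicted region index per example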