logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

140 lines
5.3 KiB

"""
Bert for Referring Expression Comprehension
"""
import sys
import torch
import torch.nn as nn
from torch.nn import functional as F
from pytorch_pretrained_bert.modeling import BertPreTrainedModel, BertLayerNorm
from .model import BertVisionLanguageEncoder
import numpy as np
import random
class BertForReferringExpressionComprehension(BertPreTrainedModel):
    """Finetune multi-modal BERT for Referring Expression Comprehension.

    Every image region is scored by a small head (``re_output``) applied to
    the region part of the encoder output. Training uses either a
    cross-entropy loss over regions (``loss='cls'``) or a sigmoid hinge
    ranking loss against one sampled negative region (``loss='rank'``).
    """

    def __init__(self, config, img_dim, loss="cls",
                 margin=0.2, hard_ratio=0.3, mlp=1):
        """Build the encoder, the scoring head and the loss settings.

        Inputs:
            :config      BERT config; ``config.hidden_size`` sizes the head.
            :img_dim     forwarded to BertVisionLanguageEncoder
                         (presumably the region feature size — confirm there).
            :loss        'cls' (cross-entropy) or 'rank' (hinge ranking).
            :margin      hinge margin, used only when loss='rank'.
            :hard_ratio  probability of sampling a hard (highest-scoring)
                         negative instead of a random one; loss='rank' only.
            :mlp         1 -> single Linear head, 2 -> Linear/ReLU/LN/Linear.
        Raises:
            :ValueError  if ``mlp`` is not 1 or 2, or ``loss`` is unknown.
        """
        super().__init__(config)
        self.bert = BertVisionLanguageEncoder(config, img_dim)
        if mlp == 1:
            self.re_output = nn.Linear(config.hidden_size, 1)
        elif mlp == 2:
            self.re_output = nn.Sequential(
                nn.Linear(config.hidden_size, config.hidden_size),
                nn.ReLU(),
                BertLayerNorm(config.hidden_size, eps=1e-12),
                nn.Linear(config.hidden_size, 1)
            )
        else:
            # Raise instead of sys.exit: library code must not kill the
            # interpreter, and callers can now catch/report the error.
            raise ValueError(
                f"MLP restricted to be 1 or 2 layers, got mlp={mlp!r}.")
        # Explicit check instead of `assert` so it survives `python -O`
        # (an unknown loss would otherwise silently behave like 'cls').
        if loss not in ('cls', 'rank'):
            raise ValueError(f"loss must be 'cls' or 'rank', got {loss!r}")
        self.loss = loss
        if self.loss == 'rank':
            self.margin = margin
            self.hard_ratio = hard_ratio
        else:
            # Per-example losses; reduction is left to the caller.
            self.crit = nn.CrossEntropyLoss(reduction='none')
        # initialize (init_bert_weights comes from BertPreTrainedModel)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, position_ids, txt_lens, img_feat,
                img_pos_feat, num_bbs, attn_masks, obj_masks,
                targets, compute_loss=True):
        """Score every region and optionally compute the per-example loss.

        Inputs:
            :input_ids, position_ids, img_feat, img_pos_feat, attn_masks
                forwarded to ``self.bert``; see BertVisionLanguageEncoder
                for their layout.
            :txt_lens      [txt_len] per example, text length in the sequence.
            :num_bbs       [num_bb] per example, number of real regions.
            :obj_masks     boolean mask broadcastable to (n, max_num_bb);
                           True marks non-object (padding) slots.
            :targets       (n, ) index of the referred region per example.
            :compute_loss  when False, return raw scores instead of a loss.
        return:
            :(n, ) unreduced loss when ``compute_loss`` is True, otherwise
             the (n, max_num_bb) score matrix.
        """
        sequence_output = self.bert(input_ids, position_ids, txt_lens,
                                    img_feat, img_pos_feat, num_bbs,
                                    attn_masks,
                                    output_all_encoded_layers=False)
        # get only the region part
        sequence_output = self._get_image_hidden(sequence_output, txt_lens,
                                                 num_bbs)
        # re score (n, max_num_bb)
        scores = self.re_output(sequence_output).squeeze(2)
        # Large negative (not -inf) so it is still representable in fp16.
        scores = scores.masked_fill(obj_masks, -1e4)  # mask out non-objects

        if not compute_loss:
            # (n, max_num_bb)
            return scores

        if self.loss == 'cls':
            ce_loss = self.crit(scores, targets)  # (n, ) as no reduction
            return ce_loss

        # ranking: hinge on sigmoid scores of target vs. one negative
        _n = len(num_bbs)
        # positive (target)
        pos_sc = scores.gather(1, targets.view(_n, 1))  # (n, 1)
        pos_sc = torch.sigmoid(pos_sc).view(-1)  # (n, ) sc[0, 1]
        # negative
        neg_ix = self.sample_neg_ix(scores, targets, num_bbs)
        neg_sc = scores.gather(1, neg_ix.view(_n, 1))  # (n, 1)
        neg_sc = torch.sigmoid(neg_sc).view(-1)  # (n, ) sc[0, 1]
        mm_loss = torch.clamp(self.margin + neg_sc - pos_sc, 0)  # (n, )
        return mm_loss

    def sample_neg_ix(self, scores, targets, num_bbs):
        """Pick one negative region index per example.

        With probability ``self.hard_ratio`` the negative is the
        highest-scoring non-target region (hard negative); otherwise it is
        drawn uniformly from the real regions ``[0, num_bb-1]`` excluding the
        target (easy negative).

        Inputs:
            :scores  (n, max_num_bb)
            :targets (n, )
            :num_bbs list of [num_bb]
        return:
            :neg_ix (n, ) easy/hard negative (!= target), same dtype/device
             as ``targets``
        Raises:
            :ValueError if an example has fewer than 2 regions, since no
             negative exists (previously this looped forever).
        """
        # Move to host once: comparing Python ints against 0-d tensors
        # forces a device sync per comparison on GPU.
        ranked = torch.argsort(scores, dim=-1, descending=True).tolist()
        target_list = targets.view(-1).tolist()

        neg_ix = []
        for i, num_bb in enumerate(num_bbs):
            target = target_list[i]
            # Checked before any RNG draw so the random stream is unchanged
            # for valid inputs.
            if num_bb < 2:
                raise ValueError(
                    f'rank loss needs at least 2 regions per example, '
                    f'got num_bb={num_bb} at index {i}')
            if np.random.uniform(0, 1, 1) < self.hard_ratio:
                # sample hard negative, w/ highest score
                ix = next(c for c in ranked[i] if c != target)
                assert ix < num_bb, f'ix={ix}, num_bb={num_bb}'
            else:
                # sample easy negative, i.e., random one in [0, num_bb-1]
                ix = random.randint(0, num_bb - 1)
                while ix == target:
                    ix = random.randint(0, num_bb - 1)
            neg_ix.append(ix)

        neg_ix = torch.tensor(neg_ix, dtype=targets.dtype,
                              device=targets.device)
        assert neg_ix.numel() == targets.numel()
        return neg_ix

    def _get_image_hidden(self, sequence_output, txt_lens, num_bbs):
        """
        Extracting the img_hidden part from sequence_output.

        Each example's regions start right after its text tokens; rows with
        fewer than max(num_bbs) regions are right-padded with zeros.
        Inputs:
            - sequence_output: (n, txt_len+num_bb, hid_size)
            - txt_lens : [txt_len]
            - num_bbs : [num_bb]
        Output:
            - img_hidden : (n, max_num_bb, hid_size)
        """
        outputs = []
        max_bb = max(num_bbs)
        hid_size = sequence_output.size(-1)
        for seq_out, len_, nbb in zip(sequence_output.split(1, dim=0),
                                      txt_lens, num_bbs):
            img_hid = seq_out[:, len_:len_ + nbb, :]
            if nbb < max_bb:
                img_hid = torch.cat(
                    [img_hid, self._get_pad(img_hid, max_bb - nbb, hid_size)],
                    dim=1)
            outputs.append(img_hid)
        img_hidden = torch.cat(outputs, dim=0)
        return img_hidden

    def _get_pad(self, t, len_, hidden_size):
        """Return a zero tensor (1, len_, hidden_size) matching t's dtype/device."""
        pad = torch.zeros(1, len_, hidden_size, dtype=t.dtype, device=t.device)
        return pad