"""
Bert for Referring Expression Comprehension
"""
import random
import sys

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from pytorch_pretrained_bert.modeling import BertPreTrainedModel, BertLayerNorm

from .model import BertVisionLanguageEncoder


class BertForReferringExpressionComprehension(BertPreTrainedModel):
    """Finetune multi-modal BERT for Referring Expression Comprehension.

    A ``BertVisionLanguageEncoder`` encodes the joint text + region sequence;
    a small head (``re_output``) then maps every region hidden state to a
    single matching score.
    """

    def __init__(self, config, img_dim, loss="cls", margin=0.2, hard_ratio=0.3,
                 mlp=1):
        """Build the encoder and the region-scoring head.

        Args:
            config: BERT config; ``config.hidden_size`` sizes the head.
            img_dim: forwarded unchanged to ``BertVisionLanguageEncoder``.
            loss: ``'cls'`` (cross-entropy over regions) or ``'rank'``
                (margin ranking between the target and one negative).
            margin: margin of the ranking loss (stored only for ``'rank'``).
            hard_ratio: probability of sampling the hardest negative rather
                than a random one (stored only for ``'rank'``).
            mlp: number of layers in the scoring head, 1 or 2.

        Raises:
            ValueError: if ``loss`` or ``mlp`` is not a supported value.
        """
        super().__init__(config)

        # Validate with real exceptions: `assert` disappears under `-O`, and
        # `sys.exit` would kill the host process from inside a constructor.
        if loss not in ('cls', 'rank'):
            raise ValueError(
                f"loss must be 'cls' or 'rank', got {loss!r}")
        if mlp not in (1, 2):
            raise ValueError(
                f"MLP restricted to be 1 or 2 layers, got {mlp!r}")

        self.bert = BertVisionLanguageEncoder(config, img_dim)
        if mlp == 1:
            self.re_output = nn.Linear(config.hidden_size, 1)
        else:
            self.re_output = nn.Sequential(
                nn.Linear(config.hidden_size, config.hidden_size),
                nn.ReLU(),
                BertLayerNorm(config.hidden_size, eps=1e-12),
                nn.Linear(config.hidden_size, 1)
            )

        self.loss = loss
        if self.loss == 'rank':
            self.margin = margin
            self.hard_ratio = hard_ratio
        else:
            # per-sample losses; reduction is left to the caller
            self.crit = nn.CrossEntropyLoss(reduction='none')

        # initialize
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, position_ids, txt_lens, img_feat,
                img_pos_feat, num_bbs, attn_masks, obj_masks, targets,
                compute_loss=True):
        """Score every region against the expression, optionally as a loss.

        ``input_ids``, ``position_ids``, ``img_feat``, ``img_pos_feat`` and
        ``attn_masks`` are passed straight through to the encoder.

        Args:
            txt_lens: per-sample text lengths, used to locate the region part
                of the encoder output.
            num_bbs: per-sample number of regions.
            obj_masks: boolean mask broadcastable to (n, max_num_bb); positions
                that are True are filled with -1e4 (non-objects).
            targets: (n, ) index of the target region; only read when
                ``compute_loss`` is True.
            compute_loss: return the per-sample loss instead of the scores.

        Returns:
            (n, ) un-reduced loss if ``compute_loss`` is True, otherwise the
            (n, max_num_bb) masked scores.
        """
        sequence_output = self.bert(input_ids, position_ids, txt_lens,
                                    img_feat, img_pos_feat, num_bbs,
                                    attn_masks,
                                    output_all_encoded_layers=False)
        # get only the region part
        sequence_output = self._get_image_hidden(sequence_output, txt_lens,
                                                 num_bbs)

        # re score (n, max_num_bb)
        scores = self.re_output(sequence_output).squeeze(2)
        scores = scores.masked_fill(obj_masks, -1e4)  # mask out non-objects

        if not compute_loss:
            # (n, max_num_bb)
            return scores

        if self.loss == 'cls':
            # (n, ) as no reduction
            return self.crit(scores, targets)

        # ranking
        n = len(num_bbs)
        # positive (target)
        pos_sc = scores.gather(1, targets.view(n, 1))  # (n, 1)
        pos_sc = torch.sigmoid(pos_sc).view(-1)  # (n, ) in [0, 1]
        # negative
        neg_ix = self.sample_neg_ix(scores, targets, num_bbs)
        neg_sc = scores.gather(1, neg_ix.view(n, 1))  # (n, 1)
        neg_sc = torch.sigmoid(neg_sc).view(-1)  # (n, ) in [0, 1]
        # hinge on the margin between negative and positive
        return torch.clamp(self.margin + neg_sc - pos_sc, min=0)  # (n, )

    def sample_neg_ix(self, scores, targets, num_bbs):
        """Sample one negative region index per sample.

        With probability ``self.hard_ratio`` the highest-scoring non-target
        region is taken (hard negative); otherwise a uniformly random
        non-target index in ``[0, num_bb - 1]`` is drawn (easy negative).

        Inputs:
            :scores    (n, max_num_bb)
            :targets   (n, )
            :num_bbs   list of [num_bb]
        return:
            :neg_ix    (n, ) easy/hard negative (!= target), with the dtype
                       and device of ``targets``
        Raises:
            ValueError: if a sample has fewer than two regions, so that no
                negative exists (previously this looped forever).
        """
        # One host transfer up front instead of comparing Python ints with
        # 0-dim tensors inside the loops.
        target_list = targets.tolist()
        ranked = torch.argsort(
            scores, dim=-1, descending=True)  # (n, max_num_bb)

        neg_ix = []
        for i, num_bb in enumerate(num_bbs):
            target = target_list[i]
            if num_bb < 2:
                raise ValueError(
                    f'cannot sample a negative for sample {i}: '
                    f'num_bb={num_bb}')
            if np.random.uniform(0, 1, 1) < self.hard_ratio:
                # sample hard negative, w/ highest score
                for ix in ranked[i].tolist():
                    if ix != target:
                        # masked (padded) regions must never outrank real ones
                        assert ix < num_bb, f'ix={ix}, num_bb={num_bb}'
                        neg_ix.append(ix)
                        break
            else:
                # sample easy negative, i.e., random one; reject the target
                ix = random.randint(0, num_bb - 1)  # [0, num_bb-1]
                while ix == target:
                    ix = random.randint(0, num_bb - 1)
                neg_ix.append(ix)

        # Build directly with the right dtype/device rather than casting via
        # a type string, which ignores which device `targets` lives on.
        neg_ix = torch.tensor(neg_ix, dtype=targets.dtype,
                              device=targets.device)
        assert neg_ix.numel() == targets.numel()
        return neg_ix

    def _get_image_hidden(self, sequence_output, txt_lens, num_bbs):
        """
        Extracting the img_hidden part from sequence_output.
        Inputs:
        - sequence_output: (n, txt_len+num_bb, hid_size)
        - txt_lens       : [txt_len]
        - num_bbs        : [num_bb]
        Output:
        - img_hidden     : (n, max_num_bb, hid_size), zero-padded after each
                           sample's own num_bb regions
        """
        max_bb = max(num_bbs)
        hid_size = sequence_output.size(-1)

        outputs = []
        per_sample = sequence_output.split(1, dim=0)  # n x (1, seq, hid)
        for seq_out, len_, nbb in zip(per_sample, txt_lens, num_bbs):
            # regions sit right after the text tokens
            img_hid = seq_out[:, len_:len_ + nbb, :]
            if nbb < max_bb:
                pad = self._get_pad(img_hid, max_bb - nbb, hid_size)
                img_hid = torch.cat([img_hid, pad], dim=1)
            outputs.append(img_hid)

        return torch.cat(outputs, dim=0)

    def _get_pad(self, t, len_, hidden_size):
        """Return a (1, len_, hidden_size) zero tensor matching ``t``'s dtype
        and device."""
        return torch.zeros(1, len_, hidden_size,
                           dtype=t.dtype, device=t.device)