""" UNITER for ITM model """ import copy from collections import defaultdict import torch from torch import nn from .model import UniterPreTrainedModel, UniterModel class UniterForImageTextRetrieval(UniterPreTrainedModel): """ Finetune UNITER for image text retrieval """ def __init__(self, config, img_dim, margin=0.2): super().__init__(config) self.bert = UniterModel(config, img_dim) self.itm_output = nn.Linear(config.hidden_size, 2) self.rank_output = nn.Linear(config.hidden_size, 1) self.margin = margin self.apply(self.init_weights) def init_output(self): """ need to be called after from pretrained """ self.rank_output.weight.data = self.itm_output.weight.data[1:, :] self.rank_output.bias.data = self.itm_output.bias.data[1:] def forward(self, batch, compute_loss=True): batch = defaultdict(lambda: None, batch) input_ids = batch['input_ids'] position_ids = batch['position_ids'] img_feat = batch['img_feat'] img_pos_feat = batch['img_pos_feat'] attention_mask = batch['attn_masks'] gather_index = batch['gather_index'] sequence_output = self.bert(input_ids, position_ids, img_feat, img_pos_feat, attention_mask, gather_index, output_all_encoded_layers=False) pooled_output = self.bert.pooler(sequence_output) rank_scores = self.rank_output(pooled_output) if compute_loss: # triplet loss rank_scores_sigmoid = torch.sigmoid(rank_scores) sample_size = batch['sample_size'] scores = rank_scores_sigmoid.contiguous().view(-1, sample_size) pos = scores[:, :1] neg = scores[:, 1:] rank_loss = torch.clamp(self.margin + neg - pos, 0) return rank_loss else: return rank_scores class UniterForImageTextRetrievalHardNeg(UniterForImageTextRetrieval): """ Finetune UNITER for image text retrieval """ def __init__(self, config, img_dim, margin=0.2, hard_size=16): super().__init__(config, img_dim, margin) self.hard_size = hard_size def forward(self, batch, sample_from='t', compute_loss=True): # expect same input_ids for all pairs batch_size = batch['attn_masks'].size(0) input_ids = batch['input_ids'] img_feat = batch['img_feat'] img_pos_feat = batch['img_pos_feat'] if sample_from == 't': if input_ids.size(0) == 1: batch['input_ids'] = input_ids.expand(batch_size, -1) elif sample_from == 'i': if img_feat.size(0) == 1: batch['img_feat'] = img_feat.expand(batch_size, -1, -1) if img_pos_feat.size(0) == 1: batch['img_pos_feat'] = img_pos_feat.expand(batch_size, -1, -1) else: raise ValueError() if self.training and compute_loss: with torch.no_grad(): self.eval() scores = super().forward(batch, compute_loss=False) hard_batch = self._get_hard_batch(batch, scores, sample_from) self.train() return super().forward(hard_batch, compute_loss=True) else: return super().forward(batch, compute_loss) def _get_hard_batch(self, batch, scores, sample_from='t'): batch = defaultdict(lambda: None, batch) input_ids = batch['input_ids'] position_ids = batch['position_ids'] img_feat = batch['img_feat'] img_pos_feat = batch['img_pos_feat'] attention_mask = batch['attn_masks'] gather_index = batch['gather_index'] hard_batch = {'sample_size': self.hard_size + 1} # NOTE first example is positive hard_indices = scores.squeeze(-1)[1:].topk( self.hard_size, sorted=False)[1] + 1 indices = torch.cat([torch.zeros(1, dtype=torch.long, device=hard_indices.device), hard_indices]) attention_mask = attention_mask.index_select(0, indices) gather_index = gather_index.index_select(0, indices) if position_ids.size(0) != 1: position_ids = position_ids[:self.hard_size+1] if sample_from == 't': # cut to minimum padding max_len = attention_mask.sum(dim=1).max().item() 
            max_i = max_len - input_ids.size(1)
            attention_mask = attention_mask[:, :max_len]
            gather_index = gather_index[:, :max_len]
            img_feat = img_feat.index_select(0, indices)[:, :max_i, :]
            img_pos_feat = img_pos_feat.index_select(0, indices)[:, :max_i, :]
            # expect same input_ids for all pairs
            input_ids = input_ids[:self.hard_size+1]
        elif sample_from == 'i':
            input_ids = input_ids.index_select(0, indices)
            # expect same image features for all pairs
            img_feat = img_feat[:self.hard_size+1]
            img_pos_feat = img_pos_feat[:self.hard_size+1]
        else:
            raise ValueError(f"unexpected sample_from: {sample_from}")

        hard_batch['input_ids'] = input_ids
        hard_batch['position_ids'] = position_ids
        hard_batch['img_feat'] = img_feat
        hard_batch['img_pos_feat'] = img_pos_feat
        hard_batch['attn_masks'] = attention_mask
        hard_batch['gather_index'] = gather_index

        return hard_batch


class UniterForImageTextRetrievalFast(UniterPreTrainedModel):
    """ Finetune UNITER for image-text retrieval with separate text and image
        encoders (fast scoring via cosine similarity of pooled outputs) """
    def __init__(self, config, img_dim, margin=0.2):
        super().__init__(config)
        self.bert = UniterModel(config, img_dim)
        config_img = copy.deepcopy(config)
        config_img.num_hidden_layers = config_img.num_hidden_layers_img
        self.img_bert = UniterModel(config_img, img_dim)
        self.itm_output = nn.Linear(config.hidden_size, 2)
        self.rank_output = nn.Linear(config.hidden_size, 1)
        self.margin = margin
        self.apply(self.init_weights)

    def init_output(self):
        """ Initialize the ranking head from the pretrained ITM head
            (positive-class row); must be called after loading pretrained
            weights. """
        self.rank_output.weight.data = self.itm_output.weight.data[1:, :]
        self.rank_output.bias.data = self.itm_output.bias.data[1:]

    def forward(self, batch, compute_loss=True):
        batch = defaultdict(lambda: None, batch)
        input_ids = batch['input_ids']
        position_ids = batch['position_ids']
        img_feat = batch['img_feat']
        img_pos_feat = batch['img_pos_feat']
        attention_mask_text = batch['attn_masks_text']
        attention_mask_img = batch['attn_masks_img']
        gather_index = batch['gather_index']

        # text-only encoder
        sequence_output_text = self.bert(input_ids, position_ids,
                                         None, img_pos_feat,
                                         attention_mask_text, gather_index,
                                         output_all_encoded_layers=False)
        pooled_output_text = self.bert.pooler(sequence_output_text)

        # image-only encoder
        sequence_output_img = self.img_bert(None, position_ids,
                                            img_feat, img_pos_feat,
                                            attention_mask_img, gather_index,
                                            output_all_encoded_layers=False)
        pooled_output_img = self.img_bert.pooler(sequence_output_img)

        # rank_scores = (pooled_output_text * pooled_output_img).sum(-1)
        # rank_scores = self.rank_output(pooled_output)
        rank_scores = torch.nn.CosineSimilarity()(pooled_output_text,
                                                  pooled_output_img)

        if compute_loss:
            # triplet loss
            rank_scores_sigmoid = torch.sigmoid(rank_scores)
            sample_size = batch['sample_size']
            scores = rank_scores_sigmoid.contiguous().view(-1, sample_size)
            pos = scores[:, :1]
            neg = scores[:, 1:]
            rank_loss = torch.clamp(self.margin + neg - pos, 0)
            return rank_loss
        else:
            return rank_scores
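

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training code):
# the ranking heads above expect each consecutive group of `sample_size`
# pairs to start with the positive pair, followed by negatives. The snippet
# reproduces the triplet-loss computation from `forward` on dummy scores;
# the group size, number of groups, and margin below are assumptions chosen
# for illustration, not values used by the repository.
if __name__ == '__main__':
    margin = 0.2
    sample_size = 4                     # 1 positive + 3 negatives per group
    dummy_scores = torch.randn(8, 1)    # 2 groups of 4 pairs (rank_output-like)
    scores = torch.sigmoid(dummy_scores).contiguous().view(-1, sample_size)
    pos, neg = scores[:, :1], scores[:, 1:]     # column 0 holds the positive
    rank_loss = torch.clamp(margin + neg - pos, 0)
    print('mean hinge loss:', rank_loss.mean().item())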