logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

196 lines
8.0 KiB

"""
UNITER for ITM model
"""
import copy
from collections import defaultdict
import torch
from torch import nn
from .model import UniterPreTrainedModel, UniterModel
class UniterForImageTextRetrieval(UniterPreTrainedModel):
    """ Finetune UNITER for image text retrieval.

    Each candidate pair is encoded by ``self.bert``; the pooled output is
    projected to one ranking score by ``self.rank_output``. With
    ``compute_loss`` the scores are squashed by a sigmoid, grouped into rows
    of ``batch['sample_size']`` (column 0 = positive, the rest = negatives)
    and the unreduced hinge loss ``max(0, margin + neg - pos)`` is returned.
    """

    def __init__(self, config, img_dim, margin=0.2):
        super().__init__(config)
        self.bert = UniterModel(config, img_dim)
        # 2-way head; init_output() copies its row 1 into rank_output
        self.itm_output = nn.Linear(config.hidden_size, 2)
        # 1-way ranking head used by forward()
        self.rank_output = nn.Linear(config.hidden_size, 1)
        self.margin = margin
        self.apply(self.init_weights)

    def init_output(self):
        """ Seed the ranking head from row 1 of the ITM head.

        Needs to be called after loading pretrained weights. NOTE: the
        assigned tensors are slices (views) of ``itm_output``'s data, so the
        two heads share storage afterwards.
        """
        itm = self.itm_output
        self.rank_output.weight.data = itm.weight.data[1:, :]
        self.rank_output.bias.data = itm.bias.data[1:]

    def forward(self, batch, compute_loss=True):
        """ Score every pair in ``batch`` and optionally return the loss.

        Keys read from ``batch`` (missing keys are treated as ``None``):
        ``input_ids``, ``position_ids``, ``img_feat``, ``img_pos_feat``,
        ``attn_masks``, ``gather_index`` and, for the loss, ``sample_size``.

        Returns:
            The raw ranking scores (trailing dimension 1) when
            ``compute_loss`` is false; otherwise the per-negative hinge loss
            of shape ``(num_groups, sample_size - 1)``.
        """
        inputs = defaultdict(lambda: None, batch)
        sequence_output = self.bert(
            inputs['input_ids'], inputs['position_ids'],
            inputs['img_feat'], inputs['img_pos_feat'],
            inputs['attn_masks'], inputs['gather_index'],
            output_all_encoded_layers=False)
        rank_scores = self.rank_output(self.bert.pooler(sequence_output))
        if not compute_loss:
            return rank_scores

        # triplet loss: the first column of every group is the positive pair
        grouped = torch.sigmoid(rank_scores).contiguous().view(
            -1, inputs['sample_size'])
        pos, neg = grouped[:, :1], grouped[:, 1:]
        return torch.clamp(self.margin + neg - pos, 0)
class UniterForImageTextRetrievalHardNeg(UniterForImageTextRetrieval):
    """ Finetune UNITER for image text retrieval with hard negative mining.

    In training mode with ``compute_loss``, all candidate pairs are first
    scored without gradients (model temporarily in eval mode). Then the
    positive pair (row 0) plus the highest-scoring negatives (at most
    ``hard_size``) are re-scored with gradients, and the parent class's
    triplet loss is computed on that reduced batch.
    """

    def __init__(self, config, img_dim, margin=0.2, hard_size=16):
        super().__init__(config, img_dim, margin)
        # maximum number of hard negatives kept per positive
        self.hard_size = hard_size

    def forward(self, batch, sample_from='t', compute_loss=True):
        """ Score pairs, mining hard negatives during training.

        Args:
            batch: dict of model inputs (see the parent ``forward``).
            sample_from: ``'t'`` when candidates differ in text and share
                one image... wait -- see NOTE below; ``'i'`` when candidates
                share the same text. Any other value raises ``ValueError``.
            compute_loss: return the hinge loss instead of raw scores.

        NOTE: per the code, ``'t'`` expects the same ``input_ids`` for all
        pairs (a single row is broadcast to the batch) and ``'i'`` expects
        the same image features for all pairs.

        The caller's ``batch`` is not modified.
        """
        # shallow copy so broadcasting below does not mutate the caller's dict
        batch = copy.copy(batch)
        batch_size = batch['attn_masks'].size(0)
        input_ids = batch['input_ids']
        img_feat = batch['img_feat']
        img_pos_feat = batch['img_pos_feat']
        if sample_from == 't':
            # expect same input_ids for all pairs
            if input_ids.size(0) == 1:
                batch['input_ids'] = input_ids.expand(batch_size, -1)
        elif sample_from == 'i':
            # expect same image features for all pairs
            if img_feat.size(0) == 1:
                batch['img_feat'] = img_feat.expand(batch_size, -1, -1)
            if img_pos_feat.size(0) == 1:
                batch['img_pos_feat'] = img_pos_feat.expand(batch_size, -1, -1)
        else:
            raise ValueError(
                "sample_from must be 't' or 'i', got {!r}".format(sample_from))

        if self.training and compute_loss:
            with torch.no_grad():
                self.eval()
                try:
                    scores = super().forward(batch, compute_loss=False)
                    hard_batch = self._get_hard_batch(
                        batch, scores, sample_from)
                finally:
                    # restore training mode even if scoring fails
                    self.train()
            return super().forward(hard_batch, compute_loss=True)
        return super().forward(batch, compute_loss)

    def _get_hard_batch(self, batch, scores, sample_from='t'):
        """ Build a reduced batch of the positive plus the hardest negatives.

        Row 0 of ``batch`` is the positive. The negatives are the other rows
        with the highest ``scores``, at most ``self.hard_size`` of them.
        ``sample_size`` is set to the number of kept rows.
        """
        batch = defaultdict(lambda: None, batch)
        input_ids = batch['input_ids']
        position_ids = batch['position_ids']
        img_feat = batch['img_feat']
        img_pos_feat = batch['img_pos_feat']
        attention_mask = batch['attn_masks']
        gather_index = batch['gather_index']

        # NOTE first example is positive; only rows 1.. compete as negatives
        neg_scores = scores.squeeze(-1)[1:]
        # never ask topk for more negatives than the batch contains
        num_hard = min(self.hard_size, neg_scores.size(0))
        # +1 maps positions within neg_scores back to batch rows
        hard_indices = neg_scores.topk(num_hard, sorted=False)[1] + 1
        indices = torch.cat([torch.zeros(1, dtype=torch.long,
                                         device=hard_indices.device),
                             hard_indices])
        num_keep = num_hard + 1
        hard_batch = {'sample_size': num_keep}

        attention_mask = attention_mask.index_select(0, indices)
        gather_index = gather_index.index_select(0, indices)
        if position_ids is not None and position_ids.size(0) != 1:
            position_ids = position_ids[:num_keep]

        if sample_from == 't':
            # cut to minimum padding
            max_len = attention_mask.sum(dim=1).max().item()
            # image positions remaining after the (shared) text tokens
            max_i = max_len - input_ids.size(1)
            attention_mask = attention_mask[:, :max_len]
            gather_index = gather_index[:, :max_len]
            img_feat = img_feat.index_select(0, indices)[:, :max_i, :]
            img_pos_feat = img_pos_feat.index_select(0, indices)[:, :max_i, :]
            # expect same input_ids for all pairs
            input_ids = input_ids[:num_keep]
        elif sample_from == 'i':
            input_ids = input_ids.index_select(0, indices)
            # expect same image features for all pairs
            img_feat = img_feat[:num_keep]
            img_pos_feat = img_pos_feat[:num_keep]
        else:
            raise ValueError(
                "sample_from must be 't' or 'i', got {!r}".format(sample_from))

        hard_batch['input_ids'] = input_ids
        hard_batch['position_ids'] = position_ids
        hard_batch['img_feat'] = img_feat
        hard_batch['img_pos_feat'] = img_pos_feat
        hard_batch['attn_masks'] = attention_mask
        hard_batch['gather_index'] = gather_index
        return hard_batch
class UniterForImageTextRetrievalFast(UniterPreTrainedModel):
    """ Finetune UNITER for image text retrieval with separate encoders.

    Text is encoded by ``self.bert``. Images are encoded by
    ``self.img_bert``, whose config is a deep copy with ``num_hidden_layers``
    replaced by ``config.num_hidden_layers_img``. The ranking score is the
    cosine similarity of the two pooled outputs.
    """

    def __init__(self, config, img_dim, margin=0.2):
        super().__init__(config)
        self.bert = UniterModel(config, img_dim)
        # deep copy so changing the layer count does not touch ``config``
        config_img = copy.deepcopy(config)
        config_img.num_hidden_layers = config_img.num_hidden_layers_img
        self.img_bert = UniterModel(config_img, img_dim)
        self.itm_output = nn.Linear(config.hidden_size, 2)
        # NOTE: rank_output is not used by forward() (scores are cosine
        # similarities); it is kept for checkpoint/interface compatibility
        self.rank_output = nn.Linear(config.hidden_size, 1)
        self.margin = margin
        self.apply(self.init_weights)

    def init_output(self):
        """ Copy row 1 of the ITM head into the (unused) ranking head.

        Needs to be called after loading pretrained weights. The assigned
        tensors are views of ``itm_output``'s data, so storage is shared.
        """
        self.rank_output.weight.data = self.itm_output.weight.data[1:, :]
        self.rank_output.bias.data = self.itm_output.bias.data[1:]

    def forward(self, batch, compute_loss=True):
        """ Score pairs by cosine similarity of the text and image towers.

        Keys read from ``batch`` (missing keys are treated as ``None``):
        ``input_ids``, ``position_ids``, ``img_feat``, ``img_pos_feat``,
        ``attn_masks_text``, ``attn_masks_img``, ``gather_index`` and, for
        the loss, ``sample_size``.

        Returns the cosine scores when ``compute_loss`` is false. Otherwise
        returns the unreduced hinge loss ``max(0, margin + neg - pos)`` over
        sigmoid-squashed scores grouped into rows of ``sample_size``
        (column 0 = positive).
        """
        batch = defaultdict(lambda: None, batch)
        input_ids = batch['input_ids']
        position_ids = batch['position_ids']
        img_feat = batch['img_feat']
        img_pos_feat = batch['img_pos_feat']
        attention_mask_text = batch['attn_masks_text']
        attention_mask_img = batch['attn_masks_img']
        gather_index = batch['gather_index']

        # text tower: image features are withheld (img_feat=None)
        # NOTE(review): img_pos_feat is still passed here; whether UniterModel
        # uses it on the text-only path is not visible from this file.
        sequence_output_text = self.bert(input_ids, position_ids,
                                         None, img_pos_feat,
                                         attention_mask_text, gather_index,
                                         output_all_encoded_layers=False)
        pooled_output_text = self.bert.pooler(sequence_output_text)
        # image tower: no text tokens (input_ids=None)
        sequence_output_img = self.img_bert(None, position_ids,
                                            img_feat, img_pos_feat,
                                            attention_mask_img, gather_index,
                                            output_all_encoded_layers=False)
        pooled_output_img = self.img_bert.pooler(sequence_output_img)
        # same computation as nn.CosineSimilarity() (dim=1, eps=1e-8) without
        # constructing a new module on every call
        rank_scores = nn.functional.cosine_similarity(
            pooled_output_text, pooled_output_img, dim=1, eps=1e-8)
        if not compute_loss:
            return rank_scores

        # triplet loss
        rank_scores_sigmoid = torch.sigmoid(rank_scores)
        sample_size = batch['sample_size']
        scores = rank_scores_sigmoid.contiguous().view(-1, sample_size)
        pos = scores[:, :1]
        neg = scores[:, 1:]
        return torch.clamp(self.margin + neg - pos, 0)