# NOTE(review): removed repository web-page artifacts ("lightningdot", topic
# rules, file-listing metadata) that were scraped along with this source file.
"""
Copyright (c) Microsoft Corporation.

Licensed under the MIT license.

Uniter for NLVR2 model
"""
import torch
|
||
|
from torch import nn
|
||
|
from torch.nn import functional as F
|
||
|
|
||
|
from .layer import GELU
|
||
|
from .model import UniterPreTrainedModel, UniterModel
|
||
|
from .attention import MultiheadAttention
|
||
|
|
||
|
|
||
|
class UniterForNlvr2Paired(UniterPreTrainedModel):
    """Finetune UNITER for NLVR2 in the paired format.

    The two images of an NLVR2 example occupy two consecutive batch
    entries; their pooled [CLS] vectors are concatenated before the
    binary classification head.
    """

    def __init__(self, config, img_dim):
        super().__init__(config)
        self.bert = UniterModel(config, img_dim)
        # Head consumes the concatenation of the two pooled outputs.
        self.nlvr2_output = nn.Linear(config.hidden_size * 2, 2)
        self.apply(self.init_weights)

    def init_type_embedding(self):
        """Grow the token-type embedding table from 2 to 3 rows.

        Rows 0 and 1 are copied from the pretrained table; row 2 (the
        extra image slot) is initialised from row 1.
        """
        new_emb = nn.Embedding(3, self.bert.config.hidden_size)
        new_emb.apply(self.init_weights)
        pretrained = self.bert.embeddings.token_type_embeddings.weight.data
        for idx in (0, 1):
            row = pretrained[idx, :]
            new_emb.weight.data[idx, :].copy_(row)
        # `row` still holds the type-1 (image) embedding at this point.
        new_emb.weight.data[2, :].copy_(row)
        self.bert.embeddings.token_type_embeddings = new_emb

    def forward(self, input_ids, position_ids, img_feat, img_pos_feat,
                attn_masks, gather_index,
                img_type_ids, targets, compute_loss=True):
        """Score a batch of paired examples.

        Returns per-example cross-entropy losses when ``compute_loss``
        is true, otherwise the raw answer logits.
        """
        encoded = self.bert(input_ids, position_ids,
                            img_feat, img_pos_feat,
                            attn_masks, gather_index,
                            output_all_encoded_layers=False,
                            img_type_ids=img_type_ids)
        pooled = self.bert.pooler(encoded)
        # Fold each consecutive (left, right) pair into a single row.
        num_pairs = pooled.size(0) // 2
        paired = pooled.contiguous().view(num_pairs, -1)
        logits = self.nlvr2_output(paired)
        if not compute_loss:
            return logits
        return F.cross_entropy(logits, targets, reduction='none')
class UniterForNlvr2Triplet(UniterPreTrainedModel):
    """Finetune UNITER for NLVR2 in the triplet format.

    Both images and the statement are encoded as one joint sequence,
    so a single pooled [CLS] vector feeds the binary classifier.
    """

    def __init__(self, config, img_dim):
        super().__init__(config)
        self.bert = UniterModel(config, img_dim)
        self.nlvr2_output = nn.Linear(config.hidden_size, 2)
        self.apply(self.init_weights)

    def init_type_embedding(self):
        """Extend the token-type embedding table to 3 rows.

        Rows 0 and 1 come from the pretrained table; row 2 (the extra
        image slot) reuses row 1.
        """
        new_emb = nn.Embedding(3, self.bert.config.hidden_size)
        new_emb.apply(self.init_weights)
        pretrained = self.bert.embeddings.token_type_embeddings.weight.data
        for idx in (0, 1):
            row = pretrained[idx, :]
            new_emb.weight.data[idx, :].copy_(row)
        # `row` is still the type-1 (image) embedding here.
        new_emb.weight.data[2, :].copy_(row)
        self.bert.embeddings.token_type_embeddings = new_emb

    def forward(self, input_ids, position_ids, img_feat, img_pos_feat,
                attn_masks, gather_index,
                img_type_ids, targets, compute_loss=True):
        """Per-example CE losses if ``compute_loss``, else raw logits."""
        encoded = self.bert(input_ids, position_ids,
                            img_feat, img_pos_feat,
                            attn_masks, gather_index,
                            output_all_encoded_layers=False,
                            img_type_ids=img_type_ids)
        pooled = self.bert.pooler(encoded)
        logits = self.nlvr2_output(pooled)
        if not compute_loss:
            return logits
        return F.cross_entropy(logits, targets, reduction='none')
class AttentionPool(nn.Module):
    """Pool a sequence of hidden states with a learned scalar attention."""

    def __init__(self, hidden_size, drop=0.0):
        super().__init__()
        # One scalar attention score per timestep.
        self.fc = nn.Sequential(nn.Linear(hidden_size, 1), GELU())
        self.dropout = nn.Dropout(drop)

    def forward(self, input_, mask=None):
        """input_: [B, T, D]; mask: [B, T], nonzero marks positions to suppress."""
        scores = self.fc(input_).squeeze(-1)
        if mask is not None:
            # Large negative bias drives masked positions' softmax weight to ~0.
            scores = scores + mask.to(dtype=input_.dtype) * -1e4
        weights = self.dropout(F.softmax(scores, dim=1))
        # Attention-weighted sum over the time dimension.
        return weights.unsqueeze(1).matmul(input_).squeeze(1)
class UniterForNlvr2PairedAttn(UniterPreTrainedModel):
    """ Finetune UNITER for NLVR2
        (paired format with additional attention layer)
    """
    def __init__(self, config, img_dim):
        super().__init__(config)
        self.bert = UniterModel(config, img_dim)
        # Cross-attention in both directions between the two images'
        # encoded sequences (left->right and right->left).
        self.attn1 = MultiheadAttention(config.hidden_size,
                                        config.num_attention_heads,
                                        config.attention_probs_dropout_prob)
        self.attn2 = MultiheadAttention(config.hidden_size,
                                        config.num_attention_heads,
                                        config.attention_probs_dropout_prob)
        # Fuses [cross-attended; original] features back to hidden_size.
        self.fc = nn.Sequential(
            nn.Linear(2*config.hidden_size, config.hidden_size),
            GELU(),
            nn.Dropout(config.hidden_dropout_prob))
        self.attn_pool = AttentionPool(config.hidden_size,
                                       config.attention_probs_dropout_prob)
        # Classifier over the concatenated left/right pooled vectors.
        self.nlvr2_output = nn.Linear(2*config.hidden_size, 2)
        self.apply(self.init_weights)

    def init_type_embedding(self):
        """Grow the token-type table from 2 to 3 rows; rows 0/1 are copied
        from the pretrained table and row 2 is seeded from row 1."""
        new_emb = nn.Embedding(3, self.bert.config.hidden_size)
        new_emb.apply(self.init_weights)
        for i in [0, 1]:
            emb = self.bert.embeddings.token_type_embeddings\
                .weight.data[i, :]
            new_emb.weight.data[i, :].copy_(emb)
        # `emb` still holds row 1 here (loop exit value) — intentional.
        new_emb.weight.data[2, :].copy_(emb)
        self.bert.embeddings.token_type_embeddings = new_emb

    def forward(self, input_ids, position_ids, img_feat, img_pos_feat,
                attn_masks, gather_index,
                img_type_ids, targets, compute_loss=True):
        """Score paired examples with bidirectional cross-attention.

        The two images of an example occupy consecutive batch entries
        (batch size is even). Returns per-example cross-entropy losses
        when ``compute_loss`` is true, otherwise the raw answer logits.
        """
        sequence_output = self.bert(input_ids, position_ids,
                                    img_feat, img_pos_feat,
                                    attn_masks, gather_index,
                                    output_all_encoded_layers=False,
                                    img_type_ids=img_type_ids)
        # separate left image and right image
        bs, tl, d = sequence_output.size()
        left_out, right_out = sequence_output.contiguous().view(
            bs//2, tl*2, d).chunk(2, dim=1)
        # bidirectional attention
        # mask is True at padded positions (attn_masks == 0), matching
        # MultiheadAttention's key_padding_mask convention.
        mask = attn_masks == 0
        left_mask, right_mask = mask.contiguous().view(bs//2, tl*2
                                                       ).chunk(2, dim=1)
        # MultiheadAttention expects [T, B, D]-ordered inputs.
        left_out = left_out.transpose(0, 1)
        right_out = right_out.transpose(0, 1)
        l2r_attn, _ = self.attn1(left_out, right_out, right_out,
                                 key_padding_mask=right_mask)
        r2l_attn, _ = self.attn2(right_out, left_out, left_out,
                                 key_padding_mask=left_mask)
        # Fuse cross-attended and original features, back to [B, T, D].
        left_out = self.fc(torch.cat([l2r_attn, left_out], dim=-1)
                           ).transpose(0, 1)
        right_out = self.fc(torch.cat([r2l_attn, right_out], dim=-1)
                            ).transpose(0, 1)
        # attention pooling and final prediction
        left_out = self.attn_pool(left_out, left_mask)
        right_out = self.attn_pool(right_out, right_mask)
        answer_scores = self.nlvr2_output(
            torch.cat([left_out, right_out], dim=-1))

        if compute_loss:
            nlvr2_loss = F.cross_entropy(
                answer_scores, targets, reduction='none')
            return nlvr2_loss
        else:
            return answer_scores