albef
from functools import partial
from models.vit import VisionTransformer
from models.xbert import BertConfig, BertModel, BertLMHeadModel

import torch
from torch import nn
import torch.nn.functional as F

import numpy as np


class ALBEF(nn.Module):
    def __init__(self,
                 text_encoder=None,
                 text_decoder=None,
                 tokenizer=None,
                 config=None,
                 ):
        super().__init__()

        self.tokenizer = tokenizer
        self.distill = config['distill']

        # Image encoder: ViT-B/16
        self.visual_encoder = VisionTransformer(
            img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
            mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))

        # Text/multimodal encoder for the question
        config_encoder = BertConfig.from_json_file(config['bert_config'])
        self.text_encoder = BertModel.from_pretrained(text_encoder, config=config_encoder, add_pooling_layer=False)

        # Answer decoder: a 6-layer BertLMHeadModel conditioned on the question representation
        config_decoder = BertConfig.from_json_file(config['bert_config'])
        config_decoder.fusion_layer = 0
        config_decoder.num_hidden_layers = 6
        self.text_decoder = BertLMHeadModel.from_pretrained(text_decoder, config=config_decoder)

        if self.distill:
            # Momentum (EMA) copies of the encoders and decoder, used as teachers
            # for distillation; their parameters are frozen in copy_params()
            self.visual_encoder_m = VisionTransformer(
                img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
            self.text_encoder_m = BertModel.from_pretrained(text_encoder, config=config_encoder, add_pooling_layer=False)
            self.text_decoder_m = BertLMHeadModel.from_pretrained(text_decoder, config=config_decoder)
            self.model_pairs = [[self.visual_encoder, self.visual_encoder_m],
                                [self.text_encoder, self.text_encoder_m],
                                [self.text_decoder, self.text_decoder_m],
                                ]
            self.copy_params()
            self.momentum = 0.995

    def forward(self, image, question, answer=None, alpha=0, k=None, weights=None, train=True):

        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        if train:
            '''
            k: number of answers for each question
            weights: weight for each answer
            '''
            answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100)

            question_output = self.text_encoder(question.input_ids,
                                                attention_mask=question.attention_mask,
                                                encoder_hidden_states=image_embeds,
                                                encoder_attention_mask=image_atts,
                                                return_dict=True)

            # duplicate each question's states once per ground-truth answer
            question_states = []
            question_atts = []
            for b, n in enumerate(k):
                question_states += [question_output.last_hidden_state[b]] * n
                question_atts += [question.attention_mask[b]] * n
            question_states = torch.stack(question_states, 0)
            question_atts = torch.stack(question_atts, 0)

            if self.distill:
                # soft targets from the momentum (teacher) models
                with torch.no_grad():
                    self._momentum_update()
                    image_embeds_m = self.visual_encoder_m(image)
                    question_output_m = self.text_encoder_m(question.input_ids,
                                                            attention_mask=question.attention_mask,
                                                            encoder_hidden_states=image_embeds_m,
                                                            encoder_attention_mask=image_atts,
                                                            return_dict=True)

                    question_states_m = []
                    for b, n in enumerate(k):
                        question_states_m += [question_output_m.last_hidden_state[b]] * n
                    question_states_m = torch.stack(question_states_m, 0)

                    logits_m = self.text_decoder_m(answer.input_ids,
                                                   attention_mask=answer.attention_mask,
                                                   encoder_hidden_states=question_states_m,
                                                   encoder_attention_mask=question_atts,
                                                   return_logits=True,
                                                   )

                answer_output = self.text_decoder(answer.input_ids,
                                                  attention_mask=answer.attention_mask,
                                                  encoder_hidden_states=question_states,
                                                  encoder_attention_mask=question_atts,
                                                  labels=answer_targets,
                                                  return_dict=True,
                                                  soft_labels=F.softmax(logits_m, dim=-1),
                                                  alpha=alpha,
                                                  reduction='none',
                                                  )
            else:
                answer_output = self.text_decoder(answer.input_ids,
                                                  attention_mask=answer.attention_mask,
                                                  encoder_hidden_states=question_states,
                                                  encoder_attention_mask=question_atts,
                                                  labels=answer_targets,
                                                  return_dict=True,
                                                  reduction='none',
                                                  )
            loss = weights * answer_output.loss
            loss = loss.sum() / image.size(0)

            return loss

        else:
            question_output = self.text_encoder(question.input_ids,
                                                attention_mask=question.attention_mask,
                                                encoder_hidden_states=image_embeds,
                                                encoder_attention_mask=image_atts,
                                                return_dict=True)
            topk_ids, topk_probs = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
                                                    answer.input_ids, answer.attention_mask, k)
            return topk_ids, topk_probs

    @torch.no_grad()
    def copy_params(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False   # not updated by gradient

    @torch.no_grad()
    def _momentum_update(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                # EMA update: param_m <- m * param_m + (1 - m) * param
                param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)

    def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):

        num_ques = question_states.size(0)
        start_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token

        start_output = self.text_decoder(start_ids,
                                         encoder_hidden_states=question_states,
                                         encoder_attention_mask=question_atts,
                                         return_dict=True,
                                         reduction='none')
        logits = start_output.logits[:, 0, :]  # first token's logit

        # topk_probs: top-k probability
        # topk_ids: [num_question, k]
        answer_first_token = answer_ids[:, 1]
        prob_first_token = F.softmax(logits, dim=1).index_select(dim=1, index=answer_first_token)
        topk_probs, topk_ids = prob_first_token.topk(k, dim=1)

        # answer input: [num_question*k, answer_len]
        input_ids = []
        input_atts = []
        for b, topk_id in enumerate(topk_ids):
            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
        input_ids = torch.cat(input_ids, dim=0)
        input_atts = torch.cat(input_atts, dim=0)

        targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)

        # repeat encoder's output for top-k answers
        question_states = tile(question_states, 0, k)
        question_atts = tile(question_atts, 0, k)

        output = self.text_decoder(input_ids,
                                   attention_mask=input_atts,
                                   encoder_hidden_states=question_states,
                                   encoder_attention_mask=question_atts,
                                   labels=targets_ids,
                                   return_dict=True,
                                   reduction='none')

        answer_loss = output.loss
        answer_loss = answer_loss.view(input_ids.size(0), -1)

        # topk_probs: first token probability
        topk_probs = topk_probs.view(-1, 1)
        log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1)

        # re-calculate log probabilities for the answer sequences using chain rule
        log_probs_sum = log_probs.sum(1)
        log_probs_sum = log_probs_sum.view(num_ques, k)

        topk_probs = F.softmax(log_probs_sum, dim=-1)
        # get top-k after re-ranking
        topk_probs, rerank_id = topk_probs.topk(k, dim=1)
        topk_ids = torch.gather(topk_ids, 1, rerank_id)

        return topk_ids, topk_probs


def tile(x, dim, n_tile):
    init_dim = x.size(dim)
    repeat_idx = [1] * x.dim()
    repeat_idx[dim] = n_tile
    x = x.repeat(*repeat_idx)
    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
    return torch.index_select(x, dim, order_index.to(x.device))