You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Files and versions
215 lines
11 KiB
215 lines
11 KiB
2 years ago
from functools import partial
from models.vit import VisionTransformer
from models.xbert import BertConfig, BertModel, BertLMHeadModel
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
class ALBEF(nn.Module):
    """ALBEF model for visual question answering.

    Components built in ``__init__``:

    * ``visual_encoder`` -- a ViT (patch 16, width 768, depth 12, 12 heads).
    * ``text_encoder``   -- a BERT encoder (no pooler) run on the question with
      the image embeddings passed as ``encoder_hidden_states``.
    * ``text_decoder``   -- a BERT LM head model (``fusion_layer = 0``,
      6 hidden layers) run on the answer with the question states passed as
      ``encoder_hidden_states``.

    When ``config['distill']`` is true, momentum ("_m") copies of all three
    modules are also built. They are initialised as frozen copies of the online
    modules, updated as an exponential moving average (coefficient
    ``self.momentum``) at every training forward pass, and their softmaxed
    logits are handed to the online decoder as ``soft_labels``.
    """

    def __init__(self,
                 text_encoder=None,
                 text_decoder=None,
                 tokenizer=None,
                 config=None,
                 ):
        """Build the encoders/decoder, plus momentum copies when distilling.

        Args:
            text_encoder: name or path passed to ``BertModel.from_pretrained``.
            text_decoder: name or path passed to ``BertLMHeadModel.from_pretrained``.
            tokenizer: tokenizer whose ``pad_token_id`` marks padding in answer ids.
            config: mapping that must provide ``'distill'``, ``'image_res'`` and
                ``'bert_config'`` (path to a BERT JSON config file).
        """
        super().__init__()

        self.tokenizer = tokenizer
        self.distill = config['distill']

        self.visual_encoder = self._build_visual_encoder(config['image_res'])

        config_encoder = BertConfig.from_json_file(config['bert_config'])
        self.text_encoder = BertModel.from_pretrained(
            text_encoder, config=config_encoder, add_pooling_layer=False)

        # The answer decoder reuses the BERT config but is cut to 6 hidden layers.
        config_decoder = BertConfig.from_json_file(config['bert_config'])
        config_decoder.fusion_layer = 0
        config_decoder.num_hidden_layers = 6
        self.text_decoder = BertLMHeadModel.from_pretrained(
            text_decoder, config=config_decoder)

        if self.distill:
            self.visual_encoder_m = self._build_visual_encoder(config['image_res'])
            self.text_encoder_m = BertModel.from_pretrained(
                text_encoder, config=config_encoder, add_pooling_layer=False)
            self.text_decoder_m = BertLMHeadModel.from_pretrained(
                text_decoder, config=config_decoder)
            # (online, momentum) pairs consumed by copy_params / _momentum_update.
            self.model_pairs = [[self.visual_encoder, self.visual_encoder_m],
                                [self.text_encoder, self.text_encoder_m],
                                [self.text_decoder, self.text_decoder_m],
                                ]
            # Start every momentum model as an exact, frozen copy of its online model.
            self.copy_params()
            self.momentum = 0.995

    @staticmethod
    def _build_visual_encoder(img_size):
        """Return a ViT image encoder with ALBEF's fixed hyper-parameters."""
        return VisionTransformer(
            img_size=img_size, patch_size=16, embed_dim=768, depth=12, num_heads=12,
            mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))

    @staticmethod
    def _repeat_per_answer(x, k):
        """Repeat row ``b`` of ``x`` ``k[b]`` times along dim 0, keeping row order.

        Raises if ``len(k)`` differs from ``x.size(0)``, instead of silently
        dropping or mis-indexing rows.
        """
        repeats = torch.as_tensor(k, device=x.device)
        return x.repeat_interleave(repeats, dim=0)

    def forward(self, image, quesiton, answer=None, alpha=0, k=None, weights=None, train=True):
        """Compute the VQA training loss, or rank candidate answers at inference.

        Args:
            image: image batch fed to ``visual_encoder`` (presumably
                (B, 3, H, W) -- confirm against the data pipeline).
            quesiton: tokenized questions exposing ``input_ids`` and
                ``attention_mask``, one row per image. (The misspelled name is
                kept so keyword callers keep working.)
            answer: tokenized answers exposing ``input_ids`` and
                ``attention_mask``. In training: ``sum(k)`` rows, grouped by
                question in order. At inference: the candidate answer list.
            alpha: distillation weight passed to the decoder together with the
                momentum model's soft labels (used only when distilling).
            k: in training, a sequence with the number of answers per question;
                at inference, the number of candidates to return per question.
            weights: in training, per-answer weights multiplied into the
                per-answer loss; ``None`` weights every answer by 1.
            train: selects the training (loss) or inference (ranking) path.

        Returns:
            Training: the weighted per-answer loss summed and divided by the
            batch size. Inference: ``(topk_ids, topk_probs)`` from
            :meth:`rank_answer`.
        """
        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        if not train:
            question_output = self.text_encoder(quesiton.input_ids,
                                                attention_mask=quesiton.attention_mask,
                                                encoder_hidden_states=image_embeds,
                                                encoder_attention_mask=image_atts,
                                                return_dict=True)
            topk_ids, topk_probs = self.rank_answer(question_output.last_hidden_state,
                                                    quesiton.attention_mask,
                                                    answer.input_ids, answer.attention_mask, k)
            return topk_ids, topk_probs

        # Pad positions become -100, presumably the decoder loss's ignore_index
        # (the usual convention) -- confirm in models.xbert.
        answer_targets = answer.input_ids.masked_fill(
            answer.input_ids == self.tokenizer.pad_token_id, -100)

        question_output = self.text_encoder(quesiton.input_ids,
                                            attention_mask=quesiton.attention_mask,
                                            encoder_hidden_states=image_embeds,
                                            encoder_attention_mask=image_atts,
                                            return_dict=True)

        # One copy of each question's states/mask per answer, so rows line up
        # with answer.input_ids.
        question_states = self._repeat_per_answer(question_output.last_hidden_state, k)
        question_atts = self._repeat_per_answer(quesiton.attention_mask, k)

        if self.distill:
            with torch.no_grad():
                # Advance the EMA teacher before using it for soft labels.
                self._momentum_update()
                image_embeds_m = self.visual_encoder_m(image)
                question_output_m = self.text_encoder_m(quesiton.input_ids,
                                                        attention_mask=quesiton.attention_mask,
                                                        encoder_hidden_states=image_embeds_m,
                                                        encoder_attention_mask=image_atts,
                                                        return_dict=True)
                question_states_m = self._repeat_per_answer(
                    question_output_m.last_hidden_state, k)
                logits_m = self.text_decoder_m(answer.input_ids,
                                               attention_mask=answer.attention_mask,
                                               encoder_hidden_states=question_states_m,
                                               encoder_attention_mask=question_atts,
                                               return_logits=True,
                                               )

            answer_output = self.text_decoder(answer.input_ids,
                                              attention_mask=answer.attention_mask,
                                              encoder_hidden_states=question_states,
                                              encoder_attention_mask=question_atts,
                                              labels=answer_targets,
                                              return_dict=True,
                                              soft_labels=F.softmax(logits_m, dim=-1),
                                              alpha=alpha,
                                              reduction='none',
                                              )
        else:
            answer_output = self.text_decoder(answer.input_ids,
                                              attention_mask=answer.attention_mask,
                                              encoder_hidden_states=question_states,
                                              encoder_attention_mask=question_atts,
                                              labels=answer_targets,
                                              return_dict=True,
                                              reduction='none',
                                              )

        loss = answer_output.loss
        if weights is not None:
            loss = weights * loss
        return loss.sum() / image.size(0)

    @torch.no_grad()
    def copy_params(self):
        """Initialise each momentum model from its online model and freeze it."""
        for model, model_m in self.model_pairs:
            for param, param_m in zip(model.parameters(), model_m.parameters()):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False  # not updated by gradient

    @torch.no_grad()
    def _momentum_update(self):
        """EMA step: ``param_m <- momentum * param_m + (1 - momentum) * param``."""
        for model, model_m in self.model_pairs:
            for param, param_m in zip(model.parameters(), model_m.parameters()):
                # In place, so no new tensor is allocated per parameter.
                param_m.data.mul_(self.momentum).add_(param.data, alpha=1.0 - self.momentum)

    def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):
        """Rank candidate answers for each question and return the best ``k``.

        Stage 1: decode one step from the BOS token (``answer_ids[0, 0]``) and
        keep the ``k`` candidates whose first answer token (``answer_ids[:, 1]``)
        is most probable.
        Stage 2: score those ``k`` full answers with the decoder loss, combine
        the score with the first-token probability, and re-rank.

        Args:
            question_states: encoder output, one question per row along dim 0.
            question_atts: attention mask matching ``question_states``.
            answer_ids: token ids of the candidate answer list, one candidate per row.
            answer_atts: attention mask matching ``answer_ids``.
            k: number of candidates to keep per question.

        Returns:
            ``(topk_ids, topk_probs)``, both ``[num_questions, k]``.
            ``topk_ids`` holds row indices into ``answer_ids``. ``topk_probs``
            holds the softmax over the k re-ranked scores, in descending order.
        """
        num_ques = question_states.size(0)
        # BOS token, assumed shared by every candidate (read from row 0).
        start_ids = answer_ids[0, 0].repeat(num_ques, 1)

        start_output = self.text_decoder(start_ids,
                                         encoder_hidden_states=question_states,
                                         encoder_attention_mask=question_atts,
                                         return_dict=True,
                                         reduction='none')
        logits = start_output.logits[:, 0, :]  # logits for the first answer token

        # Probability of each candidate's first real token; keep the top-k.
        answer_first_token = answer_ids[:, 1]
        prob_first_token = F.softmax(logits, dim=1).index_select(dim=1, index=answer_first_token)
        topk_probs, topk_ids = prob_first_token.topk(k, dim=1)  # [num_ques, k]

        # Selected candidates: [num_ques * k, answer_len], grouped by question.
        flat_ids = topk_ids.reshape(-1)
        input_ids = answer_ids.index_select(dim=0, index=flat_ids)
        input_atts = answer_atts.index_select(dim=0, index=flat_ids)
        targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)

        # Repeat each question k consecutive times to line up with its candidates.
        question_states = question_states.repeat_interleave(k, dim=0)
        question_atts = question_atts.repeat_interleave(k, dim=0)

        output = self.text_decoder(input_ids,
                                   attention_mask=input_atts,
                                   encoder_hidden_states=question_states,
                                   encoder_attention_mask=question_atts,
                                   labels=targets_ids,
                                   return_dict=True,
                                   reduction='none')
        answer_loss = output.loss.view(input_ids.size(0), -1)

        # Chain rule: log P(answer) = log P(first token) - loss(remaining tokens).
        # This assumes the per-sample loss is a summed token NLL; confirm in xbert.
        log_probs = torch.cat([topk_probs.view(-1, 1).log(), -answer_loss], dim=1)
        log_probs_sum = log_probs.sum(1).view(num_ques, k)

        topk_probs = F.softmax(log_probs_sum, dim=-1)
        # Re-rank the k candidates by their full-sequence score.
        topk_probs, rerank_id = topk_probs.topk(k, dim=1)
        topk_ids = torch.gather(topk_ids, 1, rerank_id)

        return topk_ids, topk_probs
def tile(x, dim, n_tile):
init_dim = x.size(dim)
repeat_idx = [1] * x.dim()
repeat_idx[dim] = n_tile
x = x.repeat(*(repeat_idx))
order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
return torch.index_select(x, dim,