from functools import partial
from models.vit import VisionTransformer
from models.xbert import BertConfig, BertModel, BertLMHeadModel

import torch
from torch import nn
import torch.nn.functional as F

import numpy as np


class ALBEF(nn.Module):
    def __init__(self,
                 text_encoder=None,
                 text_decoder=None,
                 tokenizer=None,
                 config=None,
                 ):
        super().__init__()

        self.tokenizer = tokenizer
        self.distill = config['distill']

        self.visual_encoder = VisionTransformer(
            img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
            mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))

        config_encoder = BertConfig.from_json_file(config['bert_config'])
        self.text_encoder = BertModel.from_pretrained(text_encoder, config=config_encoder, add_pooling_layer=False)

        config_decoder = BertConfig.from_json_file(config['bert_config'])
        config_decoder.fusion_layer = 0
        config_decoder.num_hidden_layers = 6
        self.text_decoder = BertLMHeadModel.from_pretrained(text_decoder, config=config_decoder)

        if self.distill:
            # Momentum (teacher) copies of each module, updated by EMA instead of gradients.
            self.visual_encoder_m = VisionTransformer(
                img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
            self.text_encoder_m = BertModel.from_pretrained(text_encoder, config=config_encoder, add_pooling_layer=False)
            self.text_decoder_m = BertLMHeadModel.from_pretrained(text_decoder, config=config_decoder)
            self.model_pairs = [[self.visual_encoder, self.visual_encoder_m],
                                [self.text_encoder, self.text_encoder_m],
                                [self.text_decoder, self.text_decoder_m],
                                ]
            self.copy_params()
            self.momentum = 0.995

    def forward(self, image, question, answer=None, alpha=0, k=None, weights=None, train=True):

        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        if train:
            '''
            k: number of answers for each question
            weights: weight for each answer
            '''
            # Ignore pad tokens when computing the language-modeling loss.
            answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100)

            question_output = self.text_encoder(question.input_ids,
                                                attention_mask=question.attention_mask,
                                                encoder_hidden_states=image_embeds,
                                                encoder_attention_mask=image_atts,
                                                return_dict=True)

            # Duplicate each question's states k[b] times so that every
            # (question, answer) pair is aligned row-by-row.
            question_states = []
            question_atts = []
            for b, n in enumerate(k):
                question_states += [question_output.last_hidden_state[b]] * n
                question_atts += [question.attention_mask[b]] * n
            question_states = torch.stack(question_states, 0)
            question_atts = torch.stack(question_atts, 0)

            if self.distill:
                with torch.no_grad():
                    self._momentum_update()
                    image_embeds_m = self.visual_encoder_m(image)
                    question_output_m = self.text_encoder_m(question.input_ids,
                                                            attention_mask=question.attention_mask,
                                                            encoder_hidden_states=image_embeds_m,
                                                            encoder_attention_mask=image_atts,
                                                            return_dict=True)

                    question_states_m = []
                    for b, n in enumerate(k):
                        question_states_m += [question_output_m.last_hidden_state[b]] * n
                    question_states_m = torch.stack(question_states_m, 0)

                    logits_m = self.text_decoder_m(answer.input_ids,
                                                   attention_mask=answer.attention_mask,
                                                   encoder_hidden_states=question_states_m,
                                                   encoder_attention_mask=question_atts,
                                                   return_logits=True,
                                                   )

                answer_output = self.text_decoder(answer.input_ids,
                                                  attention_mask=answer.attention_mask,
                                                  encoder_hidden_states=question_states,
                                                  encoder_attention_mask=question_atts,
                                                  labels=answer_targets,
                                                  return_dict=True,
                                                  soft_labels=F.softmax(logits_m, dim=-1),
                                                  alpha=alpha,
                                                  reduction='none',
                                                  )
            else:
                answer_output = self.text_decoder(answer.input_ids,
                                                  attention_mask=answer.attention_mask,
                                                  encoder_hidden_states=question_states,
                                                  encoder_attention_mask=question_atts,
                                                  labels=answer_targets,
                                                  return_dict=True,
                                                  reduction='none',
                                                  )

            # Weight each answer's per-sequence loss, then average over the batch.
            loss = weights * answer_output.loss
            loss = loss.sum() / image.size(0)

            return loss

        else:
            question_output = self.text_encoder(question.input_ids,
                                                attention_mask=question.attention_mask,
                                                encoder_hidden_states=image_embeds,
                                                encoder_attention_mask=image_atts,
                                                return_dict=True)
            topk_ids, topk_probs = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
                                                    answer.input_ids, answer.attention_mask, k)
            return topk_ids, topk_probs

    @torch.no_grad()
    def copy_params(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False   # not updated by gradient

    @torch.no_grad()
    def _momentum_update(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)

    def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):

        num_ques = question_states.size(0)
        start_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token

        start_output = self.text_decoder(start_ids,
                                         encoder_hidden_states=question_states,
                                         encoder_attention_mask=question_atts,
                                         return_dict=True,
                                         reduction='none')
        logits = start_output.logits[:, 0, :]  # first token's logit

        # topk_probs: top-k probability
        # topk_ids: [num_question, k]
        answer_first_token = answer_ids[:, 1]
        prob_first_token = F.softmax(logits, dim=1).index_select(dim=1, index=answer_first_token)
        topk_probs, topk_ids = prob_first_token.topk(k, dim=1)

        # answer input: [num_question*k, answer_len]
        input_ids = []
        input_atts = []
        for b, topk_id in enumerate(topk_ids):
            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
        input_ids = torch.cat(input_ids, dim=0)
        input_atts = torch.cat(input_atts, dim=0)

        targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)

        # repeat encoder's output for the top-k answers
        question_states = tile(question_states, 0, k)
        question_atts = tile(question_atts, 0, k)

        output = self.text_decoder(input_ids,
                                   attention_mask=input_atts,
                                   encoder_hidden_states=question_states,
                                   encoder_attention_mask=question_atts,
                                   labels=targets_ids,
                                   return_dict=True,
                                   reduction='none')

        answer_loss = output.loss
        answer_loss = answer_loss.view(input_ids.size(0), -1)

        # topk_probs: first token probability
        topk_probs = topk_probs.view(-1, 1)
        log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1)

        # re-calculate log probabilities for the answer sequences using the chain rule
        log_probs_sum = log_probs.sum(1)
        log_probs_sum = log_probs_sum.view(num_ques, k)

        topk_probs = F.softmax(log_probs_sum, dim=-1)
        # get top-k after re-ranking
        topk_probs, rerank_id = topk_probs.topk(k, dim=1)
        topk_ids = torch.gather(topk_ids, 1, rerank_id)

        return topk_ids, topk_probs


def tile(x, dim, n_tile):
    # Repeat each slice of `x` along `dim` n_tile times consecutively,
    # e.g. rows [a, b] -> [a, a, b, b] for dim=0, n_tile=2
    # (equivalent to torch.repeat_interleave(x, n_tile, dim)).
    init_dim = x.size(dim)
    repeat_idx = [1] * x.dim()
    repeat_idx[dim] = n_tile
    x = x.repeat(*repeat_idx)
    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
    return torch.index_select(x, dim, order_index.to(x.device))
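

# ---------------------------------------------------------------------------
# Sanity-check sketch (illustrative, not part of the model): tile() above
# repeats each slice along `dim` n_tile times consecutively, so it should
# match torch.repeat_interleave. The tensor shape below is an arbitrary
# assumption chosen only for this check.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    x = torch.arange(6).view(3, 2)
    # tile(x, 0, 2) -> rows [r0, r0, r1, r1, r2, r2]
    assert torch.equal(tile(x, 0, 2), torch.repeat_interleave(x, 2, dim=0))
    print('tile() matches torch.repeat_interleave')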