from functools import partial

from models.vit import VisionTransformer
from models.xbert import BertConfig, BertModel

import torch
from torch import nn
import torch.nn.functional as F

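# ALBEF classification model: a ViT visual encoder plus a BERT text encoder
# that cross-attends to a pair of image encodings, with an optional momentum
# (EMA) copy of each module used for distillation during training.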
class ALBEF(nn.Module):
    def __init__(self,
                 text_encoder=None,
                 tokenizer=None,
                 config=None,
                 ):
        super().__init__()

        self.tokenizer = tokenizer
        self.distill = config['distill']

        self.visual_encoder = VisionTransformer(
            img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
            mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))

        bert_config = BertConfig.from_json_file(config['bert_config'])
        bert_config.num_hidden_layers = 18

        self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False)

        self.cls_head = nn.Sequential(
            nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
            nn.ReLU(),
            nn.Linear(self.text_encoder.config.hidden_size, 2)
        )

        self.share_cross_attention(self.text_encoder.encoder)

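        # With distillation enabled, keep momentum (exponential-moving-average)
        # copies of the visual encoder, text encoder and classification head;
        # they provide the soft targets for the distillation term of the loss.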
        if self.distill:
            self.visual_encoder_m = VisionTransformer(
                img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
            self.text_encoder_m = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False)
            self.share_cross_attention(self.text_encoder_m.encoder)

            self.cls_head_m = nn.Sequential(
                nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
                nn.ReLU(),
                nn.Linear(self.text_encoder.config.hidden_size, 2)
            )

            self.model_pairs = [[self.visual_encoder, self.visual_encoder_m],
                                [self.text_encoder, self.text_encoder_m],
                                [self.cls_head, self.cls_head_m],
                                ]
            self.copy_params()
            self.momentum = 0.995

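    # The image batch is expected to hold two images per example, concatenated
    # along the batch dimension: the first targets.size(0) and the last
    # targets.size(0) image encodings are fed to the text encoder as a pair of
    # cross-attention contexts. alpha weights the distillation term when
    # self.distill is True.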
    def forward(self, image, text, targets, alpha=0, train=True):

        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        image0_embeds, image1_embeds = torch.split(image_embeds, targets.size(0))

        output = self.text_encoder(text.input_ids,
                                   attention_mask=text.attention_mask,
                                   encoder_hidden_states=[image0_embeds, image1_embeds],
                                   encoder_attention_mask=[image_atts[:image0_embeds.size(0)],
                                                           image_atts[image0_embeds.size(0):]],
                                   return_dict=True,
                                   )
        hidden_state = output.last_hidden_state[:, 0, :]
        prediction = self.cls_head(hidden_state)

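        # Training loss: hard-label cross-entropy, blended (when distilling)
        # with a soft cross-entropy against the momentum model's predictions,
        # weighted by alpha.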
        if train:
            if self.distill:
                with torch.no_grad():
                    self._momentum_update()
                    image_embeds_m = self.visual_encoder_m(image)
                    image0_embeds_m, image1_embeds_m = torch.split(image_embeds_m, targets.size(0))
                    output_m = self.text_encoder_m(text.input_ids,
                                                   attention_mask=text.attention_mask,
                                                   encoder_hidden_states=[image0_embeds_m, image1_embeds_m],
                                                   encoder_attention_mask=[image_atts[:image0_embeds.size(0)],
                                                                           image_atts[image0_embeds.size(0):]],
                                                   return_dict=True,
                                                   )
                    prediction_m = self.cls_head_m(output_m.last_hidden_state[:, 0, :])

                loss = (1 - alpha) * F.cross_entropy(prediction, targets) - alpha * torch.sum(
                    F.log_softmax(prediction, dim=1) * F.softmax(prediction_m, dim=1), dim=1).mean()
            else:
                loss = F.cross_entropy(prediction, targets)
            return loss
        else:
            return prediction

    @torch.no_grad()
    def copy_params(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False  # not updated by gradient

    @torch.no_grad()
    def _momentum_update(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)

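    # Ties the key/value projections of each even-numbered cross-attention
    # layer (layers 6, 8, ..., 16) to those of the following layer, so the two
    # image streams are processed with shared cross-attention parameters.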
    def share_cross_attention(self, model):

        for i in range(6):
            layer_num = 6 + i * 2
            modules_0 = model.layer[layer_num].crossattention.self._modules
            modules_1 = model.layer[layer_num + 1].crossattention.self._modules

            for name in modules_0.keys():
                if 'key' in name or 'value' in name:
                    module_0 = modules_0[name]
                    module_1 = modules_1[name]
                    if hasattr(module_0, "weight"):
                        module_0.weight = module_1.weight
                        if hasattr(module_0, "bias"):
                            module_0.bias = module_1.bias
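# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes the
# 'bert-base-uncased' checkpoint, a BertTokenizer, a bert_config.json path and
# dummy tensor shapes chosen purely for illustration; the config keys mirror
# the ones read in __init__ above.
if __name__ == '__main__':
    from transformers import BertTokenizer

    config = {
        'distill': False,                           # skip the momentum branch for a quick check
        'image_res': 384,                           # input resolution expected by the ViT
        'bert_config': 'configs/config_bert.json',  # assumed path to the BERT config file
    }
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = ALBEF(text_encoder='bert-base-uncased', tokenizer=tokenizer, config=config)

    batch_size = 2
    # Two images per example, concatenated along the batch dimension.
    images = torch.randn(2 * batch_size, 3, config['image_res'], config['image_res'])
    text = tokenizer(['a sentence about the image pair'] * batch_size,
                     padding='longest', return_tensors='pt')
    targets = torch.zeros(batch_size, dtype=torch.long)

    loss = model(images, text, targets, train=True)      # scalar training loss
    logits = model(images, text, targets, train=False)   # (batch_size, 2) class scores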