from functools import partial

import torch
from torch import nn
import torch.nn.functional as F

from models.vit import VisionTransformer
from models.xbert import BertConfig, BertModel

class ALBEF(nn.Module):
    """ALBEF pre-training model for paired-image tasks (e.g. NLVR2): the text
    encoder cross-attends to two image streams and the model predicts which
    stream, if either, the text describes."""

    def __init__(self,
                 text_encoder=None,
                 tokenizer=None,
                 config=None,
                 ):
        super().__init__()

        self.tokenizer = tokenizer
        vision_width = config['vision_width']
        embed_dim = config['embed_dim']

        # ViT-B/16 image encoder.
        self.visual_encoder = VisionTransformer(
            img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
            mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
        # Text encoder: BERT with cross-attention layers; depth is raised to 18,
        # and share_cross_attention below ties the cross-attention key/value
        # weights of consecutive layer pairs.
        bert_config = BertConfig.from_json_file(config['bert_config'])
        bert_config.num_hidden_layers = 18
        self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False)

        # Share the cross-attention layers for the two images.
        self.share_cross_attention(self.text_encoder.encoder)

        text_width = self.text_encoder.config.hidden_size
        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)

        self.temp = nn.Parameter(torch.ones([]) * 0.07)
        # Ternary text-assignment head: text matches image 0, image 1, or neither.
        self.ta_head = nn.Linear(self.text_encoder.config.hidden_size, 3)

    def forward(self, image, text):
        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        with torch.no_grad():
            # Image-to-image similarity of the projected [CLS] features, used to
            # sample visually similar (hard) negatives for each image.
            image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
            sim = image_feat @ image_feat.t() / 0.07
            weights = F.softmax(sim, dim=1)
            weights.fill_diagonal_(0)  # never sample an image as its own negative
        # Assemble two image streams per text. Label 0: the text's own image is
        # in stream 0; label 1: it is in stream 1; label 2: both slots hold
        # hard negatives, so the text matches neither.
        image_inputs = [[], []]
        labels = []
        for b in range(image.size(0)):
            if torch.rand(1) > 1 / 3:
                idx = torch.multinomial(weights[b], 1).item()
                if torch.rand(1) > 0.5:
                    image_inputs[0].append(image_embeds[b])
                    image_inputs[1].append(image_embeds[idx])
                    labels.append(0)
                else:
                    image_inputs[1].append(image_embeds[b])
                    image_inputs[0].append(image_embeds[idx])
                    labels.append(1)
            else:
                idx = torch.multinomial(weights[b], 2)
                image_inputs[0].append(image_embeds[idx[0]])
                image_inputs[1].append(image_embeds[idx[1]])
                labels.append(2)

        image_inputs[0] = torch.stack(image_inputs[0], dim=0)
        image_inputs[1] = torch.stack(image_inputs[1], dim=0)
        labels = torch.LongTensor(labels).to(image.device)
        output = self.text_encoder(text.input_ids,
                                   attention_mask=text.attention_mask,
                                   encoder_hidden_states=image_inputs,
                                   encoder_attention_mask=[image_atts, image_atts],
                                   return_dict=True,
                                   )
        # Classify the text assignment from the fused [CLS] representation.
        pred = self.ta_head(output.last_hidden_state[:, 0, :])
        loss = F.cross_entropy(pred, labels)
        return loss

    def share_cross_attention(self, model):
        # Tie the key/value projections of consecutive cross-attention layer
        # pairs (6,7), (8,9), ..., (16,17) so both image streams are encoded
        # with the same weights.
        for i in range(6):
            layer_num = 6 + i * 2
            modules_0 = model.layer[layer_num].crossattention.self._modules
            modules_1 = model.layer[layer_num + 1].crossattention.self._modules

            for name in modules_0.keys():
                if 'key' in name or 'value' in name:
                    module_0 = modules_0[name]
                    module_1 = modules_1[name]
                    if hasattr(module_0, "weight"):
                        module_0.weight = module_1.weight
                    if hasattr(module_0, "bias"):
                        module_0.bias = module_1.bias
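

# --- Usage sketch (illustrative, not part of the original file) -------------
# A minimal example of building the model and computing one loss. The config
# keys mirror those read above; the values, the 'bert-base-uncased' checkpoint
# name, and the bert-config path are assumptions for illustration only.
if __name__ == "__main__":
    from transformers import BertTokenizer

    config = {
        'image_res': 256,      # assumed input resolution
        'vision_width': 768,   # matches the ViT embed_dim used above
        'embed_dim': 256,      # assumed joint embedding size
        'bert_config': 'configs/config_bert.json',  # assumed BertConfig JSON path
    }
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = ALBEF(text_encoder='bert-base-uncased', tokenizer=tokenizer, config=config)

    # Batch of 4 images and 4 captions; the forward pass pairs each caption
    # with two image streams and returns the text-assignment loss.
    images = torch.randn(4, 3, config['image_res'], config['image_res'])
    text = tokenizer(['a photo of a dog'] * 4, padding='longest', return_tensors='pt')
    loss = model(images, text)
    loss.backward()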