albef
copied
wxywb
2 years ago
18 changed files with 3847 additions and 0 deletions
@@ -0,0 +1,18 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .albef import Albef


def albef(model_name: str, modality: str):
    return Albef(model_name, modality)
@@ -0,0 +1,113 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from pathlib import Path
from PIL import Image
import torch
import torch.nn.functional as F
import yaml
from torchvision import transforms

from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color
from towhee import register

# NOTE: these imports assume the ALBEF model code bundled with this operator is importable as `models`.
from models.model_retrieval import ALBEF
from models.tokenization_bert import BertTokenizer
from models.vit import interpolate_pos_embed


@register(output_schema=['vec'])
class Albef(NNOperator):
    """
    ALBEF multi-modal embedding operator.
    """
    def prepare_model(self, checkpoint_path, model):
        # Load a pretrained checkpoint, resize the positional embeddings to the current
        # image resolution, and strip the 'bert.' prefix from text-encoder keys.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        state_dict = checkpoint['model']
        pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'], model.visual_encoder)
        state_dict['visual_encoder.pos_embed'] = pos_embed_reshaped
        m_pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'], model.visual_encoder_m)
        state_dict['visual_encoder_m.pos_embed'] = m_pos_embed_reshaped
        for key in list(state_dict.keys()):
            if 'bert' in key:
                encoder_key = key.replace('bert.', '')
                state_dict[encoder_key] = state_dict[key]
                del state_dict[key]
        msg = model.load_state_dict(state_dict, strict=False)
        print('load checkpoint from ' + checkpoint_path)
        print(msg)
        return model

    def __init__(self, model_name: str, modality: str):
        super().__init__()
        self.modality = modality
        config = self._configs()[model_name]

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        cfg = yaml.load(open(config['cfg_path'], 'r'), Loader=yaml.Loader)
        checkpoint_path = cfg['ckpt_path']

        self.tokenizer = BertTokenizer.from_pretrained(config['tokenizer'])
        model = ALBEF(config=cfg, text_encoder=config['text_encoder'], tokenizer=self.tokenizer)
        self.model = self.prepare_model(checkpoint_path, model).to(self.device).eval()

        self.test_transform = transforms.Compose([
            transforms.Resize((cfg['image_res'], cfg['image_res']), interpolation=Image.BICUBIC),
            transforms.ToTensor(),
            normalize,
            ])

    def inference_single_data(self, data):
        if self.modality == 'image':
            vec = self._inference_from_image(data)
        elif self.modality == 'text':
            vec = self._inference_from_text(data)
        else:
            raise ValueError("modality[{}] not implemented.".format(self.modality))
        return vec.detach().cpu().numpy().flatten()

    def __call__(self, data):
        if not isinstance(data, list):
            data = [data]
        results = []
        for single_data in data:
            result = self.inference_single_data(single_data)
            results.append(result)
        if len(data) == 1:
            return results[0]
        return results

    def _inference_from_text(self, text):
        # Encode the text with the ALBEF text encoder in 'text' mode and project the
        # [CLS] embedding into the shared image-text space (mirrors the retrieval model).
        tokens = self.tokenizer(text, return_tensors='pt', padding=True).to(self.device)
        text_output = self.model.text_encoder(tokens.input_ids, attention_mask=tokens.attention_mask,
                                              return_dict=True, mode='text')
        text_features = F.normalize(self.model.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)
        return text_features

    @arg(1, to_image_color('RGB'))
    def _inference_from_image(self, img):
        # Preprocess with the transform built in __init__, then encode and project the [CLS] embedding.
        image = self.test_transform(to_pil(img)).unsqueeze(0).to(self.device)
        image_embeds = self.model.visual_encoder(image)
        image_features = F.normalize(self.model.vision_proj(image_embeds[:, 0, :]), dim=-1)
        return image_features

    def _configs(self):
        config = {}
        config['albef_4m'] = {}
        config['albef_4m']['tokenizer'] = 'bert-base-uncased'
        config['albef_4m']['text_encoder'] = 'bert-base-uncased'
        config['albef_4m']['cfg_path'] = './configs/Retrieval_flickr.yaml'
        config['albef_4m']['ckpt_path'] = ''
        return config
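A minimal usage sketch for the operator above (not part of the changeset; the local import path is hypothetical, and `ckpt_path` in `_configs` still has to point to a downloaded ALBEF checkpoint):

```python
from albef import Albef  # hypothetical import path; depends on where the operator file lives

# Build a text-side operator for the albef_4m configuration.
text_op = Albef(model_name='albef_4m', modality='text')

# __call__ accepts a single item or a list; a single item returns one flattened numpy vector.
vec = text_op('a dog playing in the snow')
print(vec.shape)  # projected ALBEF embedding; size depends on embed_dim in the retrieval config
```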
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,128 @@
from functools import partial
from models.vit import VisionTransformer
from models.xbert import BertConfig, BertModel

import torch
from torch import nn
import torch.nn.functional as F


class ALBEF(nn.Module):
    def __init__(self,
                 text_encoder = None,
                 tokenizer = None,
                 config = None,
                 ):
        super().__init__()

        self.tokenizer = tokenizer
        self.distill = config['distill']

        self.visual_encoder = VisionTransformer(
            img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
            mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))

        bert_config = BertConfig.from_json_file(config['bert_config'])
        bert_config.num_hidden_layers = 18

        self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False)
        self.cls_head = nn.Sequential(
            nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
            nn.ReLU(),
            nn.Linear(self.text_encoder.config.hidden_size, 2)
        )

        self.share_cross_attention(self.text_encoder.encoder)

        if self.distill:
            self.visual_encoder_m = VisionTransformer(
                img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
            self.text_encoder_m = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False)
            self.share_cross_attention(self.text_encoder_m.encoder)

            self.cls_head_m = nn.Sequential(
                nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
                nn.ReLU(),
                nn.Linear(self.text_encoder.config.hidden_size, 2)
            )

            self.model_pairs = [[self.visual_encoder, self.visual_encoder_m],
                                [self.text_encoder, self.text_encoder_m],
                                [self.cls_head, self.cls_head_m],
                               ]
            self.copy_params()
            self.momentum = 0.995

    def forward(self, image, text, targets, alpha=0, train=True):
        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        image0_embeds, image1_embeds = torch.split(image_embeds, targets.size(0))

        output = self.text_encoder(text.input_ids,
                                   attention_mask = text.attention_mask,
                                   encoder_hidden_states = [image0_embeds, image1_embeds],
                                   encoder_attention_mask = [image_atts[:image0_embeds.size(0)],
                                                             image_atts[image0_embeds.size(0):]],
                                   return_dict = True,
                                  )
        hidden_state = output.last_hidden_state[:, 0, :]
        prediction = self.cls_head(hidden_state)

        if train:
            if self.distill:
                with torch.no_grad():
                    self._momentum_update()
                    image_embeds_m = self.visual_encoder_m(image)
                    image0_embeds_m, image1_embeds_m = torch.split(image_embeds_m, targets.size(0))
                    output_m = self.text_encoder_m(text.input_ids,
                                                   attention_mask = text.attention_mask,
                                                   encoder_hidden_states = [image0_embeds_m, image1_embeds_m],
                                                   encoder_attention_mask = [image_atts[:image0_embeds.size(0)],
                                                                             image_atts[image0_embeds.size(0):]],
                                                   return_dict = True,
                                                  )
                    prediction_m = self.cls_head_m(output_m.last_hidden_state[:, 0, :])

                loss = (1 - alpha) * F.cross_entropy(prediction, targets) - alpha * torch.sum(
                    F.log_softmax(prediction, dim=1) * F.softmax(prediction_m, dim=1), dim=1).mean()
            else:
                loss = F.cross_entropy(prediction, targets)
            return loss
        else:
            return prediction

    @torch.no_grad()
    def copy_params(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False   # not update by gradient

    @torch.no_grad()
    def _momentum_update(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)

    def share_cross_attention(self, model):
        for i in range(6):
            layer_num = 6 + i * 2
            modules_0 = model.layer[layer_num].crossattention.self._modules
            modules_1 = model.layer[layer_num + 1].crossattention.self._modules

            for name in modules_0.keys():
                if 'key' in name or 'value' in name:
                    module_0 = modules_0[name]
                    module_1 = modules_1[name]
                    if hasattr(module_0, "weight"):
                        module_0.weight = module_1.weight
                        if hasattr(module_0, "bias"):
                            module_0.bias = module_1.bias
@@ -0,0 +1,291 @@
'''
 * Copyright (c) 2021, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
'''

from functools import partial
from models.vit import VisionTransformer, interpolate_pos_embed
from models.xbert import BertConfig, BertForMaskedLM

import torch
import torch.nn.functional as F
from torch import nn

import numpy as np
import random


class ALBEF(nn.Module):
    def __init__(self,
                 text_encoder = None,
                 tokenizer = None,
                 config = None,
                 temp = 0.07,
                 init_deit = True
                 ):
        super().__init__()

        self.tokenizer = tokenizer
        self.mlm_probability = config['mlm_probability']
        embed_dim = config['embed_dim']

        self.visual_encoder = VisionTransformer(
            img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
            mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))

        if init_deit:
            checkpoint = torch.hub.load_state_dict_from_url(
                url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
                map_location="cpu", check_hash=True)
            state_dict = checkpoint["model"]
            pos_embed_reshaped = interpolate_pos_embed(state_dict['pos_embed'], self.visual_encoder)
            state_dict['pos_embed'] = pos_embed_reshaped
            msg = self.visual_encoder.load_state_dict(state_dict, strict=False)
            print(msg)

        vision_width = config['vision_width']
        bert_config = BertConfig.from_json_file(config['bert_config'])

        self.text_encoder = BertForMaskedLM.from_pretrained(text_encoder, config=bert_config)

        text_width = self.text_encoder.config.hidden_size
        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)

        self.temp = nn.Parameter(torch.ones([]) * config['temp'])
        self.queue_size = config['queue_size']
        self.momentum = config['momentum']
        self.itm_head = nn.Linear(text_width, 2)

        # create momentum models
        self.visual_encoder_m = VisionTransformer(
            img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
            mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
        self.vision_proj_m = nn.Linear(vision_width, embed_dim)
        self.text_encoder_m = BertForMaskedLM.from_pretrained(text_encoder, config=bert_config)
        self.text_proj_m = nn.Linear(text_width, embed_dim)

        self.model_pairs = [[self.visual_encoder, self.visual_encoder_m],
                            [self.vision_proj, self.vision_proj_m],
                            [self.text_encoder, self.text_encoder_m],
                            [self.text_proj, self.text_proj_m],
                           ]

        self.copy_params()

        # create the queue
        self.register_buffer("image_queue", torch.randn(embed_dim, self.queue_size))
        self.register_buffer("text_queue", torch.randn(embed_dim, self.queue_size))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

        self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
        self.text_queue = nn.functional.normalize(self.text_queue, dim=0)

    def forward(self, image, text, alpha=0):
        with torch.no_grad():
            self.temp.clamp_(0.001, 0.5)

        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)

        text_output = self.text_encoder.bert(text.input_ids, attention_mask = text.attention_mask,
                                             return_dict = True, mode = 'text')
        text_embeds = text_output.last_hidden_state
        text_feat = F.normalize(self.text_proj(text_embeds[:, 0, :]), dim=-1)

        # get momentum features
        with torch.no_grad():
            self._momentum_update()
            image_embeds_m = self.visual_encoder_m(image)
            image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:, 0, :]), dim=-1)
            image_feat_all = torch.cat([image_feat_m.t(), self.image_queue.clone().detach()], dim=1)
            text_output_m = self.text_encoder_m.bert(text.input_ids, attention_mask = text.attention_mask,
                                                     return_dict = True, mode = 'text')
            text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:, 0, :]), dim=-1)
            text_feat_all = torch.cat([text_feat_m.t(), self.text_queue.clone().detach()], dim=1)

            sim_i2t_m = image_feat_m @ text_feat_all / self.temp
            sim_t2i_m = text_feat_m @ image_feat_all / self.temp

            sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device)
            sim_targets.fill_diagonal_(1)

            sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
            sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets

        sim_i2t = image_feat @ text_feat_all / self.temp
        sim_t2i = text_feat @ image_feat_all / self.temp

        loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1).mean()
        loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1).mean()

        loss_ita = (loss_i2t + loss_t2i) / 2

        self._dequeue_and_enqueue(image_feat_m, text_feat_m)

        ###=================================###
        # forward the positive image-text pair
        output_pos = self.text_encoder.bert(encoder_embeds = text_embeds,
                                            attention_mask = text.attention_mask,
                                            encoder_hidden_states = image_embeds,
                                            encoder_attention_mask = image_atts,
                                            return_dict = True,
                                            mode = 'fusion',
                                           )
        with torch.no_grad():
            bs = image.size(0)
            weights_i2t = F.softmax(sim_i2t[:, :bs], dim=1)
            weights_t2i = F.softmax(sim_t2i[:, :bs], dim=1)

            weights_i2t.fill_diagonal_(0)
            weights_t2i.fill_diagonal_(0)

        # select a negative image for each text
        image_embeds_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_t2i[b], 1).item()
            image_embeds_neg.append(image_embeds[neg_idx])
        image_embeds_neg = torch.stack(image_embeds_neg, dim=0)

        # select a negative text for each image
        text_embeds_neg = []
        text_atts_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_i2t[b], 1).item()
            text_embeds_neg.append(text_embeds[neg_idx])
            text_atts_neg.append(text.attention_mask[neg_idx])
        text_embeds_neg = torch.stack(text_embeds_neg, dim=0)
        text_atts_neg = torch.stack(text_atts_neg, dim=0)

        text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0)
        text_atts_all = torch.cat([text.attention_mask, text_atts_neg], dim=0)

        image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0)
        image_atts_all = torch.cat([image_atts, image_atts], dim=0)

        output_neg = self.text_encoder.bert(encoder_embeds = text_embeds_all,
                                            attention_mask = text_atts_all,
                                            encoder_hidden_states = image_embeds_all,
                                            encoder_attention_mask = image_atts_all,
                                            return_dict = True,
                                            mode = 'fusion',
                                           )

        vl_embeddings = torch.cat([output_pos.last_hidden_state[:, 0, :], output_neg.last_hidden_state[:, 0, :]], dim=0)
        vl_output = self.itm_head(vl_embeddings)

        itm_labels = torch.cat([torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
                               dim=0).to(image.device)
        loss_itm = F.cross_entropy(vl_output, itm_labels)

        ##================= MLM ========================##
        input_ids = text.input_ids.clone()
        labels = input_ids.clone()

        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        input_ids, labels = self.mask(input_ids, self.text_encoder.config.vocab_size, image.device, targets=labels,
                                      probability_matrix = probability_matrix)

        with torch.no_grad():
            logits_m = self.text_encoder_m(input_ids,
                                           attention_mask = text.attention_mask,
                                           encoder_hidden_states = image_embeds_m,
                                           encoder_attention_mask = image_atts,
                                           return_dict = True,
                                           return_logits = True,
                                          )
        mlm_output = self.text_encoder(input_ids,
                                       attention_mask = text.attention_mask,
                                       encoder_hidden_states = image_embeds,
                                       encoder_attention_mask = image_atts,
                                       return_dict = True,
                                       labels = labels,
                                       soft_labels = F.softmax(logits_m, dim=-1),
                                       alpha = alpha
                                      )
        loss_mlm = mlm_output.loss

        return loss_mlm, loss_ita, loss_itm

    @torch.no_grad()
    def copy_params(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False   # not update by gradient

    @torch.no_grad()
    def _momentum_update(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, image_feat, text_feat):
        # gather keys before updating queue
        image_feats = concat_all_gather(image_feat)
        text_feats = concat_all_gather(text_feat)

        batch_size = image_feats.shape[0]

        ptr = int(self.queue_ptr)
        assert self.queue_size % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
        self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
        ptr = (ptr + batch_size) % self.queue_size  # move pointer

        self.queue_ptr[0] = ptr

    def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
        if masked_indices is None:
            masked_indices = torch.bernoulli(probability_matrix).bool()

        masked_indices[input_ids == self.tokenizer.pad_token_id] = False
        masked_indices[input_ids == self.tokenizer.cls_token_id] = False

        if targets is not None:
            targets[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
        input_ids[indices_replaced] = self.tokenizer.mask_token_id

        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
        input_ids[indices_random] = random_words[indices_random]
        # The rest of the time (10% of the time) we keep the masked input tokens unchanged

        if targets is not None:
            return input_ids, targets
        else:
            return input_ids


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
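For readers skimming the diff, a short standalone sketch (not part of the changeset) of the two update rules the pretraining model above relies on: the EMA rule applied by `_momentum_update`, and the 80/10/10 split that `mask` applies to tokens selected for masked language modeling. All tensors here are illustrative.

```python
import torch

# EMA rule: each momentum parameter tracks its online counterpart.
momentum = 0.995
param = torch.tensor([1.0, 2.0])   # online parameter
param_m = torch.zeros(2)           # momentum copy
param_m = param_m * momentum + param * (1. - momentum)

# 80/10/10 split over tokens already selected for MLM (mlm_probability picks the candidates):
selected = torch.bernoulli(torch.full((8, 30), 0.15)).bool()                        # candidate tokens
to_mask = torch.bernoulli(torch.full((8, 30), 0.8)).bool() & selected               # -> [MASK] token
to_random = torch.bernoulli(torch.full((8, 30), 0.5)).bool() & selected & ~to_mask  # -> random word
# whatever remains selected (roughly 10%) is left unchanged, exactly as in `mask` above
```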
@@ -0,0 +1,99 @@
from functools import partial
from models.vit import VisionTransformer
from models.xbert import BertConfig, BertModel

import torch
from torch import nn
import torch.nn.functional as F


class ALBEF(nn.Module):
    def __init__(self,
                 text_encoder = None,
                 tokenizer = None,
                 config = None,
                 ):
        super().__init__()

        self.tokenizer = tokenizer
        vision_width = config['vision_width']
        embed_dim = config['embed_dim']

        self.visual_encoder = VisionTransformer(
            img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
            mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))

        bert_config = BertConfig.from_json_file(config['bert_config'])
        bert_config.num_hidden_layers = 18
        self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False)

        # share the cross-attention layers for two images
        self.share_cross_attention(self.text_encoder.encoder)

        text_width = self.text_encoder.config.hidden_size
        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)
        self.temp = nn.Parameter(torch.ones([]) * 0.07)
        self.ta_head = nn.Linear(self.text_encoder.config.hidden_size, 3)

    def forward(self, image, text):
        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
        with torch.no_grad():
            image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
            sim = image_feat @ image_feat.t() / 0.07
            weights = F.softmax(sim, dim=1)
            weights.fill_diagonal_(0)

        image_inputs = [[], []]
        labels = []
        for b in range(image.size(0)):
            if torch.rand(1) > 1 / 3:
                idx = torch.multinomial(weights[b], 1).item()
                if torch.rand(1) > 0.5:
                    image_inputs[0].append(image_embeds[b])
                    image_inputs[1].append(image_embeds[idx])
                    labels.append(0)
                else:
                    image_inputs[1].append(image_embeds[b])
                    image_inputs[0].append(image_embeds[idx])
                    labels.append(1)
            else:
                idx = torch.multinomial(weights[b], 2)
                image_inputs[0].append(image_embeds[idx[0]])
                image_inputs[1].append(image_embeds[idx[1]])
                labels.append(2)

        image_inputs[0] = torch.stack(image_inputs[0], dim=0)
        image_inputs[1] = torch.stack(image_inputs[1], dim=0)
        labels = torch.LongTensor(labels).to(image.device)

        output = self.text_encoder(text.input_ids,
                                   attention_mask = text.attention_mask,
                                   encoder_hidden_states = image_inputs,
                                   encoder_attention_mask = [image_atts, image_atts],
                                   return_dict = True,
                                  )

        pred = self.ta_head(output.last_hidden_state[:, 0, :])
        loss = F.cross_entropy(pred, labels)

        return loss

    def share_cross_attention(self, model):
        for i in range(6):
            layer_num = 6 + i * 2
            modules_0 = model.layer[layer_num].crossattention.self._modules
            modules_1 = model.layer[layer_num + 1].crossattention.self._modules

            for name in modules_0.keys():
                if 'key' in name or 'value' in name:
                    module_0 = modules_0[name]
                    module_1 = modules_1[name]
                    if hasattr(module_0, "weight"):
                        module_0.weight = module_1.weight
                        if hasattr(module_0, "bias"):
                            module_0.bias = module_1.bias
@@ -0,0 +1,217 @@
from functools import partial
from models.vit import VisionTransformer
from models.xbert import BertConfig, BertModel

import torch
from torch import nn
import torch.nn.functional as F


class ALBEF(nn.Module):
    def __init__(self,
                 text_encoder = None,
                 tokenizer = None,
                 config = None,
                 ):
        super().__init__()

        self.tokenizer = tokenizer
        self.distill = config['distill']
        embed_dim = config['embed_dim']
        vision_width = config['vision_width']
        self.visual_encoder = VisionTransformer(
            img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
            mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))

        bert_config = BertConfig.from_json_file(config['bert_config'])
        self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False)

        text_width = self.text_encoder.config.hidden_size
        self.vision_proj = nn.Linear(vision_width, embed_dim)
        self.text_proj = nn.Linear(text_width, embed_dim)

        self.temp = nn.Parameter(torch.ones([]) * config['temp'])
        self.queue_size = config['queue_size']
        self.momentum = config['momentum']
        self.itm_head = nn.Linear(text_width, 2)

        # create momentum models
        self.visual_encoder_m = VisionTransformer(
            img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
            mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
        self.vision_proj_m = nn.Linear(vision_width, embed_dim)
        self.text_encoder_m = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False)
        self.text_proj_m = nn.Linear(text_width, embed_dim)

        self.model_pairs = [[self.visual_encoder, self.visual_encoder_m],
                            [self.vision_proj, self.vision_proj_m],
                            [self.text_encoder, self.text_encoder_m],
                            [self.text_proj, self.text_proj_m],
                           ]
        self.copy_params()

        # create the queue
        self.register_buffer("image_queue", torch.randn(embed_dim, self.queue_size))
        self.register_buffer("text_queue", torch.randn(embed_dim, self.queue_size))
        self.register_buffer("idx_queue", torch.full((1, self.queue_size), -100))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

        self.image_queue = nn.functional.normalize(self.image_queue, dim=0)
        self.text_queue = nn.functional.normalize(self.text_queue, dim=0)

    def forward(self, image, text, alpha, idx):
        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        image_feat = F.normalize(self.vision_proj(image_embeds[:, 0, :]), dim=-1)
        text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask,
                                        return_dict = True, mode = 'text')
        text_embeds = text_output.last_hidden_state
        text_feat = F.normalize(self.text_proj(text_embeds[:, 0, :]), dim=-1)

        idx = idx.view(-1, 1)
        idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()], dim=1)
        pos_idx = torch.eq(idx, idx_all).float()
        sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)

        with torch.no_grad():
            self._momentum_update()
            image_embeds_m = self.visual_encoder_m(image)
            image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:, 0, :]), dim=-1)
            image_feat_all = torch.cat([image_feat_m.t(), self.image_queue.clone().detach()], dim=1)
            text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask,
                                                return_dict = True, mode = 'text')
            text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:, 0, :]), dim=-1)
            text_feat_all = torch.cat([text_feat_m.t(), self.text_queue.clone().detach()], dim=1)

            if self.distill:
                sim_i2t_m = image_feat_m @ text_feat_all / self.temp
                sim_t2i_m = text_feat_m @ image_feat_all / self.temp

                sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets
                sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets

        sim_i2t = image_feat @ text_feat_all / self.temp
        sim_t2i = text_feat @ image_feat_all / self.temp

        if self.distill:
            loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1).mean()
            loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1).mean()
        else:
            loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1) * sim_targets, dim=1).mean()
            loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1) * sim_targets, dim=1).mean()

        loss_ita = (loss_i2t + loss_t2i) / 2

        self._dequeue_and_enqueue(image_feat_m, text_feat_m, idx)

        ###=================================###
        # forward the positive image-text pair
        output_pos = self.text_encoder(encoder_embeds = text_embeds,
                                       attention_mask = text.attention_mask,
                                       encoder_hidden_states = image_embeds,
                                       encoder_attention_mask = image_atts,
                                       return_dict = True,
                                       mode = 'fusion',
                                      )
        with torch.no_grad():
            bs = image.size(0)
            weights_i2t = F.softmax(sim_i2t[:, :bs] + 1e-4, dim=1)
            weights_t2i = F.softmax(sim_t2i[:, :bs] + 1e-4, dim=1)

            mask = torch.eq(idx, idx.T)
            weights_i2t.masked_fill_(mask, 0)
            weights_t2i.masked_fill_(mask, 0)

        # select a negative image for each text
        image_embeds_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_t2i[b], 1).item()
            image_embeds_neg.append(image_embeds[neg_idx])
        image_embeds_neg = torch.stack(image_embeds_neg, dim=0)

        # select a negative text for each image
        text_embeds_neg = []
        text_atts_neg = []
        for b in range(bs):
            neg_idx = torch.multinomial(weights_i2t[b], 1).item()
            text_embeds_neg.append(text_embeds[neg_idx])
            text_atts_neg.append(text.attention_mask[neg_idx])
        text_embeds_neg = torch.stack(text_embeds_neg, dim=0)
        text_atts_neg = torch.stack(text_atts_neg, dim=0)

        text_embeds_all = torch.cat([text_embeds, text_embeds_neg], dim=0)
        text_atts_all = torch.cat([text.attention_mask, text_atts_neg], dim=0)

        image_embeds_all = torch.cat([image_embeds_neg, image_embeds], dim=0)
        image_atts_all = torch.cat([image_atts, image_atts], dim=0)

        output_neg = self.text_encoder(encoder_embeds = text_embeds_all,
                                       attention_mask = text_atts_all,
                                       encoder_hidden_states = image_embeds_all,
                                       encoder_attention_mask = image_atts_all,
                                       return_dict = True,
                                       mode = 'fusion',
                                      )

        vl_embeddings = torch.cat([output_pos.last_hidden_state[:, 0, :], output_neg.last_hidden_state[:, 0, :]], dim=0)
        vl_output = self.itm_head(vl_embeddings)

        itm_labels = torch.cat([torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],
                               dim=0).to(image.device)
        loss_itm = F.cross_entropy(vl_output, itm_labels)

        return loss_ita, loss_itm

    @torch.no_grad()
    def copy_params(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False   # not update by gradient

    @torch.no_grad()
    def _momentum_update(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, image_feat, text_feat, idx):
        # gather keys before updating queue
        image_feats = concat_all_gather(image_feat)
        text_feats = concat_all_gather(text_feat)
        idxs = concat_all_gather(idx)

        batch_size = image_feats.shape[0]

        ptr = int(self.queue_ptr)
        assert self.queue_size % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.image_queue[:, ptr:ptr + batch_size] = image_feats.T
        self.text_queue[:, ptr:ptr + batch_size] = text_feats.T
        self.idx_queue[:, ptr:ptr + batch_size] = idxs.T
        ptr = (ptr + batch_size) % self.queue_size  # move pointer

        self.queue_ptr[0] = ptr


@torch.no_grad()
def concat_all_gather(tensor):
    """
    Performs all_gather operation on the provided tensors.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    tensors_gather = [torch.ones_like(tensor)
                      for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=0)
    return output
@@ -0,0 +1,110 @@
from functools import partial
from models.vit import VisionTransformer
from models.xbert import BertConfig, BertModel

import torch
from torch import nn
import torch.nn.functional as F


class ALBEF(nn.Module):
    def __init__(self,
                 text_encoder = None,
                 tokenizer = None,
                 config = None,
                 ):
        super().__init__()

        self.tokenizer = tokenizer
        self.distill = config['distill']

        self.visual_encoder = VisionTransformer(
            img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
            mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))

        bert_config = BertConfig.from_json_file(config['bert_config'])

        self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False)

        self.cls_head = nn.Sequential(
            nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
            nn.ReLU(),
            nn.Linear(self.text_encoder.config.hidden_size, 3)
        )

        if self.distill:
            self.visual_encoder_m = VisionTransformer(
                img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
            self.text_encoder_m = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False)
            self.cls_head_m = nn.Sequential(
                nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size),
                nn.ReLU(),
                nn.Linear(self.text_encoder.config.hidden_size, 3)
            )

            self.model_pairs = [[self.visual_encoder, self.visual_encoder_m],
                                [self.text_encoder, self.text_encoder_m],
                                [self.cls_head, self.cls_head_m],
                               ]
            self.copy_params()
            self.momentum = 0.995

    def forward(self, image, text, targets, alpha=0, train=True):
        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        if train:
            output = self.text_encoder(text.input_ids,
                                       attention_mask = text.attention_mask,
                                       encoder_hidden_states = image_embeds,
                                       encoder_attention_mask = image_atts,
                                       return_dict = True
                                      )
            prediction = self.cls_head(output.last_hidden_state[:, 0, :])
            if self.distill:
                with torch.no_grad():
                    self._momentum_update()
                    image_embeds_m = self.visual_encoder_m(image)
                    output_m = self.text_encoder_m(text.input_ids,
                                                   attention_mask = text.attention_mask,
                                                   encoder_hidden_states = image_embeds_m,
                                                   encoder_attention_mask = image_atts,
                                                   return_dict = True
                                                  )
                    prediction_m = self.cls_head_m(output_m.last_hidden_state[:, 0, :])

                loss = (1 - alpha) * F.cross_entropy(prediction, targets) - alpha * torch.sum(
                    F.log_softmax(prediction, dim=1) * F.softmax(prediction_m, dim=1), dim=1).mean()
            else:
                loss = F.cross_entropy(prediction, targets)
            return loss

        else:
            output = self.text_encoder(text.input_ids,
                                       attention_mask = text.attention_mask,
                                       encoder_hidden_states = image_embeds,
                                       encoder_attention_mask = image_atts,
                                       return_dict = True
                                      )
            prediction = self.cls_head(output.last_hidden_state[:, 0, :])
            return prediction

    @torch.no_grad()
    def copy_params(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False   # not update by gradient

    @torch.no_grad()
    def _momentum_update(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)
@@ -0,0 +1,214 @@
from functools import partial
from models.vit import VisionTransformer
from models.xbert import BertConfig, BertModel, BertLMHeadModel

import torch
from torch import nn
import torch.nn.functional as F

import numpy as np


class ALBEF(nn.Module):
    def __init__(self,
                 text_encoder = None,
                 text_decoder = None,
                 tokenizer = None,
                 config = None,
                 ):
        super().__init__()

        self.tokenizer = tokenizer
        self.distill = config['distill']

        self.visual_encoder = VisionTransformer(
            img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
            mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))

        config_encoder = BertConfig.from_json_file(config['bert_config'])
        self.text_encoder = BertModel.from_pretrained(text_encoder, config=config_encoder, add_pooling_layer=False)

        config_decoder = BertConfig.from_json_file(config['bert_config'])
        config_decoder.fusion_layer = 0
        config_decoder.num_hidden_layers = 6
        self.text_decoder = BertLMHeadModel.from_pretrained(text_decoder, config=config_decoder)

        if self.distill:
            self.visual_encoder_m = VisionTransformer(
                img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12,
                mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))
            self.text_encoder_m = BertModel.from_pretrained(text_encoder, config=config_encoder, add_pooling_layer=False)
            self.text_decoder_m = BertLMHeadModel.from_pretrained(text_decoder, config=config_decoder)
            self.model_pairs = [[self.visual_encoder, self.visual_encoder_m],
                                [self.text_encoder, self.text_encoder_m],
                                [self.text_decoder, self.text_decoder_m],
                               ]
            self.copy_params()
            self.momentum = 0.995

    def forward(self, image, question, answer=None, alpha=0, k=None, weights=None, train=True):
        image_embeds = self.visual_encoder(image)
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)

        if train:
            '''
            k: number of answers for each question
            weights: weight for each answer
            '''
            answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100)

            question_output = self.text_encoder(question.input_ids,
                                                attention_mask = question.attention_mask,
                                                encoder_hidden_states = image_embeds,
                                                encoder_attention_mask = image_atts,
                                                return_dict = True)

            question_states = []
            question_atts = []
            for b, n in enumerate(k):
                question_states += [question_output.last_hidden_state[b]] * n
                question_atts += [question.attention_mask[b]] * n
            question_states = torch.stack(question_states, 0)
            question_atts = torch.stack(question_atts, 0)

            if self.distill:
                with torch.no_grad():
                    self._momentum_update()
                    image_embeds_m = self.visual_encoder_m(image)
                    question_output_m = self.text_encoder_m(question.input_ids,
                                                            attention_mask = question.attention_mask,
                                                            encoder_hidden_states = image_embeds_m,
                                                            encoder_attention_mask = image_atts,
                                                            return_dict = True)

                    question_states_m = []
                    for b, n in enumerate(k):
                        question_states_m += [question_output_m.last_hidden_state[b]] * n
                    question_states_m = torch.stack(question_states_m, 0)

                    logits_m = self.text_decoder_m(answer.input_ids,
                                                   attention_mask = answer.attention_mask,
                                                   encoder_hidden_states = question_states_m,
                                                   encoder_attention_mask = question_atts,
                                                   return_logits = True,
                                                  )

                answer_output = self.text_decoder(answer.input_ids,
                                                  attention_mask = answer.attention_mask,
                                                  encoder_hidden_states = question_states,
                                                  encoder_attention_mask = question_atts,
                                                  labels = answer_targets,
                                                  return_dict = True,
                                                  soft_labels = F.softmax(logits_m, dim=-1),
                                                  alpha = alpha,
                                                  reduction = 'none',
                                                 )
            else:
                answer_output = self.text_decoder(answer.input_ids,
                                                  attention_mask = answer.attention_mask,
                                                  encoder_hidden_states = question_states,
                                                  encoder_attention_mask = question_atts,
                                                  labels = answer_targets,
                                                  return_dict = True,
                                                  reduction = 'none',
                                                 )
            loss = weights * answer_output.loss
            loss = loss.sum() / image.size(0)

            return loss

        else:
            question_output = self.text_encoder(question.input_ids,
                                                attention_mask = question.attention_mask,
                                                encoder_hidden_states = image_embeds,
                                                encoder_attention_mask = image_atts,
                                                return_dict = True)
            topk_ids, topk_probs = self.rank_answer(question_output.last_hidden_state, question.attention_mask,
                                                    answer.input_ids, answer.attention_mask, k)
            return topk_ids, topk_probs

    @torch.no_grad()
    def copy_params(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data.copy_(param.data)  # initialize
                param_m.requires_grad = False   # not update by gradient

    @torch.no_grad()
    def _momentum_update(self):
        for model_pair in self.model_pairs:
            for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()):
                param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum)

    def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k):
        num_ques = question_states.size(0)
        start_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token

        start_output = self.text_decoder(start_ids,
                                         encoder_hidden_states = question_states,
                                         encoder_attention_mask = question_atts,
                                         return_dict = True,
                                         reduction = 'none')
        logits = start_output.logits[:, 0, :]  # first token's logit

        # topk_probs: top-k probability
        # topk_ids: [num_question, k]
        answer_first_token = answer_ids[:, 1]
        prob_first_token = F.softmax(logits, dim=1).index_select(dim=1, index=answer_first_token)
        topk_probs, topk_ids = prob_first_token.topk(k, dim=1)

        # answer input: [num_question*k, answer_len]
        input_ids = []
        input_atts = []
        for b, topk_id in enumerate(topk_ids):
            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
        input_ids = torch.cat(input_ids, dim=0)
        input_atts = torch.cat(input_atts, dim=0)

        targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100)

        # repeat encoder's output for top-k answers
        question_states = tile(question_states, 0, k)
        question_atts = tile(question_atts, 0, k)

        output = self.text_decoder(input_ids,
                                   attention_mask = input_atts,
                                   encoder_hidden_states = question_states,
                                   encoder_attention_mask = question_atts,
                                   labels = targets_ids,
                                   return_dict = True,
                                   reduction = 'none')

        answer_loss = output.loss
        answer_loss = answer_loss.view(input_ids.size(0), -1)

        # topk_prob: first token probability
        topk_probs = topk_probs.view(-1, 1)
        log_probs = torch.cat([topk_probs.log(), -answer_loss], dim=1)

        # re-calculate log probabilities for the answer sequences using chain rule
        log_probs_sum = log_probs.sum(1)
        log_probs_sum = log_probs_sum.view(num_ques, k)

        topk_probs = F.softmax(log_probs_sum, dim=-1)
        # get top-k after re-ranking
        topk_probs, rerank_id = topk_probs.topk(k, dim=1)
        topk_ids = torch.gather(topk_ids, 1, rerank_id)

        return topk_ids, topk_probs


def tile(x, dim, n_tile):
    init_dim = x.size(dim)
    repeat_idx = [1] * x.dim()
    repeat_idx[dim] = n_tile
    x = x.repeat(*(repeat_idx))
    order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
    return torch.index_select(x, dim, order_index.to(x.device))
@ -0,0 +1,539 @@ |
|||||
|
# coding=utf-8 |
||||
|
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
"""Tokenization classes for Bert.""" |
||||
|
|
||||
|
|
||||
|
import collections |
||||
|
import os |
||||
|
import unicodedata |
||||
|
from typing import List, Optional, Tuple |
||||
|
|
||||
|
from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace |
||||
|
from transformers.utils import logging |
||||
|
|
||||
|
|
||||
|
logger = logging.get_logger(__name__) |
||||
|
|
||||
|
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} |
||||
|
|
||||
|
PRETRAINED_VOCAB_FILES_MAP = { |
||||
|
"vocab_file": { |
||||
|
"bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt", |
||||
|
"bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt", |
||||
|
"bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt", |
||||
|
"bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt", |
||||
|
"bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt", |
||||
|
"bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt", |
||||
|
"bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt", |
||||
|
"bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/vocab.txt", |
||||
|
"bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt", |
||||
|
"bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt", |
||||
|
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt", |
||||
|
"bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt", |
||||
|
"bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt", |
||||
|
"bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt", |
||||
|
"bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt", |
||||
|
"TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt", |
||||
|
"TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt", |
||||
|
"wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt", |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { |
||||
|
"bert-base-uncased": 512, |
||||
|
"bert-large-uncased": 512, |
||||
|
"bert-base-cased": 512, |
||||
|
"bert-large-cased": 512, |
||||
|
"bert-base-multilingual-uncased": 512, |
||||
|
"bert-base-multilingual-cased": 512, |
||||
|
"bert-base-chinese": 512, |
||||
|
"bert-base-german-cased": 512, |
||||
|
"bert-large-uncased-whole-word-masking": 512, |
||||
|
"bert-large-cased-whole-word-masking": 512, |
||||
|
"bert-large-uncased-whole-word-masking-finetuned-squad": 512, |
||||
|
"bert-large-cased-whole-word-masking-finetuned-squad": 512, |
||||
|
"bert-base-cased-finetuned-mrpc": 512, |
||||
|
"bert-base-german-dbmdz-cased": 512, |
||||
|
"bert-base-german-dbmdz-uncased": 512, |
||||
|
"TurkuNLP/bert-base-finnish-cased-v1": 512, |
||||
|
"TurkuNLP/bert-base-finnish-uncased-v1": 512, |
||||
|
"wietsedv/bert-base-dutch-cased": 512, |
||||
|
} |
||||
|
|
||||
|
PRETRAINED_INIT_CONFIGURATION = { |
||||
|
"bert-base-uncased": {"do_lower_case": True}, |
||||
|
"bert-large-uncased": {"do_lower_case": True}, |
||||
|
"bert-base-cased": {"do_lower_case": False}, |
||||
|
"bert-large-cased": {"do_lower_case": False}, |
||||
|
"bert-base-multilingual-uncased": {"do_lower_case": True}, |
||||
|
"bert-base-multilingual-cased": {"do_lower_case": False}, |
||||
|
"bert-base-chinese": {"do_lower_case": False}, |
||||
|
"bert-base-german-cased": {"do_lower_case": False}, |
||||
|
"bert-large-uncased-whole-word-masking": {"do_lower_case": True}, |
||||
|
"bert-large-cased-whole-word-masking": {"do_lower_case": False}, |
||||
|
"bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True}, |
||||
|
"bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False}, |
||||
|
"bert-base-cased-finetuned-mrpc": {"do_lower_case": False}, |
||||
|
"bert-base-german-dbmdz-cased": {"do_lower_case": False}, |
||||
|
"bert-base-german-dbmdz-uncased": {"do_lower_case": True}, |
||||
|
"TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False}, |
||||
|
"TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True}, |
||||
|
"wietsedv/bert-base-dutch-cased": {"do_lower_case": False}, |
||||
|
} |
||||
|
|
||||
|
|
||||
|
def load_vocab(vocab_file): |
||||
|
"""Loads a vocabulary file into a dictionary.""" |
||||
|
vocab = collections.OrderedDict() |
||||
|
with open(vocab_file, "r", encoding="utf-8") as reader: |
||||
|
tokens = reader.readlines() |
||||
|
for index, token in enumerate(tokens): |
||||
|
token = token.rstrip("\n") |
||||
|
vocab[token] = index |
||||
|
return vocab |
||||
|
|
||||
|
|
||||
|
def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


class BertTokenizer(PreTrainedTokenizer):
    r"""
    Construct a BERT tokenizer. Based on WordPiece.

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
    Users should refer to this superclass for more information regarding those methods.

    Args:
        vocab_file (:obj:`str`):
            File containing the vocabulary.
        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to lowercase the input when tokenizing.
        do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to do basic tokenization before WordPiece.
        never_split (:obj:`Iterable`, `optional`):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            :obj:`do_basic_tokenize=True`
        unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to tokenize Chinese characters.

            This should likely be deactivated for Japanese (see this `issue
            <https://github.com/huggingface/transformers/issues/328>`__).
        strip_accents: (:obj:`bool`, `optional`):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for :obj:`lowercase` (as in the original BERT).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        do_basic_tokenize=True,
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        **kwargs
    ):
        super().__init__(
            do_lower_case=do_lower_case,
            do_basic_tokenize=do_basic_tokenize,
            never_split=never_split,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)
            )
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(
                do_lower_case=do_lower_case,
                never_split=never_split,
                tokenize_chinese_chars=tokenize_chinese_chars,
                strip_accents=strip_accents,
            )
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)

    @property
    def do_lower_case(self):
        return self.basic_tokenizer.do_lower_case

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        split_tokens = []
        if self.do_basic_tokenize:
            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):

                # If the token is part of the never_split set
                if token in self.basic_tokenizer.never_split:
                    split_tokens.append(token)
                else:
                    split_tokens += self.wordpiece_tokenizer.tokenize(token)
        else:
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) to an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) to a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (list of str) to a single string."""
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating
        and adding special tokens. A BERT sequence has the following format:

        - single sequence: ``[CLS] X``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` method.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not the token list is already formatted with special tokens for the model.

        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0))

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
        pair mask has the following format:

        ::

            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence    | second sequence |

        If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
            sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        index = 0
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
            )
        else:
            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        "Saving vocabulary to {}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!".format(vocab_file)
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)


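# --- Illustrative sketch (not part of the original file): how the special-token helpers of
# --- BertTokenizer above compose model inputs. The toy vocabulary and sample tokens below
# --- are made up for the example.
def _demo_bert_tokenizer_special_tokens():
    import tempfile
    toy_vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "a", "photo", "of", "dog"]
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as f:
        f.write("\n".join(toy_vocab) + "\n")
        vocab_path = f.name
    tokenizer = BertTokenizer(vocab_path)
    ids_a = tokenizer.convert_tokens_to_ids(["a", "photo"])
    ids_b = tokenizer.convert_tokens_to_ids(["of", "dog"])
    # Single sequence: this ALBEF variant prepends only [CLS] (no trailing [SEP]).
    print(tokenizer.build_inputs_with_special_tokens(ids_a))
    # Sequence pair: [CLS] A [SEP] B [SEP], with token type ids 0...0 then 1...1.
    print(tokenizer.build_inputs_with_special_tokens(ids_a, ids_b))
    print(tokenizer.create_token_type_ids_from_sequences(ids_a, ids_b))

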
class BasicTokenizer(object):
    """
    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).

    Args:
        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (:obj:`Iterable`, `optional`):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            :obj:`do_basic_tokenize=True`
        tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to tokenize Chinese characters.

            This should likely be deactivated for Japanese (see this `issue
            <https://github.com/huggingface/transformers/issues/328>`__).
        strip_accents: (:obj:`bool`, `optional`):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for :obj:`lowercase` (as in the original BERT).
    """

    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = set(never_split)
        self.tokenize_chinese_chars = tokenize_chinese_chars
        self.strip_accents = strip_accents

    def tokenize(self, text, never_split=None):
        """
        Basic tokenization of a piece of text. Splits on "white spaces" only; for sub-word tokenization, see
        WordPieceTokenizer.

        Args:
            **never_split**: (`optional`) list of str
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                :func:`PreTrainedTokenizer.tokenize`). List of tokens not to split.
        """
        # union() returns a new set by concatenating the two sets.
        never_split = self.never_split.union(set(never_split)) if never_split else self.never_split
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if token not in never_split:
                if self.do_lower_case:
                    token = token.lower()
                    if self.strip_accents is not False:
                        token = self._run_strip_accents(token)
                elif self.strip_accents:
                    token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if never_split is not None and text in never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "Chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)
            or (cp >= 0x20000 and cp <= 0x2A6DF)
            or (cp >= 0x2A700 and cp <= 0x2B73F)
            or (cp >= 0x2B740 and cp <= 0x2B81F)
            or (cp >= 0x2B820 and cp <= 0x2CEAF)
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)
        ):
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


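# --- Illustrative sketch (not part of the original file): how BasicTokenizer handles casing,
# --- accents, punctuation, and protected tokens. The input string below is made up.
def _demo_basic_tokenizer():
    tokenizer = BasicTokenizer(do_lower_case=True, never_split=["[CLS]"])
    # Lowercases, strips accents, and splits punctuation into separate tokens,
    # while leaving protected tokens such as "[CLS]" untouched.
    print(tokenizer.tokenize("[CLS] Héllo, World!"))
    # Expected under these assumptions: ['[CLS]', 'hello', ',', 'world', '!']

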
class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """
        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.

        For example, :obj:`input = "unaffable"` will return as output :obj:`["un", "##aff", "##able"]`.

        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.

        Returns:
            A list of wordpiece tokens.
        """

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens
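

# --- Illustrative sketch (not part of the original file): the greedy longest-match-first
# --- behaviour described in the docstring above, using a made-up toy vocabulary.
def _demo_wordpiece_tokenizer():
    toy_vocab = {"un": 0, "##aff": 1, "##able": 2, "[UNK]": 3}
    tokenizer = WordpieceTokenizer(vocab=toy_vocab, unk_token="[UNK]")
    print(tokenizer.tokenize("unaffable"))    # ['un', '##aff', '##able']
    print(tokenizer.tokenize("unbelievable"))  # ['[UNK]'] -- the toy vocab cannot cover the word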
@ -0,0 +1,202 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models.vision_transformer import _cfg, PatchEmbed
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_, DropPath


class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attn_gradients = None
        self.attention_map = None

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def forward(self, x, register_hook=False):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

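# --- Illustrative sketch (not part of the original file): registering the attention hook so the
# --- attention map and its gradient can be read back (e.g. for Grad-CAM-style inspection).
# --- The shapes and the dummy scalar loss below are made up for the example.
def _demo_attention_hooks():
    attn = Attention(dim=64, num_heads=8)
    x = torch.randn(2, 10, 64, requires_grad=True)  # (batch, tokens, dim)
    out = attn(x, register_hook=True)
    out.sum().backward()  # any scalar loss works for the sketch
    print(attn.get_attention_map().shape)   # torch.Size([2, 8, 10, 10])
    print(attn.get_attn_gradients().shape)  # torch.Size([2, 8, 10, 10])

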
class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, register_hook=False):
        x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class VisionTransformer(nn.Module):
    """ Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
        https://arxiv.org/abs/2010.11929
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer
        """
        super().__init__()
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    def forward(self, x, register_blk=-1):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.pos_embed[:, :x.size(1), :]
        x = self.pos_drop(x)

        for i, blk in enumerate(self.blocks):
            x = blk(x, register_blk == i)
        x = self.norm(x)

        return x


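# --- Illustrative sketch (not part of the original file): running the encoder above on a dummy
# --- batch. The sizes below are made up; ALBEF's actual config may differ.
def _demo_vision_transformer():
    model = VisionTransformer(img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12)
    images = torch.randn(2, 3, 224, 224)
    feats = model(images)
    # One [CLS] token plus 14*14 = 196 patch tokens, each of width embed_dim.
    print(feats.shape)  # torch.Size([2, 197, 768])

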
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder):
    # interpolate position embedding
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = visual_encoder.patch_embed.num_patches
    num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches
    # height (== width) for the checkpoint position embedding
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    # height (== width) for the new position embedding
    new_size = int(num_patches ** 0.5)

    if orig_size != new_size:
        # class_token and dist_token are kept unchanged
        extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
        # only the position tokens are interpolated
        pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        print('reshape position embedding from %d to %d' % (orig_size ** 2, new_size ** 2))

        return new_pos_embed
    else:
        return pos_embed_checkpoint
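

# --- Illustrative sketch (not part of the original file): resizing a 224px (14x14 patches)
# --- checkpoint position embedding for a 384px (24x24 patches) encoder. All sizes are made up.
def _demo_interpolate_pos_embed():
    encoder_384 = VisionTransformer(img_size=384, patch_size=16, embed_dim=768, depth=1, num_heads=12)
    ckpt_pos_embed_224 = torch.zeros(1, 1 + 14 * 14, 768)  # [CLS] + 196 patch positions
    resized = interpolate_pos_embed(ckpt_pos_embed_224, encoder_384)
    print(resized.shape)  # torch.Size([1, 577, 768]) -> [CLS] + 24*24 positions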
File diff suppressed because it is too large