albef
copied
wxywb
2 years ago
18 changed files with 3847 additions and 0 deletions
@ -0,0 +1,18 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
from .albef import Albef |
|||
|
|||
def albef(model_name: str, modality: str): |
|||
return Albef(model_name, modality) |
@ -0,0 +1,113 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
import sys |
|||
from pathlib import Path |
|||
from PIL import Image |
|||
import torch |
|||
import yaml |
|||
from torchvision import transforms |
|||
|
|||
from towhee.types.image_utils import to_pil |
|||
from towhee.operator.base import NNOperator, OperatorFlag |
|||
from towhee.types.arg import arg, to_image_color |
|||
from towhee import register |
|||
|
|||
@register(output_schema=['vec']) |
|||
class Albef(NNOperator): |
|||
""" |
|||
ALBEF multi-modal embedding operator |
|||
""" |
|||
def prepare_model(checkpoint_path, model): |
|||
checkpoint = torch.load(checkpoint_path, map_location='cpu') |
|||
state_dict = checkpoint['model'] |
|||
pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder.pos_embed'],model.visual_encoder) |
|||
state_dict['visual_encoder.pos_embed'] = pos_embed_reshaped |
|||
m_pos_embed_reshaped = interpolate_pos_embed(state_dict['visual_encoder_m.pos_embed'],model.visual_encoder_m) |
|||
state_dict['visual_encoder_m.pos_embed'] = m_pos_embed_reshaped |
|||
for key in list(state_dict.keys()): |
|||
if 'bert' in key: |
|||
encoder_key = key.replace('bert.','') |
|||
state_dict[encoder_key] = state_dict[key] |
|||
del state_dict[key] |
|||
msg = model.load_state_dict(state_dict,strict=False) |
|||
print('load checkpoint from ' + checkpoint_path) |
|||
return model |
|||
|
|||
def __init__(self, model_name: str, modality: str): |
|||
self.modality = modality |
|||
config = self._configs()[model_name] |
|||
|
|||
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|||
|
|||
normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) |
|||
tokenizer = BertTokenizer.from_pretrained(config) |
|||
model = ALBEF(config=config, text_encoder=config['text_encoder'], tokenizer=tokenizer) |
|||
cfg = yaml.load(open(config['cfg'], 'r'), Loader=yaml.Loader) |
|||
checkpoint_path = cfg['ckpt_path'] |
|||
|
|||
self.model = self.prepare_model(checkpoint_path, model) |
|||
|
|||
self.test_transform = transforms.Compose([ |
|||
transforms.Resize((cfg['image_res'],cfg['image_res']),interpolation=Image.BICUBIC), |
|||
transforms.ToTensor(), |
|||
normalize, |
|||
]) |
|||
|
|||
|
|||
def inference_single_data(self, data): |
|||
if self.modality == 'image': |
|||
vec = self._inference_from_image(data) |
|||
elif self.modality == 'text': |
|||
vec = self._inference_from_text(data) |
|||
else: |
|||
raise ValueError("modality[{}] not implemented.".format(self._modality)) |
|||
return vec.detach().cpu().numpy().flatten() |
|||
|
|||
def __call__(self, data): |
|||
if not isinstance(data, list): |
|||
data = [data] |
|||
else: |
|||
data = data |
|||
results = [] |
|||
for single_data in data: |
|||
result = self.inference_single_data(single_data) |
|||
results.append(result) |
|||
if len(data) == 1: |
|||
return results[0] |
|||
else: |
|||
return results |
|||
|
|||
def _inference_from_text(self, text): |
|||
tokens = self.text_tokenizer(text, return_tensors='pt', padding=True)['input_ids'].to(self.device) |
|||
text_features = self.text_encoder(tokens).logits |
|||
return text_features |
|||
|
|||
@arg(1, to_image_color('RGB')) |
|||
def _inference_from_image(self, img): |
|||
image = to_pil(img) |
|||
image = self.processor(images=image, return_tensors="pt").to(self.device) |
|||
image_features = self.clip_model.get_image_features(**image) |
|||
return image_features |
|||
|
|||
def _configs(self): |
|||
config = {} |
|||
config['albef_4m'] = {} |
|||
config['albef_4m']['tokenizer'] = 'bert-base-uncased' |
|||
config['albef_4m']['text_encoder'] = 'bert-base-uncased' |
|||
config['albef_4m']['cfg_path'] = './configs/Retrieval_flickr.yaml' |
|||
config['albef_4m']['ckpt_path'] = '' |
|||
|
|||
|
|||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -0,0 +1,128 @@ |
|||
from functools import partial |
|||
from models.vit import VisionTransformer |
|||
from models.xbert import BertConfig, BertModel |
|||
|
|||
import torch |
|||
from torch import nn |
|||
import torch.nn.functional as F |
|||
|
|||
class ALBEF(nn.Module): |
|||
def __init__(self, |
|||
text_encoder = None, |
|||
tokenizer = None, |
|||
config = None, |
|||
): |
|||
super().__init__() |
|||
|
|||
self.tokenizer = tokenizer |
|||
self.distill = config['distill'] |
|||
|
|||
self.visual_encoder = VisionTransformer( |
|||
img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12, |
|||
mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)) |
|||
|
|||
bert_config = BertConfig.from_json_file(config['bert_config']) |
|||
bert_config.num_hidden_layers = 18 |
|||
|
|||
self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False) |
|||
self.cls_head = nn.Sequential( |
|||
nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size), |
|||
nn.ReLU(), |
|||
nn.Linear(self.text_encoder.config.hidden_size, 2) |
|||
) |
|||
|
|||
self.share_cross_attention(self.text_encoder.encoder) |
|||
|
|||
if self.distill: |
|||
self.visual_encoder_m = VisionTransformer( |
|||
img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12, |
|||
mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)) |
|||
self.text_encoder_m = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False) |
|||
self.share_cross_attention(self.text_encoder_m.encoder) |
|||
|
|||
self.cls_head_m = nn.Sequential( |
|||
nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size), |
|||
nn.ReLU(), |
|||
nn.Linear(self.text_encoder.config.hidden_size, 2) |
|||
) |
|||
|
|||
self.model_pairs = [[self.visual_encoder,self.visual_encoder_m], |
|||
[self.text_encoder,self.text_encoder_m], |
|||
[self.cls_head,self.cls_head_m], |
|||
] |
|||
self.copy_params() |
|||
self.momentum = 0.995 |
|||
|
|||
|
|||
def forward(self, image, text, targets, alpha=0, train=True): |
|||
|
|||
image_embeds = self.visual_encoder(image) |
|||
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) |
|||
|
|||
image0_embeds, image1_embeds = torch.split(image_embeds,targets.size(0)) |
|||
|
|||
output = self.text_encoder(text.input_ids, |
|||
attention_mask = text.attention_mask, |
|||
encoder_hidden_states = [image0_embeds,image1_embeds], |
|||
encoder_attention_mask = [image_atts[:image0_embeds.size(0)], |
|||
image_atts[image0_embeds.size(0):]], |
|||
return_dict = True, |
|||
) |
|||
hidden_state = output.last_hidden_state[:,0,:] |
|||
prediction = self.cls_head(hidden_state) |
|||
|
|||
if train: |
|||
if self.distill: |
|||
with torch.no_grad(): |
|||
self._momentum_update() |
|||
image_embeds_m = self.visual_encoder_m(image) |
|||
image0_embeds_m, image1_embeds_m = torch.split(image_embeds_m,targets.size(0)) |
|||
output_m = self.text_encoder_m(text.input_ids, |
|||
attention_mask = text.attention_mask, |
|||
encoder_hidden_states = [image0_embeds_m,image1_embeds_m], |
|||
encoder_attention_mask = [image_atts[:image0_embeds.size(0)], |
|||
image_atts[image0_embeds.size(0):]], |
|||
return_dict = True, |
|||
) |
|||
prediction_m = self.cls_head_m(output_m.last_hidden_state[:,0,:]) |
|||
|
|||
loss = (1-alpha)*F.cross_entropy(prediction, targets) - alpha*torch.sum( |
|||
F.log_softmax(prediction, dim=1)*F.softmax(prediction_m, dim=1),dim=1).mean() |
|||
else: |
|||
loss = F.cross_entropy(prediction, targets) |
|||
return loss |
|||
else: |
|||
return prediction |
|||
|
|||
|
|||
|
|||
@torch.no_grad() |
|||
def copy_params(self): |
|||
for model_pair in self.model_pairs: |
|||
for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): |
|||
param_m.data.copy_(param.data) # initialize |
|||
param_m.requires_grad = False # not update by gradient |
|||
|
|||
|
|||
@torch.no_grad() |
|||
def _momentum_update(self): |
|||
for model_pair in self.model_pairs: |
|||
for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): |
|||
param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum) |
|||
|
|||
|
|||
def share_cross_attention(self, model): |
|||
|
|||
for i in range(6): |
|||
layer_num = 6+i*2 |
|||
modules_0 = model.layer[layer_num].crossattention.self._modules |
|||
modules_1 = model.layer[layer_num+1].crossattention.self._modules |
|||
|
|||
for name in modules_0.keys(): |
|||
if 'key' in name or 'value' in name: |
|||
module_0 = modules_0[name] |
|||
module_1 = modules_1[name] |
|||
if hasattr(module_0, "weight"): |
|||
module_0.weight = module_1.weight |
|||
if hasattr(module_0, "bias"): |
|||
module_0.bias = module_1.bias |
@ -0,0 +1,291 @@ |
|||
''' |
|||
* Copyright (c) 2021, salesforce.com, inc. |
|||
* All rights reserved. |
|||
* SPDX-License-Identifier: BSD-3-Clause |
|||
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause |
|||
''' |
|||
|
|||
from functools import partial |
|||
from models.vit import VisionTransformer, interpolate_pos_embed |
|||
from models.xbert import BertConfig, BertForMaskedLM |
|||
|
|||
import torch |
|||
import torch.nn.functional as F |
|||
from torch import nn |
|||
|
|||
import numpy as np |
|||
import random |
|||
|
|||
|
|||
class ALBEF(nn.Module): |
|||
def __init__(self, |
|||
text_encoder = None, |
|||
tokenizer = None, |
|||
config = None, |
|||
temp = 0.07, |
|||
init_deit = True |
|||
): |
|||
super().__init__() |
|||
|
|||
self.tokenizer = tokenizer |
|||
self.mlm_probability = config['mlm_probability'] |
|||
embed_dim = config['embed_dim'] |
|||
|
|||
self.visual_encoder = VisionTransformer( |
|||
img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12, |
|||
mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)) |
|||
|
|||
if init_deit: |
|||
checkpoint = torch.hub.load_state_dict_from_url( |
|||
url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", |
|||
map_location="cpu", check_hash=True) |
|||
state_dict = checkpoint["model"] |
|||
pos_embed_reshaped = interpolate_pos_embed(state_dict['pos_embed'], self.visual_encoder) |
|||
state_dict['pos_embed'] = pos_embed_reshaped |
|||
msg = self.visual_encoder.load_state_dict(state_dict,strict=False) |
|||
print(msg) |
|||
|
|||
vision_width = config['vision_width'] |
|||
bert_config = BertConfig.from_json_file(config['bert_config']) |
|||
|
|||
self.text_encoder = BertForMaskedLM.from_pretrained(text_encoder, config=bert_config) |
|||
|
|||
text_width = self.text_encoder.config.hidden_size |
|||
self.vision_proj = nn.Linear(vision_width, embed_dim) |
|||
self.text_proj = nn.Linear(text_width, embed_dim) |
|||
|
|||
self.temp = nn.Parameter(torch.ones([]) * config['temp']) |
|||
self.queue_size = config['queue_size'] |
|||
self.momentum = config['momentum'] |
|||
self.itm_head = nn.Linear(text_width, 2) |
|||
|
|||
# create momentum models |
|||
self.visual_encoder_m = VisionTransformer( |
|||
img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12, |
|||
mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)) |
|||
self.vision_proj_m = nn.Linear(vision_width, embed_dim) |
|||
self.text_encoder_m = BertForMaskedLM.from_pretrained(text_encoder, config=bert_config) |
|||
self.text_proj_m = nn.Linear(text_width, embed_dim) |
|||
|
|||
self.model_pairs = [[self.visual_encoder,self.visual_encoder_m], |
|||
[self.vision_proj,self.vision_proj_m], |
|||
[self.text_encoder,self.text_encoder_m], |
|||
[self.text_proj,self.text_proj_m], |
|||
] |
|||
|
|||
self.copy_params() |
|||
|
|||
# create the queue |
|||
self.register_buffer("image_queue", torch.randn(embed_dim, self.queue_size)) |
|||
self.register_buffer("text_queue", torch.randn(embed_dim, self.queue_size)) |
|||
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) |
|||
|
|||
self.image_queue = nn.functional.normalize(self.image_queue, dim=0) |
|||
self.text_queue = nn.functional.normalize(self.text_queue, dim=0) |
|||
|
|||
|
|||
|
|||
def forward(self, image, text, alpha=0): |
|||
with torch.no_grad(): |
|||
self.temp.clamp_(0.001,0.5) |
|||
|
|||
image_embeds = self.visual_encoder(image) |
|||
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) |
|||
|
|||
image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) |
|||
|
|||
text_output = self.text_encoder.bert(text.input_ids, attention_mask = text.attention_mask, |
|||
return_dict = True, mode = 'text') |
|||
text_embeds = text_output.last_hidden_state |
|||
text_feat = F.normalize(self.text_proj(text_embeds[:,0,:]),dim=-1) |
|||
|
|||
# get momentum features |
|||
with torch.no_grad(): |
|||
self._momentum_update() |
|||
image_embeds_m = self.visual_encoder_m(image) |
|||
image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1) |
|||
image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1) |
|||
text_output_m = self.text_encoder_m.bert(text.input_ids, attention_mask = text.attention_mask, |
|||
return_dict = True, mode = 'text') |
|||
text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) |
|||
text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1) |
|||
|
|||
sim_i2t_m = image_feat_m @ text_feat_all / self.temp |
|||
sim_t2i_m = text_feat_m @ image_feat_all / self.temp |
|||
|
|||
sim_targets = torch.zeros(sim_i2t_m.size()).to(image.device) |
|||
sim_targets.fill_diagonal_(1) |
|||
|
|||
sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets |
|||
sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets |
|||
|
|||
sim_i2t = image_feat @ text_feat_all / self.temp |
|||
sim_t2i = text_feat @ image_feat_all / self.temp |
|||
|
|||
loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean() |
|||
loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() |
|||
|
|||
loss_ita = (loss_i2t+loss_t2i)/2 |
|||
|
|||
self._dequeue_and_enqueue(image_feat_m, text_feat_m) |
|||
|
|||
###=================================### |
|||
# forward the positve image-text pair |
|||
output_pos = self.text_encoder.bert(encoder_embeds = text_embeds, |
|||
attention_mask = text.attention_mask, |
|||
encoder_hidden_states = image_embeds, |
|||
encoder_attention_mask = image_atts, |
|||
return_dict = True, |
|||
mode = 'fusion', |
|||
) |
|||
with torch.no_grad(): |
|||
bs = image.size(0) |
|||
weights_i2t = F.softmax(sim_i2t[:,:bs],dim=1) |
|||
weights_t2i = F.softmax(sim_t2i[:,:bs],dim=1) |
|||
|
|||
weights_i2t.fill_diagonal_(0) |
|||
weights_t2i.fill_diagonal_(0) |
|||
|
|||
# select a negative image for each text |
|||
image_embeds_neg = [] |
|||
for b in range(bs): |
|||
neg_idx = torch.multinomial(weights_t2i[b], 1).item() |
|||
image_embeds_neg.append(image_embeds[neg_idx]) |
|||
image_embeds_neg = torch.stack(image_embeds_neg,dim=0) |
|||
|
|||
# select a negative text for each image |
|||
text_embeds_neg = [] |
|||
text_atts_neg = [] |
|||
for b in range(bs): |
|||
neg_idx = torch.multinomial(weights_i2t[b], 1).item() |
|||
text_embeds_neg.append(text_embeds[neg_idx]) |
|||
text_atts_neg.append(text.attention_mask[neg_idx]) |
|||
text_embeds_neg = torch.stack(text_embeds_neg,dim=0) |
|||
text_atts_neg = torch.stack(text_atts_neg,dim=0) |
|||
|
|||
text_embeds_all = torch.cat([text_embeds, text_embeds_neg],dim=0) |
|||
text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0) |
|||
|
|||
image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0) |
|||
image_atts_all = torch.cat([image_atts,image_atts],dim=0) |
|||
|
|||
output_neg = self.text_encoder.bert(encoder_embeds = text_embeds_all, |
|||
attention_mask = text_atts_all, |
|||
encoder_hidden_states = image_embeds_all, |
|||
encoder_attention_mask = image_atts_all, |
|||
return_dict = True, |
|||
mode = 'fusion', |
|||
) |
|||
|
|||
vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0) |
|||
vl_output = self.itm_head(vl_embeddings) |
|||
|
|||
itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)], |
|||
dim=0).to(image.device) |
|||
loss_itm = F.cross_entropy(vl_output, itm_labels) |
|||
|
|||
##================= MLM ========================## |
|||
input_ids = text.input_ids.clone() |
|||
labels = input_ids.clone() |
|||
|
|||
probability_matrix = torch.full(labels.shape, self.mlm_probability) |
|||
input_ids, labels = self.mask(input_ids, self.text_encoder.config.vocab_size, image.device, targets=labels, |
|||
probability_matrix = probability_matrix) |
|||
|
|||
with torch.no_grad(): |
|||
logits_m = self.text_encoder_m(input_ids, |
|||
attention_mask = text.attention_mask, |
|||
encoder_hidden_states = image_embeds_m, |
|||
encoder_attention_mask = image_atts, |
|||
return_dict = True, |
|||
return_logits = True, |
|||
) |
|||
mlm_output = self.text_encoder(input_ids, |
|||
attention_mask = text.attention_mask, |
|||
encoder_hidden_states = image_embeds, |
|||
encoder_attention_mask = image_atts, |
|||
return_dict = True, |
|||
labels = labels, |
|||
soft_labels = F.softmax(logits_m,dim=-1), |
|||
alpha = alpha |
|||
) |
|||
loss_mlm = mlm_output.loss |
|||
|
|||
return loss_mlm, loss_ita, loss_itm |
|||
|
|||
|
|||
|
|||
@torch.no_grad() |
|||
def copy_params(self): |
|||
for model_pair in self.model_pairs: |
|||
for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): |
|||
param_m.data.copy_(param.data) # initialize |
|||
param_m.requires_grad = False # not update by gradient |
|||
|
|||
|
|||
@torch.no_grad() |
|||
def _momentum_update(self): |
|||
for model_pair in self.model_pairs: |
|||
for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): |
|||
param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum) |
|||
|
|||
|
|||
|
|||
@torch.no_grad() |
|||
def _dequeue_and_enqueue(self, image_feat, text_feat): |
|||
# gather keys before updating queue |
|||
image_feats = concat_all_gather(image_feat) |
|||
text_feats = concat_all_gather(text_feat) |
|||
|
|||
batch_size = image_feats.shape[0] |
|||
|
|||
ptr = int(self.queue_ptr) |
|||
assert self.queue_size % batch_size == 0 # for simplicity |
|||
|
|||
# replace the keys at ptr (dequeue and enqueue) |
|||
self.image_queue[:, ptr:ptr + batch_size] = image_feats.T |
|||
self.text_queue[:, ptr:ptr + batch_size] = text_feats.T |
|||
ptr = (ptr + batch_size) % self.queue_size # move pointer |
|||
|
|||
self.queue_ptr[0] = ptr |
|||
|
|||
|
|||
def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None): |
|||
if masked_indices is None: |
|||
masked_indices = torch.bernoulli(probability_matrix).bool() |
|||
|
|||
masked_indices[input_ids == self.tokenizer.pad_token_id] = False |
|||
masked_indices[input_ids == self.tokenizer.cls_token_id] = False |
|||
|
|||
if targets is not None: |
|||
targets[~masked_indices] = -100 # We only compute loss on masked tokens |
|||
|
|||
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) |
|||
indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices |
|||
input_ids[indices_replaced] = self.tokenizer.mask_token_id |
|||
|
|||
# 10% of the time, we replace masked input tokens with random word |
|||
indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced |
|||
random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device) |
|||
input_ids[indices_random] = random_words[indices_random] |
|||
# The rest of the time (10% of the time) we keep the masked input tokens unchanged |
|||
|
|||
if targets is not None: |
|||
return input_ids, targets |
|||
else: |
|||
return input_ids |
|||
|
|||
|
|||
@torch.no_grad() |
|||
def concat_all_gather(tensor): |
|||
""" |
|||
Performs all_gather operation on the provided tensors. |
|||
*** Warning ***: torch.distributed.all_gather has no gradient. |
|||
""" |
|||
tensors_gather = [torch.ones_like(tensor) |
|||
for _ in range(torch.distributed.get_world_size())] |
|||
torch.distributed.all_gather(tensors_gather, tensor, async_op=False) |
|||
|
|||
output = torch.cat(tensors_gather, dim=0) |
|||
return output |
|||
|
@ -0,0 +1,99 @@ |
|||
from functools import partial |
|||
from models.vit import VisionTransformer |
|||
from models.xbert import BertConfig, BertModel |
|||
|
|||
import torch |
|||
from torch import nn |
|||
import torch.nn.functional as F |
|||
|
|||
class ALBEF(nn.Module): |
|||
def __init__(self, |
|||
text_encoder = None, |
|||
tokenizer = None, |
|||
config = None, |
|||
): |
|||
super().__init__() |
|||
|
|||
self.tokenizer = tokenizer |
|||
vision_width = config['vision_width'] |
|||
embed_dim = config['embed_dim'] |
|||
|
|||
self.visual_encoder = VisionTransformer( |
|||
img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12, |
|||
mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)) |
|||
|
|||
bert_config = BertConfig.from_json_file(config['bert_config']) |
|||
bert_config.num_hidden_layers = 18 |
|||
self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False) |
|||
|
|||
#share the cross-attention layers for two images |
|||
self.share_cross_attention(self.text_encoder.encoder) |
|||
|
|||
text_width = self.text_encoder.config.hidden_size |
|||
self.vision_proj = nn.Linear(vision_width, embed_dim) |
|||
self.text_proj = nn.Linear(text_width, embed_dim) |
|||
self.temp = nn.Parameter(torch.ones([]) * 0.07) |
|||
self.ta_head = nn.Linear(self.text_encoder.config.hidden_size, 3) |
|||
|
|||
|
|||
def forward(self, image, text): |
|||
image_embeds = self.visual_encoder(image) |
|||
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) |
|||
with torch.no_grad(): |
|||
image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) |
|||
sim = image_feat @ image_feat.t() / 0.07 |
|||
weights = F.softmax(sim,dim=1) |
|||
weights.fill_diagonal_(0) |
|||
|
|||
image_inputs = [[],[]] |
|||
labels = [] |
|||
for b in range(image.size(0)): |
|||
if torch.rand(1)>1/3: |
|||
idx = torch.multinomial(weights[b], 1).item() |
|||
if torch.rand(1)>0.5: |
|||
image_inputs[0].append(image_embeds[b]) |
|||
image_inputs[1].append(image_embeds[idx]) |
|||
labels.append(0) |
|||
else: |
|||
image_inputs[1].append(image_embeds[b]) |
|||
image_inputs[0].append(image_embeds[idx]) |
|||
labels.append(1) |
|||
else: |
|||
idx = torch.multinomial(weights[b], 2) |
|||
image_inputs[0].append(image_embeds[idx[0]]) |
|||
image_inputs[1].append(image_embeds[idx[1]]) |
|||
labels.append(2) |
|||
|
|||
image_inputs[0] = torch.stack(image_inputs[0],dim=0) |
|||
image_inputs[1] = torch.stack(image_inputs[1],dim=0) |
|||
labels = torch.LongTensor(labels).to(image.device) |
|||
|
|||
output = self.text_encoder(text.input_ids, |
|||
attention_mask = text.attention_mask, |
|||
encoder_hidden_states = image_inputs, |
|||
encoder_attention_mask = [image_atts,image_atts], |
|||
return_dict = True, |
|||
) |
|||
|
|||
pred = self.ta_head(output.last_hidden_state[:,0,:]) |
|||
loss = F.cross_entropy(pred, labels) |
|||
|
|||
return loss |
|||
|
|||
|
|||
|
|||
def share_cross_attention(self, model): |
|||
|
|||
for i in range(6): |
|||
layer_num = 6+i*2 |
|||
modules_0 = model.layer[layer_num].crossattention.self._modules |
|||
modules_1 = model.layer[layer_num+1].crossattention.self._modules |
|||
|
|||
for name in modules_0.keys(): |
|||
if 'key' in name or 'value' in name: |
|||
module_0 = modules_0[name] |
|||
module_1 = modules_1[name] |
|||
if hasattr(module_0, "weight"): |
|||
module_0.weight = module_1.weight |
|||
if hasattr(module_0, "bias"): |
|||
module_0.bias = module_1.bias |
@ -0,0 +1,217 @@ |
|||
from functools import partial |
|||
from models.vit import VisionTransformer |
|||
from models.xbert import BertConfig, BertModel |
|||
|
|||
import torch |
|||
from torch import nn |
|||
import torch.nn.functional as F |
|||
|
|||
class ALBEF(nn.Module): |
|||
def __init__(self, |
|||
text_encoder = None, |
|||
tokenizer = None, |
|||
config = None, |
|||
): |
|||
super().__init__() |
|||
|
|||
self.tokenizer = tokenizer |
|||
self.distill = config['distill'] |
|||
embed_dim = config['embed_dim'] |
|||
vision_width = config['vision_width'] |
|||
self.visual_encoder = VisionTransformer( |
|||
img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12, |
|||
mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)) |
|||
|
|||
bert_config = BertConfig.from_json_file(config['bert_config']) |
|||
self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False) |
|||
|
|||
text_width = self.text_encoder.config.hidden_size |
|||
self.vision_proj = nn.Linear(vision_width, embed_dim) |
|||
self.text_proj = nn.Linear(text_width, embed_dim) |
|||
|
|||
self.temp = nn.Parameter(torch.ones([]) * config['temp']) |
|||
self.queue_size = config['queue_size'] |
|||
self.momentum = config['momentum'] |
|||
self.itm_head = nn.Linear(text_width, 2) |
|||
|
|||
# create momentum models |
|||
self.visual_encoder_m = VisionTransformer( |
|||
img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12, |
|||
mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)) |
|||
self.vision_proj_m = nn.Linear(vision_width, embed_dim) |
|||
self.text_encoder_m = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False) |
|||
self.text_proj_m = nn.Linear(text_width, embed_dim) |
|||
|
|||
self.model_pairs = [[self.visual_encoder,self.visual_encoder_m], |
|||
[self.vision_proj,self.vision_proj_m], |
|||
[self.text_encoder,self.text_encoder_m], |
|||
[self.text_proj,self.text_proj_m], |
|||
] |
|||
self.copy_params() |
|||
|
|||
# create the queue |
|||
self.register_buffer("image_queue", torch.randn(embed_dim, self.queue_size)) |
|||
self.register_buffer("text_queue", torch.randn(embed_dim, self.queue_size)) |
|||
self.register_buffer("idx_queue", torch.full((1,self.queue_size),-100)) |
|||
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) |
|||
|
|||
self.image_queue = nn.functional.normalize(self.image_queue, dim=0) |
|||
self.text_queue = nn.functional.normalize(self.text_queue, dim=0) |
|||
|
|||
|
|||
def forward(self, image, text, alpha, idx): |
|||
|
|||
image_embeds = self.visual_encoder(image) |
|||
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) |
|||
|
|||
image_feat = F.normalize(self.vision_proj(image_embeds[:,0,:]),dim=-1) |
|||
text_output = self.text_encoder(text.input_ids, attention_mask = text.attention_mask, |
|||
return_dict = True, mode = 'text') |
|||
text_embeds = text_output.last_hidden_state |
|||
text_feat = F.normalize(self.text_proj(text_embeds[:,0,:]),dim=-1) |
|||
|
|||
idx = idx.view(-1,1) |
|||
idx_all = torch.cat([idx.t(), self.idx_queue.clone().detach()],dim=1) |
|||
pos_idx = torch.eq(idx, idx_all).float() |
|||
sim_targets = pos_idx / pos_idx.sum(1,keepdim=True) |
|||
|
|||
with torch.no_grad(): |
|||
self._momentum_update() |
|||
image_embeds_m = self.visual_encoder_m(image) |
|||
image_feat_m = F.normalize(self.vision_proj_m(image_embeds_m[:,0,:]),dim=-1) |
|||
image_feat_all = torch.cat([image_feat_m.t(),self.image_queue.clone().detach()],dim=1) |
|||
text_output_m = self.text_encoder_m(text.input_ids, attention_mask = text.attention_mask, |
|||
return_dict = True, mode = 'text') |
|||
text_feat_m = F.normalize(self.text_proj_m(text_output_m.last_hidden_state[:,0,:]),dim=-1) |
|||
text_feat_all = torch.cat([text_feat_m.t(),self.text_queue.clone().detach()],dim=1) |
|||
|
|||
if self.distill: |
|||
sim_i2t_m = image_feat_m @ text_feat_all / self.temp |
|||
sim_t2i_m = text_feat_m @ image_feat_all / self.temp |
|||
|
|||
sim_i2t_targets = alpha * F.softmax(sim_i2t_m, dim=1) + (1 - alpha) * sim_targets |
|||
sim_t2i_targets = alpha * F.softmax(sim_t2i_m, dim=1) + (1 - alpha) * sim_targets |
|||
|
|||
sim_i2t = image_feat @ text_feat_all / self.temp |
|||
sim_t2i = text_feat @ image_feat_all / self.temp |
|||
|
|||
if self.distill: |
|||
loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_i2t_targets,dim=1).mean() |
|||
loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_t2i_targets,dim=1).mean() |
|||
else: |
|||
loss_i2t = -torch.sum(F.log_softmax(sim_i2t, dim=1)*sim_targets,dim=1).mean() |
|||
loss_t2i = -torch.sum(F.log_softmax(sim_t2i, dim=1)*sim_targets,dim=1).mean() |
|||
|
|||
loss_ita = (loss_i2t+loss_t2i)/2 |
|||
|
|||
self._dequeue_and_enqueue(image_feat_m, text_feat_m, idx) |
|||
|
|||
###=================================### |
|||
# forward the positve image-text pair |
|||
output_pos = self.text_encoder(encoder_embeds = text_embeds, |
|||
attention_mask = text.attention_mask, |
|||
encoder_hidden_states = image_embeds, |
|||
encoder_attention_mask = image_atts, |
|||
return_dict = True, |
|||
mode = 'fusion', |
|||
) |
|||
with torch.no_grad(): |
|||
bs = image.size(0) |
|||
weights_i2t = F.softmax(sim_i2t[:,:bs]+1e-4,dim=1) |
|||
weights_t2i = F.softmax(sim_t2i[:,:bs]+1e-4,dim=1) |
|||
|
|||
mask = torch.eq(idx, idx.T) |
|||
weights_i2t.masked_fill_(mask, 0) |
|||
weights_t2i.masked_fill_(mask, 0) |
|||
|
|||
# select a negative image for each text |
|||
image_embeds_neg = [] |
|||
for b in range(bs): |
|||
neg_idx = torch.multinomial(weights_t2i[b], 1).item() |
|||
image_embeds_neg.append(image_embeds[neg_idx]) |
|||
image_embeds_neg = torch.stack(image_embeds_neg,dim=0) |
|||
|
|||
# select a negative text for each image |
|||
text_embeds_neg = [] |
|||
text_atts_neg = [] |
|||
for b in range(bs): |
|||
neg_idx = torch.multinomial(weights_i2t[b], 1).item() |
|||
text_embeds_neg.append(text_embeds[neg_idx]) |
|||
text_atts_neg.append(text.attention_mask[neg_idx]) |
|||
text_embeds_neg = torch.stack(text_embeds_neg,dim=0) |
|||
text_atts_neg = torch.stack(text_atts_neg,dim=0) |
|||
|
|||
text_embeds_all = torch.cat([text_embeds, text_embeds_neg],dim=0) |
|||
text_atts_all = torch.cat([text.attention_mask, text_atts_neg],dim=0) |
|||
|
|||
image_embeds_all = torch.cat([image_embeds_neg,image_embeds],dim=0) |
|||
image_atts_all = torch.cat([image_atts,image_atts],dim=0) |
|||
|
|||
output_neg = self.text_encoder(encoder_embeds = text_embeds_all, |
|||
attention_mask = text_atts_all, |
|||
encoder_hidden_states = image_embeds_all, |
|||
encoder_attention_mask = image_atts_all, |
|||
return_dict = True, |
|||
mode = 'fusion', |
|||
) |
|||
|
|||
vl_embeddings = torch.cat([output_pos.last_hidden_state[:,0,:], output_neg.last_hidden_state[:,0,:]],dim=0) |
|||
vl_output = self.itm_head(vl_embeddings) |
|||
|
|||
itm_labels = torch.cat([torch.ones(bs,dtype=torch.long),torch.zeros(2*bs,dtype=torch.long)], |
|||
dim=0).to(image.device) |
|||
loss_itm = F.cross_entropy(vl_output, itm_labels) |
|||
|
|||
return loss_ita, loss_itm |
|||
|
|||
|
|||
|
|||
@torch.no_grad() |
|||
def copy_params(self): |
|||
for model_pair in self.model_pairs: |
|||
for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): |
|||
param_m.data.copy_(param.data) # initialize |
|||
param_m.requires_grad = False # not update by gradient |
|||
|
|||
|
|||
@torch.no_grad() |
|||
def _momentum_update(self): |
|||
for model_pair in self.model_pairs: |
|||
for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): |
|||
param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum) |
|||
|
|||
|
|||
@torch.no_grad() |
|||
def _dequeue_and_enqueue(self, image_feat, text_feat, idx): |
|||
# gather keys before updating queue |
|||
image_feats = concat_all_gather(image_feat) |
|||
text_feats = concat_all_gather(text_feat) |
|||
idxs = concat_all_gather(idx) |
|||
|
|||
batch_size = image_feats.shape[0] |
|||
|
|||
ptr = int(self.queue_ptr) |
|||
assert self.queue_size % batch_size == 0 # for simplicity |
|||
|
|||
# replace the keys at ptr (dequeue and enqueue) |
|||
self.image_queue[:, ptr:ptr + batch_size] = image_feats.T |
|||
self.text_queue[:, ptr:ptr + batch_size] = text_feats.T |
|||
self.idx_queue[:, ptr:ptr + batch_size] = idxs.T |
|||
ptr = (ptr + batch_size) % self.queue_size # move pointer |
|||
|
|||
self.queue_ptr[0] = ptr |
|||
|
|||
|
|||
@torch.no_grad() |
|||
def concat_all_gather(tensor): |
|||
""" |
|||
Performs all_gather operation on the provided tensors. |
|||
*** Warning ***: torch.distributed.all_gather has no gradient. |
|||
""" |
|||
tensors_gather = [torch.ones_like(tensor) |
|||
for _ in range(torch.distributed.get_world_size())] |
|||
torch.distributed.all_gather(tensors_gather, tensor, async_op=False) |
|||
|
|||
output = torch.cat(tensors_gather, dim=0) |
|||
return output |
|||
|
@ -0,0 +1,110 @@ |
|||
from functools import partial |
|||
from models.vit import VisionTransformer |
|||
from models.xbert import BertConfig, BertModel |
|||
|
|||
import torch |
|||
from torch import nn |
|||
import torch.nn.functional as F |
|||
|
|||
class ALBEF(nn.Module): |
|||
def __init__(self, |
|||
text_encoder = None, |
|||
tokenizer = None, |
|||
config = None, |
|||
): |
|||
super().__init__() |
|||
|
|||
self.tokenizer = tokenizer |
|||
self.distill = config['distill'] |
|||
|
|||
self.visual_encoder = VisionTransformer( |
|||
img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12, |
|||
mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)) |
|||
|
|||
bert_config = BertConfig.from_json_file(config['bert_config']) |
|||
|
|||
self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False) |
|||
|
|||
self.cls_head = nn.Sequential( |
|||
nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size), |
|||
nn.ReLU(), |
|||
nn.Linear(self.text_encoder.config.hidden_size, 3) |
|||
) |
|||
|
|||
if self.distill: |
|||
self.visual_encoder_m = VisionTransformer( |
|||
img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12, |
|||
mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)) |
|||
self.text_encoder_m = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False) |
|||
self.cls_head_m = nn.Sequential( |
|||
nn.Linear(self.text_encoder.config.hidden_size, self.text_encoder.config.hidden_size), |
|||
nn.ReLU(), |
|||
nn.Linear(self.text_encoder.config.hidden_size, 3) |
|||
) |
|||
|
|||
self.model_pairs = [[self.visual_encoder,self.visual_encoder_m], |
|||
[self.text_encoder,self.text_encoder_m], |
|||
[self.cls_head,self.cls_head_m], |
|||
] |
|||
self.copy_params() |
|||
self.momentum = 0.995 |
|||
|
|||
|
|||
def forward(self, image, text, targets, alpha=0, train=True): |
|||
|
|||
image_embeds = self.visual_encoder(image) |
|||
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) |
|||
|
|||
if train: |
|||
output = self.text_encoder(text.input_ids, |
|||
attention_mask = text.attention_mask, |
|||
encoder_hidden_states = image_embeds, |
|||
encoder_attention_mask = image_atts, |
|||
return_dict = True |
|||
) |
|||
prediction = self.cls_head(output.last_hidden_state[:,0,:]) |
|||
if self.distill: |
|||
with torch.no_grad(): |
|||
self._momentum_update() |
|||
image_embeds_m = self.visual_encoder_m(image) |
|||
output_m = self.text_encoder_m(text.input_ids, |
|||
attention_mask = text.attention_mask, |
|||
encoder_hidden_states = image_embeds_m, |
|||
encoder_attention_mask = image_atts, |
|||
return_dict = True |
|||
) |
|||
prediction_m = self.cls_head_m(output_m.last_hidden_state[:,0,:]) |
|||
|
|||
loss = (1-alpha)*F.cross_entropy(prediction, targets) - alpha*torch.sum( |
|||
F.log_softmax(prediction, dim=1)*F.softmax(prediction_m, dim=1),dim=1).mean() |
|||
else: |
|||
loss = F.cross_entropy(prediction, targets) |
|||
return loss |
|||
|
|||
else: |
|||
output = self.text_encoder(text.input_ids, |
|||
attention_mask = text.attention_mask, |
|||
encoder_hidden_states = image_embeds, |
|||
encoder_attention_mask = image_atts, |
|||
return_dict = True |
|||
) |
|||
prediction = self.cls_head(output.last_hidden_state[:,0,:]) |
|||
return prediction |
|||
|
|||
|
|||
|
|||
@torch.no_grad() |
|||
def copy_params(self): |
|||
for model_pair in self.model_pairs: |
|||
for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): |
|||
param_m.data.copy_(param.data) # initialize |
|||
param_m.requires_grad = False # not update by gradient |
|||
|
|||
|
|||
@torch.no_grad() |
|||
def _momentum_update(self): |
|||
for model_pair in self.model_pairs: |
|||
for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): |
|||
param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum) |
|||
|
|||
|
@ -0,0 +1,214 @@ |
|||
from functools import partial |
|||
from models.vit import VisionTransformer |
|||
from models.xbert import BertConfig, BertModel, BertLMHeadModel |
|||
|
|||
import torch |
|||
from torch import nn |
|||
import torch.nn.functional as F |
|||
|
|||
import numpy as np |
|||
|
|||
class ALBEF(nn.Module): |
|||
def __init__(self, |
|||
text_encoder = None, |
|||
text_decoder = None, |
|||
tokenizer = None, |
|||
config = None, |
|||
): |
|||
super().__init__() |
|||
|
|||
self.tokenizer = tokenizer |
|||
self.distill = config['distill'] |
|||
|
|||
self.visual_encoder = VisionTransformer( |
|||
img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12, |
|||
mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)) |
|||
|
|||
config_encoder = BertConfig.from_json_file(config['bert_config']) |
|||
self.text_encoder = BertModel.from_pretrained(text_encoder, config=config_encoder, add_pooling_layer=False) |
|||
|
|||
config_decoder = BertConfig.from_json_file(config['bert_config']) |
|||
config_decoder.fusion_layer = 0 |
|||
config_decoder.num_hidden_layers = 6 |
|||
self.text_decoder = BertLMHeadModel.from_pretrained(text_decoder, config=config_decoder) |
|||
|
|||
if self.distill: |
|||
self.visual_encoder_m = VisionTransformer( |
|||
img_size=config['image_res'], patch_size=16, embed_dim=768, depth=12, num_heads=12, |
|||
mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6)) |
|||
self.text_encoder_m = BertModel.from_pretrained(text_encoder, config=config_encoder, add_pooling_layer=False) |
|||
self.text_decoder_m = BertLMHeadModel.from_pretrained(text_decoder, config=config_decoder) |
|||
self.model_pairs = [[self.visual_encoder,self.visual_encoder_m], |
|||
[self.text_encoder,self.text_encoder_m], |
|||
[self.text_decoder,self.text_decoder_m], |
|||
] |
|||
self.copy_params() |
|||
self.momentum = 0.995 |
|||
|
|||
|
|||
def forward(self, image, quesiton, answer=None, alpha=0, k=None, weights=None, train=True): |
|||
|
|||
image_embeds = self.visual_encoder(image) |
|||
image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device) |
|||
|
|||
if train: |
|||
''' |
|||
k: number of answers for each question |
|||
weights: weight for each answer |
|||
''' |
|||
answer_targets = answer.input_ids.masked_fill(answer.input_ids == self.tokenizer.pad_token_id, -100) |
|||
|
|||
question_output = self.text_encoder(quesiton.input_ids, |
|||
attention_mask = quesiton.attention_mask, |
|||
encoder_hidden_states = image_embeds, |
|||
encoder_attention_mask = image_atts, |
|||
return_dict = True) |
|||
|
|||
question_states = [] |
|||
question_atts = [] |
|||
for b, n in enumerate(k): |
|||
question_states += [question_output.last_hidden_state[b]]*n |
|||
question_atts += [quesiton.attention_mask[b]]*n |
|||
question_states = torch.stack(question_states,0) |
|||
question_atts = torch.stack(question_atts,0) |
|||
|
|||
if self.distill: |
|||
with torch.no_grad(): |
|||
self._momentum_update() |
|||
image_embeds_m = self.visual_encoder_m(image) |
|||
question_output_m = self.text_encoder_m(quesiton.input_ids, |
|||
attention_mask = quesiton.attention_mask, |
|||
encoder_hidden_states = image_embeds_m, |
|||
encoder_attention_mask = image_atts, |
|||
return_dict = True) |
|||
|
|||
question_states_m = [] |
|||
for b, n in enumerate(k): |
|||
question_states_m += [question_output_m.last_hidden_state[b]]*n |
|||
question_states_m = torch.stack(question_states_m,0) |
|||
|
|||
logits_m = self.text_decoder_m(answer.input_ids, |
|||
attention_mask = answer.attention_mask, |
|||
encoder_hidden_states = question_states_m, |
|||
encoder_attention_mask = question_atts, |
|||
return_logits = True, |
|||
) |
|||
|
|||
answer_output = self.text_decoder(answer.input_ids, |
|||
attention_mask = answer.attention_mask, |
|||
encoder_hidden_states = question_states, |
|||
encoder_attention_mask = question_atts, |
|||
labels = answer_targets, |
|||
return_dict = True, |
|||
soft_labels = F.softmax(logits_m,dim=-1), |
|||
alpha = alpha, |
|||
reduction = 'none', |
|||
) |
|||
else: |
|||
answer_output = self.text_decoder(answer.input_ids, |
|||
attention_mask = answer.attention_mask, |
|||
encoder_hidden_states = question_states, |
|||
encoder_attention_mask = question_atts, |
|||
labels = answer_targets, |
|||
return_dict = True, |
|||
reduction = 'none', |
|||
) |
|||
loss = weights * answer_output.loss |
|||
loss = loss.sum()/image.size(0) |
|||
|
|||
return loss |
|||
|
|||
|
|||
else: |
|||
question_output = self.text_encoder(quesiton.input_ids, |
|||
attention_mask = quesiton.attention_mask, |
|||
encoder_hidden_states = image_embeds, |
|||
encoder_attention_mask = image_atts, |
|||
return_dict = True) |
|||
topk_ids, topk_probs = self.rank_answer(question_output.last_hidden_state, quesiton.attention_mask, |
|||
answer.input_ids, answer.attention_mask, k) |
|||
return topk_ids, topk_probs |
|||
|
|||
|
|||
|
|||
@torch.no_grad() |
|||
def copy_params(self): |
|||
for model_pair in self.model_pairs: |
|||
for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): |
|||
param_m.data.copy_(param.data) # initialize |
|||
param_m.requires_grad = False # not update by gradient |
|||
|
|||
|
|||
@torch.no_grad() |
|||
def _momentum_update(self): |
|||
for model_pair in self.model_pairs: |
|||
for param, param_m in zip(model_pair[0].parameters(), model_pair[1].parameters()): |
|||
param_m.data = param_m.data * self.momentum + param.data * (1. - self.momentum) |
|||
|
|||
|
|||
def rank_answer(self, question_states, question_atts, answer_ids, answer_atts, k): |
|||
|
|||
num_ques = question_states.size(0) |
|||
start_ids = answer_ids[0,0].repeat(num_ques,1) # bos token |
|||
|
|||
start_output = self.text_decoder(start_ids, |
|||
encoder_hidden_states = question_states, |
|||
encoder_attention_mask = question_atts, |
|||
return_dict = True, |
|||
reduction = 'none') |
|||
logits = start_output.logits[:,0,:] # first token's logit |
|||
|
|||
# topk_probs: top-k probability |
|||
# topk_ids: [num_question, k] |
|||
answer_first_token = answer_ids[:,1] |
|||
prob_first_token = F.softmax(logits,dim=1).index_select(dim=1, index=answer_first_token) |
|||
topk_probs, topk_ids = prob_first_token.topk(k,dim=1) |
|||
|
|||
# answer input: [num_question*k, answer_len] |
|||
input_ids = [] |
|||
input_atts = [] |
|||
for b, topk_id in enumerate(topk_ids): |
|||
input_ids.append(answer_ids.index_select(dim=0, index=topk_id)) |
|||
input_atts.append(answer_atts.index_select(dim=0, index=topk_id)) |
|||
input_ids = torch.cat(input_ids,dim=0) |
|||
input_atts = torch.cat(input_atts,dim=0) |
|||
|
|||
targets_ids = input_ids.masked_fill(input_ids == self.tokenizer.pad_token_id, -100) |
|||
|
|||
# repeat encoder's output for top-k answers |
|||
question_states = tile(question_states, 0, k) |
|||
question_atts = tile(question_atts, 0, k) |
|||
|
|||
output = self.text_decoder(input_ids, |
|||
attention_mask = input_atts, |
|||
encoder_hidden_states = question_states, |
|||
encoder_attention_mask = question_atts, |
|||
labels = targets_ids, |
|||
return_dict = True, |
|||
reduction = 'none') |
|||
|
|||
answer_loss = output.loss |
|||
answer_loss = answer_loss.view(input_ids.size(0),-1) |
|||
|
|||
# topk_prob: first token probability |
|||
topk_probs = topk_probs.view(-1,1) |
|||
log_probs = torch.cat([topk_probs.log(), -answer_loss],dim=1) |
|||
|
|||
# re-calculate log probabilities for the answer sequences using chain rule |
|||
log_probs_sum = log_probs.sum(1) |
|||
log_probs_sum = log_probs_sum.view(num_ques,k) |
|||
|
|||
topk_probs = F.softmax(log_probs_sum, dim=-1) |
|||
# get top-k after re-ranking |
|||
topk_probs, rerank_id = topk_probs.topk(k,dim=1) |
|||
topk_ids = torch.gather(topk_ids, 1, rerank_id) |
|||
|
|||
return topk_ids, topk_probs |
|||
|
|||
def tile(x, dim, n_tile): |
|||
init_dim = x.size(dim) |
|||
repeat_idx = [1] * x.dim() |
|||
repeat_idx[dim] = n_tile |
|||
x = x.repeat(*(repeat_idx)) |
|||
order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])) |
|||
return torch.index_select(x, dim, order_index.to(x.device)) |
@ -0,0 +1,539 @@ |
|||
# coding=utf-8 |
|||
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
"""Tokenization classes for Bert.""" |
|||
|
|||
|
|||
import collections |
|||
import os |
|||
import unicodedata |
|||
from typing import List, Optional, Tuple |
|||
|
|||
from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace |
|||
from transformers.utils import logging |
|||
|
|||
|
|||
logger = logging.get_logger(__name__) |
|||
|
|||
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"} |
|||
|
|||
PRETRAINED_VOCAB_FILES_MAP = { |
|||
"vocab_file": { |
|||
"bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt", |
|||
"bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt", |
|||
"bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt", |
|||
"bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt", |
|||
"bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt", |
|||
"bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt", |
|||
"bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt", |
|||
"bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/vocab.txt", |
|||
"bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt", |
|||
"bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt", |
|||
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt", |
|||
"bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt", |
|||
"bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt", |
|||
"bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt", |
|||
"bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt", |
|||
"TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt", |
|||
"TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt", |
|||
"wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt", |
|||
} |
|||
} |
|||
|
|||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { |
|||
"bert-base-uncased": 512, |
|||
"bert-large-uncased": 512, |
|||
"bert-base-cased": 512, |
|||
"bert-large-cased": 512, |
|||
"bert-base-multilingual-uncased": 512, |
|||
"bert-base-multilingual-cased": 512, |
|||
"bert-base-chinese": 512, |
|||
"bert-base-german-cased": 512, |
|||
"bert-large-uncased-whole-word-masking": 512, |
|||
"bert-large-cased-whole-word-masking": 512, |
|||
"bert-large-uncased-whole-word-masking-finetuned-squad": 512, |
|||
"bert-large-cased-whole-word-masking-finetuned-squad": 512, |
|||
"bert-base-cased-finetuned-mrpc": 512, |
|||
"bert-base-german-dbmdz-cased": 512, |
|||
"bert-base-german-dbmdz-uncased": 512, |
|||
"TurkuNLP/bert-base-finnish-cased-v1": 512, |
|||
"TurkuNLP/bert-base-finnish-uncased-v1": 512, |
|||
"wietsedv/bert-base-dutch-cased": 512, |
|||
} |
|||
|
|||
PRETRAINED_INIT_CONFIGURATION = { |
|||
"bert-base-uncased": {"do_lower_case": True}, |
|||
"bert-large-uncased": {"do_lower_case": True}, |
|||
"bert-base-cased": {"do_lower_case": False}, |
|||
"bert-large-cased": {"do_lower_case": False}, |
|||
"bert-base-multilingual-uncased": {"do_lower_case": True}, |
|||
"bert-base-multilingual-cased": {"do_lower_case": False}, |
|||
"bert-base-chinese": {"do_lower_case": False}, |
|||
"bert-base-german-cased": {"do_lower_case": False}, |
|||
"bert-large-uncased-whole-word-masking": {"do_lower_case": True}, |
|||
"bert-large-cased-whole-word-masking": {"do_lower_case": False}, |
|||
"bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True}, |
|||
"bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False}, |
|||
"bert-base-cased-finetuned-mrpc": {"do_lower_case": False}, |
|||
"bert-base-german-dbmdz-cased": {"do_lower_case": False}, |
|||
"bert-base-german-dbmdz-uncased": {"do_lower_case": True}, |
|||
"TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False}, |
|||
"TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True}, |
|||
"wietsedv/bert-base-dutch-cased": {"do_lower_case": False}, |
|||
} |
|||
|
|||
|
|||
def load_vocab(vocab_file): |
|||
"""Loads a vocabulary file into a dictionary.""" |
|||
vocab = collections.OrderedDict() |
|||
with open(vocab_file, "r", encoding="utf-8") as reader: |
|||
tokens = reader.readlines() |
|||
for index, token in enumerate(tokens): |
|||
token = token.rstrip("\n") |
|||
vocab[token] = index |
|||
return vocab |
|||
|
|||
|
|||
def whitespace_tokenize(text): |
|||
"""Runs basic whitespace cleaning and splitting on a piece of text.""" |
|||
text = text.strip() |
|||
if not text: |
|||
return [] |
|||
tokens = text.split() |
|||
return tokens |
|||
|
|||
|
|||
class BertTokenizer(PreTrainedTokenizer): |
|||
r""" |
|||
Construct a BERT tokenizer. Based on WordPiece. |
|||
This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods. |
|||
Users should refer to this superclass for more information regarding those methods. |
|||
Args: |
|||
vocab_file (:obj:`str`): |
|||
File containing the vocabulary. |
|||
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): |
|||
Whether or not to lowercase the input when tokenizing. |
|||
do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`): |
|||
Whether or not to do basic tokenization before WordPiece. |
|||
never_split (:obj:`Iterable`, `optional`): |
|||
Collection of tokens which will never be split during tokenization. Only has an effect when |
|||
:obj:`do_basic_tokenize=True` |
|||
unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`): |
|||
The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this |
|||
token instead. |
|||
sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`): |
|||
The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for |
|||
sequence classification or for a text and a question for question answering. It is also used as the last |
|||
token of a sequence built with special tokens. |
|||
pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`): |
|||
The token used for padding, for example when batching sequences of different lengths. |
|||
cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`): |
|||
The classifier token which is used when doing sequence classification (classification of the whole sequence |
|||
instead of per-token classification). It is the first token of the sequence when built with special tokens. |
|||
mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`): |
|||
The token used for masking values. This is the token used when training this model with masked language |
|||
modeling. This is the token which the model will try to predict. |
|||
tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): |
|||
Whether or not to tokenize Chinese characters. |
|||
This should likely be deactivated for Japanese (see this `issue |
|||
<https://github.com/huggingface/transformers/issues/328>`__). |
|||
strip_accents: (:obj:`bool`, `optional`): |
|||
Whether or not to strip all accents. If this option is not specified, then it will be determined by the |
|||
value for :obj:`lowercase` (as in the original BERT). |
|||
""" |
|||
|
|||
vocab_files_names = VOCAB_FILES_NAMES |
|||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP |
|||
pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION |
|||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES |
|||
|
|||
def __init__( |
|||
self, |
|||
vocab_file, |
|||
do_lower_case=True, |
|||
do_basic_tokenize=True, |
|||
never_split=None, |
|||
unk_token="[UNK]", |
|||
sep_token="[SEP]", |
|||
pad_token="[PAD]", |
|||
cls_token="[CLS]", |
|||
mask_token="[MASK]", |
|||
tokenize_chinese_chars=True, |
|||
strip_accents=None, |
|||
**kwargs |
|||
): |
|||
super().__init__( |
|||
do_lower_case=do_lower_case, |
|||
do_basic_tokenize=do_basic_tokenize, |
|||
never_split=never_split, |
|||
unk_token=unk_token, |
|||
sep_token=sep_token, |
|||
pad_token=pad_token, |
|||
cls_token=cls_token, |
|||
mask_token=mask_token, |
|||
tokenize_chinese_chars=tokenize_chinese_chars, |
|||
strip_accents=strip_accents, |
|||
**kwargs, |
|||
) |
|||
|
|||
if not os.path.isfile(vocab_file): |
|||
raise ValueError( |
|||
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " |
|||
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file) |
|||
) |
|||
self.vocab = load_vocab(vocab_file) |
|||
self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()]) |
|||
self.do_basic_tokenize = do_basic_tokenize |
|||
if do_basic_tokenize: |
|||
self.basic_tokenizer = BasicTokenizer( |
|||
do_lower_case=do_lower_case, |
|||
never_split=never_split, |
|||
tokenize_chinese_chars=tokenize_chinese_chars, |
|||
strip_accents=strip_accents, |
|||
) |
|||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token) |
|||
|
|||
@property |
|||
def do_lower_case(self): |
|||
return self.basic_tokenizer.do_lower_case |
|||
|
|||
@property |
|||
def vocab_size(self): |
|||
return len(self.vocab) |
|||
|
|||
def get_vocab(self): |
|||
return dict(self.vocab, **self.added_tokens_encoder) |
|||
|
|||
def _tokenize(self, text): |
|||
split_tokens = [] |
|||
if self.do_basic_tokenize: |
|||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens): |
|||
|
|||
# If the token is part of the never_split set |
|||
if token in self.basic_tokenizer.never_split: |
|||
split_tokens.append(token) |
|||
else: |
|||
split_tokens += self.wordpiece_tokenizer.tokenize(token) |
|||
else: |
|||
split_tokens = self.wordpiece_tokenizer.tokenize(text) |
|||
return split_tokens |
|||
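# Illustrative trace (assuming the pieces shown are in the vocab): |
|||
#   "unaffable puppy" -> basic: ["unaffable", "puppy"] -> wordpiece: ["un", "##aff", "##able", "puppy"] |
|||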
|
|||
def _convert_token_to_id(self, token): |
|||
""" Converts a token (str) in an id using the vocab. """ |
|||
return self.vocab.get(token, self.vocab.get(self.unk_token)) |
|||
|
|||
def _convert_id_to_token(self, index): |
|||
"""Converts an index (integer) in a token (str) using the vocab.""" |
|||
return self.ids_to_tokens.get(index, self.unk_token) |
|||
|
|||
def convert_tokens_to_string(self, tokens): |
|||
""" Converts a sequence of tokens (string) in a single string. """ |
|||
out_string = " ".join(tokens).replace(" ##", "").strip() |
|||
return out_string |
|||
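# e.g. ["un", "##aff", "##able"] -> "unaffable" |
|||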
|
|||
def build_inputs_with_special_tokens( |
|||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None |
|||
) -> List[int]: |
|||
""" |
|||
Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and |
|||
adding special tokens. A BERT sequence has the following format: |
|||
- single sequence: ``[CLS] X`` |
|||
- pair of sequences: ``[CLS] A [SEP] B [SEP]`` |
|||
Args: |
|||
token_ids_0 (:obj:`List[int]`): |
|||
List of IDs to which the special tokens will be added. |
|||
token_ids_1 (:obj:`List[int]`, `optional`): |
|||
Optional second list of IDs for sequence pairs. |
|||
Returns: |
|||
:obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. |
|||
""" |
|||
if token_ids_1 is None: |
|||
return [self.cls_token_id] + token_ids_0 |
|||
cls = [self.cls_token_id] |
|||
sep = [self.sep_token_id] |
|||
return cls + token_ids_0 + sep + token_ids_1 + sep |
|||
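# Resulting layouts: single sequence -> [CLS] A (no trailing [SEP] is appended here, unlike the stock |
|||
# HuggingFace BertTokenizer); pair of sequences -> [CLS] A [SEP] B [SEP]. |
|||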
|
|||
def get_special_tokens_mask( |
|||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False |
|||
) -> List[int]: |
|||
""" |
|||
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding |
|||
special tokens using the tokenizer ``prepare_for_model`` method. |
|||
Args: |
|||
token_ids_0 (:obj:`List[int]`): |
|||
List of IDs. |
|||
token_ids_1 (:obj:`List[int]`, `optional`): |
|||
Optional second list of IDs for sequence pairs. |
|||
already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): |
|||
Whether or not the token list is already formatted with special tokens for the model. |
|||
Returns: |
|||
:obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. |
|||
""" |
|||
|
|||
if already_has_special_tokens: |
|||
if token_ids_1 is not None: |
|||
raise ValueError( |
|||
"You should not supply a second sequence if the provided sequence of " |
|||
"ids is already formatted with special tokens for the model." |
|||
) |
|||
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0)) |
|||
|
|||
if token_ids_1 is not None: |
|||
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] |
|||
return [1] + ([0] * len(token_ids_0)) |
|||
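# e.g. token_ids_0=[7, 8], token_ids_1=[9] -> [1, 0, 0, 1, 0, 1]; token_ids_0 alone -> [1, 0, 0] |
|||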
|
|||
def create_token_type_ids_from_sequences( |
|||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None |
|||
) -> List[int]: |
|||
""" |
|||
Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence |
|||
pair mask has the following format: |
|||
:: |
|||
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 |
|||
| first sequence | second sequence | |
|||
If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s). |
|||
Args: |
|||
token_ids_0 (:obj:`List[int]`): |
|||
List of IDs. |
|||
token_ids_1 (:obj:`List[int]`, `optional`): |
|||
Optional second list of IDs for sequence pairs. |
|||
Returns: |
|||
:obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given |
|||
sequence(s). |
|||
""" |
|||
sep = [self.sep_token_id] |
|||
cls = [self.cls_token_id] |
|||
if token_ids_1 is None: |
|||
return len(cls + token_ids_0 + sep) * [0] |
|||
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] |
|||
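# e.g. token_ids_0=[7, 8], token_ids_1=[9] -> [0, 0, 0, 0, 1, 1] ([CLS] A [SEP] -> 0, B [SEP] -> 1) |
|||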
|
|||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: |
|||
index = 0 |
|||
if os.path.isdir(save_directory): |
|||
vocab_file = os.path.join( |
|||
save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] |
|||
) |
|||
else: |
|||
vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory |
|||
with open(vocab_file, "w", encoding="utf-8") as writer: |
|||
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): |
|||
if index != token_index: |
|||
logger.warning( |
|||
"Saving vocabulary to {}: vocabulary indices are not consecutive." |
|||
" Please check that the vocabulary is not corrupted!".format(vocab_file) |
|||
) |
|||
index = token_index |
|||
writer.write(token + "\n") |
|||
index += 1 |
|||
return (vocab_file,) |
|||
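# The vocabulary is written one token per line, ordered by index, so the file can be reloaded later |
|||
# (e.g. via load_vocab()). |
|||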
|
|||
|
|||
class BasicTokenizer(object): |
|||
""" |
|||
Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.). |
|||
Args: |
|||
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`): |
|||
Whether or not to lowercase the input when tokenizing. |
|||
never_split (:obj:`Iterable`, `optional`): |
|||
Collection of tokens which will never be split during tokenization. Only has an effect when |
|||
:obj:`do_basic_tokenize=True` |
|||
tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`): |
|||
Whether or not to tokenize Chinese characters. |
|||
This should likely be deactivated for Japanese (see this `issue |
|||
<https://github.com/huggingface/transformers/issues/328>`__). |
|||
strip_accents: (:obj:`bool`, `optional`): |
|||
Whether or not to strip all accents. If this option is not specified, then it will be determined by the |
|||
value for :obj:`lowercase` (as in the original BERT). |
|||
""" |
|||
|
|||
def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None): |
|||
if never_split is None: |
|||
never_split = [] |
|||
self.do_lower_case = do_lower_case |
|||
self.never_split = set(never_split) |
|||
self.tokenize_chinese_chars = tokenize_chinese_chars |
|||
self.strip_accents = strip_accents |
|||
|
|||
def tokenize(self, text, never_split=None): |
|||
""" |
|||
Basic tokenization of a piece of text: whitespace splitting, punctuation splitting, optional lower casing. For |
|||
sub-word tokenization, see WordpieceTokenizer. |
|||
Args: |
|||
**never_split**: (`optional`) list of str |
|||
Kept for backward compatibility purposes. Now implemented directly at the base class level (see |
|||
:func:`PreTrainedTokenizer.tokenize`) List of tokens not to split. |
|||
""" |
|||
# union() returns a new set by concatenating the two sets. |
|||
never_split = self.never_split.union(set(never_split)) if never_split else self.never_split |
|||
text = self._clean_text(text) |
|||
|
|||
# This was added on November 1st, 2018 for the multilingual and Chinese |
|||
# models. This is also applied to the English models now, but it doesn't |
|||
# matter since the English models were not trained on any Chinese data |
|||
# and generally don't have any Chinese data in them (there are Chinese |
|||
# characters in the vocabulary because Wikipedia does have some Chinese |
|||
# words in the English Wikipedia.). |
|||
if self.tokenize_chinese_chars: |
|||
text = self._tokenize_chinese_chars(text) |
|||
orig_tokens = whitespace_tokenize(text) |
|||
split_tokens = [] |
|||
for token in orig_tokens: |
|||
if token not in never_split: |
|||
if self.do_lower_case: |
|||
token = token.lower() |
|||
if self.strip_accents is not False: |
|||
token = self._run_strip_accents(token) |
|||
elif self.strip_accents: |
|||
token = self._run_strip_accents(token) |
|||
split_tokens.extend(self._run_split_on_punc(token, never_split)) |
|||
|
|||
output_tokens = whitespace_tokenize(" ".join(split_tokens)) |
|||
return output_tokens |
|||
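# Illustrative example with the defaults (do_lower_case=True, tokenize_chinese_chars=True): |
|||
#   "HeLLo, 世界!" -> ["hello", ",", "世", "界", "!"] |
|||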
|
|||
def _run_strip_accents(self, text): |
|||
"""Strips accents from a piece of text.""" |
|||
text = unicodedata.normalize("NFD", text) |
|||
output = [] |
|||
for char in text: |
|||
cat = unicodedata.category(char) |
|||
if cat == "Mn": |
|||
continue |
|||
output.append(char) |
|||
return "".join(output) |
|||
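# e.g. "héllo" -> "hello" (NFD-decomposed combining marks are dropped) |
|||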
|
|||
def _run_split_on_punc(self, text, never_split=None): |
|||
"""Splits punctuation on a piece of text.""" |
|||
if never_split is not None and text in never_split: |
|||
return [text] |
|||
chars = list(text) |
|||
i = 0 |
|||
start_new_word = True |
|||
output = [] |
|||
while i < len(chars): |
|||
char = chars[i] |
|||
if _is_punctuation(char): |
|||
output.append([char]) |
|||
start_new_word = True |
|||
else: |
|||
if start_new_word: |
|||
output.append([]) |
|||
start_new_word = False |
|||
output[-1].append(char) |
|||
i += 1 |
|||
|
|||
return ["".join(x) for x in output] |
|||
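# e.g. "can't-stop" -> ["can", "'", "t", "-", "stop"] |
|||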
|
|||
def _tokenize_chinese_chars(self, text): |
|||
"""Adds whitespace around any CJK character.""" |
|||
output = [] |
|||
for char in text: |
|||
cp = ord(char) |
|||
if self._is_chinese_char(cp): |
|||
output.append(" ") |
|||
output.append(char) |
|||
output.append(" ") |
|||
else: |
|||
output.append(char) |
|||
return "".join(output) |
|||
|
|||
def _is_chinese_char(self, cp): |
|||
"""Checks whether CP is the codepoint of a CJK character.""" |
|||
# This defines a "chinese character" as anything in the CJK Unicode block: |
|||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) |
|||
# |
|||
# Note that the CJK Unicode block is NOT all Japanese and Korean characters, |
|||
# despite its name. The modern Korean Hangul alphabet is a different block, |
|||
# as is Japanese Hiragana and Katakana. Those alphabets are used to write |
|||
# space-separated words, so they are not treated specially and handled |
|||
# like all of the other languages. |
|||
if ( |
|||
(cp >= 0x4E00 and cp <= 0x9FFF)  # CJK Unified Ideographs |
|||
or (cp >= 0x3400 and cp <= 0x4DBF)  # CJK Unified Ideographs Extension A |
|||
or (cp >= 0x20000 and cp <= 0x2A6DF)  # CJK Unified Ideographs Extension B |
|||
or (cp >= 0x2A700 and cp <= 0x2B73F)  # CJK Unified Ideographs Extension C |
|||
or (cp >= 0x2B740 and cp <= 0x2B81F)  # CJK Unified Ideographs Extension D |
|||
or (cp >= 0x2B820 and cp <= 0x2CEAF)  # CJK Unified Ideographs Extension E |
|||
or (cp >= 0xF900 and cp <= 0xFAFF)  # CJK Compatibility Ideographs |
|||
or (cp >= 0x2F800 and cp <= 0x2FA1F)  # CJK Compatibility Ideographs Supplement |
|||
): # |
|||
return True |
|||
|
|||
return False |
|||
|
|||
def _clean_text(self, text): |
|||
"""Performs invalid character removal and whitespace cleanup on text.""" |
|||
output = [] |
|||
for char in text: |
|||
cp = ord(char) |
|||
if cp == 0 or cp == 0xFFFD or _is_control(char): |
|||
continue |
|||
if _is_whitespace(char): |
|||
output.append(" ") |
|||
else: |
|||
output.append(char) |
|||
return "".join(output) |
|||
|
|||
|
|||
class WordpieceTokenizer(object): |
|||
"""Runs WordPiece tokenization.""" |
|||
|
|||
def __init__(self, vocab, unk_token, max_input_chars_per_word=100): |
|||
self.vocab = vocab |
|||
self.unk_token = unk_token |
|||
self.max_input_chars_per_word = max_input_chars_per_word |
|||
|
|||
def tokenize(self, text): |
|||
""" |
|||
Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform |
|||
tokenization using the given vocabulary. |
|||
For example, :obj:`input = "unaffable"` will return as output :obj:`["un", "##aff", "##able"]`. |
|||
Args: |
|||
text: A single token or whitespace separated tokens. This should have |
|||
already been passed through `BasicTokenizer`. |
|||
Returns: |
|||
A list of wordpiece tokens. |
|||
""" |
|||
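# Note: if any remaining span of a word cannot be matched against the vocab, the whole word is replaced |
|||
# by unk_token rather than being partially split. |
|||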
|
|||
output_tokens = [] |
|||
for token in whitespace_tokenize(text): |
|||
chars = list(token) |
|||
if len(chars) > self.max_input_chars_per_word: |
|||
output_tokens.append(self.unk_token) |
|||
continue |
|||
|
|||
is_bad = False |
|||
start = 0 |
|||
sub_tokens = [] |
|||
while start < len(chars): |
|||
end = len(chars) |
|||
cur_substr = None |
|||
while start < end: |
|||
substr = "".join(chars[start:end]) |
|||
if start > 0: |
|||
substr = "##" + substr |
|||
if substr in self.vocab: |
|||
cur_substr = substr |
|||
break |
|||
end -= 1 |
|||
if cur_substr is None: |
|||
is_bad = True |
|||
break |
|||
sub_tokens.append(cur_substr) |
|||
start = end |
|||
|
|||
if is_bad: |
|||
output_tokens.append(self.unk_token) |
|||
else: |
|||
output_tokens.extend(sub_tokens) |
|||
return output_tokens |
@ -0,0 +1,202 @@ |
|||
import torch |
|||
import torch.nn as nn |
|||
import torch.nn.functional as F |
|||
from functools import partial |
|||
|
|||
from timm.models.vision_transformer import _cfg, PatchEmbed |
|||
from timm.models.registry import register_model |
|||
from timm.models.layers import trunc_normal_, DropPath |
|||
|
|||
|
|||
class Mlp(nn.Module): |
|||
""" MLP as used in Vision Transformer, MLP-Mixer and related networks |
|||
""" |
|||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): |
|||
super().__init__() |
|||
out_features = out_features or in_features |
|||
hidden_features = hidden_features or in_features |
|||
self.fc1 = nn.Linear(in_features, hidden_features) |
|||
self.act = act_layer() |
|||
self.fc2 = nn.Linear(hidden_features, out_features) |
|||
self.drop = nn.Dropout(drop) |
|||
|
|||
def forward(self, x): |
|||
x = self.fc1(x) |
|||
x = self.act(x) |
|||
x = self.drop(x) |
|||
x = self.fc2(x) |
|||
x = self.drop(x) |
|||
return x |
|||
|
|||
|
|||
class Attention(nn.Module): |
|||
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): |
|||
super().__init__() |
|||
self.num_heads = num_heads |
|||
head_dim = dim // num_heads |
|||
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights |
|||
self.scale = qk_scale or head_dim ** -0.5 |
|||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) |
|||
self.attn_drop = nn.Dropout(attn_drop) |
|||
self.proj = nn.Linear(dim, dim) |
|||
self.proj_drop = nn.Dropout(proj_drop) |
|||
self.attn_gradients = None |
|||
self.attention_map = None |
|||
|
|||
def save_attn_gradients(self, attn_gradients): |
|||
self.attn_gradients = attn_gradients |
|||
|
|||
def get_attn_gradients(self): |
|||
return self.attn_gradients |
|||
|
|||
def save_attention_map(self, attention_map): |
|||
self.attention_map = attention_map |
|||
|
|||
def get_attention_map(self): |
|||
return self.attention_map |
|||
|
|||
def forward(self, x, register_hook=False): |
|||
B, N, C = x.shape |
|||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) |
|||
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) |
|||
|
|||
attn = (q @ k.transpose(-2, -1)) * self.scale |
|||
attn = attn.softmax(dim=-1) |
|||
attn = self.attn_drop(attn) |
|||
|
|||
if register_hook: |
|||
self.save_attention_map(attn) |
|||
attn.register_hook(self.save_attn_gradients) |
|||
|
|||
x = (attn @ v).transpose(1, 2).reshape(B, N, C) |
|||
x = self.proj(x) |
|||
x = self.proj_drop(x) |
|||
return x |
|||
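# Shape walk-through (illustrative): x is (B, N, C); qkv -> (3, B, num_heads, N, C // num_heads); |
|||
# attn -> (B, num_heads, N, N); the projected output is (B, N, C) again. |
|||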
|
|||
|
|||
class Block(nn.Module): |
|||
|
|||
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., |
|||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): |
|||
super().__init__() |
|||
self.norm1 = norm_layer(dim) |
|||
self.attn = Attention( |
|||
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) |
|||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here |
|||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|||
self.norm2 = norm_layer(dim) |
|||
mlp_hidden_dim = int(dim * mlp_ratio) |
|||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) |
|||
|
|||
def forward(self, x, register_hook=False): |
|||
x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook)) |
|||
x = x + self.drop_path(self.mlp(self.norm2(x))) |
|||
return x |
|||
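# Pre-norm residual block: x + DropPath(Attn(LN(x))), followed by x + DropPath(MLP(LN(x))). |
|||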
|
|||
|
|||
class VisionTransformer(nn.Module): |
|||
""" Vision Transformer |
|||
A PyTorch impl of: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - |
|||
https://arxiv.org/abs/2010.11929 |
|||
""" |
|||
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, |
|||
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, |
|||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None): |
|||
""" |
|||
Args: |
|||
img_size (int, tuple): input image size |
|||
patch_size (int, tuple): patch size |
|||
in_chans (int): number of input channels |
|||
num_classes (int): number of classes for classification head |
|||
embed_dim (int): embedding dimension |
|||
depth (int): depth of transformer |
|||
num_heads (int): number of attention heads |
|||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim |
|||
qkv_bias (bool): enable bias for qkv if True |
|||
qk_scale (float): override default qk scale of head_dim ** -0.5 if set |
|||
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set |
|||
drop_rate (float): dropout rate |
|||
attn_drop_rate (float): attention dropout rate |
|||
drop_path_rate (float): stochastic depth rate |
|||
norm_layer: (nn.Module): normalization layer |
|||
""" |
|||
super().__init__() |
|||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models |
|||
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) |
|||
|
|||
self.patch_embed = PatchEmbed( |
|||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) |
|||
num_patches = self.patch_embed.num_patches |
|||
|
|||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) |
|||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) |
|||
self.pos_drop = nn.Dropout(p=drop_rate) |
|||
|
|||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule |
|||
self.blocks = nn.ModuleList([ |
|||
Block( |
|||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, |
|||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) |
|||
for i in range(depth)]) |
|||
self.norm = norm_layer(embed_dim) |
|||
|
|||
trunc_normal_(self.pos_embed, std=.02) |
|||
trunc_normal_(self.cls_token, std=.02) |
|||
self.apply(self._init_weights) |
|||
|
|||
def _init_weights(self, m): |
|||
if isinstance(m, nn.Linear): |
|||
trunc_normal_(m.weight, std=.02) |
|||
if isinstance(m, nn.Linear) and m.bias is not None: |
|||
nn.init.constant_(m.bias, 0) |
|||
elif isinstance(m, nn.LayerNorm): |
|||
nn.init.constant_(m.bias, 0) |
|||
nn.init.constant_(m.weight, 1.0) |
|||
|
|||
@torch.jit.ignore |
|||
def no_weight_decay(self): |
|||
return {'pos_embed', 'cls_token'} |
|||
|
|||
def forward(self, x, register_blk=-1): |
|||
B = x.shape[0] |
|||
x = self.patch_embed(x) |
|||
|
|||
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks |
|||
x = torch.cat((cls_tokens, x), dim=1) |
|||
|
|||
x = x + self.pos_embed[:,:x.size(1),:] |
|||
x = self.pos_drop(x) |
|||
|
|||
for i,blk in enumerate(self.blocks): |
|||
x = blk(x, register_blk==i) |
|||
x = self.norm(x) |
|||
|
|||
return x |
|||
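# Output shape: (B, num_patches + 1, embed_dim); index 0 along dim 1 is the class-token embedding. |
|||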
|
|||
|
|||
|
|||
def interpolate_pos_embed(pos_embed_checkpoint, visual_encoder): |
|||
# interpolate position embedding |
|||
embedding_size = pos_embed_checkpoint.shape[-1] |
|||
num_patches = visual_encoder.patch_embed.num_patches |
|||
num_extra_tokens = visual_encoder.pos_embed.shape[-2] - num_patches |
|||
# height (== width) for the checkpoint position embedding |
|||
orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) |
|||
# height (== width) for the new position embedding |
|||
new_size = int(num_patches ** 0.5) |
|||
|
|||
if orig_size!=new_size: |
|||
# class_token and dist_token are kept unchanged |
|||
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] |
|||
# only the position tokens are interpolated |
|||
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] |
|||
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) |
|||
pos_tokens = torch.nn.functional.interpolate( |
|||
pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) |
|||
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) |
|||
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) |
|||
print('reshape position embedding from %d to %d'%(orig_size ** 2,new_size ** 2)) |
|||
|
|||
return new_pos_embed |
|||
else: |
|||
return pos_embed_checkpoint |
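|||
# Illustrative example: a checkpoint trained at 256x256 with patch size 16 stores a 16x16 grid of position |
|||
# tokens; loading it into a 384x384 model (24x24 grid) bicubically resizes 256 -> 576 position tokens, while |
|||
# the extra (class) tokens pass through unchanged. |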