expansionnet-v2
copied
wxywb
2 years ago
9 changed files with 1732 additions and 0 deletions
@ -0,0 +1,18 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
from .expansionnet_v2 import ExpansionNetV2 |
|||
|
|||
def expansionnet_v2(model_name: str): |
|||
return ExpansionNetV2(model_name) |
Binary file not shown.
@ -0,0 +1,55 @@ |
|||
# Copyright 2021 Zilliz. All rights reserved. |
|||
# |
|||
# Licensed under the Apache License, Version 2.0 (the "License"); |
|||
# you may not use this file except in compliance with the License. |
|||
# You may obtain a copy of the License at |
|||
# |
|||
# http://www.apache.org/licenses/LICENSE-2.0 |
|||
# |
|||
# Unless required by applicable law or agreed to in writing, software |
|||
# distributed under the License is distributed on an "AS IS" BASIS, |
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
# See the License for the specific language governing permissions and |
|||
# limitations under the License. |
|||
|
|||
import sys |
|||
import os |
|||
from pathlib import Path |
|||
|
|||
import torch |
|||
from torchvision import transforms |
|||
from transformers import GPT2Tokenizer |
|||
|
|||
from towhee.types.arg import arg, to_image_color |
|||
from towhee.types.image_utils import to_pil |
|||
from towhee.operator.base import NNOperator, OperatorFlag |
|||
from towhee import register |
|||
from towhee.models import clip |
|||
|
|||
class ExpansionNetV2(NNOperator): |
|||
""" |
|||
ExpansionNet V2 image captioning operator |
|||
""" |
|||
def __init__(self, model_name: str): |
|||
super().__init__() |
|||
sys.path.append(str(Path(__file__).parent)) |
|||
from models.End_ExpansionNet_v2 import End_ExpansionNet_v2 |
|||
sys.path.pop() |
|||
with open('demo_coco_tokens.pickle') as fw: |
|||
|
|||
self.model = End_ExpansionNet_v2(swin_img_size=img_size, swin_patch_size=4, swin_in_chans=3, |
|||
swin_embed_dim=192, swin_depths=[2, 2, 18, 2], swin_num_heads=[6, 12, 24, 48], |
|||
swin_window_size=12, swin_mlp_ratio=4., swin_qkv_bias=True, swin_qk_scale=None, |
|||
swin_drop_rate=0.0, swin_attn_drop_rate=0.0, swin_drop_path_rate=0.0, |
|||
swin_norm_layer=torch.nn.LayerNorm, swin_ape=False, swin_patch_norm=True, |
|||
swin_use_checkpoint=False, |
|||
final_swin_dim=1536, |
|||
|
|||
d_model=model_args.model_dim, N_enc=model_args.N_enc, |
|||
N_dec=model_args.N_dec, num_heads=8, ff=2048, |
|||
num_exp_enc_list=[32, 64, 128, 256, 512], |
|||
num_exp_dec=16, |
|||
output_word2idx=coco_tokens['word2idx_dict'], |
|||
output_idx2word=coco_tokens['idx2word_list'], |
|||
max_seq_len=args.max_seq_len, drop_args=model_args.drop_args, |
|||
rank='cpu') |
@ -0,0 +1,187 @@ |
|||
import torch |
|||
from models.layers import EmbeddingLayer, DecoderLayer, EncoderLayer |
|||
from utils.masking import create_pad_mask, create_no_peak_and_pad_mask |
|||
from models.captioning_model import CaptioningModel |
|||
from models.swin_transformer_mod import SwinTransformer |
|||
|
|||
import torch.nn as nn |
|||
|
|||
|
|||
class End_ExpansionNet_v2(CaptioningModel): |
|||
def __init__(self, |
|||
|
|||
# swin transf |
|||
swin_img_size, swin_patch_size, swin_in_chans, |
|||
swin_embed_dim, swin_depths, swin_num_heads, |
|||
swin_window_size, swin_mlp_ratio, swin_qkv_bias, swin_qk_scale, |
|||
swin_drop_rate, swin_attn_drop_rate, swin_drop_path_rate, |
|||
swin_norm_layer, swin_ape, swin_patch_norm, |
|||
swin_use_checkpoint, |
|||
|
|||
# linear_size, |
|||
final_swin_dim, |
|||
|
|||
# captioning |
|||
d_model, N_enc, N_dec, ff, num_heads, num_exp_enc_list, num_exp_dec, |
|||
output_word2idx, output_idx2word, max_seq_len, drop_args, rank=0): |
|||
super(End_ExpansionNet_v2, self).__init__() |
|||
|
|||
self.swin_transf = SwinTransformer( |
|||
img_size=swin_img_size, patch_size=swin_patch_size, in_chans=swin_in_chans, |
|||
embed_dim=swin_embed_dim, depths=swin_depths, num_heads=swin_num_heads, |
|||
window_size=swin_window_size, mlp_ratio=swin_mlp_ratio, qkv_bias=swin_qkv_bias, qk_scale=swin_qk_scale, |
|||
drop_rate=swin_drop_rate, attn_drop_rate=swin_attn_drop_rate, drop_path_rate=swin_drop_path_rate, |
|||
norm_layer=swin_norm_layer, ape=swin_ape, patch_norm=swin_patch_norm, |
|||
use_checkpoint=swin_use_checkpoint) |
|||
|
|||
self.output_word2idx = output_word2idx |
|||
self.output_idx2word = output_idx2word |
|||
self.max_seq_len = max_seq_len |
|||
|
|||
self.num_exp_dec = num_exp_dec |
|||
self.num_exp_enc_list = num_exp_enc_list |
|||
|
|||
self.N_enc = N_enc |
|||
self.N_dec = N_dec |
|||
self.d_model = d_model |
|||
|
|||
self.encoders = nn.ModuleList([EncoderLayer(d_model, ff, num_exp_enc_list, drop_args.enc) for _ in range(N_enc)]) |
|||
self.decoders = nn.ModuleList([DecoderLayer(d_model, num_heads, ff, num_exp_dec, drop_args.dec) for _ in range(N_dec)]) |
|||
|
|||
self.input_embedder_dropout = nn.Dropout(drop_args.enc_input) |
|||
self.input_linear = torch.nn.Linear(final_swin_dim, d_model) |
|||
self.vocab_linear = torch.nn.Linear(d_model, len(output_word2idx)) |
|||
self.log_softmax = nn.LogSoftmax(dim=-1) |
|||
|
|||
self.out_enc_dropout = nn.Dropout(drop_args.other) |
|||
self.out_dec_dropout = nn.Dropout(drop_args.other) |
|||
|
|||
self.out_embedder = EmbeddingLayer(len(output_word2idx), d_model, drop_args.dec_input) |
|||
self.pos_encoder = nn.Embedding(max_seq_len, d_model) |
|||
|
|||
self.enc_reduce_group = nn.Linear(d_model * self.N_enc, d_model) |
|||
self.enc_reduce_norm = nn.LayerNorm(d_model) |
|||
self.dec_reduce_group = nn.Linear(d_model * self.N_dec, d_model) |
|||
self.dec_reduce_norm = nn.LayerNorm(d_model) |
|||
|
|||
for p in self.parameters(): |
|||
if p.dim() > 1: |
|||
nn.init.xavier_uniform_(p) |
|||
|
|||
self.trained_steps = 0 |
|||
self.rank = rank |
|||
|
|||
self.check_required_attributes() |
|||
|
|||
def forward_enc(self, enc_input, enc_input_num_pads): |
|||
|
|||
assert (enc_input_num_pads is None or enc_input_num_pads == ([0] * enc_input.size(0))), "End to End case have no padding" |
|||
x = self.swin_transf(enc_input) |
|||
# --------------- Normale parte di Captioning --------------------------------- |
|||
enc_input = self.input_embedder_dropout(self.input_linear(x)) |
|||
x = enc_input |
|||
enc_input_num_pads = [0] * enc_input.size(0) |
|||
|
|||
max_num_enc = sum(self.num_exp_enc_list) |
|||
pos_x = torch.arange(max_num_enc).unsqueeze(0).expand(enc_input.size(0), max_num_enc).to(self.rank) |
|||
pad_mask = create_pad_mask(mask_size=(enc_input.size(0), max_num_enc, enc_input.size(1)), |
|||
pad_along_row_input=[0] * enc_input.size(0), |
|||
pad_along_column_input=enc_input_num_pads, |
|||
rank=self.rank) |
|||
|
|||
x_list = [] |
|||
for i in range(self.N_enc): |
|||
x = self.encoders[i](x=x, n_indexes=pos_x, mask=pad_mask) |
|||
x_list.append(x) |
|||
x_list = torch.cat(x_list, dim=-1) |
|||
x = x + self.out_enc_dropout(self.enc_reduce_group(x_list)) |
|||
x = self.enc_reduce_norm(x) |
|||
|
|||
return x |
|||
|
|||
def forward_dec(self, cross_input, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False): |
|||
assert (enc_input_num_pads is None or enc_input_num_pads == ([0] * cross_input.size(0))), "enc_input_num_pads should be no None" |
|||
|
|||
enc_input_num_pads = [0] * dec_input.size(0) |
|||
no_peak_and_pad_mask = create_no_peak_and_pad_mask( |
|||
mask_size=(dec_input.size(0), dec_input.size(1), dec_input.size(1)), |
|||
num_pads=dec_input_num_pads, |
|||
rank=self.rank) |
|||
pad_mask = create_pad_mask(mask_size=(dec_input.size(0), dec_input.size(1), cross_input.size(1)), |
|||
pad_along_row_input=dec_input_num_pads, |
|||
pad_along_column_input=enc_input_num_pads, |
|||
rank=self.rank) |
|||
|
|||
y = self.out_embedder(dec_input) |
|||
pos_x = torch.arange(self.num_exp_dec).unsqueeze(0).expand(dec_input.size(0), self.num_exp_dec).to(self.rank) |
|||
pos_y = torch.arange(dec_input.size(1)).unsqueeze(0).expand(dec_input.size(0), dec_input.size(1)).to(self.rank) |
|||
y = y + self.pos_encoder(pos_y) |
|||
y_list = [] |
|||
for i in range(self.N_dec): |
|||
y = self.decoders[i](x=y, |
|||
n_indexes=pos_x, |
|||
cross_connection_x=cross_input, |
|||
input_attention_mask=no_peak_and_pad_mask, |
|||
cross_attention_mask=pad_mask) |
|||
y_list.append(y) |
|||
y_list = torch.cat(y_list, dim=-1) |
|||
y = y + self.out_dec_dropout(self.dec_reduce_group(y_list)) |
|||
y = self.dec_reduce_norm(y) |
|||
|
|||
y = self.vocab_linear(y) |
|||
|
|||
if apply_log_softmax: |
|||
y = self.log_softmax(y) |
|||
|
|||
return y |
|||
|
|||
|
|||
def get_batch_multiple_sampled_prediction(self, enc_input, enc_input_num_pads, num_outputs, |
|||
sos_idx, eos_idx, max_seq_len): |
|||
|
|||
bs = enc_input.size(0) |
|||
x = self.forward_enc(enc_input=enc_input, enc_input_num_pads=enc_input_num_pads) |
|||
enc_seq_len = x.size(1) |
|||
x = x.unsqueeze(1).expand(-1, num_outputs, -1, -1).reshape(bs * num_outputs, enc_seq_len, x.shape[-1]) |
|||
|
|||
upperbound_vector = torch.tensor([max_seq_len] * bs * num_outputs, dtype=torch.int).to(self.rank) |
|||
where_is_eos_vector = upperbound_vector.clone() |
|||
eos_vector = torch.tensor([eos_idx] * bs * num_outputs, dtype=torch.long).to(self.rank) |
|||
finished_flag_vector = torch.zeros(bs * num_outputs).type(torch.int) |
|||
|
|||
predicted_caption = torch.tensor([sos_idx] * (bs * num_outputs), dtype=torch.long).to(self.rank).unsqueeze(-1) |
|||
predicted_caption_prob = torch.zeros(bs * num_outputs).to(self.rank).unsqueeze(-1) |
|||
|
|||
dec_input_num_pads = [0]*(bs*num_outputs) |
|||
time_step = 0 |
|||
while (finished_flag_vector.sum() != bs * num_outputs) and time_step < max_seq_len: |
|||
dec_input = predicted_caption |
|||
log_probs = self.forward_dec(x, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=True) |
|||
|
|||
prob_dist = torch.distributions.Categorical(torch.exp(log_probs[:, time_step])) |
|||
sampled_word_indexes = prob_dist.sample() |
|||
|
|||
predicted_caption = torch.cat((predicted_caption, sampled_word_indexes.unsqueeze(-1)), dim=-1) |
|||
predicted_caption_prob = torch.cat((predicted_caption_prob, |
|||
log_probs[:, time_step].gather(index=sampled_word_indexes.unsqueeze(-1), dim=-1)), dim=-1) |
|||
time_step += 1 |
|||
|
|||
where_is_eos_vector = torch.min(where_is_eos_vector, |
|||
upperbound_vector.masked_fill(sampled_word_indexes == eos_vector, time_step)) |
|||
finished_flag_vector = torch.max(finished_flag_vector, |
|||
(sampled_word_indexes == eos_vector).type(torch.IntTensor)) |
|||
|
|||
res_predicted_caption = [] |
|||
for i in range(bs): |
|||
res_predicted_caption.append([]) |
|||
for j in range(num_outputs): |
|||
index = i*num_outputs + j |
|||
res_predicted_caption[i].append( |
|||
predicted_caption[index, :where_is_eos_vector[index].item()+1].tolist()) |
|||
|
|||
where_is_eos_vector = where_is_eos_vector.unsqueeze(-1).expand(-1, time_step+1) |
|||
arange_tensor = torch.arange(time_step+1).unsqueeze(0).expand(bs * num_outputs, -1).to(self.rank) |
|||
predicted_caption_prob.masked_fill_(arange_tensor > where_is_eos_vector, 0.0) |
|||
res_predicted_caption_prob = predicted_caption_prob.reshape(bs, num_outputs, -1) |
|||
|
|||
return res_predicted_caption, res_predicted_caption_prob |
@ -0,0 +1,103 @@ |
|||
import torch |
|||
from models.layers import EmbeddingLayer, EncoderLayer, DecoderLayer |
|||
from utils.masking import create_pad_mask, create_no_peak_and_pad_mask |
|||
from models.captioning_model import CaptioningModel |
|||
|
|||
import torch.nn as nn |
|||
|
|||
|
|||
class ExpansionNet_v2(CaptioningModel): |
|||
def __init__(self, d_model, N_enc, N_dec, ff, num_heads, num_exp_enc_list, num_exp_dec, |
|||
output_word2idx, output_idx2word, max_seq_len, drop_args, img_feature_dim=2048, rank=0): |
|||
super().__init__() |
|||
self.output_word2idx = output_word2idx |
|||
self.output_idx2word = output_idx2word |
|||
self.max_seq_len = max_seq_len |
|||
|
|||
self.num_exp_dec = num_exp_dec |
|||
self.num_exp_enc_list = num_exp_enc_list |
|||
|
|||
self.N_enc = N_enc |
|||
self.N_dec = N_dec |
|||
self.d_model = d_model |
|||
|
|||
self.encoders = nn.ModuleList([EncoderLayer(d_model, ff, num_exp_enc_list, drop_args.enc) for _ in range(N_enc)]) |
|||
self.decoders = nn.ModuleList([DecoderLayer(d_model, num_heads, ff, num_exp_dec, drop_args.dec) for _ in range(N_dec)]) |
|||
|
|||
self.input_embedder_dropout = nn.Dropout(drop_args.enc_input) |
|||
self.input_linear = torch.nn.Linear(img_feature_dim, d_model) |
|||
self.vocab_linear = torch.nn.Linear(d_model, len(output_word2idx)) |
|||
self.log_softmax = nn.LogSoftmax(dim=-1) |
|||
|
|||
self.out_enc_dropout = nn.Dropout(drop_args.other) |
|||
self.out_dec_dropout = nn.Dropout(drop_args.other) |
|||
|
|||
self.out_embedder = EmbeddingLayer(len(output_word2idx), d_model, drop_args.dec_input) |
|||
self.pos_encoder = nn.Embedding(max_seq_len, d_model) |
|||
|
|||
self.enc_reduce_group = nn.Linear(d_model * self.N_enc, d_model) |
|||
self.enc_reduce_norm = nn.LayerNorm(d_model) |
|||
self.dec_reduce_group = nn.Linear(d_model * self.N_dec, d_model) |
|||
self.dec_reduce_norm = nn.LayerNorm(d_model) |
|||
|
|||
for p in self.parameters(): |
|||
if p.dim() > 1: |
|||
nn.init.xavier_uniform_(p) |
|||
|
|||
self.trained_steps = 0 |
|||
self.rank = rank |
|||
|
|||
def forward_enc(self, enc_input, enc_input_num_pads): |
|||
|
|||
x = self.input_embedder_dropout(self.input_linear(enc_input)) |
|||
|
|||
max_num_enc = sum(self.num_exp_enc_list) |
|||
pos_x = torch.arange(max_num_enc).unsqueeze(0).expand(enc_input.size(0), max_num_enc).to(self.rank) |
|||
pad_mask = create_pad_mask(mask_size=(enc_input.size(0), max_num_enc, enc_input.size(1)), |
|||
pad_along_row_input=[0] * enc_input.size(0), |
|||
pad_along_column_input=enc_input_num_pads, |
|||
rank=self.rank) |
|||
|
|||
x_list = [] |
|||
for i in range(self.N_enc): |
|||
x = self.encoders[i](x=x, n_indexes=pos_x, mask=pad_mask) |
|||
x_list.append(x) |
|||
x_list = torch.cat(x_list, dim=-1) |
|||
x = x + self.out_enc_dropout(self.enc_reduce_group(x_list)) |
|||
x = self.enc_reduce_norm(x) |
|||
return x |
|||
|
|||
def forward_dec(self, cross_input, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False): |
|||
|
|||
no_peak_and_pad_mask = create_no_peak_and_pad_mask( |
|||
mask_size=(dec_input.size(0), dec_input.size(1), dec_input.size(1)), |
|||
num_pads=dec_input_num_pads, |
|||
rank=self.rank) |
|||
|
|||
pad_mask = create_pad_mask(mask_size=(dec_input.size(0), dec_input.size(1), cross_input.size(1)), |
|||
pad_along_row_input=dec_input_num_pads, |
|||
pad_along_column_input=enc_input_num_pads, |
|||
rank=self.rank) |
|||
|
|||
y = self.out_embedder(dec_input) |
|||
pos_x = torch.arange(self.num_exp_dec).unsqueeze(0).expand(dec_input.size(0), self.num_exp_dec).to(self.rank) |
|||
pos_y = torch.arange(dec_input.size(1)).unsqueeze(0).expand(dec_input.size(0), dec_input.size(1)).to(self.rank) |
|||
y = y + self.pos_encoder(pos_y) |
|||
y_list = [] |
|||
for i in range(self.N_dec): |
|||
y = self.decoders[i](x=y, |
|||
n_indexes=pos_x, |
|||
cross_connection_x=cross_input, |
|||
input_attention_mask=no_peak_and_pad_mask, |
|||
cross_attention_mask=pad_mask) |
|||
y_list.append(y) |
|||
y_list = torch.cat(y_list, dim=-1) |
|||
y = y + self.out_dec_dropout(self.dec_reduce_group(y_list)) |
|||
y = self.dec_reduce_norm(y) |
|||
|
|||
y = self.vocab_linear(y) |
|||
|
|||
if apply_log_softmax: |
|||
y = self.log_softmax(y) |
|||
|
|||
return y |
@ -0,0 +1,241 @@ |
|||
|
|||
|
|||
import torch |
|||
import torch.nn as nn |
|||
|
|||
|
|||
class CaptioningModel(nn.Module): |
|||
def __init__(self): |
|||
super(CaptioningModel, self).__init__() |
|||
# mandatory attributes |
|||
# rank: to enable multiprocessing |
|||
self.rank = None |
|||
|
|||
def check_required_attributes(self): |
|||
if self.rank is None: |
|||
raise NotImplementedError("Subclass must assign the rank integer according to the GPU group") |
|||
|
|||
def forward_enc(self, enc_input, enc_input_num_pads): |
|||
raise NotImplementedError |
|||
|
|||
def forward_dec(self, cross_input, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False): |
|||
raise NotImplementedError |
|||
|
|||
def forward(self, enc_x, dec_x=None, |
|||
enc_x_num_pads=[0], dec_x_num_pads=[0], apply_log_softmax=False, |
|||
mode='forward', **kwargs): |
|||
if mode == 'forward': |
|||
x = self.forward_enc(enc_x, enc_x_num_pads) |
|||
y = self.forward_dec(x, enc_x_num_pads, dec_x, dec_x_num_pads, apply_log_softmax) |
|||
return y |
|||
else: |
|||
assert ('sos_idx' in kwargs.keys() or 'eos_idx' in kwargs.keys()), \ |
|||
'sos and eos must be provided in case of batch sampling or beam search' |
|||
sos_idx = kwargs.get('sos_idx', -999) |
|||
eos_idx = kwargs.get('eos_idx', -999) |
|||
if mode == 'beam_search': |
|||
beam_size_arg = kwargs.get('beam_size', 5) |
|||
how_many_outputs_per_beam = kwargs.get('how_many_outputs', 1) |
|||
beam_max_seq_len = kwargs.get('beam_max_seq_len', 20) |
|||
sample_or_max = kwargs.get('sample_or_max', 'max') |
|||
out_classes, out_logprobs = self.beam_search( |
|||
enc_x, enc_x_num_pads, |
|||
beam_size=beam_size_arg, |
|||
sos_idx=sos_idx, |
|||
eos_idx=eos_idx, |
|||
how_many_outputs=how_many_outputs_per_beam, |
|||
max_seq_len=beam_max_seq_len, |
|||
sample_or_max=sample_or_max) |
|||
return out_classes, out_logprobs |
|||
if mode == 'sampling': |
|||
how_many_outputs = kwargs.get('how_many_outputs', 1) |
|||
sample_max_seq_len = kwargs.get('sample_max_seq_len', 20) |
|||
out_classes, out_logprobs = self.get_batch_multiple_sampled_prediction( |
|||
enc_x, enc_x_num_pads, num_outputs=how_many_outputs, |
|||
sos_idx=sos_idx, eos_idx=eos_idx, |
|||
max_seq_len=sample_max_seq_len) |
|||
return out_classes, out_logprobs |
|||
|
|||
def get_batch_multiple_sampled_prediction(self, enc_input, enc_input_num_pads, num_outputs, |
|||
sos_idx, eos_idx, max_seq_len): |
|||
bs, enc_seq_len, _ = enc_input.shape |
|||
|
|||
enc_input_num_pads = [enc_input_num_pads[i] for i in range(bs) for _ in range(num_outputs)] |
|||
|
|||
x = self.forward_enc(enc_input=enc_input, enc_input_num_pads=enc_input_num_pads) |
|||
x = x.unsqueeze(1).expand(-1, num_outputs, -1, -1).reshape(bs * num_outputs, enc_seq_len, x.shape[-1]) |
|||
|
|||
upperbound_vector = torch.tensor([max_seq_len] * bs * num_outputs, dtype=torch.int).to(self.rank) |
|||
where_is_eos_vector = upperbound_vector.clone() |
|||
eos_vector = torch.tensor([eos_idx] * bs * num_outputs, dtype=torch.long).to(self.rank) |
|||
finished_flag_vector = torch.zeros(bs * num_outputs).type(torch.int) |
|||
|
|||
predicted_caption = torch.tensor([sos_idx] * (bs * num_outputs), dtype=torch.long).to(self.rank).unsqueeze(-1) |
|||
predicted_caption_prob = torch.zeros(bs * num_outputs).to(self.rank).unsqueeze(-1) |
|||
|
|||
dec_input_num_pads = [0]*(bs*num_outputs) |
|||
time_step = 0 |
|||
while (finished_flag_vector.sum() != bs * num_outputs) and time_step < max_seq_len: |
|||
dec_input = predicted_caption |
|||
log_probs = self.forward_dec(x, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=True) |
|||
|
|||
prob_dist = torch.distributions.Categorical(torch.exp(log_probs[:, time_step])) |
|||
sampled_word_indexes = prob_dist.sample() |
|||
|
|||
predicted_caption = torch.cat((predicted_caption, sampled_word_indexes.unsqueeze(-1)), dim=-1) |
|||
predicted_caption_prob = torch.cat((predicted_caption_prob, |
|||
log_probs[:, time_step].gather(index=sampled_word_indexes.unsqueeze(-1), dim=-1)), dim=-1) |
|||
time_step += 1 |
|||
|
|||
where_is_eos_vector = torch.min(where_is_eos_vector, |
|||
upperbound_vector.masked_fill(sampled_word_indexes == eos_vector, time_step)) |
|||
finished_flag_vector = torch.max(finished_flag_vector, |
|||
(sampled_word_indexes == eos_vector).type(torch.IntTensor)) |
|||
|
|||
# remove the elements that come after the first eos from the sequence |
|||
res_predicted_caption = [] |
|||
for i in range(bs): |
|||
res_predicted_caption.append([]) |
|||
for j in range(num_outputs): |
|||
index = i*num_outputs + j |
|||
res_predicted_caption[i].append( |
|||
predicted_caption[index, :where_is_eos_vector[index].item()+1].tolist()) |
|||
|
|||
where_is_eos_vector = where_is_eos_vector.unsqueeze(-1).expand(-1, time_step+1) |
|||
arange_tensor = torch.arange(time_step+1).unsqueeze(0).expand(bs * num_outputs, -1).to(self.rank) |
|||
predicted_caption_prob.masked_fill_(arange_tensor > where_is_eos_vector, 0.0) |
|||
res_predicted_caption_prob = predicted_caption_prob.reshape(bs, num_outputs, -1) |
|||
|
|||
return res_predicted_caption, res_predicted_caption_prob |
|||
|
|||
def beam_search(self, enc_input, enc_input_num_pads, sos_idx, eos_idx, |
|||
beam_size=3, how_many_outputs=1, max_seq_len=20, sample_or_max='max',): |
|||
assert (how_many_outputs <= beam_size), "requested output per sequence must be lower than beam width" |
|||
assert (sample_or_max == 'max' or sample_or_max == 'sample'), "argument must be chosen between \'max\' and \'sample\'" |
|||
bs = enc_input.shape[0] |
|||
|
|||
cross_enc_output = self.forward_enc(enc_input, enc_input_num_pads) |
|||
|
|||
# init: ------------------------------------------------------------------ |
|||
init_dec_class = torch.tensor([sos_idx] * bs).unsqueeze(1).type(torch.long).to(self.rank) |
|||
init_dec_logprob = torch.tensor([0.0] * bs).unsqueeze(1).type(torch.float).to(self.rank) |
|||
log_probs = self.forward_dec(cross_input=cross_enc_output, enc_input_num_pads=enc_input_num_pads, |
|||
dec_input=init_dec_class, dec_input_num_pads=[0] * bs, |
|||
apply_log_softmax=True) |
|||
if sample_or_max == 'max': |
|||
_, topi = torch.topk(log_probs, k=beam_size, sorted=True) |
|||
else: # sample |
|||
topi = torch.exp(log_probs[:, 0, :]).multinomial(num_samples=beam_size, replacement=False) |
|||
topi = topi.unsqueeze(1) |
|||
|
|||
init_dec_class = init_dec_class.repeat(1, beam_size) |
|||
init_dec_class = init_dec_class.unsqueeze(-1) |
|||
top_beam_size_class = topi.transpose(-2, -1) |
|||
init_dec_class = torch.cat((init_dec_class, top_beam_size_class), dim=-1) |
|||
|
|||
init_dec_logprob = init_dec_logprob.repeat(1, beam_size) |
|||
init_dec_logprob = init_dec_logprob.unsqueeze(-1) |
|||
top_beam_size_logprob = log_probs.gather(dim=-1, index=topi) |
|||
top_beam_size_logprob = top_beam_size_logprob.transpose(-2, -1) |
|||
init_dec_logprob = torch.cat((init_dec_logprob, top_beam_size_logprob), dim=-1) |
|||
|
|||
bs, enc_seq_len, d_model = cross_enc_output.shape |
|||
cross_enc_output = cross_enc_output.unsqueeze(1) |
|||
cross_enc_output = cross_enc_output.expand(-1, beam_size, -1, -1) |
|||
cross_enc_output = cross_enc_output.reshape(bs * beam_size, enc_seq_len, d_model).contiguous() |
|||
enc_input_num_pads = [enc_input_num_pads[i] for i in range(bs) for _ in range(beam_size)] |
|||
|
|||
# loop: ----------------------------------------------------------------- |
|||
loop_dec_classes = init_dec_class |
|||
loop_dec_logprobs = init_dec_logprob |
|||
loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True) |
|||
|
|||
loop_num_elem_vector = torch.tensor([2] * (bs * beam_size)).to(self.rank) |
|||
|
|||
for time_step in range(2, max_seq_len): |
|||
loop_dec_classes = loop_dec_classes.reshape(bs * beam_size, time_step).contiguous() |
|||
|
|||
log_probs = self.forward_dec(cross_input=cross_enc_output, enc_input_num_pads=enc_input_num_pads, |
|||
dec_input=loop_dec_classes, |
|||
dec_input_num_pads=(time_step-loop_num_elem_vector).tolist(), |
|||
apply_log_softmax=True) |
|||
if sample_or_max == 'max': |
|||
_, topi = torch.topk(log_probs[:, time_step-1, :], k=beam_size, sorted=True) |
|||
else: # sample |
|||
topi = torch.exp(log_probs[:, time_step-1, :]).multinomial(num_samples=beam_size, |
|||
replacement=False) |
|||
|
|||
top_beam_size_word_classes = topi.reshape(bs, beam_size, beam_size) |
|||
|
|||
top_beam_size_word_logprobs = log_probs[:, time_step-1, :].gather(dim=-1, index=topi) |
|||
top_beam_size_word_logprobs = top_beam_size_word_logprobs.reshape(bs, beam_size, beam_size) |
|||
|
|||
# each sequence have now its best prediction, but some sequence may have already been terminated with EOS, |
|||
# in that case its candidates are simply ignored, and do not sum up in the "loop_dec_logprobs" their value |
|||
# are set to zero |
|||
there_is_eos_mask = (loop_dec_classes.view(bs, beam_size, time_step) == eos_idx). \ |
|||
sum(dim=-1, keepdims=True).type(torch.bool) |
|||
|
|||
# if we pad with -999 its candidates logprobabilities, also the sequence containing EOS would be |
|||
# straightforwardly discarded, instead we want to keep it in the exploration. Therefore we mask with 0.0 |
|||
# one arbitrary candidate word probability so the sequence probability is unchanged but it |
|||
# can still be discarded when a better candidate sequence is found |
|||
top_beam_size_word_logprobs[:, :, 0:1].masked_fill_(there_is_eos_mask, 0.0) |
|||
top_beam_size_word_logprobs[:, :, 1:].masked_fill_(there_is_eos_mask, -999.0) |
|||
|
|||
comparison_logprobs = loop_cumul_logprobs + top_beam_size_word_logprobs |
|||
|
|||
comparison_logprobs = comparison_logprobs.contiguous().view(bs, beam_size * beam_size) |
|||
_, topi = torch.topk(comparison_logprobs, k=beam_size, sorted=True) |
|||
which_sequence = topi // beam_size |
|||
which_word = topi % beam_size |
|||
|
|||
loop_dec_classes = loop_dec_classes.view(bs, beam_size, -1) |
|||
loop_dec_logprobs = loop_dec_logprobs.view(bs, beam_size, -1) |
|||
|
|||
bs_idxes = torch.arange(bs).unsqueeze(-1) |
|||
new_loop_dec_classes = loop_dec_classes[[bs_idxes, which_sequence]] |
|||
new_loop_dec_logprobs = loop_dec_logprobs[[bs_idxes, which_sequence]] |
|||
|
|||
which_sequence_top_beam_size_word_classes = top_beam_size_word_classes[[bs_idxes, which_sequence]] |
|||
which_sequence_top_beam_size_word_logprobs = top_beam_size_word_logprobs[ |
|||
[bs_idxes, which_sequence]] |
|||
which_word = which_word.unsqueeze(-1) |
|||
|
|||
lastword_top_beam_size_classes = which_sequence_top_beam_size_word_classes.gather(dim=-1, |
|||
index=which_word) |
|||
lastword_top_beam_size_logprobs = which_sequence_top_beam_size_word_logprobs.gather(dim=-1, index=which_word) |
|||
|
|||
new_loop_dec_classes = torch.cat((new_loop_dec_classes, lastword_top_beam_size_classes), dim=-1) |
|||
new_loop_dec_logprobs = torch.cat((new_loop_dec_logprobs, lastword_top_beam_size_logprobs), dim=-1) |
|||
loop_dec_classes = new_loop_dec_classes |
|||
loop_dec_logprobs = new_loop_dec_logprobs |
|||
|
|||
loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True) |
|||
|
|||
# -----------------------update loop_num_elem_vector ---------------------------- |
|||
loop_num_elem_vector = loop_num_elem_vector.view(bs, beam_size)[[bs_idxes, which_sequence]].view(bs * beam_size) |
|||
there_was_eos_mask = (loop_dec_classes[:, :, :-1].view(bs, beam_size, time_step) == eos_idx). \ |
|||
sum(dim=-1).type(torch.bool).view(bs * beam_size) |
|||
loop_num_elem_vector = loop_num_elem_vector + (1 * (1 - there_was_eos_mask.type(torch.int))) |
|||
|
|||
if (loop_num_elem_vector != time_step + 1).sum() == (bs * beam_size): |
|||
break |
|||
|
|||
# sort out the best result |
|||
loop_cumul_logprobs /= loop_num_elem_vector.reshape(bs, beam_size, 1) |
|||
_, topi = torch.topk(loop_cumul_logprobs.squeeze(-1), k=beam_size) |
|||
res_caption_pred = [[] for _ in range(bs)] |
|||
res_caption_logprob = [[] for _ in range(bs)] |
|||
for i in range(bs): |
|||
for j in range(how_many_outputs): |
|||
idx = topi[i, j].item() |
|||
res_caption_pred[i].append( |
|||
loop_dec_classes[i, idx, :loop_num_elem_vector[i * beam_size + idx]].tolist()) |
|||
res_caption_logprob[i].append(loop_dec_logprobs[i, idx, :loop_num_elem_vector[i * beam_size + idx]]) |
|||
|
|||
flatted_res_caption_logprob = [logprobs for i in range(bs) for logprobs in res_caption_logprob[i]] |
|||
flatted_res_caption_logprob = torch.nn.utils.rnn.pad_sequence(flatted_res_caption_logprob, batch_first=True) |
|||
res_caption_logprob = flatted_res_caption_logprob.view(bs, how_many_outputs, -1) |
|||
|
|||
return res_caption_pred, res_caption_logprob |
@ -0,0 +1,187 @@ |
|||
|
|||
import torch |
|||
import torch.nn as nn |
|||
from models.captioning_model import CaptioningModel |
|||
|
|||
|
|||
class EsembleCaptioningModel(CaptioningModel): |
|||
def __init__(self, models_list, rank): |
|||
super().__init__() |
|||
self.num_models = len(models_list) |
|||
self.models_list = models_list |
|||
self.rank = rank |
|||
|
|||
self.dummy_linear = nn.Linear(1, 1) |
|||
|
|||
for model in self.models_list: |
|||
model.eval() |
|||
|
|||
def forward(self, enc_x, dec_x=None, |
|||
enc_x_num_pads=[0], dec_x_num_pads=[0], apply_log_softmax=False, |
|||
mode='beam_search', **kwargs): |
|||
assert (mode == 'beam_search'), "this class supports only beam search." |
|||
sos_idx = kwargs.get('sos_idx', -999) |
|||
eos_idx = kwargs.get('eos_idx', -999) |
|||
if mode == 'beam_search': |
|||
beam_size_arg = kwargs.get('beam_size', 5) |
|||
how_many_outputs_per_beam = kwargs.get('how_many_outputs', 1) |
|||
beam_max_seq_len = kwargs.get('beam_max_seq_len', 20) |
|||
sample_or_max = kwargs.get('sample_or_max', 'max') |
|||
out_classes, out_logprobs = self.ensemble_beam_search( |
|||
enc_x, enc_x_num_pads, |
|||
beam_size=beam_size_arg, |
|||
sos_idx=sos_idx, |
|||
eos_idx=eos_idx, |
|||
how_many_outputs=how_many_outputs_per_beam, |
|||
max_seq_len=beam_max_seq_len, |
|||
sample_or_max=sample_or_max) |
|||
return out_classes, out_logprobs |
|||
|
|||
def forward_enc(self, enc_input, enc_input_num_pads): |
|||
x_outputs_list = [] |
|||
for i in range(self.num_models): |
|||
x_outputs = self.models_list[i].forward_enc(enc_input, enc_input_num_pads) |
|||
x_outputs_list.append(x_outputs) |
|||
return x_outputs_list |
|||
|
|||
def forward_dec(self, cross_input_list, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False): |
|||
|
|||
import torch.nn.functional as F |
|||
y_outputs = [] |
|||
for i in range(self.num_models): |
|||
y_outputs.append( |
|||
F.softmax(self.models_list[i].forward_dec( |
|||
cross_input_list[i], enc_input_num_pads, |
|||
dec_input, dec_input_num_pads, False).unsqueeze(0), dim=-1)) |
|||
avg = torch.cat(y_outputs, dim=0).mean(dim=0).log() |
|||
|
|||
return avg |
|||
|
|||
# quite unclean coding, to be re-factored in the future... |
|||
# since it's a bit similar to the single model case |
|||
def ensemble_beam_search(self, enc_input, enc_input_num_pads, sos_idx, eos_idx, |
|||
beam_size=3, how_many_outputs=1, max_seq_len=20, sample_or_max='max',): |
|||
assert (how_many_outputs <= beam_size), "requested output per sequence must be lower than beam width" |
|||
assert (sample_or_max == 'max' or sample_or_max == 'sample'), "argument must be chosen between \'max\' and \'sample\'" |
|||
bs = enc_input.shape[0] |
|||
|
|||
# the cross_dec_input is computed once |
|||
cross_enc_output_list = self.forward_enc(enc_input, enc_input_num_pads) |
|||
|
|||
# init: ------------------------------------------------------------------ |
|||
init_dec_class = torch.tensor([sos_idx] * bs).unsqueeze(1).type(torch.long).to(self.rank) |
|||
init_dec_logprob = torch.tensor([0.0] * bs).unsqueeze(1).type(torch.float).to(self.rank) |
|||
log_probs = self.forward_dec(cross_input_list=cross_enc_output_list, enc_input_num_pads=enc_input_num_pads, |
|||
dec_input=init_dec_class, dec_input_num_pads=[0] * bs, |
|||
apply_log_softmax=True) |
|||
if sample_or_max == 'max': |
|||
_, topi = torch.topk(log_probs, k=beam_size, sorted=True) |
|||
else: # sample |
|||
topi = torch.exp(log_probs[:, 0, :]).multinomial(num_samples=beam_size, replacement=False) |
|||
topi = topi.unsqueeze(1) |
|||
|
|||
init_dec_class = init_dec_class.repeat(1, beam_size) |
|||
init_dec_class = init_dec_class.unsqueeze(-1) |
|||
top_beam_size_class = topi.transpose(-2, -1) |
|||
init_dec_class = torch.cat((init_dec_class, top_beam_size_class), dim=-1) |
|||
|
|||
init_dec_logprob = init_dec_logprob.repeat(1, beam_size) |
|||
init_dec_logprob = init_dec_logprob.unsqueeze(-1) |
|||
top_beam_size_logprob = log_probs.gather(dim=-1, index=topi) |
|||
top_beam_size_logprob = top_beam_size_logprob.transpose(-2, -1) |
|||
init_dec_logprob = torch.cat((init_dec_logprob, top_beam_size_logprob), dim=-1) |
|||
|
|||
tmp_cross_enc_output_list = [] |
|||
for cross_enc_output in cross_enc_output_list: |
|||
bs, enc_seq_len, d_model = cross_enc_output.shape |
|||
cross_enc_output = cross_enc_output.unsqueeze(1) |
|||
cross_enc_output = cross_enc_output.expand(-1, beam_size, -1, -1) |
|||
cross_enc_output = cross_enc_output.reshape(bs * beam_size, enc_seq_len, d_model).contiguous() |
|||
tmp_cross_enc_output_list.append(cross_enc_output) |
|||
cross_enc_output_list = tmp_cross_enc_output_list |
|||
enc_input_num_pads = [enc_input_num_pads[i] for i in range(bs) for _ in range(beam_size)] |
|||
|
|||
loop_dec_classes = init_dec_class |
|||
loop_dec_logprobs = init_dec_logprob |
|||
loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True) |
|||
|
|||
loop_num_elem_vector = torch.tensor([2] * (bs * beam_size)).to(self.rank) |
|||
|
|||
for time_step in range(2, max_seq_len): |
|||
loop_dec_classes = loop_dec_classes.reshape(bs * beam_size, time_step).contiguous() |
|||
|
|||
log_probs = self.forward_dec(cross_input_list=cross_enc_output_list, enc_input_num_pads=enc_input_num_pads, |
|||
dec_input=loop_dec_classes, |
|||
dec_input_num_pads=(time_step-loop_num_elem_vector).tolist(), |
|||
apply_log_softmax=True) |
|||
if sample_or_max == 'max': |
|||
_, topi = torch.topk(log_probs[:, time_step-1, :], k=beam_size, sorted=True) |
|||
else: # sample |
|||
topi = torch.exp(log_probs[:, time_step-1, :]).multinomial(num_samples=beam_size, |
|||
replacement=False) |
|||
|
|||
top_beam_size_word_classes = topi.reshape(bs, beam_size, beam_size) |
|||
|
|||
top_beam_size_word_logprobs = log_probs[:, time_step-1, :].gather(dim=-1, index=topi) |
|||
top_beam_size_word_logprobs = top_beam_size_word_logprobs.reshape(bs, beam_size, beam_size) |
|||
|
|||
there_is_eos_mask = (loop_dec_classes.view(bs, beam_size, time_step) == eos_idx). \ |
|||
sum(dim=-1, keepdims=True).type(torch.bool) |
|||
|
|||
top_beam_size_word_logprobs[:, :, 0:1].masked_fill_(there_is_eos_mask, 0.0) |
|||
top_beam_size_word_logprobs[:, :, 1:].masked_fill_(there_is_eos_mask, -999.0) |
|||
|
|||
comparison_logprobs = loop_cumul_logprobs + top_beam_size_word_logprobs |
|||
|
|||
comparison_logprobs = comparison_logprobs.contiguous().view(bs, beam_size * beam_size) |
|||
_, topi = torch.topk(comparison_logprobs, k=beam_size, sorted=True) |
|||
which_sequence = topi // beam_size |
|||
which_word = topi % beam_size |
|||
|
|||
loop_dec_classes = loop_dec_classes.view(bs, beam_size, -1) |
|||
loop_dec_logprobs = loop_dec_logprobs.view(bs, beam_size, -1) |
|||
|
|||
bs_idxes = torch.arange(bs).unsqueeze(-1) |
|||
new_loop_dec_classes = loop_dec_classes[[bs_idxes, which_sequence]] |
|||
new_loop_dec_logprobs = loop_dec_logprobs[[bs_idxes, which_sequence]] |
|||
|
|||
which_sequence_top_beam_size_word_classes = top_beam_size_word_classes[[bs_idxes, which_sequence]] |
|||
which_sequence_top_beam_size_word_logprobs = top_beam_size_word_logprobs[ |
|||
[bs_idxes, which_sequence]] |
|||
which_word = which_word.unsqueeze(-1) |
|||
|
|||
lastword_top_beam_size_classes = which_sequence_top_beam_size_word_classes.gather(dim=-1, |
|||
index=which_word) |
|||
lastword_top_beam_size_logprobs = which_sequence_top_beam_size_word_logprobs.gather(dim=-1, index=which_word) |
|||
|
|||
new_loop_dec_classes = torch.cat((new_loop_dec_classes, lastword_top_beam_size_classes), dim=-1) |
|||
new_loop_dec_logprobs = torch.cat((new_loop_dec_logprobs, lastword_top_beam_size_logprobs), dim=-1) |
|||
loop_dec_classes = new_loop_dec_classes |
|||
loop_dec_logprobs = new_loop_dec_logprobs |
|||
|
|||
loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True) |
|||
|
|||
loop_num_elem_vector = loop_num_elem_vector.view(bs, beam_size)[[bs_idxes, which_sequence]].view(bs * beam_size) |
|||
there_was_eos_mask = (loop_dec_classes[:, :, :-1].view(bs, beam_size, time_step) == eos_idx). \ |
|||
sum(dim=-1).type(torch.bool).view(bs * beam_size) |
|||
loop_num_elem_vector = loop_num_elem_vector + (1 * (1 - there_was_eos_mask.type(torch.int))) |
|||
|
|||
if (loop_num_elem_vector != time_step + 1).sum() == (bs * beam_size): |
|||
break |
|||
|
|||
loop_cumul_logprobs /= loop_num_elem_vector.reshape(bs, beam_size, 1) |
|||
_, topi = torch.topk(loop_cumul_logprobs.squeeze(-1), k=beam_size) |
|||
res_caption_pred = [[] for _ in range(bs)] |
|||
res_caption_logprob = [[] for _ in range(bs)] |
|||
for i in range(bs): |
|||
for j in range(how_many_outputs): |
|||
idx = topi[i, j].item() |
|||
res_caption_pred[i].append( |
|||
loop_dec_classes[i, idx, :loop_num_elem_vector[i * beam_size + idx]].tolist()) |
|||
res_caption_logprob[i].append(loop_dec_logprobs[i, idx, :loop_num_elem_vector[i * beam_size + idx]]) |
|||
|
|||
flatted_res_caption_logprob = [logprobs for i in range(bs) for logprobs in res_caption_logprob[i]] |
|||
flatted_res_caption_logprob = torch.nn.utils.rnn.pad_sequence(flatted_res_caption_logprob, batch_first=True) |
|||
res_caption_logprob = flatted_res_caption_logprob.view(bs, how_many_outputs, -1) |
|||
|
|||
return res_caption_pred, res_caption_logprob |
@ -0,0 +1,286 @@ |
|||
import torch |
|||
import torch.nn as nn |
|||
import math |
|||
|
|||
import numpy as np |
|||
import torch.nn.functional as F |
|||
|
|||
class EmbeddingLayer(nn.Module): |
|||
def __init__(self, vocab_size, d_model, dropout_perc): |
|||
super(EmbeddingLayer, self).__init__() |
|||
self.dropout = nn.Dropout(dropout_perc) |
|||
self.embed = nn.Embedding(vocab_size, d_model) |
|||
self.d_model = d_model |
|||
|
|||
def forward(self, x): |
|||
return self.dropout(self.embed(x)) * math.sqrt(float(self.d_model)) |
|||
|
|||
|
|||
class PositionalEncoder(nn.Module): |
|||
def __init__(self, d_model, max_seq_len, rank=0): |
|||
super().__init__() |
|||
assert d_model % 2 == 0, "d_model is not even, even number suggested" |
|||
self.d_model = d_model |
|||
self.pe = torch.zeros(max_seq_len, d_model).to(rank) |
|||
for pos in range(max_seq_len): |
|||
for i in range(0, d_model, 2): |
|||
self.pe.data[pos, i] = math.sin(pos / (10000.0 ** ((2.0 * i) / d_model))) |
|||
self.pe.data[pos, i + 1] = math.cos(pos / (10000.0 ** ((2.0 * i) / d_model))) |
|||
self.pe.data = self.pe.data.unsqueeze(0) |
|||
|
|||
def forward(self, x): |
|||
seq_len = x.shape[1] |
|||
return self.pe.data[0, :seq_len] |
|||
|
|||
|
|||
|
|||
class StaticExpansionBlock(nn.Module): |
|||
def __init__(self, d_model, num_enc_exp_list, dropout_perc, eps): |
|||
super().__init__() |
|||
self.d_model = d_model |
|||
self.num_enc_exp_list = num_enc_exp_list |
|||
|
|||
self.query_exp_vectors = nn.Embedding(sum(num_enc_exp_list), d_model) |
|||
self.bias_exp_vectors = nn.Embedding(sum(num_enc_exp_list), d_model) |
|||
|
|||
self.key_embed = nn.Linear(d_model, d_model) |
|||
self.class_a_embed = nn.Linear(d_model, d_model) |
|||
self.class_b_embed = nn.Linear(d_model, d_model) |
|||
|
|||
self.selector_embed = nn.Linear(d_model, d_model) |
|||
|
|||
self.dropout_class_a_fw = nn.Dropout(dropout_perc) |
|||
self.dropout_class_b_fw = nn.Dropout(dropout_perc) |
|||
|
|||
self.dropout_class_a_bw = nn.Dropout(dropout_perc) |
|||
self.dropout_class_b_bw = nn.Dropout(dropout_perc) |
|||
|
|||
self.Z_dropout = nn.Dropout(dropout_perc) |
|||
|
|||
self.eps = eps |
|||
|
|||
def forward(self, x, n_indexes, mask): |
|||
bs, enc_len, _ = x.shape |
|||
|
|||
query_exp = self.query_exp_vectors(n_indexes) |
|||
bias_exp = self.bias_exp_vectors(n_indexes) |
|||
x_key = self.key_embed(x) |
|||
|
|||
z = torch.matmul(query_exp, x_key.transpose(-1, -2)) / np.sqrt(self.d_model) |
|||
z = self.Z_dropout(z) |
|||
|
|||
class_a_fw = F.relu(z) |
|||
class_b_fw = F.relu(-z) |
|||
class_a_fw = class_a_fw.masked_fill(mask == 0, 0.0) |
|||
class_b_fw = class_b_fw.masked_fill(mask == 0, 0.0) |
|||
class_a_fw = class_a_fw / (class_a_fw.sum(dim=-1, keepdim=True) + self.eps) |
|||
class_b_fw = class_b_fw / (class_b_fw.sum(dim=-1, keepdim=True) + self.eps) |
|||
|
|||
class_a = torch.matmul(class_a_fw, self.class_a_embed(x)) + bias_exp |
|||
class_b = torch.matmul(class_b_fw, self.class_b_embed(x)) + bias_exp |
|||
class_a = self.dropout_class_a_fw(class_a) |
|||
class_b = self.dropout_class_b_fw(class_b) |
|||
|
|||
class_a_bw = F.relu(z.transpose(-2, -1)) |
|||
class_b_bw = F.relu(-z.transpose(-2, -1)) |
|||
|
|||
accum = 0 |
|||
class_a_bw_list = [] |
|||
class_b_bw_list = [] |
|||
for j in range(len(self.num_enc_exp_list)): |
|||
from_idx = accum |
|||
to_idx = accum + self.num_enc_exp_list[j] |
|||
accum += self.num_enc_exp_list[j] |
|||
class_a_bw_list.append(class_a_bw[:, :, from_idx:to_idx] / (class_a_bw[:, :, from_idx:to_idx].sum(dim=-1, keepdim=True) + self.eps)) |
|||
class_b_bw_list.append(class_b_bw[:, :, from_idx:to_idx] / (class_b_bw[:, :, from_idx:to_idx].sum(dim=-1, keepdim=True) + self.eps)) |
|||
class_a_bw = torch.cat(class_a_bw_list, dim=-1) |
|||
class_b_bw = torch.cat(class_b_bw_list, dim=-1) |
|||
|
|||
class_a = torch.matmul(class_a_bw, class_a) / len(self.num_enc_exp_list) |
|||
class_b = torch.matmul(class_b_bw, class_b) / len(self.num_enc_exp_list) |
|||
class_a = self.dropout_class_a_bw(class_a) |
|||
class_b = self.dropout_class_b_bw(class_b) |
|||
|
|||
selector = torch.sigmoid(self.selector_embed(x)) |
|||
x_result = selector * class_a + (1 - selector) * class_b |
|||
|
|||
return x_result |
|||
|
|||
|
|||
class EncoderLayer(nn.Module): |
|||
def __init__(self, d_model, d_ff, num_enc_exp_list, dropout_perc, eps=1e-9): |
|||
super().__init__() |
|||
self.norm_1 = nn.LayerNorm(d_model) |
|||
self.norm_2 = nn.LayerNorm(d_model) |
|||
self.dropout_1 = nn.Dropout(dropout_perc) |
|||
self.dropout_2 = nn.Dropout(dropout_perc) |
|||
|
|||
self.stc_exp = StaticExpansionBlock(d_model, num_enc_exp_list, dropout_perc, eps) |
|||
self.ff = FeedForward(d_model, d_ff, dropout_perc) |
|||
|
|||
def forward(self, x, n_indexes, mask): |
|||
x2 = self.norm_1(x) |
|||
x = x + self.dropout_1(self.stc_exp(x=x2, n_indexes=n_indexes, mask=mask)) |
|||
x2 = self.norm_2(x) |
|||
x = x + self.dropout_2(self.ff(x2)) |
|||
return x |
|||
|
|||
|
|||
class DynamicExpansionBlock(nn.Module): |
|||
def __init__(self, d_model, num_exp, dropout_perc, eps): |
|||
super().__init__() |
|||
self.d_model = d_model |
|||
|
|||
self.num_exp = num_exp |
|||
self.cond_embed = nn.Linear(d_model, d_model) |
|||
|
|||
self.query_exp_vectors = nn.Embedding(self.num_exp, d_model) |
|||
self.bias_exp_vectors = nn.Embedding(self.num_exp, d_model) |
|||
|
|||
self.key_linear = nn.Linear(d_model, d_model) |
|||
self.class_a_embed = nn.Linear(d_model, d_model) |
|||
self.class_b_embed = nn.Linear(d_model, d_model) |
|||
|
|||
self.selector_embed = nn.Linear(d_model, d_model) |
|||
|
|||
self.dropout_class_a_fw = nn.Dropout(dropout_perc) |
|||
self.dropout_class_b_fw = nn.Dropout(dropout_perc) |
|||
self.dropout_class_a_bw = nn.Dropout(dropout_perc) |
|||
self.dropout_class_b_bw = nn.Dropout(dropout_perc) |
|||
|
|||
self.Z_dropout = nn.Dropout(dropout_perc) |
|||
|
|||
self.eps = eps |
|||
|
|||
def forward(self, x, n_indexes, mask): |
|||
bs, dec_len, _ = x.shape |
|||
|
|||
cond = self.cond_embed(x).view(bs, dec_len, 1, self.d_model) |
|||
query_exp = self.query_exp_vectors(n_indexes).unsqueeze(1) |
|||
bias_exp = self.bias_exp_vectors(n_indexes).unsqueeze(1) |
|||
query_exp = (query_exp + cond).view(bs, dec_len * self.num_exp, self.d_model) |
|||
bias_exp = (bias_exp + cond).view(bs, dec_len * self.num_exp, self.d_model) |
|||
|
|||
x_key = self.key_linear(x) |
|||
z = torch.matmul(query_exp, x_key.transpose(-1, -2)) / np.sqrt(self.d_model) |
|||
z = self.Z_dropout(z) |
|||
|
|||
mod_mask_1 = mask.unsqueeze(2).expand(bs, dec_len, self.num_exp, dec_len).contiguous(). \ |
|||
view(bs, dec_len * self.num_exp, dec_len) |
|||
|
|||
class_a_fw = F.relu(z) |
|||
class_b_fw = F.relu(-z) |
|||
class_a_fw = class_a_fw.masked_fill(mod_mask_1 == 0, 0.0) |
|||
class_b_fw = class_b_fw.masked_fill(mod_mask_1 == 0, 0.0) |
|||
class_a_fw = class_a_fw / (class_a_fw.sum(dim=-1, keepdim=True) + self.eps) |
|||
class_b_fw = class_b_fw / (class_b_fw.sum(dim=-1, keepdim=True) + self.eps) |
|||
class_a = torch.matmul(class_a_fw, self.class_a_embed(x)) |
|||
class_b = torch.matmul(class_b_fw, self.class_b_embed(x)) |
|||
class_a = self.dropout_class_a_fw(class_a) |
|||
class_b = self.dropout_class_b_fw(class_b) |
|||
|
|||
mod_mask_2 = mask.unsqueeze(-1).expand(bs, dec_len, dec_len, self.num_exp).contiguous(). \ |
|||
view(bs, dec_len, dec_len * self.num_exp) |
|||
|
|||
class_a_bw = F.relu(z.transpose(-2, -1)) |
|||
class_b_bw = F.relu(-z.transpose(-2, -1)) |
|||
class_a_bw = class_a_bw.masked_fill(mod_mask_2 == 0, 0.0) |
|||
class_b_bw = class_b_bw.masked_fill(mod_mask_2 == 0, 0.0) |
|||
class_a_bw = class_a_bw / (class_a_bw.sum(dim=-1, keepdim=True) + self.eps) |
|||
class_b_bw = class_b_bw / (class_b_bw.sum(dim=-1, keepdim=True) + self.eps) |
|||
class_a = torch.matmul(class_a_bw, class_a + bias_exp) |
|||
class_b = torch.matmul(class_b_bw, class_b + bias_exp) |
|||
class_a = self.dropout_class_a_bw(class_a) |
|||
class_b = self.dropout_class_b_bw(class_b) |
|||
|
|||
selector = torch.sigmoid(self.selector_embed(x)) |
|||
x_result = selector * class_a + (1 - selector) * class_b |
|||
|
|||
return x_result |
|||
|
|||
|
|||
class DecoderLayer(nn.Module): |
|||
def __init__(self, d_model, num_heads, d_ff, num_exp, dropout_perc, eps=1e-9): |
|||
super().__init__() |
|||
self.norm_1 = nn.LayerNorm(d_model) |
|||
self.norm_2 = nn.LayerNorm(d_model) |
|||
self.norm_3 = nn.LayerNorm(d_model) |
|||
|
|||
self.dropout_1 = nn.Dropout(dropout_perc) |
|||
self.dropout_2 = nn.Dropout(dropout_perc) |
|||
self.dropout_3 = nn.Dropout(dropout_perc) |
|||
|
|||
self.mha = MultiHeadAttention(d_model, num_heads, dropout_perc) |
|||
self.dyn_exp = DynamicExpansionBlock(d_model, num_exp, dropout_perc, eps) |
|||
self.ff = FeedForward(d_model, d_ff, dropout_perc) |
|||
|
|||
def forward(self, x, n_indexes, cross_connection_x, input_attention_mask, cross_attention_mask): |
|||
|
|||
# Pre-LayerNorm |
|||
x2 = self.norm_1(x) |
|||
x = x + self.dropout_1(self.dyn_exp(x=x2, n_indexes=n_indexes, mask=input_attention_mask)) |
|||
|
|||
x2 = self.norm_2(x) |
|||
x = x + self.dropout_2(self.mha(q=x2, k=cross_connection_x, v=cross_connection_x, |
|||
mask=cross_attention_mask)) |
|||
|
|||
x2 = self.norm_3(x) |
|||
x = x + self.dropout_3(self.ff(x2)) |
|||
return x |
|||
|
|||
|
|||
|
|||
class MultiHeadAttention(nn.Module): |
|||
def __init__(self, d_model, num_heads, dropout_perc): |
|||
super(MultiHeadAttention, self).__init__() |
|||
assert d_model % num_heads == 0, "num heads must be multiple of d_model" |
|||
|
|||
self.d_model = d_model |
|||
self.d_k = int(d_model / num_heads) |
|||
self.num_heads = num_heads |
|||
|
|||
self.Wq = nn.Linear(d_model, self.d_k * num_heads) |
|||
self.Wk = nn.Linear(d_model, self.d_k * num_heads) |
|||
self.Wv = nn.Linear(d_model, self.d_k * num_heads) |
|||
|
|||
self.out_linear = nn.Linear(d_model, d_model) |
|||
|
|||
def forward(self, q, k, v, mask=None): |
|||
batch_size, q_seq_len, _ = q.shape |
|||
k_seq_len = k.size(1) |
|||
v_seq_len = v.size(1) |
|||
|
|||
k_proj = self.Wk(k).view(batch_size, k_seq_len, self.num_heads, self.d_k) |
|||
q_proj = self.Wq(q).view(batch_size, q_seq_len, self.num_heads, self.d_k) |
|||
v_proj = self.Wv(v).view(batch_size, v_seq_len, self.num_heads, self.d_k) |
|||
|
|||
k_proj = k_proj.transpose(2, 1) |
|||
q_proj = q_proj.transpose(2, 1) |
|||
v_proj = v_proj.transpose(2, 1) |
|||
|
|||
sim_scores = torch.matmul(q_proj, k_proj.transpose(3, 2)) |
|||
sim_scores = sim_scores / self.d_k ** 0.5 |
|||
|
|||
if mask is not None: |
|||
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1) |
|||
sim_scores = sim_scores.masked_fill(mask == 0, value=-1e4) |
|||
sim_scores = F.softmax(input=sim_scores, dim=-1) |
|||
|
|||
attention_applied = torch.matmul(sim_scores, v_proj) |
|||
attention_applied_concatenated = attention_applied.permute(0, 2, 1, 3).contiguous()\ |
|||
.view(batch_size, q_seq_len, self.d_model) |
|||
|
|||
out = self.out_linear(attention_applied_concatenated) |
|||
return out |
|||
|
|||
class FeedForward(nn.Module): |
|||
def __init__(self, d_model, d_ff, dropout_perc): |
|||
super(FeedForward, self).__init__() |
|||
self.linear_1 = nn.Linear(d_model, d_ff) |
|||
self.dropout = nn.Dropout(dropout_perc) |
|||
self.linear_2 = nn.Linear(d_ff, d_model) |
|||
|
|||
def forward(self, x): |
|||
x = self.dropout(F.relu(self.linear_1(x))) |
|||
x = self.linear_2(x) |
|||
return x |
@ -0,0 +1,655 @@ |
|||
# -------------------------------------------------------- |
|||
# Swin Transformer |
|||
# Copyright (c) 2021 Microsoft |
|||
# Licensed under The MIT License [see LICENSE for details] |
|||
# Written by Ze Liu |
|||
# -------------------------------------------------------- |
|||
|
|||
# --------------------------------- |
|||
# All credits due to Ze Liu: https://github.com/microsoft/Swin-Transformer |
|||
# and the additional sources: |
|||
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py |
|||
# https://github.com/yukimasano/PASS/blob/main/vision_transformer.py |
|||
# --------------------------------- |
|||
|
|||
import torch |
|||
import torch.nn as nn |
|||
import torch.utils.checkpoint as checkpoint |
|||
|
|||
|
|||
class DropPath(nn.Module): |
|||
def __init__(self, drop_prob): |
|||
super().__init__() |
|||
self.drop_prob = drop_prob |
|||
|
|||
def forward(self, x): |
|||
if not self.training: |
|||
return x |
|||
keep_prob = 1 - self.drop_prob |
|||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets |
|||
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) |
|||
random_tensor.floor_() # binarize |
|||
output = x.div(keep_prob) * random_tensor |
|||
return output |
|||
|
|||
|
|||
import collections.abc |
|||
def to_2tuple(x): |
|||
if isinstance(x, collections.abc.Iterable): |
|||
return x |
|||
return (x, x) |
|||
|
|||
|
|||
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): |
|||
return _no_grad_trunc_normal_(tensor, mean, std, a, b) |
|||
|
|||
|
|||
import warnings |
|||
import math |
|||
def _no_grad_trunc_normal_(tensor, mean, std, a, b): |
|||
# Cut & paste from PyTorch official repo master until it's in a few official releases - RW |
|||
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf |
|||
def norm_cdf(x): |
|||
return (1. + math.erf(x / math.sqrt(2.))) / 2. |
|||
if (mean < a - 2 * std) or (mean > b + 2 * std): |
|||
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " |
|||
"The distribution of values may be incorrect.", |
|||
stacklevel=2) |
|||
|
|||
with torch.no_grad(): |
|||
# Values are generated by using a truncated uniform distribution and |
|||
# then using the inverse CDF for the normal distribution. |
|||
# Get upper and lower cdf values |
|||
l = norm_cdf((a - mean) / std) |
|||
u = norm_cdf((b - mean) / std) |
|||
|
|||
# Uniformly fill tensor with values from [l, u], then translate to |
|||
# [2l-1, 2u-1]. |
|||
tensor.uniform_(2 * l - 1, 2 * u - 1) |
|||
|
|||
# Use inverse cdf transform for normal distribution to get truncated |
|||
# standard normal |
|||
tensor.erfinv_() |
|||
|
|||
# Transform to proper mean, std |
|||
tensor.mul_(std * math.sqrt(2.)) |
|||
tensor.add_(mean) |
|||
|
|||
# Clamp to ensure it's in the proper range |
|||
tensor.clamp_(min=a, max=b) |
|||
return tensor |
|||
|
|||
|
|||
class Mlp(nn.Module): |
|||
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): |
|||
super().__init__() |
|||
out_features = out_features or in_features |
|||
hidden_features = hidden_features or in_features |
|||
self.fc1 = nn.Linear(in_features, hidden_features) |
|||
self.act = act_layer() |
|||
self.fc2 = nn.Linear(hidden_features, out_features) |
|||
self.drop = nn.Dropout(drop) |
|||
|
|||
def forward(self, x): |
|||
x = self.fc1(x) |
|||
x = self.act(x) |
|||
x = self.drop(x) |
|||
x = self.fc2(x) |
|||
x = self.drop(x) |
|||
return x |
|||
|
|||
|
|||
def window_partition(x, window_size): |
|||
""" |
|||
Args: |
|||
x: (B, H, W, C) |
|||
window_size (int): window size |
|||
|
|||
Returns: |
|||
windows: (num_windows*B, window_size, window_size, C) |
|||
""" |
|||
B, H, W, C = x.shape |
|||
# mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 |
|||
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) |
|||
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) |
|||
return windows |
|||
|
|||
|
|||
def window_reverse(windows, window_size, H, W): |
|||
""" |
|||
Args: |
|||
windows: (num_windows*B, window_size, window_size, C) |
|||
window_size (int): Window size |
|||
H (int): Height of image |
|||
W (int): Width of image |
|||
|
|||
Returns: |
|||
x: (B, H, W, C) |
|||
""" |
|||
B = int(windows.shape[0] / (H * W / window_size / window_size)) |
|||
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) |
|||
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) |
|||
return x |
|||
|
|||
|
|||
class WindowAttention(nn.Module): |
|||
r""" Window based multi-head self attention (W-MSA) module with relative position bias. |
|||
It supports both of shifted and non-shifted window. |
|||
|
|||
Args: |
|||
dim (int): Number of input channels. |
|||
window_size (tuple[int]): The height and width of the window. |
|||
num_heads (int): Number of attention heads. |
|||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True |
|||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set |
|||
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 |
|||
proj_drop (float, optional): Dropout ratio of output. Default: 0.0 |
|||
""" |
|||
|
|||
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): |
|||
|
|||
super().__init__() |
|||
self.dim = dim |
|||
self.window_size = window_size # Wh, Ww |
|||
self.num_heads = num_heads |
|||
head_dim = dim // num_heads |
|||
self.scale = qk_scale or head_dim ** -0.5 |
|||
|
|||
# define a parameter table of relative position bias |
|||
self.relative_position_bias_table = nn.Parameter( |
|||
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH |
|||
|
|||
# get pair-wise relative position index for each token inside the window |
|||
coords_h = torch.arange(self.window_size[0]) |
|||
coords_w = torch.arange(self.window_size[1]) |
|||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww |
|||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww |
|||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww |
|||
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 |
|||
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 |
|||
relative_coords[:, :, 1] += self.window_size[1] - 1 |
|||
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 |
|||
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww |
|||
self.register_buffer("relative_position_index", relative_position_index) |
|||
|
|||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) |
|||
self.attn_drop = nn.Dropout(attn_drop) |
|||
self.proj = nn.Linear(dim, dim) |
|||
self.proj_drop = nn.Dropout(proj_drop) |
|||
|
|||
trunc_normal_(self.relative_position_bias_table, std=.02) |
|||
self.softmax = nn.Softmax(dim=-1) |
|||
|
|||
def forward(self, x, mask=None): |
|||
""" |
|||
Args: |
|||
x: input features with shape of (num_windows*B, N, C) |
|||
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None |
|||
""" |
|||
B_, N, C = x.shape |
|||
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) |
|||
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) |
|||
|
|||
q = q * self.scale |
|||
attn = (q @ k.transpose(-2, -1)) |
|||
|
|||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( |
|||
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH |
|||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww |
|||
attn = attn + relative_position_bias.unsqueeze(0) |
|||
|
|||
if mask is not None: |
|||
nW = mask.shape[0] |
|||
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) |
|||
attn = attn.view(-1, self.num_heads, N, N) |
|||
attn = self.softmax(attn) |
|||
else: |
|||
attn = self.softmax(attn) |
|||
|
|||
attn = self.attn_drop(attn) |
|||
|
|||
x = (attn @ v).transpose(1, 2).reshape(B_, N, C) |
|||
x = self.proj(x) |
|||
x = self.proj_drop(x) |
|||
return x |
|||
|
|||
def extra_repr(self) -> str: |
|||
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' |
|||
|
|||
def flops(self, N): |
|||
# calculate flops for 1 window with token length of N |
|||
flops = 0 |
|||
# qkv = self.qkv(x) |
|||
flops += N * self.dim * 3 * self.dim |
|||
# attn = (q @ k.transpose(-2, -1)) |
|||
flops += self.num_heads * N * (self.dim // self.num_heads) * N |
|||
# x = (attn @ v) |
|||
flops += self.num_heads * N * N * (self.dim // self.num_heads) |
|||
# x = self.proj(x) |
|||
flops += N * self.dim * self.dim |
|||
return flops |
|||
|
|||
|
|||
class SwinTransformerBlock(nn.Module): |
|||
r""" Swin Transformer Block. |
|||
|
|||
Args: |
|||
dim (int): Number of input channels. |
|||
input_resolution (tuple[int]): Input resulotion. |
|||
num_heads (int): Number of attention heads. |
|||
window_size (int): Window size. |
|||
shift_size (int): Shift size for SW-MSA. |
|||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. |
|||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True |
|||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. |
|||
drop (float, optional): Dropout rate. Default: 0.0 |
|||
attn_drop (float, optional): Attention dropout rate. Default: 0.0 |
|||
drop_path (float, optional): Stochastic depth rate. Default: 0.0 |
|||
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU |
|||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm |
|||
""" |
|||
|
|||
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, |
|||
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., |
|||
act_layer=nn.GELU, norm_layer=nn.LayerNorm): |
|||
super().__init__() |
|||
self.dim = dim |
|||
self.input_resolution = input_resolution |
|||
self.num_heads = num_heads |
|||
self.window_size = window_size |
|||
self.shift_size = shift_size |
|||
self.mlp_ratio = mlp_ratio |
|||
if min(self.input_resolution) <= self.window_size: |
|||
# if window size is larger than input resolution, we don't partition windows |
|||
self.shift_size = 0 |
|||
self.window_size = min(self.input_resolution) |
|||
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" |
|||
|
|||
self.norm1 = norm_layer(dim) |
|||
self.attn = WindowAttention( |
|||
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, |
|||
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) |
|||
|
|||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
|||
self.norm2 = norm_layer(dim) |
|||
mlp_hidden_dim = int(dim * mlp_ratio) |
|||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) |
|||
|
|||
if self.shift_size > 0: |
|||
# calculate attention mask for SW-MSA |
|||
H, W = self.input_resolution |
|||
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 |
|||
h_slices = (slice(0, -self.window_size), |
|||
slice(-self.window_size, -self.shift_size), |
|||
slice(-self.shift_size, None)) |
|||
w_slices = (slice(0, -self.window_size), |
|||
slice(-self.window_size, -self.shift_size), |
|||
slice(-self.shift_size, None)) |
|||
cnt = 0 |
|||
for h in h_slices: |
|||
for w in w_slices: |
|||
img_mask[:, h, w, :] = cnt |
|||
cnt += 1 |
|||
|
|||
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 |
|||
mask_windows = mask_windows.view(-1, self.window_size * self.window_size) |
|||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) |
|||
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) |
|||
else: |
|||
attn_mask = None |
|||
|
|||
self.register_buffer("attn_mask", attn_mask) |
|||
|
|||
def forward(self, x): |
|||
H, W = self.input_resolution |
|||
B, L, C = x.shape |
|||
assert L == H * W, "input feature has wrong size" |
|||
|
|||
shortcut = x |
|||
x = self.norm1(x) |
|||
x = x.view(B, H, W, C) |
|||
|
|||
# cyclic shift |
|||
if self.shift_size > 0: |
|||
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) |
|||
else: |
|||
shifted_x = x |
|||
|
|||
# partition windows |
|||
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C |
|||
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C |
|||
|
|||
# W-MSA/SW-MSA |
|||
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C |
|||
|
|||
# merge windows |
|||
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) |
|||
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C |
|||
|
|||
# reverse cyclic shift |
|||
if self.shift_size > 0: |
|||
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) |
|||
else: |
|||
x = shifted_x |
|||
x = x.view(B, H * W, C) |
|||
|
|||
# FFN |
|||
x = shortcut + self.drop_path(x) |
|||
x = x + self.drop_path(self.mlp(self.norm2(x))) |
|||
|
|||
return x |
|||
|
|||
def extra_repr(self) -> str: |
|||
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ |
|||
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" |
|||
|
|||
def flops(self): |
|||
flops = 0 |
|||
H, W = self.input_resolution |
|||
# norm1 |
|||
flops += self.dim * H * W |
|||
# W-MSA/SW-MSA |
|||
nW = H * W / self.window_size / self.window_size |
|||
flops += nW * self.attn.flops(self.window_size * self.window_size) |
|||
# mlp |
|||
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio |
|||
# norm2 |
|||
flops += self.dim * H * W |
|||
return flops |
|||
|
|||
|
|||
class PatchMerging(nn.Module): |
|||
r""" Patch Merging Layer. |
|||
|
|||
Args: |
|||
input_resolution (tuple[int]): Resolution of input feature. |
|||
dim (int): Number of input channels. |
|||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm |
|||
""" |
|||
|
|||
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): |
|||
super().__init__() |
|||
self.input_resolution = input_resolution |
|||
self.dim = dim |
|||
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) |
|||
self.norm = norm_layer(4 * dim) |
|||
|
|||
def forward(self, x): |
|||
""" |
|||
x: B, H*W, C |
|||
""" |
|||
H, W = self.input_resolution |
|||
B, L, C = x.shape |
|||
assert L == H * W, "input feature has wrong size" |
|||
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." |
|||
|
|||
x = x.view(B, H, W, C) |
|||
|
|||
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C |
|||
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C |
|||
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C |
|||
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C |
|||
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C |
|||
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C |
|||
|
|||
x = self.norm(x) |
|||
x = self.reduction(x) |
|||
|
|||
return x |
|||
|
|||
def extra_repr(self) -> str: |
|||
return f"input_resolution={self.input_resolution}, dim={self.dim}" |
|||
|
|||
def flops(self): |
|||
H, W = self.input_resolution |
|||
flops = H * W * self.dim |
|||
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim |
|||
return flops |
|||
|
|||
|
|||
class BasicLayer(nn.Module): |
|||
""" A basic Swin Transformer layer for one stage. |
|||
|
|||
Args: |
|||
dim (int): Number of input channels. |
|||
input_resolution (tuple[int]): Input resolution. |
|||
depth (int): Number of blocks. |
|||
num_heads (int): Number of attention heads. |
|||
window_size (int): Local window size. |
|||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. |
|||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True |
|||
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. |
|||
drop (float, optional): Dropout rate. Default: 0.0 |
|||
attn_drop (float, optional): Attention dropout rate. Default: 0.0 |
|||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 |
|||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm |
|||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None |
|||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. |
|||
""" |
|||
|
|||
def __init__(self, dim, input_resolution, depth, num_heads, window_size, |
|||
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., |
|||
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): |
|||
|
|||
super().__init__() |
|||
self.dim = dim |
|||
self.input_resolution = input_resolution |
|||
self.depth = depth |
|||
self.use_checkpoint = use_checkpoint |
|||
|
|||
# build blocks |
|||
self.blocks = nn.ModuleList([ |
|||
SwinTransformerBlock(dim=dim, input_resolution=input_resolution, |
|||
num_heads=num_heads, window_size=window_size, |
|||
shift_size=0 if (i % 2 == 0) else window_size // 2, |
|||
mlp_ratio=mlp_ratio, |
|||
qkv_bias=qkv_bias, qk_scale=qk_scale, |
|||
drop=drop, attn_drop=attn_drop, |
|||
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, |
|||
norm_layer=norm_layer) |
|||
for i in range(depth)]) |
|||
|
|||
# patch merging layer |
|||
if downsample is not None: |
|||
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) |
|||
else: |
|||
self.downsample = None |
|||
|
|||
def forward(self, x): |
|||
for blk in self.blocks: |
|||
if self.use_checkpoint: |
|||
x = checkpoint.checkpoint(blk, x) |
|||
else: |
|||
x = blk(x) |
|||
if self.downsample is not None: |
|||
x = self.downsample(x) |
|||
return x |
|||
|
|||
def extra_repr(self) -> str: |
|||
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" |
|||
|
|||
def flops(self): |
|||
flops = 0 |
|||
for blk in self.blocks: |
|||
flops += blk.flops() |
|||
if self.downsample is not None: |
|||
flops += self.downsample.flops() |
|||
return flops |
|||
|
|||
|
|||
class PatchEmbed(nn.Module): |
|||
r""" Image to Patch Embedding |
|||
|
|||
Args: |
|||
img_size (int): Image size. Default: 224. |
|||
patch_size (int): Patch token size. Default: 4. |
|||
in_chans (int): Number of input image channels. Default: 3. |
|||
embed_dim (int): Number of linear projection output channels. Default: 96. |
|||
norm_layer (nn.Module, optional): Normalization layer. Default: None |
|||
""" |
|||
|
|||
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): |
|||
super().__init__() |
|||
img_size = to_2tuple(img_size) |
|||
patch_size = to_2tuple(patch_size) |
|||
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] |
|||
self.img_size = img_size |
|||
self.patch_size = patch_size |
|||
self.patches_resolution = patches_resolution |
|||
self.num_patches = patches_resolution[0] * patches_resolution[1] |
|||
|
|||
self.in_chans = in_chans |
|||
self.embed_dim = embed_dim |
|||
|
|||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) |
|||
if norm_layer is not None: |
|||
self.norm = norm_layer(embed_dim) |
|||
else: |
|||
self.norm = None |
|||
|
|||
def forward(self, x): |
|||
B, C, H, W = x.shape |
|||
# FIXME look at relaxing size constraints |
|||
assert H == self.img_size[0] and W == self.img_size[1], \ |
|||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." |
|||
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C |
|||
if self.norm is not None: |
|||
x = self.norm(x) |
|||
return x |
|||
|
|||
def flops(self): |
|||
Ho, Wo = self.patches_resolution |
|||
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) |
|||
if self.norm is not None: |
|||
flops += Ho * Wo * self.embed_dim |
|||
return flops |
|||
|
|||
|
|||
class SwinTransformer(nn.Module): |
|||
r""" Swin Transformer |
|||
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - |
|||
https://arxiv.org/pdf/2103.14030 |
|||
|
|||
Args: |
|||
img_size (int | tuple(int)): Input image size. Default 224 |
|||
patch_size (int | tuple(int)): Patch size. Default: 4 |
|||
in_chans (int): Number of input image channels. Default: 3 |
|||
num_classes (int): Number of classes for classification head. Default: 1000 |
|||
embed_dim (int): Patch embedding dimension. Default: 96 |
|||
depths (tuple(int)): Depth of each Swin Transformer layer. |
|||
num_heads (tuple(int)): Number of attention heads in different layers. |
|||
window_size (int): Window size. Default: 7 |
|||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 |
|||
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True |
|||
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None |
|||
drop_rate (float): Dropout rate. Default: 0 |
|||
attn_drop_rate (float): Attention dropout rate. Default: 0 |
|||
drop_path_rate (float): Stochastic depth rate. Default: 0.1 |
|||
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. |
|||
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False |
|||
patch_norm (bool): If True, add normalization after patch embedding. Default: True |
|||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False |
|||
""" |
|||
|
|||
def __init__(self, img_size=224, patch_size=4, in_chans=3, |
|||
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], |
|||
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, |
|||
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, |
|||
norm_layer=nn.LayerNorm, ape=False, patch_norm=True, |
|||
use_checkpoint=False): |
|||
super().__init__() |
|||
|
|||
# self.num_classes = num_classes |
|||
self.num_layers = len(depths) |
|||
self.embed_dim = embed_dim |
|||
self.ape = ape |
|||
self.patch_norm = patch_norm |
|||
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) |
|||
self.mlp_ratio = mlp_ratio |
|||
|
|||
# split image into non-overlapping patches |
|||
self.patch_embed = PatchEmbed( |
|||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, |
|||
norm_layer=norm_layer if self.patch_norm else None) |
|||
num_patches = self.patch_embed.num_patches |
|||
patches_resolution = self.patch_embed.patches_resolution |
|||
self.patches_resolution = patches_resolution |
|||
|
|||
# absolute position embedding |
|||
if self.ape: |
|||
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) |
|||
trunc_normal_(self.absolute_pos_embed, std=.02) |
|||
|
|||
self.pos_drop = nn.Dropout(p=drop_rate) |
|||
|
|||
# stochastic depth |
|||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule |
|||
|
|||
# build layers |
|||
self.layers = nn.ModuleList() |
|||
for i_layer in range(self.num_layers): |
|||
layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), |
|||
input_resolution=(patches_resolution[0] // (2 ** i_layer), |
|||
patches_resolution[1] // (2 ** i_layer)), |
|||
depth=depths[i_layer], |
|||
num_heads=num_heads[i_layer], |
|||
window_size=window_size, |
|||
mlp_ratio=self.mlp_ratio, |
|||
qkv_bias=qkv_bias, qk_scale=qk_scale, |
|||
drop=drop_rate, attn_drop=attn_drop_rate, |
|||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], |
|||
norm_layer=norm_layer, |
|||
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, |
|||
use_checkpoint=use_checkpoint) |
|||
self.layers.append(layer) |
|||
|
|||
self.norm = norm_layer(self.num_features) |
|||
# self.avgpool = nn.AdaptiveAvgPool1d(1) |
|||
# self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() |
|||
|
|||
self.apply(self._init_weights) |
|||
|
|||
def _init_weights(self, m): |
|||
if isinstance(m, nn.Linear): |
|||
trunc_normal_(m.weight, std=.02) |
|||
if isinstance(m, nn.Linear) and m.bias is not None: |
|||
nn.init.constant_(m.bias, 0) |
|||
elif isinstance(m, nn.LayerNorm): |
|||
nn.init.constant_(m.bias, 0) |
|||
nn.init.constant_(m.weight, 1.0) |
|||
|
|||
@torch.jit.ignore |
|||
def no_weight_decay(self): |
|||
return {'absolute_pos_embed'} |
|||
|
|||
@torch.jit.ignore |
|||
def no_weight_decay_keywords(self): |
|||
return {'relative_position_bias_table'} |
|||
|
|||
def forward_features(self, x): |
|||
x = self.patch_embed(x) |
|||
if self.ape: |
|||
x = x + self.absolute_pos_embed |
|||
x = self.pos_drop(x) |
|||
|
|||
for layer in self.layers: |
|||
x = layer(x) |
|||
|
|||
x = self.norm(x) # B L C |
|||
# x = self.avgpool(x.transpose(1, 2)) # B C 1 |
|||
# x = torch.flatten(x, 1) |
|||
return x |
|||
|
|||
def forward(self, x): |
|||
x = self.forward_features(x) |
|||
#x = self.head(x) |
|||
return x |
|||
|
|||
def flops(self): |
|||
flops = 0 |
|||
flops += self.patch_embed.flops() |
|||
for i, layer in enumerate(self.layers): |
|||
flops += layer.flops() |
|||
#flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) |
|||
#flops += self.num_features * self.num_classes |
|||
return flops |
Loading…
Reference in new issue