expansionnet-v2
copied
wxywb
2 years ago
9 changed files with 1732 additions and 0 deletions
@ -0,0 +1,18 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
from .expansionnet_v2 import ExpansionNetV2 |
||||
|
|
||||
|
def expansionnet_v2(model_name: str): |
||||
|
return ExpansionNetV2(model_name) |
Binary file not shown.
@ -0,0 +1,55 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
import sys |
||||
|
import os |
||||
|
from pathlib import Path |
||||
|
|
||||
|
import torch |
||||
|
from torchvision import transforms |
||||
|
from transformers import GPT2Tokenizer |
||||
|
|
||||
|
from towhee.types.arg import arg, to_image_color |
||||
|
from towhee.types.image_utils import to_pil |
||||
|
from towhee.operator.base import NNOperator, OperatorFlag |
||||
|
from towhee import register |
||||
|
from towhee.models import clip |
||||
|
|
||||
|
class ExpansionNetV2(NNOperator): |
||||
|
""" |
||||
|
ExpansionNet V2 image captioning operator |
||||
|
""" |
||||
|
def __init__(self, model_name: str): |
||||
|
super().__init__() |
||||
|
sys.path.append(str(Path(__file__).parent)) |
||||
|
from models.End_ExpansionNet_v2 import End_ExpansionNet_v2 |
||||
|
sys.path.pop() |
||||
|
with open('demo_coco_tokens.pickle') as fw: |
||||
|
|
||||
|
self.model = End_ExpansionNet_v2(swin_img_size=img_size, swin_patch_size=4, swin_in_chans=3, |
||||
|
swin_embed_dim=192, swin_depths=[2, 2, 18, 2], swin_num_heads=[6, 12, 24, 48], |
||||
|
swin_window_size=12, swin_mlp_ratio=4., swin_qkv_bias=True, swin_qk_scale=None, |
||||
|
swin_drop_rate=0.0, swin_attn_drop_rate=0.0, swin_drop_path_rate=0.0, |
||||
|
swin_norm_layer=torch.nn.LayerNorm, swin_ape=False, swin_patch_norm=True, |
||||
|
swin_use_checkpoint=False, |
||||
|
final_swin_dim=1536, |
||||
|
|
||||
|
d_model=model_args.model_dim, N_enc=model_args.N_enc, |
||||
|
N_dec=model_args.N_dec, num_heads=8, ff=2048, |
||||
|
num_exp_enc_list=[32, 64, 128, 256, 512], |
||||
|
num_exp_dec=16, |
||||
|
output_word2idx=coco_tokens['word2idx_dict'], |
||||
|
output_idx2word=coco_tokens['idx2word_list'], |
||||
|
max_seq_len=args.max_seq_len, drop_args=model_args.drop_args, |
||||
|
rank='cpu') |
@ -0,0 +1,187 @@ |
|||||
|
import torch |
||||
|
from models.layers import EmbeddingLayer, DecoderLayer, EncoderLayer |
||||
|
from utils.masking import create_pad_mask, create_no_peak_and_pad_mask |
||||
|
from models.captioning_model import CaptioningModel |
||||
|
from models.swin_transformer_mod import SwinTransformer |
||||
|
|
||||
|
import torch.nn as nn |
||||
|
|
||||
|
|
||||
|
class End_ExpansionNet_v2(CaptioningModel): |
||||
|
def __init__(self, |
||||
|
|
||||
|
# swin transf |
||||
|
swin_img_size, swin_patch_size, swin_in_chans, |
||||
|
swin_embed_dim, swin_depths, swin_num_heads, |
||||
|
swin_window_size, swin_mlp_ratio, swin_qkv_bias, swin_qk_scale, |
||||
|
swin_drop_rate, swin_attn_drop_rate, swin_drop_path_rate, |
||||
|
swin_norm_layer, swin_ape, swin_patch_norm, |
||||
|
swin_use_checkpoint, |
||||
|
|
||||
|
# linear_size, |
||||
|
final_swin_dim, |
||||
|
|
||||
|
# captioning |
||||
|
d_model, N_enc, N_dec, ff, num_heads, num_exp_enc_list, num_exp_dec, |
||||
|
output_word2idx, output_idx2word, max_seq_len, drop_args, rank=0): |
||||
|
super(End_ExpansionNet_v2, self).__init__() |
||||
|
|
||||
|
self.swin_transf = SwinTransformer( |
||||
|
img_size=swin_img_size, patch_size=swin_patch_size, in_chans=swin_in_chans, |
||||
|
embed_dim=swin_embed_dim, depths=swin_depths, num_heads=swin_num_heads, |
||||
|
window_size=swin_window_size, mlp_ratio=swin_mlp_ratio, qkv_bias=swin_qkv_bias, qk_scale=swin_qk_scale, |
||||
|
drop_rate=swin_drop_rate, attn_drop_rate=swin_attn_drop_rate, drop_path_rate=swin_drop_path_rate, |
||||
|
norm_layer=swin_norm_layer, ape=swin_ape, patch_norm=swin_patch_norm, |
||||
|
use_checkpoint=swin_use_checkpoint) |
||||
|
|
||||
|
self.output_word2idx = output_word2idx |
||||
|
self.output_idx2word = output_idx2word |
||||
|
self.max_seq_len = max_seq_len |
||||
|
|
||||
|
self.num_exp_dec = num_exp_dec |
||||
|
self.num_exp_enc_list = num_exp_enc_list |
||||
|
|
||||
|
self.N_enc = N_enc |
||||
|
self.N_dec = N_dec |
||||
|
self.d_model = d_model |
||||
|
|
||||
|
self.encoders = nn.ModuleList([EncoderLayer(d_model, ff, num_exp_enc_list, drop_args.enc) for _ in range(N_enc)]) |
||||
|
self.decoders = nn.ModuleList([DecoderLayer(d_model, num_heads, ff, num_exp_dec, drop_args.dec) for _ in range(N_dec)]) |
||||
|
|
||||
|
self.input_embedder_dropout = nn.Dropout(drop_args.enc_input) |
||||
|
self.input_linear = torch.nn.Linear(final_swin_dim, d_model) |
||||
|
self.vocab_linear = torch.nn.Linear(d_model, len(output_word2idx)) |
||||
|
self.log_softmax = nn.LogSoftmax(dim=-1) |
||||
|
|
||||
|
self.out_enc_dropout = nn.Dropout(drop_args.other) |
||||
|
self.out_dec_dropout = nn.Dropout(drop_args.other) |
||||
|
|
||||
|
self.out_embedder = EmbeddingLayer(len(output_word2idx), d_model, drop_args.dec_input) |
||||
|
self.pos_encoder = nn.Embedding(max_seq_len, d_model) |
||||
|
|
||||
|
self.enc_reduce_group = nn.Linear(d_model * self.N_enc, d_model) |
||||
|
self.enc_reduce_norm = nn.LayerNorm(d_model) |
||||
|
self.dec_reduce_group = nn.Linear(d_model * self.N_dec, d_model) |
||||
|
self.dec_reduce_norm = nn.LayerNorm(d_model) |
||||
|
|
||||
|
for p in self.parameters(): |
||||
|
if p.dim() > 1: |
||||
|
nn.init.xavier_uniform_(p) |
||||
|
|
||||
|
self.trained_steps = 0 |
||||
|
self.rank = rank |
||||
|
|
||||
|
self.check_required_attributes() |
||||
|
|
||||
|
def forward_enc(self, enc_input, enc_input_num_pads): |
||||
|
|
||||
|
assert (enc_input_num_pads is None or enc_input_num_pads == ([0] * enc_input.size(0))), "End to End case have no padding" |
||||
|
x = self.swin_transf(enc_input) |
||||
|
# --------------- Normale parte di Captioning --------------------------------- |
||||
|
enc_input = self.input_embedder_dropout(self.input_linear(x)) |
||||
|
x = enc_input |
||||
|
enc_input_num_pads = [0] * enc_input.size(0) |
||||
|
|
||||
|
max_num_enc = sum(self.num_exp_enc_list) |
||||
|
pos_x = torch.arange(max_num_enc).unsqueeze(0).expand(enc_input.size(0), max_num_enc).to(self.rank) |
||||
|
pad_mask = create_pad_mask(mask_size=(enc_input.size(0), max_num_enc, enc_input.size(1)), |
||||
|
pad_along_row_input=[0] * enc_input.size(0), |
||||
|
pad_along_column_input=enc_input_num_pads, |
||||
|
rank=self.rank) |
||||
|
|
||||
|
x_list = [] |
||||
|
for i in range(self.N_enc): |
||||
|
x = self.encoders[i](x=x, n_indexes=pos_x, mask=pad_mask) |
||||
|
x_list.append(x) |
||||
|
x_list = torch.cat(x_list, dim=-1) |
||||
|
x = x + self.out_enc_dropout(self.enc_reduce_group(x_list)) |
||||
|
x = self.enc_reduce_norm(x) |
||||
|
|
||||
|
return x |
||||
|
|
||||
|
def forward_dec(self, cross_input, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False): |
||||
|
assert (enc_input_num_pads is None or enc_input_num_pads == ([0] * cross_input.size(0))), "enc_input_num_pads should be no None" |
||||
|
|
||||
|
enc_input_num_pads = [0] * dec_input.size(0) |
||||
|
no_peak_and_pad_mask = create_no_peak_and_pad_mask( |
||||
|
mask_size=(dec_input.size(0), dec_input.size(1), dec_input.size(1)), |
||||
|
num_pads=dec_input_num_pads, |
||||
|
rank=self.rank) |
||||
|
pad_mask = create_pad_mask(mask_size=(dec_input.size(0), dec_input.size(1), cross_input.size(1)), |
||||
|
pad_along_row_input=dec_input_num_pads, |
||||
|
pad_along_column_input=enc_input_num_pads, |
||||
|
rank=self.rank) |
||||
|
|
||||
|
y = self.out_embedder(dec_input) |
||||
|
pos_x = torch.arange(self.num_exp_dec).unsqueeze(0).expand(dec_input.size(0), self.num_exp_dec).to(self.rank) |
||||
|
pos_y = torch.arange(dec_input.size(1)).unsqueeze(0).expand(dec_input.size(0), dec_input.size(1)).to(self.rank) |
||||
|
y = y + self.pos_encoder(pos_y) |
||||
|
y_list = [] |
||||
|
for i in range(self.N_dec): |
||||
|
y = self.decoders[i](x=y, |
||||
|
n_indexes=pos_x, |
||||
|
cross_connection_x=cross_input, |
||||
|
input_attention_mask=no_peak_and_pad_mask, |
||||
|
cross_attention_mask=pad_mask) |
||||
|
y_list.append(y) |
||||
|
y_list = torch.cat(y_list, dim=-1) |
||||
|
y = y + self.out_dec_dropout(self.dec_reduce_group(y_list)) |
||||
|
y = self.dec_reduce_norm(y) |
||||
|
|
||||
|
y = self.vocab_linear(y) |
||||
|
|
||||
|
if apply_log_softmax: |
||||
|
y = self.log_softmax(y) |
||||
|
|
||||
|
return y |
||||
|
|
||||
|
|
||||
|
def get_batch_multiple_sampled_prediction(self, enc_input, enc_input_num_pads, num_outputs, |
||||
|
sos_idx, eos_idx, max_seq_len): |
||||
|
|
||||
|
bs = enc_input.size(0) |
||||
|
x = self.forward_enc(enc_input=enc_input, enc_input_num_pads=enc_input_num_pads) |
||||
|
enc_seq_len = x.size(1) |
||||
|
x = x.unsqueeze(1).expand(-1, num_outputs, -1, -1).reshape(bs * num_outputs, enc_seq_len, x.shape[-1]) |
||||
|
|
||||
|
upperbound_vector = torch.tensor([max_seq_len] * bs * num_outputs, dtype=torch.int).to(self.rank) |
||||
|
where_is_eos_vector = upperbound_vector.clone() |
||||
|
eos_vector = torch.tensor([eos_idx] * bs * num_outputs, dtype=torch.long).to(self.rank) |
||||
|
finished_flag_vector = torch.zeros(bs * num_outputs).type(torch.int) |
||||
|
|
||||
|
predicted_caption = torch.tensor([sos_idx] * (bs * num_outputs), dtype=torch.long).to(self.rank).unsqueeze(-1) |
||||
|
predicted_caption_prob = torch.zeros(bs * num_outputs).to(self.rank).unsqueeze(-1) |
||||
|
|
||||
|
dec_input_num_pads = [0]*(bs*num_outputs) |
||||
|
time_step = 0 |
||||
|
while (finished_flag_vector.sum() != bs * num_outputs) and time_step < max_seq_len: |
||||
|
dec_input = predicted_caption |
||||
|
log_probs = self.forward_dec(x, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=True) |
||||
|
|
||||
|
prob_dist = torch.distributions.Categorical(torch.exp(log_probs[:, time_step])) |
||||
|
sampled_word_indexes = prob_dist.sample() |
||||
|
|
||||
|
predicted_caption = torch.cat((predicted_caption, sampled_word_indexes.unsqueeze(-1)), dim=-1) |
||||
|
predicted_caption_prob = torch.cat((predicted_caption_prob, |
||||
|
log_probs[:, time_step].gather(index=sampled_word_indexes.unsqueeze(-1), dim=-1)), dim=-1) |
||||
|
time_step += 1 |
||||
|
|
||||
|
where_is_eos_vector = torch.min(where_is_eos_vector, |
||||
|
upperbound_vector.masked_fill(sampled_word_indexes == eos_vector, time_step)) |
||||
|
finished_flag_vector = torch.max(finished_flag_vector, |
||||
|
(sampled_word_indexes == eos_vector).type(torch.IntTensor)) |
||||
|
|
||||
|
res_predicted_caption = [] |
||||
|
for i in range(bs): |
||||
|
res_predicted_caption.append([]) |
||||
|
for j in range(num_outputs): |
||||
|
index = i*num_outputs + j |
||||
|
res_predicted_caption[i].append( |
||||
|
predicted_caption[index, :where_is_eos_vector[index].item()+1].tolist()) |
||||
|
|
||||
|
where_is_eos_vector = where_is_eos_vector.unsqueeze(-1).expand(-1, time_step+1) |
||||
|
arange_tensor = torch.arange(time_step+1).unsqueeze(0).expand(bs * num_outputs, -1).to(self.rank) |
||||
|
predicted_caption_prob.masked_fill_(arange_tensor > where_is_eos_vector, 0.0) |
||||
|
res_predicted_caption_prob = predicted_caption_prob.reshape(bs, num_outputs, -1) |
||||
|
|
||||
|
return res_predicted_caption, res_predicted_caption_prob |
@ -0,0 +1,103 @@ |
|||||
|
import torch |
||||
|
from models.layers import EmbeddingLayer, EncoderLayer, DecoderLayer |
||||
|
from utils.masking import create_pad_mask, create_no_peak_and_pad_mask |
||||
|
from models.captioning_model import CaptioningModel |
||||
|
|
||||
|
import torch.nn as nn |
||||
|
|
||||
|
|
||||
|
class ExpansionNet_v2(CaptioningModel): |
||||
|
def __init__(self, d_model, N_enc, N_dec, ff, num_heads, num_exp_enc_list, num_exp_dec, |
||||
|
output_word2idx, output_idx2word, max_seq_len, drop_args, img_feature_dim=2048, rank=0): |
||||
|
super().__init__() |
||||
|
self.output_word2idx = output_word2idx |
||||
|
self.output_idx2word = output_idx2word |
||||
|
self.max_seq_len = max_seq_len |
||||
|
|
||||
|
self.num_exp_dec = num_exp_dec |
||||
|
self.num_exp_enc_list = num_exp_enc_list |
||||
|
|
||||
|
self.N_enc = N_enc |
||||
|
self.N_dec = N_dec |
||||
|
self.d_model = d_model |
||||
|
|
||||
|
self.encoders = nn.ModuleList([EncoderLayer(d_model, ff, num_exp_enc_list, drop_args.enc) for _ in range(N_enc)]) |
||||
|
self.decoders = nn.ModuleList([DecoderLayer(d_model, num_heads, ff, num_exp_dec, drop_args.dec) for _ in range(N_dec)]) |
||||
|
|
||||
|
self.input_embedder_dropout = nn.Dropout(drop_args.enc_input) |
||||
|
self.input_linear = torch.nn.Linear(img_feature_dim, d_model) |
||||
|
self.vocab_linear = torch.nn.Linear(d_model, len(output_word2idx)) |
||||
|
self.log_softmax = nn.LogSoftmax(dim=-1) |
||||
|
|
||||
|
self.out_enc_dropout = nn.Dropout(drop_args.other) |
||||
|
self.out_dec_dropout = nn.Dropout(drop_args.other) |
||||
|
|
||||
|
self.out_embedder = EmbeddingLayer(len(output_word2idx), d_model, drop_args.dec_input) |
||||
|
self.pos_encoder = nn.Embedding(max_seq_len, d_model) |
||||
|
|
||||
|
self.enc_reduce_group = nn.Linear(d_model * self.N_enc, d_model) |
||||
|
self.enc_reduce_norm = nn.LayerNorm(d_model) |
||||
|
self.dec_reduce_group = nn.Linear(d_model * self.N_dec, d_model) |
||||
|
self.dec_reduce_norm = nn.LayerNorm(d_model) |
||||
|
|
||||
|
for p in self.parameters(): |
||||
|
if p.dim() > 1: |
||||
|
nn.init.xavier_uniform_(p) |
||||
|
|
||||
|
self.trained_steps = 0 |
||||
|
self.rank = rank |
||||
|
|
||||
|
def forward_enc(self, enc_input, enc_input_num_pads): |
||||
|
|
||||
|
x = self.input_embedder_dropout(self.input_linear(enc_input)) |
||||
|
|
||||
|
max_num_enc = sum(self.num_exp_enc_list) |
||||
|
pos_x = torch.arange(max_num_enc).unsqueeze(0).expand(enc_input.size(0), max_num_enc).to(self.rank) |
||||
|
pad_mask = create_pad_mask(mask_size=(enc_input.size(0), max_num_enc, enc_input.size(1)), |
||||
|
pad_along_row_input=[0] * enc_input.size(0), |
||||
|
pad_along_column_input=enc_input_num_pads, |
||||
|
rank=self.rank) |
||||
|
|
||||
|
x_list = [] |
||||
|
for i in range(self.N_enc): |
||||
|
x = self.encoders[i](x=x, n_indexes=pos_x, mask=pad_mask) |
||||
|
x_list.append(x) |
||||
|
x_list = torch.cat(x_list, dim=-1) |
||||
|
x = x + self.out_enc_dropout(self.enc_reduce_group(x_list)) |
||||
|
x = self.enc_reduce_norm(x) |
||||
|
return x |
||||
|
|
||||
|
def forward_dec(self, cross_input, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False): |
||||
|
|
||||
|
no_peak_and_pad_mask = create_no_peak_and_pad_mask( |
||||
|
mask_size=(dec_input.size(0), dec_input.size(1), dec_input.size(1)), |
||||
|
num_pads=dec_input_num_pads, |
||||
|
rank=self.rank) |
||||
|
|
||||
|
pad_mask = create_pad_mask(mask_size=(dec_input.size(0), dec_input.size(1), cross_input.size(1)), |
||||
|
pad_along_row_input=dec_input_num_pads, |
||||
|
pad_along_column_input=enc_input_num_pads, |
||||
|
rank=self.rank) |
||||
|
|
||||
|
y = self.out_embedder(dec_input) |
||||
|
pos_x = torch.arange(self.num_exp_dec).unsqueeze(0).expand(dec_input.size(0), self.num_exp_dec).to(self.rank) |
||||
|
pos_y = torch.arange(dec_input.size(1)).unsqueeze(0).expand(dec_input.size(0), dec_input.size(1)).to(self.rank) |
||||
|
y = y + self.pos_encoder(pos_y) |
||||
|
y_list = [] |
||||
|
for i in range(self.N_dec): |
||||
|
y = self.decoders[i](x=y, |
||||
|
n_indexes=pos_x, |
||||
|
cross_connection_x=cross_input, |
||||
|
input_attention_mask=no_peak_and_pad_mask, |
||||
|
cross_attention_mask=pad_mask) |
||||
|
y_list.append(y) |
||||
|
y_list = torch.cat(y_list, dim=-1) |
||||
|
y = y + self.out_dec_dropout(self.dec_reduce_group(y_list)) |
||||
|
y = self.dec_reduce_norm(y) |
||||
|
|
||||
|
y = self.vocab_linear(y) |
||||
|
|
||||
|
if apply_log_softmax: |
||||
|
y = self.log_softmax(y) |
||||
|
|
||||
|
return y |
@ -0,0 +1,241 @@ |
|||||
|
|
||||
|
|
||||
|
import torch |
||||
|
import torch.nn as nn |
||||
|
|
||||
|
|
||||
|
class CaptioningModel(nn.Module): |
||||
|
def __init__(self): |
||||
|
super(CaptioningModel, self).__init__() |
||||
|
# mandatory attributes |
||||
|
# rank: to enable multiprocessing |
||||
|
self.rank = None |
||||
|
|
||||
|
def check_required_attributes(self): |
||||
|
if self.rank is None: |
||||
|
raise NotImplementedError("Subclass must assign the rank integer according to the GPU group") |
||||
|
|
||||
|
def forward_enc(self, enc_input, enc_input_num_pads): |
||||
|
raise NotImplementedError |
||||
|
|
||||
|
def forward_dec(self, cross_input, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False): |
||||
|
raise NotImplementedError |
||||
|
|
||||
|
def forward(self, enc_x, dec_x=None, |
||||
|
enc_x_num_pads=[0], dec_x_num_pads=[0], apply_log_softmax=False, |
||||
|
mode='forward', **kwargs): |
||||
|
if mode == 'forward': |
||||
|
x = self.forward_enc(enc_x, enc_x_num_pads) |
||||
|
y = self.forward_dec(x, enc_x_num_pads, dec_x, dec_x_num_pads, apply_log_softmax) |
||||
|
return y |
||||
|
else: |
||||
|
assert ('sos_idx' in kwargs.keys() or 'eos_idx' in kwargs.keys()), \ |
||||
|
'sos and eos must be provided in case of batch sampling or beam search' |
||||
|
sos_idx = kwargs.get('sos_idx', -999) |
||||
|
eos_idx = kwargs.get('eos_idx', -999) |
||||
|
if mode == 'beam_search': |
||||
|
beam_size_arg = kwargs.get('beam_size', 5) |
||||
|
how_many_outputs_per_beam = kwargs.get('how_many_outputs', 1) |
||||
|
beam_max_seq_len = kwargs.get('beam_max_seq_len', 20) |
||||
|
sample_or_max = kwargs.get('sample_or_max', 'max') |
||||
|
out_classes, out_logprobs = self.beam_search( |
||||
|
enc_x, enc_x_num_pads, |
||||
|
beam_size=beam_size_arg, |
||||
|
sos_idx=sos_idx, |
||||
|
eos_idx=eos_idx, |
||||
|
how_many_outputs=how_many_outputs_per_beam, |
||||
|
max_seq_len=beam_max_seq_len, |
||||
|
sample_or_max=sample_or_max) |
||||
|
return out_classes, out_logprobs |
||||
|
if mode == 'sampling': |
||||
|
how_many_outputs = kwargs.get('how_many_outputs', 1) |
||||
|
sample_max_seq_len = kwargs.get('sample_max_seq_len', 20) |
||||
|
out_classes, out_logprobs = self.get_batch_multiple_sampled_prediction( |
||||
|
enc_x, enc_x_num_pads, num_outputs=how_many_outputs, |
||||
|
sos_idx=sos_idx, eos_idx=eos_idx, |
||||
|
max_seq_len=sample_max_seq_len) |
||||
|
return out_classes, out_logprobs |
||||
|
|
||||
|
def get_batch_multiple_sampled_prediction(self, enc_input, enc_input_num_pads, num_outputs, |
||||
|
sos_idx, eos_idx, max_seq_len): |
||||
|
bs, enc_seq_len, _ = enc_input.shape |
||||
|
|
||||
|
enc_input_num_pads = [enc_input_num_pads[i] for i in range(bs) for _ in range(num_outputs)] |
||||
|
|
||||
|
x = self.forward_enc(enc_input=enc_input, enc_input_num_pads=enc_input_num_pads) |
||||
|
x = x.unsqueeze(1).expand(-1, num_outputs, -1, -1).reshape(bs * num_outputs, enc_seq_len, x.shape[-1]) |
||||
|
|
||||
|
upperbound_vector = torch.tensor([max_seq_len] * bs * num_outputs, dtype=torch.int).to(self.rank) |
||||
|
where_is_eos_vector = upperbound_vector.clone() |
||||
|
eos_vector = torch.tensor([eos_idx] * bs * num_outputs, dtype=torch.long).to(self.rank) |
||||
|
finished_flag_vector = torch.zeros(bs * num_outputs).type(torch.int) |
||||
|
|
||||
|
predicted_caption = torch.tensor([sos_idx] * (bs * num_outputs), dtype=torch.long).to(self.rank).unsqueeze(-1) |
||||
|
predicted_caption_prob = torch.zeros(bs * num_outputs).to(self.rank).unsqueeze(-1) |
||||
|
|
||||
|
dec_input_num_pads = [0]*(bs*num_outputs) |
||||
|
time_step = 0 |
||||
|
while (finished_flag_vector.sum() != bs * num_outputs) and time_step < max_seq_len: |
||||
|
dec_input = predicted_caption |
||||
|
log_probs = self.forward_dec(x, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=True) |
||||
|
|
||||
|
prob_dist = torch.distributions.Categorical(torch.exp(log_probs[:, time_step])) |
||||
|
sampled_word_indexes = prob_dist.sample() |
||||
|
|
||||
|
predicted_caption = torch.cat((predicted_caption, sampled_word_indexes.unsqueeze(-1)), dim=-1) |
||||
|
predicted_caption_prob = torch.cat((predicted_caption_prob, |
||||
|
log_probs[:, time_step].gather(index=sampled_word_indexes.unsqueeze(-1), dim=-1)), dim=-1) |
||||
|
time_step += 1 |
||||
|
|
||||
|
where_is_eos_vector = torch.min(where_is_eos_vector, |
||||
|
upperbound_vector.masked_fill(sampled_word_indexes == eos_vector, time_step)) |
||||
|
finished_flag_vector = torch.max(finished_flag_vector, |
||||
|
(sampled_word_indexes == eos_vector).type(torch.IntTensor)) |
||||
|
|
||||
|
# remove the elements that come after the first eos from the sequence |
||||
|
res_predicted_caption = [] |
||||
|
for i in range(bs): |
||||
|
res_predicted_caption.append([]) |
||||
|
for j in range(num_outputs): |
||||
|
index = i*num_outputs + j |
||||
|
res_predicted_caption[i].append( |
||||
|
predicted_caption[index, :where_is_eos_vector[index].item()+1].tolist()) |
||||
|
|
||||
|
where_is_eos_vector = where_is_eos_vector.unsqueeze(-1).expand(-1, time_step+1) |
||||
|
arange_tensor = torch.arange(time_step+1).unsqueeze(0).expand(bs * num_outputs, -1).to(self.rank) |
||||
|
predicted_caption_prob.masked_fill_(arange_tensor > where_is_eos_vector, 0.0) |
||||
|
res_predicted_caption_prob = predicted_caption_prob.reshape(bs, num_outputs, -1) |
||||
|
|
||||
|
return res_predicted_caption, res_predicted_caption_prob |
||||
|
|
||||
|
def beam_search(self, enc_input, enc_input_num_pads, sos_idx, eos_idx, |
||||
|
beam_size=3, how_many_outputs=1, max_seq_len=20, sample_or_max='max',): |
||||
|
assert (how_many_outputs <= beam_size), "requested output per sequence must be lower than beam width" |
||||
|
assert (sample_or_max == 'max' or sample_or_max == 'sample'), "argument must be chosen between \'max\' and \'sample\'" |
||||
|
bs = enc_input.shape[0] |
||||
|
|
||||
|
cross_enc_output = self.forward_enc(enc_input, enc_input_num_pads) |
||||
|
|
||||
|
# init: ------------------------------------------------------------------ |
||||
|
init_dec_class = torch.tensor([sos_idx] * bs).unsqueeze(1).type(torch.long).to(self.rank) |
||||
|
init_dec_logprob = torch.tensor([0.0] * bs).unsqueeze(1).type(torch.float).to(self.rank) |
||||
|
log_probs = self.forward_dec(cross_input=cross_enc_output, enc_input_num_pads=enc_input_num_pads, |
||||
|
dec_input=init_dec_class, dec_input_num_pads=[0] * bs, |
||||
|
apply_log_softmax=True) |
||||
|
if sample_or_max == 'max': |
||||
|
_, topi = torch.topk(log_probs, k=beam_size, sorted=True) |
||||
|
else: # sample |
||||
|
topi = torch.exp(log_probs[:, 0, :]).multinomial(num_samples=beam_size, replacement=False) |
||||
|
topi = topi.unsqueeze(1) |
||||
|
|
||||
|
init_dec_class = init_dec_class.repeat(1, beam_size) |
||||
|
init_dec_class = init_dec_class.unsqueeze(-1) |
||||
|
top_beam_size_class = topi.transpose(-2, -1) |
||||
|
init_dec_class = torch.cat((init_dec_class, top_beam_size_class), dim=-1) |
||||
|
|
||||
|
init_dec_logprob = init_dec_logprob.repeat(1, beam_size) |
||||
|
init_dec_logprob = init_dec_logprob.unsqueeze(-1) |
||||
|
top_beam_size_logprob = log_probs.gather(dim=-1, index=topi) |
||||
|
top_beam_size_logprob = top_beam_size_logprob.transpose(-2, -1) |
||||
|
init_dec_logprob = torch.cat((init_dec_logprob, top_beam_size_logprob), dim=-1) |
||||
|
|
||||
|
bs, enc_seq_len, d_model = cross_enc_output.shape |
||||
|
cross_enc_output = cross_enc_output.unsqueeze(1) |
||||
|
cross_enc_output = cross_enc_output.expand(-1, beam_size, -1, -1) |
||||
|
cross_enc_output = cross_enc_output.reshape(bs * beam_size, enc_seq_len, d_model).contiguous() |
||||
|
enc_input_num_pads = [enc_input_num_pads[i] for i in range(bs) for _ in range(beam_size)] |
||||
|
|
||||
|
# loop: ----------------------------------------------------------------- |
||||
|
loop_dec_classes = init_dec_class |
||||
|
loop_dec_logprobs = init_dec_logprob |
||||
|
loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True) |
||||
|
|
||||
|
loop_num_elem_vector = torch.tensor([2] * (bs * beam_size)).to(self.rank) |
||||
|
|
||||
|
for time_step in range(2, max_seq_len): |
||||
|
loop_dec_classes = loop_dec_classes.reshape(bs * beam_size, time_step).contiguous() |
||||
|
|
||||
|
log_probs = self.forward_dec(cross_input=cross_enc_output, enc_input_num_pads=enc_input_num_pads, |
||||
|
dec_input=loop_dec_classes, |
||||
|
dec_input_num_pads=(time_step-loop_num_elem_vector).tolist(), |
||||
|
apply_log_softmax=True) |
||||
|
if sample_or_max == 'max': |
||||
|
_, topi = torch.topk(log_probs[:, time_step-1, :], k=beam_size, sorted=True) |
||||
|
else: # sample |
||||
|
topi = torch.exp(log_probs[:, time_step-1, :]).multinomial(num_samples=beam_size, |
||||
|
replacement=False) |
||||
|
|
||||
|
top_beam_size_word_classes = topi.reshape(bs, beam_size, beam_size) |
||||
|
|
||||
|
top_beam_size_word_logprobs = log_probs[:, time_step-1, :].gather(dim=-1, index=topi) |
||||
|
top_beam_size_word_logprobs = top_beam_size_word_logprobs.reshape(bs, beam_size, beam_size) |
||||
|
|
||||
|
# each sequence have now its best prediction, but some sequence may have already been terminated with EOS, |
||||
|
# in that case its candidates are simply ignored, and do not sum up in the "loop_dec_logprobs" their value |
||||
|
# are set to zero |
||||
|
there_is_eos_mask = (loop_dec_classes.view(bs, beam_size, time_step) == eos_idx). \ |
||||
|
sum(dim=-1, keepdims=True).type(torch.bool) |
||||
|
|
||||
|
# if we pad with -999 its candidates logprobabilities, also the sequence containing EOS would be |
||||
|
# straightforwardly discarded, instead we want to keep it in the exploration. Therefore we mask with 0.0 |
||||
|
# one arbitrary candidate word probability so the sequence probability is unchanged but it |
||||
|
# can still be discarded when a better candidate sequence is found |
||||
|
top_beam_size_word_logprobs[:, :, 0:1].masked_fill_(there_is_eos_mask, 0.0) |
||||
|
top_beam_size_word_logprobs[:, :, 1:].masked_fill_(there_is_eos_mask, -999.0) |
||||
|
|
||||
|
comparison_logprobs = loop_cumul_logprobs + top_beam_size_word_logprobs |
||||
|
|
||||
|
comparison_logprobs = comparison_logprobs.contiguous().view(bs, beam_size * beam_size) |
||||
|
_, topi = torch.topk(comparison_logprobs, k=beam_size, sorted=True) |
||||
|
which_sequence = topi // beam_size |
||||
|
which_word = topi % beam_size |
||||
|
|
||||
|
loop_dec_classes = loop_dec_classes.view(bs, beam_size, -1) |
||||
|
loop_dec_logprobs = loop_dec_logprobs.view(bs, beam_size, -1) |
||||
|
|
||||
|
bs_idxes = torch.arange(bs).unsqueeze(-1) |
||||
|
new_loop_dec_classes = loop_dec_classes[[bs_idxes, which_sequence]] |
||||
|
new_loop_dec_logprobs = loop_dec_logprobs[[bs_idxes, which_sequence]] |
||||
|
|
||||
|
which_sequence_top_beam_size_word_classes = top_beam_size_word_classes[[bs_idxes, which_sequence]] |
||||
|
which_sequence_top_beam_size_word_logprobs = top_beam_size_word_logprobs[ |
||||
|
[bs_idxes, which_sequence]] |
||||
|
which_word = which_word.unsqueeze(-1) |
||||
|
|
||||
|
lastword_top_beam_size_classes = which_sequence_top_beam_size_word_classes.gather(dim=-1, |
||||
|
index=which_word) |
||||
|
lastword_top_beam_size_logprobs = which_sequence_top_beam_size_word_logprobs.gather(dim=-1, index=which_word) |
||||
|
|
||||
|
new_loop_dec_classes = torch.cat((new_loop_dec_classes, lastword_top_beam_size_classes), dim=-1) |
||||
|
new_loop_dec_logprobs = torch.cat((new_loop_dec_logprobs, lastword_top_beam_size_logprobs), dim=-1) |
||||
|
loop_dec_classes = new_loop_dec_classes |
||||
|
loop_dec_logprobs = new_loop_dec_logprobs |
||||
|
|
||||
|
loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True) |
||||
|
|
||||
|
# -----------------------update loop_num_elem_vector ---------------------------- |
||||
|
loop_num_elem_vector = loop_num_elem_vector.view(bs, beam_size)[[bs_idxes, which_sequence]].view(bs * beam_size) |
||||
|
there_was_eos_mask = (loop_dec_classes[:, :, :-1].view(bs, beam_size, time_step) == eos_idx). \ |
||||
|
sum(dim=-1).type(torch.bool).view(bs * beam_size) |
||||
|
loop_num_elem_vector = loop_num_elem_vector + (1 * (1 - there_was_eos_mask.type(torch.int))) |
||||
|
|
||||
|
if (loop_num_elem_vector != time_step + 1).sum() == (bs * beam_size): |
||||
|
break |
||||
|
|
||||
|
# sort out the best result |
||||
|
loop_cumul_logprobs /= loop_num_elem_vector.reshape(bs, beam_size, 1) |
||||
|
_, topi = torch.topk(loop_cumul_logprobs.squeeze(-1), k=beam_size) |
||||
|
res_caption_pred = [[] for _ in range(bs)] |
||||
|
res_caption_logprob = [[] for _ in range(bs)] |
||||
|
for i in range(bs): |
||||
|
for j in range(how_many_outputs): |
||||
|
idx = topi[i, j].item() |
||||
|
res_caption_pred[i].append( |
||||
|
loop_dec_classes[i, idx, :loop_num_elem_vector[i * beam_size + idx]].tolist()) |
||||
|
res_caption_logprob[i].append(loop_dec_logprobs[i, idx, :loop_num_elem_vector[i * beam_size + idx]]) |
||||
|
|
||||
|
flatted_res_caption_logprob = [logprobs for i in range(bs) for logprobs in res_caption_logprob[i]] |
||||
|
flatted_res_caption_logprob = torch.nn.utils.rnn.pad_sequence(flatted_res_caption_logprob, batch_first=True) |
||||
|
res_caption_logprob = flatted_res_caption_logprob.view(bs, how_many_outputs, -1) |
||||
|
|
||||
|
return res_caption_pred, res_caption_logprob |
@ -0,0 +1,187 @@ |
|||||
|
|
||||
|
import torch |
||||
|
import torch.nn as nn |
||||
|
from models.captioning_model import CaptioningModel |
||||
|
|
||||
|
|
||||
|
class EsembleCaptioningModel(CaptioningModel): |
||||
|
def __init__(self, models_list, rank): |
||||
|
super().__init__() |
||||
|
self.num_models = len(models_list) |
||||
|
self.models_list = models_list |
||||
|
self.rank = rank |
||||
|
|
||||
|
self.dummy_linear = nn.Linear(1, 1) |
||||
|
|
||||
|
for model in self.models_list: |
||||
|
model.eval() |
||||
|
|
||||
|
def forward(self, enc_x, dec_x=None, |
||||
|
enc_x_num_pads=[0], dec_x_num_pads=[0], apply_log_softmax=False, |
||||
|
mode='beam_search', **kwargs): |
||||
|
assert (mode == 'beam_search'), "this class supports only beam search." |
||||
|
sos_idx = kwargs.get('sos_idx', -999) |
||||
|
eos_idx = kwargs.get('eos_idx', -999) |
||||
|
if mode == 'beam_search': |
||||
|
beam_size_arg = kwargs.get('beam_size', 5) |
||||
|
how_many_outputs_per_beam = kwargs.get('how_many_outputs', 1) |
||||
|
beam_max_seq_len = kwargs.get('beam_max_seq_len', 20) |
||||
|
sample_or_max = kwargs.get('sample_or_max', 'max') |
||||
|
out_classes, out_logprobs = self.ensemble_beam_search( |
||||
|
enc_x, enc_x_num_pads, |
||||
|
beam_size=beam_size_arg, |
||||
|
sos_idx=sos_idx, |
||||
|
eos_idx=eos_idx, |
||||
|
how_many_outputs=how_many_outputs_per_beam, |
||||
|
max_seq_len=beam_max_seq_len, |
||||
|
sample_or_max=sample_or_max) |
||||
|
return out_classes, out_logprobs |
||||
|
|
||||
|
def forward_enc(self, enc_input, enc_input_num_pads): |
||||
|
x_outputs_list = [] |
||||
|
for i in range(self.num_models): |
||||
|
x_outputs = self.models_list[i].forward_enc(enc_input, enc_input_num_pads) |
||||
|
x_outputs_list.append(x_outputs) |
||||
|
return x_outputs_list |
||||
|
|
||||
|
def forward_dec(self, cross_input_list, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False): |
||||
|
|
||||
|
import torch.nn.functional as F |
||||
|
y_outputs = [] |
||||
|
for i in range(self.num_models): |
||||
|
y_outputs.append( |
||||
|
F.softmax(self.models_list[i].forward_dec( |
||||
|
cross_input_list[i], enc_input_num_pads, |
||||
|
dec_input, dec_input_num_pads, False).unsqueeze(0), dim=-1)) |
||||
|
avg = torch.cat(y_outputs, dim=0).mean(dim=0).log() |
||||
|
|
||||
|
return avg |
||||
|
|
||||
|
# quite unclean coding, to be re-factored in the future... |
||||
|
# since it's a bit similar to the single model case |
||||
|
def ensemble_beam_search(self, enc_input, enc_input_num_pads, sos_idx, eos_idx, |
||||
|
beam_size=3, how_many_outputs=1, max_seq_len=20, sample_or_max='max',): |
||||
|
assert (how_many_outputs <= beam_size), "requested output per sequence must be lower than beam width" |
||||
|
assert (sample_or_max == 'max' or sample_or_max == 'sample'), "argument must be chosen between \'max\' and \'sample\'" |
||||
|
bs = enc_input.shape[0] |
||||
|
|
||||
|
# the cross_dec_input is computed once |
||||
|
cross_enc_output_list = self.forward_enc(enc_input, enc_input_num_pads) |
||||
|
|
||||
|
# init: ------------------------------------------------------------------ |
||||
|
init_dec_class = torch.tensor([sos_idx] * bs).unsqueeze(1).type(torch.long).to(self.rank) |
||||
|
init_dec_logprob = torch.tensor([0.0] * bs).unsqueeze(1).type(torch.float).to(self.rank) |
||||
|
log_probs = self.forward_dec(cross_input_list=cross_enc_output_list, enc_input_num_pads=enc_input_num_pads, |
||||
|
dec_input=init_dec_class, dec_input_num_pads=[0] * bs, |
||||
|
apply_log_softmax=True) |
||||
|
if sample_or_max == 'max': |
||||
|
_, topi = torch.topk(log_probs, k=beam_size, sorted=True) |
||||
|
else: # sample |
||||
|
topi = torch.exp(log_probs[:, 0, :]).multinomial(num_samples=beam_size, replacement=False) |
||||
|
topi = topi.unsqueeze(1) |
||||
|
|
||||
|
init_dec_class = init_dec_class.repeat(1, beam_size) |
||||
|
init_dec_class = init_dec_class.unsqueeze(-1) |
||||
|
top_beam_size_class = topi.transpose(-2, -1) |
||||
|
init_dec_class = torch.cat((init_dec_class, top_beam_size_class), dim=-1) |
||||
|
|
||||
|
init_dec_logprob = init_dec_logprob.repeat(1, beam_size) |
||||
|
init_dec_logprob = init_dec_logprob.unsqueeze(-1) |
||||
|
top_beam_size_logprob = log_probs.gather(dim=-1, index=topi) |
||||
|
top_beam_size_logprob = top_beam_size_logprob.transpose(-2, -1) |
||||
|
init_dec_logprob = torch.cat((init_dec_logprob, top_beam_size_logprob), dim=-1) |
||||
|
|
||||
|
tmp_cross_enc_output_list = [] |
||||
|
for cross_enc_output in cross_enc_output_list: |
||||
|
bs, enc_seq_len, d_model = cross_enc_output.shape |
||||
|
cross_enc_output = cross_enc_output.unsqueeze(1) |
||||
|
cross_enc_output = cross_enc_output.expand(-1, beam_size, -1, -1) |
||||
|
cross_enc_output = cross_enc_output.reshape(bs * beam_size, enc_seq_len, d_model).contiguous() |
||||
|
tmp_cross_enc_output_list.append(cross_enc_output) |
||||
|
cross_enc_output_list = tmp_cross_enc_output_list |
||||
|
enc_input_num_pads = [enc_input_num_pads[i] for i in range(bs) for _ in range(beam_size)] |
||||
|
|
||||
|
loop_dec_classes = init_dec_class |
||||
|
loop_dec_logprobs = init_dec_logprob |
||||
|
loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True) |
||||
|
|
||||
|
loop_num_elem_vector = torch.tensor([2] * (bs * beam_size)).to(self.rank) |
||||
|
|
||||
|
for time_step in range(2, max_seq_len): |
||||
|
loop_dec_classes = loop_dec_classes.reshape(bs * beam_size, time_step).contiguous() |
||||
|
|
||||
|
log_probs = self.forward_dec(cross_input_list=cross_enc_output_list, enc_input_num_pads=enc_input_num_pads, |
||||
|
dec_input=loop_dec_classes, |
||||
|
dec_input_num_pads=(time_step-loop_num_elem_vector).tolist(), |
||||
|
apply_log_softmax=True) |
||||
|
if sample_or_max == 'max': |
||||
|
_, topi = torch.topk(log_probs[:, time_step-1, :], k=beam_size, sorted=True) |
||||
|
else: # sample |
||||
|
topi = torch.exp(log_probs[:, time_step-1, :]).multinomial(num_samples=beam_size, |
||||
|
replacement=False) |
||||
|
|
||||
|
top_beam_size_word_classes = topi.reshape(bs, beam_size, beam_size) |
||||
|
|
||||
|
top_beam_size_word_logprobs = log_probs[:, time_step-1, :].gather(dim=-1, index=topi) |
||||
|
top_beam_size_word_logprobs = top_beam_size_word_logprobs.reshape(bs, beam_size, beam_size) |
||||
|
|
||||
|
there_is_eos_mask = (loop_dec_classes.view(bs, beam_size, time_step) == eos_idx). \ |
||||
|
sum(dim=-1, keepdims=True).type(torch.bool) |
||||
|
|
||||
|
top_beam_size_word_logprobs[:, :, 0:1].masked_fill_(there_is_eos_mask, 0.0) |
||||
|
top_beam_size_word_logprobs[:, :, 1:].masked_fill_(there_is_eos_mask, -999.0) |
||||
|
|
||||
|
comparison_logprobs = loop_cumul_logprobs + top_beam_size_word_logprobs |
||||
|
|
||||
|
comparison_logprobs = comparison_logprobs.contiguous().view(bs, beam_size * beam_size) |
||||
|
_, topi = torch.topk(comparison_logprobs, k=beam_size, sorted=True) |
||||
|
which_sequence = topi // beam_size |
||||
|
which_word = topi % beam_size |
||||
|
|
||||
|
loop_dec_classes = loop_dec_classes.view(bs, beam_size, -1) |
||||
|
loop_dec_logprobs = loop_dec_logprobs.view(bs, beam_size, -1) |
||||
|
|
||||
|
bs_idxes = torch.arange(bs).unsqueeze(-1) |
||||
|
new_loop_dec_classes = loop_dec_classes[[bs_idxes, which_sequence]] |
||||
|
new_loop_dec_logprobs = loop_dec_logprobs[[bs_idxes, which_sequence]] |
||||
|
|
||||
|
which_sequence_top_beam_size_word_classes = top_beam_size_word_classes[[bs_idxes, which_sequence]] |
||||
|
which_sequence_top_beam_size_word_logprobs = top_beam_size_word_logprobs[ |
||||
|
[bs_idxes, which_sequence]] |
||||
|
which_word = which_word.unsqueeze(-1) |
||||
|
|
||||
|
lastword_top_beam_size_classes = which_sequence_top_beam_size_word_classes.gather(dim=-1, |
||||
|
index=which_word) |
||||
|
lastword_top_beam_size_logprobs = which_sequence_top_beam_size_word_logprobs.gather(dim=-1, index=which_word) |
||||
|
|
||||
|
new_loop_dec_classes = torch.cat((new_loop_dec_classes, lastword_top_beam_size_classes), dim=-1) |
||||
|
new_loop_dec_logprobs = torch.cat((new_loop_dec_logprobs, lastword_top_beam_size_logprobs), dim=-1) |
||||
|
loop_dec_classes = new_loop_dec_classes |
||||
|
loop_dec_logprobs = new_loop_dec_logprobs |
||||
|
|
||||
|
loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True) |
||||
|
|
||||
|
loop_num_elem_vector = loop_num_elem_vector.view(bs, beam_size)[[bs_idxes, which_sequence]].view(bs * beam_size) |
||||
|
there_was_eos_mask = (loop_dec_classes[:, :, :-1].view(bs, beam_size, time_step) == eos_idx). \ |
||||
|
sum(dim=-1).type(torch.bool).view(bs * beam_size) |
||||
|
loop_num_elem_vector = loop_num_elem_vector + (1 * (1 - there_was_eos_mask.type(torch.int))) |
||||
|
|
||||
|
if (loop_num_elem_vector != time_step + 1).sum() == (bs * beam_size): |
||||
|
break |
||||
|
|
||||
|
loop_cumul_logprobs /= loop_num_elem_vector.reshape(bs, beam_size, 1) |
||||
|
_, topi = torch.topk(loop_cumul_logprobs.squeeze(-1), k=beam_size) |
||||
|
res_caption_pred = [[] for _ in range(bs)] |
||||
|
res_caption_logprob = [[] for _ in range(bs)] |
||||
|
for i in range(bs): |
||||
|
for j in range(how_many_outputs): |
||||
|
idx = topi[i, j].item() |
||||
|
res_caption_pred[i].append( |
||||
|
loop_dec_classes[i, idx, :loop_num_elem_vector[i * beam_size + idx]].tolist()) |
||||
|
res_caption_logprob[i].append(loop_dec_logprobs[i, idx, :loop_num_elem_vector[i * beam_size + idx]]) |
||||
|
|
||||
|
flatted_res_caption_logprob = [logprobs for i in range(bs) for logprobs in res_caption_logprob[i]] |
||||
|
flatted_res_caption_logprob = torch.nn.utils.rnn.pad_sequence(flatted_res_caption_logprob, batch_first=True) |
||||
|
res_caption_logprob = flatted_res_caption_logprob.view(bs, how_many_outputs, -1) |
||||
|
|
||||
|
return res_caption_pred, res_caption_logprob |
@ -0,0 +1,286 @@ |
|||||
|
import torch |
||||
|
import torch.nn as nn |
||||
|
import math |
||||
|
|
||||
|
import numpy as np |
||||
|
import torch.nn.functional as F |
||||
|
|
||||
|
class EmbeddingLayer(nn.Module): |
||||
|
def __init__(self, vocab_size, d_model, dropout_perc): |
||||
|
super(EmbeddingLayer, self).__init__() |
||||
|
self.dropout = nn.Dropout(dropout_perc) |
||||
|
self.embed = nn.Embedding(vocab_size, d_model) |
||||
|
self.d_model = d_model |
||||
|
|
||||
|
def forward(self, x): |
||||
|
return self.dropout(self.embed(x)) * math.sqrt(float(self.d_model)) |
||||
|
|
||||
|
|
||||
|
class PositionalEncoder(nn.Module): |
||||
|
def __init__(self, d_model, max_seq_len, rank=0): |
||||
|
super().__init__() |
||||
|
assert d_model % 2 == 0, "d_model is not even, even number suggested" |
||||
|
self.d_model = d_model |
||||
|
self.pe = torch.zeros(max_seq_len, d_model).to(rank) |
||||
|
for pos in range(max_seq_len): |
||||
|
for i in range(0, d_model, 2): |
||||
|
self.pe.data[pos, i] = math.sin(pos / (10000.0 ** ((2.0 * i) / d_model))) |
||||
|
self.pe.data[pos, i + 1] = math.cos(pos / (10000.0 ** ((2.0 * i) / d_model))) |
||||
|
self.pe.data = self.pe.data.unsqueeze(0) |
||||
|
|
||||
|
def forward(self, x): |
||||
|
seq_len = x.shape[1] |
||||
|
return self.pe.data[0, :seq_len] |
||||
|
|
||||
|
|
||||
|
|
||||
|
class StaticExpansionBlock(nn.Module): |
||||
|
def __init__(self, d_model, num_enc_exp_list, dropout_perc, eps): |
||||
|
super().__init__() |
||||
|
self.d_model = d_model |
||||
|
self.num_enc_exp_list = num_enc_exp_list |
||||
|
|
||||
|
self.query_exp_vectors = nn.Embedding(sum(num_enc_exp_list), d_model) |
||||
|
self.bias_exp_vectors = nn.Embedding(sum(num_enc_exp_list), d_model) |
||||
|
|
||||
|
self.key_embed = nn.Linear(d_model, d_model) |
||||
|
self.class_a_embed = nn.Linear(d_model, d_model) |
||||
|
self.class_b_embed = nn.Linear(d_model, d_model) |
||||
|
|
||||
|
self.selector_embed = nn.Linear(d_model, d_model) |
||||
|
|
||||
|
self.dropout_class_a_fw = nn.Dropout(dropout_perc) |
||||
|
self.dropout_class_b_fw = nn.Dropout(dropout_perc) |
||||
|
|
||||
|
self.dropout_class_a_bw = nn.Dropout(dropout_perc) |
||||
|
self.dropout_class_b_bw = nn.Dropout(dropout_perc) |
||||
|
|
||||
|
self.Z_dropout = nn.Dropout(dropout_perc) |
||||
|
|
||||
|
self.eps = eps |
||||
|
|
||||
|
def forward(self, x, n_indexes, mask): |
||||
|
bs, enc_len, _ = x.shape |
||||
|
|
||||
|
query_exp = self.query_exp_vectors(n_indexes) |
||||
|
bias_exp = self.bias_exp_vectors(n_indexes) |
||||
|
x_key = self.key_embed(x) |
||||
|
|
||||
|
z = torch.matmul(query_exp, x_key.transpose(-1, -2)) / np.sqrt(self.d_model) |
||||
|
z = self.Z_dropout(z) |
||||
|
|
||||
|
class_a_fw = F.relu(z) |
||||
|
class_b_fw = F.relu(-z) |
||||
|
class_a_fw = class_a_fw.masked_fill(mask == 0, 0.0) |
||||
|
class_b_fw = class_b_fw.masked_fill(mask == 0, 0.0) |
||||
|
class_a_fw = class_a_fw / (class_a_fw.sum(dim=-1, keepdim=True) + self.eps) |
||||
|
class_b_fw = class_b_fw / (class_b_fw.sum(dim=-1, keepdim=True) + self.eps) |
||||
|
|
||||
|
class_a = torch.matmul(class_a_fw, self.class_a_embed(x)) + bias_exp |
||||
|
class_b = torch.matmul(class_b_fw, self.class_b_embed(x)) + bias_exp |
||||
|
class_a = self.dropout_class_a_fw(class_a) |
||||
|
class_b = self.dropout_class_b_fw(class_b) |
||||
|
|
||||
|
class_a_bw = F.relu(z.transpose(-2, -1)) |
||||
|
class_b_bw = F.relu(-z.transpose(-2, -1)) |
||||
|
|
||||
|
accum = 0 |
||||
|
class_a_bw_list = [] |
||||
|
class_b_bw_list = [] |
||||
|
for j in range(len(self.num_enc_exp_list)): |
||||
|
from_idx = accum |
||||
|
to_idx = accum + self.num_enc_exp_list[j] |
||||
|
accum += self.num_enc_exp_list[j] |
||||
|
class_a_bw_list.append(class_a_bw[:, :, from_idx:to_idx] / (class_a_bw[:, :, from_idx:to_idx].sum(dim=-1, keepdim=True) + self.eps)) |
||||
|
class_b_bw_list.append(class_b_bw[:, :, from_idx:to_idx] / (class_b_bw[:, :, from_idx:to_idx].sum(dim=-1, keepdim=True) + self.eps)) |
||||
|
class_a_bw = torch.cat(class_a_bw_list, dim=-1) |
||||
|
class_b_bw = torch.cat(class_b_bw_list, dim=-1) |
||||
|
|
||||
|
class_a = torch.matmul(class_a_bw, class_a) / len(self.num_enc_exp_list) |
||||
|
class_b = torch.matmul(class_b_bw, class_b) / len(self.num_enc_exp_list) |
||||
|
class_a = self.dropout_class_a_bw(class_a) |
||||
|
class_b = self.dropout_class_b_bw(class_b) |
||||
|
|
||||
|
selector = torch.sigmoid(self.selector_embed(x)) |
||||
|
x_result = selector * class_a + (1 - selector) * class_b |
||||
|
|
||||
|
return x_result |
||||
|
|
||||
|
|
||||
|
class EncoderLayer(nn.Module): |
||||
|
def __init__(self, d_model, d_ff, num_enc_exp_list, dropout_perc, eps=1e-9): |
||||
|
super().__init__() |
||||
|
self.norm_1 = nn.LayerNorm(d_model) |
||||
|
self.norm_2 = nn.LayerNorm(d_model) |
||||
|
self.dropout_1 = nn.Dropout(dropout_perc) |
||||
|
self.dropout_2 = nn.Dropout(dropout_perc) |
||||
|
|
||||
|
self.stc_exp = StaticExpansionBlock(d_model, num_enc_exp_list, dropout_perc, eps) |
||||
|
self.ff = FeedForward(d_model, d_ff, dropout_perc) |
||||
|
|
||||
|
def forward(self, x, n_indexes, mask): |
||||
|
x2 = self.norm_1(x) |
||||
|
x = x + self.dropout_1(self.stc_exp(x=x2, n_indexes=n_indexes, mask=mask)) |
||||
|
x2 = self.norm_2(x) |
||||
|
x = x + self.dropout_2(self.ff(x2)) |
||||
|
return x |
||||
|
|
||||
|
|
||||
|
class DynamicExpansionBlock(nn.Module): |
||||
|
def __init__(self, d_model, num_exp, dropout_perc, eps): |
||||
|
super().__init__() |
||||
|
self.d_model = d_model |
||||
|
|
||||
|
self.num_exp = num_exp |
||||
|
self.cond_embed = nn.Linear(d_model, d_model) |
||||
|
|
||||
|
self.query_exp_vectors = nn.Embedding(self.num_exp, d_model) |
||||
|
self.bias_exp_vectors = nn.Embedding(self.num_exp, d_model) |
||||
|
|
||||
|
self.key_linear = nn.Linear(d_model, d_model) |
||||
|
self.class_a_embed = nn.Linear(d_model, d_model) |
||||
|
self.class_b_embed = nn.Linear(d_model, d_model) |
||||
|
|
||||
|
self.selector_embed = nn.Linear(d_model, d_model) |
||||
|
|
||||
|
self.dropout_class_a_fw = nn.Dropout(dropout_perc) |
||||
|
self.dropout_class_b_fw = nn.Dropout(dropout_perc) |
||||
|
self.dropout_class_a_bw = nn.Dropout(dropout_perc) |
||||
|
self.dropout_class_b_bw = nn.Dropout(dropout_perc) |
||||
|
|
||||
|
self.Z_dropout = nn.Dropout(dropout_perc) |
||||
|
|
||||
|
self.eps = eps |
||||
|
|
||||
|
def forward(self, x, n_indexes, mask): |
||||
|
bs, dec_len, _ = x.shape |
||||
|
|
||||
|
cond = self.cond_embed(x).view(bs, dec_len, 1, self.d_model) |
||||
|
query_exp = self.query_exp_vectors(n_indexes).unsqueeze(1) |
||||
|
bias_exp = self.bias_exp_vectors(n_indexes).unsqueeze(1) |
||||
|
query_exp = (query_exp + cond).view(bs, dec_len * self.num_exp, self.d_model) |
||||
|
bias_exp = (bias_exp + cond).view(bs, dec_len * self.num_exp, self.d_model) |
||||
|
|
||||
|
x_key = self.key_linear(x) |
||||
|
z = torch.matmul(query_exp, x_key.transpose(-1, -2)) / np.sqrt(self.d_model) |
||||
|
z = self.Z_dropout(z) |
||||
|
|
||||
|
mod_mask_1 = mask.unsqueeze(2).expand(bs, dec_len, self.num_exp, dec_len).contiguous(). \ |
||||
|
view(bs, dec_len * self.num_exp, dec_len) |
||||
|
|
||||
|
class_a_fw = F.relu(z) |
||||
|
class_b_fw = F.relu(-z) |
||||
|
class_a_fw = class_a_fw.masked_fill(mod_mask_1 == 0, 0.0) |
||||
|
class_b_fw = class_b_fw.masked_fill(mod_mask_1 == 0, 0.0) |
||||
|
class_a_fw = class_a_fw / (class_a_fw.sum(dim=-1, keepdim=True) + self.eps) |
||||
|
class_b_fw = class_b_fw / (class_b_fw.sum(dim=-1, keepdim=True) + self.eps) |
||||
|
class_a = torch.matmul(class_a_fw, self.class_a_embed(x)) |
||||
|
class_b = torch.matmul(class_b_fw, self.class_b_embed(x)) |
||||
|
class_a = self.dropout_class_a_fw(class_a) |
||||
|
class_b = self.dropout_class_b_fw(class_b) |
||||
|
|
||||
|
mod_mask_2 = mask.unsqueeze(-1).expand(bs, dec_len, dec_len, self.num_exp).contiguous(). \ |
||||
|
view(bs, dec_len, dec_len * self.num_exp) |
||||
|
|
||||
|
class_a_bw = F.relu(z.transpose(-2, -1)) |
||||
|
class_b_bw = F.relu(-z.transpose(-2, -1)) |
||||
|
class_a_bw = class_a_bw.masked_fill(mod_mask_2 == 0, 0.0) |
||||
|
class_b_bw = class_b_bw.masked_fill(mod_mask_2 == 0, 0.0) |
||||
|
class_a_bw = class_a_bw / (class_a_bw.sum(dim=-1, keepdim=True) + self.eps) |
||||
|
class_b_bw = class_b_bw / (class_b_bw.sum(dim=-1, keepdim=True) + self.eps) |
||||
|
class_a = torch.matmul(class_a_bw, class_a + bias_exp) |
||||
|
class_b = torch.matmul(class_b_bw, class_b + bias_exp) |
||||
|
class_a = self.dropout_class_a_bw(class_a) |
||||
|
class_b = self.dropout_class_b_bw(class_b) |
||||
|
|
||||
|
selector = torch.sigmoid(self.selector_embed(x)) |
||||
|
x_result = selector * class_a + (1 - selector) * class_b |
||||
|
|
||||
|
return x_result |
||||
|
|
||||
|
|
||||
|
class DecoderLayer(nn.Module): |
||||
|
def __init__(self, d_model, num_heads, d_ff, num_exp, dropout_perc, eps=1e-9): |
||||
|
super().__init__() |
||||
|
self.norm_1 = nn.LayerNorm(d_model) |
||||
|
self.norm_2 = nn.LayerNorm(d_model) |
||||
|
self.norm_3 = nn.LayerNorm(d_model) |
||||
|
|
||||
|
self.dropout_1 = nn.Dropout(dropout_perc) |
||||
|
self.dropout_2 = nn.Dropout(dropout_perc) |
||||
|
self.dropout_3 = nn.Dropout(dropout_perc) |
||||
|
|
||||
|
self.mha = MultiHeadAttention(d_model, num_heads, dropout_perc) |
||||
|
self.dyn_exp = DynamicExpansionBlock(d_model, num_exp, dropout_perc, eps) |
||||
|
self.ff = FeedForward(d_model, d_ff, dropout_perc) |
||||
|
|
||||
|
def forward(self, x, n_indexes, cross_connection_x, input_attention_mask, cross_attention_mask): |
||||
|
|
||||
|
# Pre-LayerNorm |
||||
|
x2 = self.norm_1(x) |
||||
|
x = x + self.dropout_1(self.dyn_exp(x=x2, n_indexes=n_indexes, mask=input_attention_mask)) |
||||
|
|
||||
|
x2 = self.norm_2(x) |
||||
|
x = x + self.dropout_2(self.mha(q=x2, k=cross_connection_x, v=cross_connection_x, |
||||
|
mask=cross_attention_mask)) |
||||
|
|
||||
|
x2 = self.norm_3(x) |
||||
|
x = x + self.dropout_3(self.ff(x2)) |
||||
|
return x |
||||
|
|
||||
|
|
||||
|
|
||||
|
class MultiHeadAttention(nn.Module): |
||||
|
def __init__(self, d_model, num_heads, dropout_perc): |
||||
|
super(MultiHeadAttention, self).__init__() |
||||
|
assert d_model % num_heads == 0, "num heads must be multiple of d_model" |
||||
|
|
||||
|
self.d_model = d_model |
||||
|
self.d_k = int(d_model / num_heads) |
||||
|
self.num_heads = num_heads |
||||
|
|
||||
|
self.Wq = nn.Linear(d_model, self.d_k * num_heads) |
||||
|
self.Wk = nn.Linear(d_model, self.d_k * num_heads) |
||||
|
self.Wv = nn.Linear(d_model, self.d_k * num_heads) |
||||
|
|
||||
|
self.out_linear = nn.Linear(d_model, d_model) |
||||
|
|
||||
|
def forward(self, q, k, v, mask=None): |
||||
|
batch_size, q_seq_len, _ = q.shape |
||||
|
k_seq_len = k.size(1) |
||||
|
v_seq_len = v.size(1) |
||||
|
|
||||
|
k_proj = self.Wk(k).view(batch_size, k_seq_len, self.num_heads, self.d_k) |
||||
|
q_proj = self.Wq(q).view(batch_size, q_seq_len, self.num_heads, self.d_k) |
||||
|
v_proj = self.Wv(v).view(batch_size, v_seq_len, self.num_heads, self.d_k) |
||||
|
|
||||
|
k_proj = k_proj.transpose(2, 1) |
||||
|
q_proj = q_proj.transpose(2, 1) |
||||
|
v_proj = v_proj.transpose(2, 1) |
||||
|
|
||||
|
sim_scores = torch.matmul(q_proj, k_proj.transpose(3, 2)) |
||||
|
sim_scores = sim_scores / self.d_k ** 0.5 |
||||
|
|
||||
|
if mask is not None: |
||||
|
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1) |
||||
|
sim_scores = sim_scores.masked_fill(mask == 0, value=-1e4) |
||||
|
sim_scores = F.softmax(input=sim_scores, dim=-1) |
||||
|
|
||||
|
attention_applied = torch.matmul(sim_scores, v_proj) |
||||
|
attention_applied_concatenated = attention_applied.permute(0, 2, 1, 3).contiguous()\ |
||||
|
.view(batch_size, q_seq_len, self.d_model) |
||||
|
|
||||
|
out = self.out_linear(attention_applied_concatenated) |
||||
|
return out |
||||
|
|
||||
|
class FeedForward(nn.Module): |
||||
|
def __init__(self, d_model, d_ff, dropout_perc): |
||||
|
super(FeedForward, self).__init__() |
||||
|
self.linear_1 = nn.Linear(d_model, d_ff) |
||||
|
self.dropout = nn.Dropout(dropout_perc) |
||||
|
self.linear_2 = nn.Linear(d_ff, d_model) |
||||
|
|
||||
|
def forward(self, x): |
||||
|
x = self.dropout(F.relu(self.linear_1(x))) |
||||
|
x = self.linear_2(x) |
||||
|
return x |
@ -0,0 +1,655 @@ |
|||||
|
# -------------------------------------------------------- |
||||
|
# Swin Transformer |
||||
|
# Copyright (c) 2021 Microsoft |
||||
|
# Licensed under The MIT License [see LICENSE for details] |
||||
|
# Written by Ze Liu |
||||
|
# -------------------------------------------------------- |
||||
|
|
||||
|
# --------------------------------- |
||||
|
# All credits due to Ze Liu: https://github.com/microsoft/Swin-Transformer |
||||
|
# and the additional sources: |
||||
|
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py |
||||
|
# https://github.com/yukimasano/PASS/blob/main/vision_transformer.py |
||||
|
# --------------------------------- |
||||
|
|
||||
|
import torch |
||||
|
import torch.nn as nn |
||||
|
import torch.utils.checkpoint as checkpoint |
||||
|
|
||||
|
|
||||
|
class DropPath(nn.Module): |
||||
|
def __init__(self, drop_prob): |
||||
|
super().__init__() |
||||
|
self.drop_prob = drop_prob |
||||
|
|
||||
|
def forward(self, x): |
||||
|
if not self.training: |
||||
|
return x |
||||
|
keep_prob = 1 - self.drop_prob |
||||
|
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets |
||||
|
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) |
||||
|
random_tensor.floor_() # binarize |
||||
|
output = x.div(keep_prob) * random_tensor |
||||
|
return output |
||||
|
|
||||
|
|
||||
|
import collections.abc |
||||
|
def to_2tuple(x): |
||||
|
if isinstance(x, collections.abc.Iterable): |
||||
|
return x |
||||
|
return (x, x) |
||||
|
|
||||
|
|
||||
|
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): |
||||
|
return _no_grad_trunc_normal_(tensor, mean, std, a, b) |
||||
|
|
||||
|
|
||||
|
import warnings |
||||
|
import math |
||||
|
def _no_grad_trunc_normal_(tensor, mean, std, a, b): |
||||
|
# Cut & paste from PyTorch official repo master until it's in a few official releases - RW |
||||
|
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf |
||||
|
def norm_cdf(x): |
||||
|
return (1. + math.erf(x / math.sqrt(2.))) / 2. |
||||
|
if (mean < a - 2 * std) or (mean > b + 2 * std): |
||||
|
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " |
||||
|
"The distribution of values may be incorrect.", |
||||
|
stacklevel=2) |
||||
|
|
||||
|
with torch.no_grad(): |
||||
|
# Values are generated by using a truncated uniform distribution and |
||||
|
# then using the inverse CDF for the normal distribution. |
||||
|
# Get upper and lower cdf values |
||||
|
l = norm_cdf((a - mean) / std) |
||||
|
u = norm_cdf((b - mean) / std) |
||||
|
|
||||
|
# Uniformly fill tensor with values from [l, u], then translate to |
||||
|
# [2l-1, 2u-1]. |
||||
|
tensor.uniform_(2 * l - 1, 2 * u - 1) |
||||
|
|
||||
|
# Use inverse cdf transform for normal distribution to get truncated |
||||
|
# standard normal |
||||
|
tensor.erfinv_() |
||||
|
|
||||
|
# Transform to proper mean, std |
||||
|
tensor.mul_(std * math.sqrt(2.)) |
||||
|
tensor.add_(mean) |
||||
|
|
||||
|
# Clamp to ensure it's in the proper range |
||||
|
tensor.clamp_(min=a, max=b) |
||||
|
return tensor |
||||
|
|
||||
|
|
||||
|
class Mlp(nn.Module): |
||||
|
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): |
||||
|
super().__init__() |
||||
|
out_features = out_features or in_features |
||||
|
hidden_features = hidden_features or in_features |
||||
|
self.fc1 = nn.Linear(in_features, hidden_features) |
||||
|
self.act = act_layer() |
||||
|
self.fc2 = nn.Linear(hidden_features, out_features) |
||||
|
self.drop = nn.Dropout(drop) |
||||
|
|
||||
|
def forward(self, x): |
||||
|
x = self.fc1(x) |
||||
|
x = self.act(x) |
||||
|
x = self.drop(x) |
||||
|
x = self.fc2(x) |
||||
|
x = self.drop(x) |
||||
|
return x |
||||
|
|
||||
|
|
||||
|
def window_partition(x, window_size): |
||||
|
""" |
||||
|
Args: |
||||
|
x: (B, H, W, C) |
||||
|
window_size (int): window size |
||||
|
|
||||
|
Returns: |
||||
|
windows: (num_windows*B, window_size, window_size, C) |
||||
|
""" |
||||
|
B, H, W, C = x.shape |
||||
|
# mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 |
||||
|
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) |
||||
|
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) |
||||
|
return windows |
||||
|
|
||||
|
|
||||
|
def window_reverse(windows, window_size, H, W): |
||||
|
""" |
||||
|
Args: |
||||
|
windows: (num_windows*B, window_size, window_size, C) |
||||
|
window_size (int): Window size |
||||
|
H (int): Height of image |
||||
|
W (int): Width of image |
||||
|
|
||||
|
Returns: |
||||
|
x: (B, H, W, C) |
||||
|
""" |
||||
|
B = int(windows.shape[0] / (H * W / window_size / window_size)) |
||||
|
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) |
||||
|
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) |
||||
|
return x |
||||
|
|
||||
|
|
||||
|
class WindowAttention(nn.Module): |
||||
|
r""" Window based multi-head self attention (W-MSA) module with relative position bias. |
||||
|
It supports both of shifted and non-shifted window. |
||||
|
|
||||
|
Args: |
||||
|
dim (int): Number of input channels. |
||||
|
window_size (tuple[int]): The height and width of the window. |
||||
|
num_heads (int): Number of attention heads. |
||||
|
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True |
||||
|
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set |
||||
|
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 |
||||
|
proj_drop (float, optional): Dropout ratio of output. Default: 0.0 |
||||
|
""" |
||||
|
|
||||
|
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): |
||||
|
|
||||
|
super().__init__() |
||||
|
self.dim = dim |
||||
|
self.window_size = window_size # Wh, Ww |
||||
|
self.num_heads = num_heads |
||||
|
head_dim = dim // num_heads |
||||
|
self.scale = qk_scale or head_dim ** -0.5 |
||||
|
|
||||
|
# define a parameter table of relative position bias |
||||
|
self.relative_position_bias_table = nn.Parameter( |
||||
|
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH |
||||
|
|
||||
|
# get pair-wise relative position index for each token inside the window |
||||
|
coords_h = torch.arange(self.window_size[0]) |
||||
|
coords_w = torch.arange(self.window_size[1]) |
||||
|
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww |
||||
|
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww |
||||
|
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww |
||||
|
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 |
||||
|
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 |
||||
|
relative_coords[:, :, 1] += self.window_size[1] - 1 |
||||
|
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 |
||||
|
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww |
||||
|
self.register_buffer("relative_position_index", relative_position_index) |
||||
|
|
||||
|
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) |
||||
|
self.attn_drop = nn.Dropout(attn_drop) |
||||
|
self.proj = nn.Linear(dim, dim) |
||||
|
self.proj_drop = nn.Dropout(proj_drop) |
||||
|
|
||||
|
trunc_normal_(self.relative_position_bias_table, std=.02) |
||||
|
self.softmax = nn.Softmax(dim=-1) |
||||
|
|
||||
|
def forward(self, x, mask=None): |
||||
|
""" |
||||
|
Args: |
||||
|
x: input features with shape of (num_windows*B, N, C) |
||||
|
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None |
||||
|
""" |
||||
|
B_, N, C = x.shape |
||||
|
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) |
||||
|
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) |
||||
|
|
||||
|
q = q * self.scale |
||||
|
attn = (q @ k.transpose(-2, -1)) |
||||
|
|
||||
|
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( |
||||
|
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH |
||||
|
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww |
||||
|
attn = attn + relative_position_bias.unsqueeze(0) |
||||
|
|
||||
|
if mask is not None: |
||||
|
nW = mask.shape[0] |
||||
|
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) |
||||
|
attn = attn.view(-1, self.num_heads, N, N) |
||||
|
attn = self.softmax(attn) |
||||
|
else: |
||||
|
attn = self.softmax(attn) |
||||
|
|
||||
|
attn = self.attn_drop(attn) |
||||
|
|
||||
|
x = (attn @ v).transpose(1, 2).reshape(B_, N, C) |
||||
|
x = self.proj(x) |
||||
|
x = self.proj_drop(x) |
||||
|
return x |
||||
|
|
||||
|
def extra_repr(self) -> str: |
||||
|
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' |
||||
|
|
||||
|
def flops(self, N): |
||||
|
# calculate flops for 1 window with token length of N |
||||
|
flops = 0 |
||||
|
# qkv = self.qkv(x) |
||||
|
flops += N * self.dim * 3 * self.dim |
||||
|
# attn = (q @ k.transpose(-2, -1)) |
||||
|
flops += self.num_heads * N * (self.dim // self.num_heads) * N |
||||
|
# x = (attn @ v) |
||||
|
flops += self.num_heads * N * N * (self.dim // self.num_heads) |
||||
|
# x = self.proj(x) |
||||
|
flops += N * self.dim * self.dim |
||||
|
return flops |
||||
|
|
||||
|
|
||||
|
class SwinTransformerBlock(nn.Module): |
||||
|
r""" Swin Transformer Block. |
||||
|
|
||||
|
Args: |
||||
|
dim (int): Number of input channels. |
||||
|
input_resolution (tuple[int]): Input resulotion. |
||||
|
num_heads (int): Number of attention heads. |
||||
|
window_size (int): Window size. |
||||
|
shift_size (int): Shift size for SW-MSA. |
||||
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. |
||||
|
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True |
||||
|
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. |
||||
|
drop (float, optional): Dropout rate. Default: 0.0 |
||||
|
attn_drop (float, optional): Attention dropout rate. Default: 0.0 |
||||
|
drop_path (float, optional): Stochastic depth rate. Default: 0.0 |
||||
|
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU |
||||
|
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm |
||||
|
""" |
||||
|
|
||||
|
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, |
||||
|
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., |
||||
|
act_layer=nn.GELU, norm_layer=nn.LayerNorm): |
||||
|
super().__init__() |
||||
|
self.dim = dim |
||||
|
self.input_resolution = input_resolution |
||||
|
self.num_heads = num_heads |
||||
|
self.window_size = window_size |
||||
|
self.shift_size = shift_size |
||||
|
self.mlp_ratio = mlp_ratio |
||||
|
if min(self.input_resolution) <= self.window_size: |
||||
|
# if window size is larger than input resolution, we don't partition windows |
||||
|
self.shift_size = 0 |
||||
|
self.window_size = min(self.input_resolution) |
||||
|
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" |
||||
|
|
||||
|
self.norm1 = norm_layer(dim) |
||||
|
self.attn = WindowAttention( |
||||
|
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, |
||||
|
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) |
||||
|
|
||||
|
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() |
||||
|
self.norm2 = norm_layer(dim) |
||||
|
mlp_hidden_dim = int(dim * mlp_ratio) |
||||
|
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) |
||||
|
|
||||
|
if self.shift_size > 0: |
||||
|
# calculate attention mask for SW-MSA |
||||
|
H, W = self.input_resolution |
||||
|
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 |
||||
|
h_slices = (slice(0, -self.window_size), |
||||
|
slice(-self.window_size, -self.shift_size), |
||||
|
slice(-self.shift_size, None)) |
||||
|
w_slices = (slice(0, -self.window_size), |
||||
|
slice(-self.window_size, -self.shift_size), |
||||
|
slice(-self.shift_size, None)) |
||||
|
cnt = 0 |
||||
|
for h in h_slices: |
||||
|
for w in w_slices: |
||||
|
img_mask[:, h, w, :] = cnt |
||||
|
cnt += 1 |
||||
|
|
||||
|
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 |
||||
|
mask_windows = mask_windows.view(-1, self.window_size * self.window_size) |
||||
|
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) |
||||
|
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) |
||||
|
else: |
||||
|
attn_mask = None |
||||
|
|
||||
|
self.register_buffer("attn_mask", attn_mask) |
||||
|
|
||||
|
def forward(self, x): |
||||
|
H, W = self.input_resolution |
||||
|
B, L, C = x.shape |
||||
|
assert L == H * W, "input feature has wrong size" |
||||
|
|
||||
|
shortcut = x |
||||
|
x = self.norm1(x) |
||||
|
x = x.view(B, H, W, C) |
||||
|
|
||||
|
# cyclic shift |
||||
|
if self.shift_size > 0: |
||||
|
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) |
||||
|
else: |
||||
|
shifted_x = x |
||||
|
|
||||
|
# partition windows |
||||
|
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C |
||||
|
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C |
||||
|
|
||||
|
# W-MSA/SW-MSA |
||||
|
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C |
||||
|
|
||||
|
# merge windows |
||||
|
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) |
||||
|
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C |
||||
|
|
||||
|
# reverse cyclic shift |
||||
|
if self.shift_size > 0: |
||||
|
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) |
||||
|
else: |
||||
|
x = shifted_x |
||||
|
x = x.view(B, H * W, C) |
||||
|
|
||||
|
# FFN |
||||
|
x = shortcut + self.drop_path(x) |
||||
|
x = x + self.drop_path(self.mlp(self.norm2(x))) |
||||
|
|
||||
|
return x |
||||
|
|
||||
|
def extra_repr(self) -> str: |
||||
|
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ |
||||
|
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" |
||||
|
|
||||
|
def flops(self): |
||||
|
flops = 0 |
||||
|
H, W = self.input_resolution |
||||
|
# norm1 |
||||
|
flops += self.dim * H * W |
||||
|
# W-MSA/SW-MSA |
||||
|
nW = H * W / self.window_size / self.window_size |
||||
|
flops += nW * self.attn.flops(self.window_size * self.window_size) |
||||
|
# mlp |
||||
|
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio |
||||
|
# norm2 |
||||
|
flops += self.dim * H * W |
||||
|
return flops |
||||
|
|
||||
|
|
||||
|
class PatchMerging(nn.Module): |
||||
|
r""" Patch Merging Layer. |
||||
|
|
||||
|
Args: |
||||
|
input_resolution (tuple[int]): Resolution of input feature. |
||||
|
dim (int): Number of input channels. |
||||
|
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm |
||||
|
""" |
||||
|
|
||||
|
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): |
||||
|
super().__init__() |
||||
|
self.input_resolution = input_resolution |
||||
|
self.dim = dim |
||||
|
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) |
||||
|
self.norm = norm_layer(4 * dim) |
||||
|
|
||||
|
def forward(self, x): |
||||
|
""" |
||||
|
x: B, H*W, C |
||||
|
""" |
||||
|
H, W = self.input_resolution |
||||
|
B, L, C = x.shape |
||||
|
assert L == H * W, "input feature has wrong size" |
||||
|
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." |
||||
|
|
||||
|
x = x.view(B, H, W, C) |
||||
|
|
||||
|
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C |
||||
|
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C |
||||
|
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C |
||||
|
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C |
||||
|
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C |
||||
|
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C |
||||
|
|
||||
|
x = self.norm(x) |
||||
|
x = self.reduction(x) |
||||
|
|
||||
|
return x |
||||
|
|
||||
|
def extra_repr(self) -> str: |
||||
|
return f"input_resolution={self.input_resolution}, dim={self.dim}" |
||||
|
|
||||
|
def flops(self): |
||||
|
H, W = self.input_resolution |
||||
|
flops = H * W * self.dim |
||||
|
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim |
||||
|
return flops |
||||
|
|
||||
|
|
||||
|
class BasicLayer(nn.Module): |
||||
|
""" A basic Swin Transformer layer for one stage. |
||||
|
|
||||
|
Args: |
||||
|
dim (int): Number of input channels. |
||||
|
input_resolution (tuple[int]): Input resolution. |
||||
|
depth (int): Number of blocks. |
||||
|
num_heads (int): Number of attention heads. |
||||
|
window_size (int): Local window size. |
||||
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. |
||||
|
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True |
||||
|
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. |
||||
|
drop (float, optional): Dropout rate. Default: 0.0 |
||||
|
attn_drop (float, optional): Attention dropout rate. Default: 0.0 |
||||
|
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 |
||||
|
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm |
||||
|
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None |
||||
|
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. |
||||
|
""" |
||||
|
|
||||
|
def __init__(self, dim, input_resolution, depth, num_heads, window_size, |
||||
|
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., |
||||
|
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): |
||||
|
|
||||
|
super().__init__() |
||||
|
self.dim = dim |
||||
|
self.input_resolution = input_resolution |
||||
|
self.depth = depth |
||||
|
self.use_checkpoint = use_checkpoint |
||||
|
|
||||
|
# build blocks |
||||
|
self.blocks = nn.ModuleList([ |
||||
|
SwinTransformerBlock(dim=dim, input_resolution=input_resolution, |
||||
|
num_heads=num_heads, window_size=window_size, |
||||
|
shift_size=0 if (i % 2 == 0) else window_size // 2, |
||||
|
mlp_ratio=mlp_ratio, |
||||
|
qkv_bias=qkv_bias, qk_scale=qk_scale, |
||||
|
drop=drop, attn_drop=attn_drop, |
||||
|
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, |
||||
|
norm_layer=norm_layer) |
||||
|
for i in range(depth)]) |
||||
|
|
||||
|
# patch merging layer |
||||
|
if downsample is not None: |
||||
|
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) |
||||
|
else: |
||||
|
self.downsample = None |
||||
|
|
||||
|
def forward(self, x): |
||||
|
for blk in self.blocks: |
||||
|
if self.use_checkpoint: |
||||
|
x = checkpoint.checkpoint(blk, x) |
||||
|
else: |
||||
|
x = blk(x) |
||||
|
if self.downsample is not None: |
||||
|
x = self.downsample(x) |
||||
|
return x |
||||
|
|
||||
|
def extra_repr(self) -> str: |
||||
|
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" |
||||
|
|
||||
|
def flops(self): |
||||
|
flops = 0 |
||||
|
for blk in self.blocks: |
||||
|
flops += blk.flops() |
||||
|
if self.downsample is not None: |
||||
|
flops += self.downsample.flops() |
||||
|
return flops |
||||
|
|
||||
|
|
||||
|
class PatchEmbed(nn.Module): |
||||
|
r""" Image to Patch Embedding |
||||
|
|
||||
|
Args: |
||||
|
img_size (int): Image size. Default: 224. |
||||
|
patch_size (int): Patch token size. Default: 4. |
||||
|
in_chans (int): Number of input image channels. Default: 3. |
||||
|
embed_dim (int): Number of linear projection output channels. Default: 96. |
||||
|
norm_layer (nn.Module, optional): Normalization layer. Default: None |
||||
|
""" |
||||
|
|
||||
|
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): |
||||
|
super().__init__() |
||||
|
img_size = to_2tuple(img_size) |
||||
|
patch_size = to_2tuple(patch_size) |
||||
|
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] |
||||
|
self.img_size = img_size |
||||
|
self.patch_size = patch_size |
||||
|
self.patches_resolution = patches_resolution |
||||
|
self.num_patches = patches_resolution[0] * patches_resolution[1] |
||||
|
|
||||
|
self.in_chans = in_chans |
||||
|
self.embed_dim = embed_dim |
||||
|
|
||||
|
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) |
||||
|
if norm_layer is not None: |
||||
|
self.norm = norm_layer(embed_dim) |
||||
|
else: |
||||
|
self.norm = None |
||||
|
|
||||
|
def forward(self, x): |
||||
|
B, C, H, W = x.shape |
||||
|
# FIXME look at relaxing size constraints |
||||
|
assert H == self.img_size[0] and W == self.img_size[1], \ |
||||
|
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." |
||||
|
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C |
||||
|
if self.norm is not None: |
||||
|
x = self.norm(x) |
||||
|
return x |
||||
|
|
||||
|
def flops(self): |
||||
|
Ho, Wo = self.patches_resolution |
||||
|
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) |
||||
|
if self.norm is not None: |
||||
|
flops += Ho * Wo * self.embed_dim |
||||
|
return flops |
||||
|
|
||||
|
|
||||
|
class SwinTransformer(nn.Module): |
||||
|
r""" Swin Transformer |
||||
|
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - |
||||
|
https://arxiv.org/pdf/2103.14030 |
||||
|
|
||||
|
Args: |
||||
|
img_size (int | tuple(int)): Input image size. Default 224 |
||||
|
patch_size (int | tuple(int)): Patch size. Default: 4 |
||||
|
in_chans (int): Number of input image channels. Default: 3 |
||||
|
num_classes (int): Number of classes for classification head. Default: 1000 |
||||
|
embed_dim (int): Patch embedding dimension. Default: 96 |
||||
|
depths (tuple(int)): Depth of each Swin Transformer layer. |
||||
|
num_heads (tuple(int)): Number of attention heads in different layers. |
||||
|
window_size (int): Window size. Default: 7 |
||||
|
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 |
||||
|
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True |
||||
|
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None |
||||
|
drop_rate (float): Dropout rate. Default: 0 |
||||
|
attn_drop_rate (float): Attention dropout rate. Default: 0 |
||||
|
drop_path_rate (float): Stochastic depth rate. Default: 0.1 |
||||
|
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. |
||||
|
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False |
||||
|
patch_norm (bool): If True, add normalization after patch embedding. Default: True |
||||
|
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False |
||||
|
""" |
||||
|
|
||||
|
def __init__(self, img_size=224, patch_size=4, in_chans=3, |
||||
|
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], |
||||
|
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, |
||||
|
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, |
||||
|
norm_layer=nn.LayerNorm, ape=False, patch_norm=True, |
||||
|
use_checkpoint=False): |
||||
|
super().__init__() |
||||
|
|
||||
|
# self.num_classes = num_classes |
||||
|
self.num_layers = len(depths) |
||||
|
self.embed_dim = embed_dim |
||||
|
self.ape = ape |
||||
|
self.patch_norm = patch_norm |
||||
|
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) |
||||
|
self.mlp_ratio = mlp_ratio |
||||
|
|
||||
|
# split image into non-overlapping patches |
||||
|
self.patch_embed = PatchEmbed( |
||||
|
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, |
||||
|
norm_layer=norm_layer if self.patch_norm else None) |
||||
|
num_patches = self.patch_embed.num_patches |
||||
|
patches_resolution = self.patch_embed.patches_resolution |
||||
|
self.patches_resolution = patches_resolution |
||||
|
|
||||
|
# absolute position embedding |
||||
|
if self.ape: |
||||
|
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) |
||||
|
trunc_normal_(self.absolute_pos_embed, std=.02) |
||||
|
|
||||
|
self.pos_drop = nn.Dropout(p=drop_rate) |
||||
|
|
||||
|
# stochastic depth |
||||
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule |
||||
|
|
||||
|
# build layers |
||||
|
self.layers = nn.ModuleList() |
||||
|
for i_layer in range(self.num_layers): |
||||
|
layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), |
||||
|
input_resolution=(patches_resolution[0] // (2 ** i_layer), |
||||
|
patches_resolution[1] // (2 ** i_layer)), |
||||
|
depth=depths[i_layer], |
||||
|
num_heads=num_heads[i_layer], |
||||
|
window_size=window_size, |
||||
|
mlp_ratio=self.mlp_ratio, |
||||
|
qkv_bias=qkv_bias, qk_scale=qk_scale, |
||||
|
drop=drop_rate, attn_drop=attn_drop_rate, |
||||
|
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], |
||||
|
norm_layer=norm_layer, |
||||
|
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, |
||||
|
use_checkpoint=use_checkpoint) |
||||
|
self.layers.append(layer) |
||||
|
|
||||
|
self.norm = norm_layer(self.num_features) |
||||
|
# self.avgpool = nn.AdaptiveAvgPool1d(1) |
||||
|
# self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() |
||||
|
|
||||
|
self.apply(self._init_weights) |
||||
|
|
||||
|
def _init_weights(self, m): |
||||
|
if isinstance(m, nn.Linear): |
||||
|
trunc_normal_(m.weight, std=.02) |
||||
|
if isinstance(m, nn.Linear) and m.bias is not None: |
||||
|
nn.init.constant_(m.bias, 0) |
||||
|
elif isinstance(m, nn.LayerNorm): |
||||
|
nn.init.constant_(m.bias, 0) |
||||
|
nn.init.constant_(m.weight, 1.0) |
||||
|
|
||||
|
@torch.jit.ignore |
||||
|
def no_weight_decay(self): |
||||
|
return {'absolute_pos_embed'} |
||||
|
|
||||
|
@torch.jit.ignore |
||||
|
def no_weight_decay_keywords(self): |
||||
|
return {'relative_position_bias_table'} |
||||
|
|
||||
|
def forward_features(self, x): |
||||
|
x = self.patch_embed(x) |
||||
|
if self.ape: |
||||
|
x = x + self.absolute_pos_embed |
||||
|
x = self.pos_drop(x) |
||||
|
|
||||
|
for layer in self.layers: |
||||
|
x = layer(x) |
||||
|
|
||||
|
x = self.norm(x) # B L C |
||||
|
# x = self.avgpool(x.transpose(1, 2)) # B C 1 |
||||
|
# x = torch.flatten(x, 1) |
||||
|
return x |
||||
|
|
||||
|
def forward(self, x): |
||||
|
x = self.forward_features(x) |
||||
|
#x = self.head(x) |
||||
|
return x |
||||
|
|
||||
|
def flops(self): |
||||
|
flops = 0 |
||||
|
flops += self.patch_embed.flops() |
||||
|
for i, layer in enumerate(self.layers): |
||||
|
flops += layer.flops() |
||||
|
#flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) |
||||
|
#flops += self.num_features * self.num_classes |
||||
|
return flops |
Loading…
Reference in new issue