diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..af9ed2f --- /dev/null +++ b/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .expansionnet_v2 import ExpansionNetV2 + +def expansionnet_v2(model_name: str): + return ExpansionNetV2(model_name) diff --git a/demo_coco_tokens.pickle b/demo_coco_tokens.pickle new file mode 100644 index 0000000..8130d61 Binary files /dev/null and b/demo_coco_tokens.pickle differ diff --git a/expansionnet_v2.py b/expansionnet_v2.py new file mode 100644 index 0000000..cbaceb2 --- /dev/null +++ b/expansionnet_v2.py @@ -0,0 +1,55 @@ +# Copyright 2021 Zilliz. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +from pathlib import Path + +import torch +from torchvision import transforms +from transformers import GPT2Tokenizer + +from towhee.types.arg import arg, to_image_color +from towhee.types.image_utils import to_pil +from towhee.operator.base import NNOperator, OperatorFlag +from towhee import register +from towhee.models import clip + +class ExpansionNetV2(NNOperator): + """ + ExpansionNet V2 image captioning operator + """ + def __init__(self, model_name: str): + super().__init__() + sys.path.append(str(Path(__file__).parent)) + from models.End_ExpansionNet_v2 import End_ExpansionNet_v2 + sys.path.pop() + with open('demo_coco_tokens.pickle') as fw: + + self.model = End_ExpansionNet_v2(swin_img_size=img_size, swin_patch_size=4, swin_in_chans=3, + swin_embed_dim=192, swin_depths=[2, 2, 18, 2], swin_num_heads=[6, 12, 24, 48], + swin_window_size=12, swin_mlp_ratio=4., swin_qkv_bias=True, swin_qk_scale=None, + swin_drop_rate=0.0, swin_attn_drop_rate=0.0, swin_drop_path_rate=0.0, + swin_norm_layer=torch.nn.LayerNorm, swin_ape=False, swin_patch_norm=True, + swin_use_checkpoint=False, + final_swin_dim=1536, + + d_model=model_args.model_dim, N_enc=model_args.N_enc, + N_dec=model_args.N_dec, num_heads=8, ff=2048, + num_exp_enc_list=[32, 64, 128, 256, 512], + num_exp_dec=16, + output_word2idx=coco_tokens['word2idx_dict'], + output_idx2word=coco_tokens['idx2word_list'], + max_seq_len=args.max_seq_len, drop_args=model_args.drop_args, + rank='cpu') diff --git a/models/End_ExpansionNet_v2.py b/models/End_ExpansionNet_v2.py new file mode 100644 index 0000000..712f760 --- /dev/null +++ b/models/End_ExpansionNet_v2.py @@ -0,0 +1,187 @@ +import torch +from models.layers import EmbeddingLayer, DecoderLayer, EncoderLayer +from utils.masking import create_pad_mask, create_no_peak_and_pad_mask +from models.captioning_model import CaptioningModel +from models.swin_transformer_mod import SwinTransformer + +import torch.nn as nn + + +class End_ExpansionNet_v2(CaptioningModel): + def __init__(self, + + # swin transf + swin_img_size, swin_patch_size, swin_in_chans, + swin_embed_dim, swin_depths, swin_num_heads, + swin_window_size, swin_mlp_ratio, swin_qkv_bias, swin_qk_scale, + swin_drop_rate, swin_attn_drop_rate, swin_drop_path_rate, + swin_norm_layer, swin_ape, swin_patch_norm, + swin_use_checkpoint, + + # linear_size, + final_swin_dim, + + # captioning + d_model, N_enc, N_dec, ff, num_heads, num_exp_enc_list, num_exp_dec, + output_word2idx, output_idx2word, max_seq_len, drop_args, rank=0): + super(End_ExpansionNet_v2, self).__init__() + + self.swin_transf = SwinTransformer( + img_size=swin_img_size, patch_size=swin_patch_size, in_chans=swin_in_chans, + embed_dim=swin_embed_dim, depths=swin_depths, num_heads=swin_num_heads, + window_size=swin_window_size, mlp_ratio=swin_mlp_ratio, qkv_bias=swin_qkv_bias, qk_scale=swin_qk_scale, + drop_rate=swin_drop_rate, attn_drop_rate=swin_attn_drop_rate, drop_path_rate=swin_drop_path_rate, + norm_layer=swin_norm_layer, ape=swin_ape, patch_norm=swin_patch_norm, + use_checkpoint=swin_use_checkpoint) + + self.output_word2idx = output_word2idx + self.output_idx2word = output_idx2word + self.max_seq_len = max_seq_len + + self.num_exp_dec = num_exp_dec + self.num_exp_enc_list = num_exp_enc_list + + self.N_enc = N_enc + self.N_dec = N_dec + self.d_model = d_model + + self.encoders = nn.ModuleList([EncoderLayer(d_model, ff, num_exp_enc_list, drop_args.enc) for _ in range(N_enc)]) + self.decoders = nn.ModuleList([DecoderLayer(d_model, num_heads, ff, num_exp_dec, drop_args.dec) for _ in range(N_dec)]) + + self.input_embedder_dropout = nn.Dropout(drop_args.enc_input) + self.input_linear = torch.nn.Linear(final_swin_dim, d_model) + self.vocab_linear = torch.nn.Linear(d_model, len(output_word2idx)) + self.log_softmax = nn.LogSoftmax(dim=-1) + + self.out_enc_dropout = nn.Dropout(drop_args.other) + self.out_dec_dropout = nn.Dropout(drop_args.other) + + self.out_embedder = EmbeddingLayer(len(output_word2idx), d_model, drop_args.dec_input) + self.pos_encoder = nn.Embedding(max_seq_len, d_model) + + self.enc_reduce_group = nn.Linear(d_model * self.N_enc, d_model) + self.enc_reduce_norm = nn.LayerNorm(d_model) + self.dec_reduce_group = nn.Linear(d_model * self.N_dec, d_model) + self.dec_reduce_norm = nn.LayerNorm(d_model) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + self.trained_steps = 0 + self.rank = rank + + self.check_required_attributes() + + def forward_enc(self, enc_input, enc_input_num_pads): + + assert (enc_input_num_pads is None or enc_input_num_pads == ([0] * enc_input.size(0))), "End to End case have no padding" + x = self.swin_transf(enc_input) + # --------------- Normale parte di Captioning --------------------------------- + enc_input = self.input_embedder_dropout(self.input_linear(x)) + x = enc_input + enc_input_num_pads = [0] * enc_input.size(0) + + max_num_enc = sum(self.num_exp_enc_list) + pos_x = torch.arange(max_num_enc).unsqueeze(0).expand(enc_input.size(0), max_num_enc).to(self.rank) + pad_mask = create_pad_mask(mask_size=(enc_input.size(0), max_num_enc, enc_input.size(1)), + pad_along_row_input=[0] * enc_input.size(0), + pad_along_column_input=enc_input_num_pads, + rank=self.rank) + + x_list = [] + for i in range(self.N_enc): + x = self.encoders[i](x=x, n_indexes=pos_x, mask=pad_mask) + x_list.append(x) + x_list = torch.cat(x_list, dim=-1) + x = x + self.out_enc_dropout(self.enc_reduce_group(x_list)) + x = self.enc_reduce_norm(x) + + return x + + def forward_dec(self, cross_input, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False): + assert (enc_input_num_pads is None or enc_input_num_pads == ([0] * cross_input.size(0))), "enc_input_num_pads should be no None" + + enc_input_num_pads = [0] * dec_input.size(0) + no_peak_and_pad_mask = create_no_peak_and_pad_mask( + mask_size=(dec_input.size(0), dec_input.size(1), dec_input.size(1)), + num_pads=dec_input_num_pads, + rank=self.rank) + pad_mask = create_pad_mask(mask_size=(dec_input.size(0), dec_input.size(1), cross_input.size(1)), + pad_along_row_input=dec_input_num_pads, + pad_along_column_input=enc_input_num_pads, + rank=self.rank) + + y = self.out_embedder(dec_input) + pos_x = torch.arange(self.num_exp_dec).unsqueeze(0).expand(dec_input.size(0), self.num_exp_dec).to(self.rank) + pos_y = torch.arange(dec_input.size(1)).unsqueeze(0).expand(dec_input.size(0), dec_input.size(1)).to(self.rank) + y = y + self.pos_encoder(pos_y) + y_list = [] + for i in range(self.N_dec): + y = self.decoders[i](x=y, + n_indexes=pos_x, + cross_connection_x=cross_input, + input_attention_mask=no_peak_and_pad_mask, + cross_attention_mask=pad_mask) + y_list.append(y) + y_list = torch.cat(y_list, dim=-1) + y = y + self.out_dec_dropout(self.dec_reduce_group(y_list)) + y = self.dec_reduce_norm(y) + + y = self.vocab_linear(y) + + if apply_log_softmax: + y = self.log_softmax(y) + + return y + + + def get_batch_multiple_sampled_prediction(self, enc_input, enc_input_num_pads, num_outputs, + sos_idx, eos_idx, max_seq_len): + + bs = enc_input.size(0) + x = self.forward_enc(enc_input=enc_input, enc_input_num_pads=enc_input_num_pads) + enc_seq_len = x.size(1) + x = x.unsqueeze(1).expand(-1, num_outputs, -1, -1).reshape(bs * num_outputs, enc_seq_len, x.shape[-1]) + + upperbound_vector = torch.tensor([max_seq_len] * bs * num_outputs, dtype=torch.int).to(self.rank) + where_is_eos_vector = upperbound_vector.clone() + eos_vector = torch.tensor([eos_idx] * bs * num_outputs, dtype=torch.long).to(self.rank) + finished_flag_vector = torch.zeros(bs * num_outputs).type(torch.int) + + predicted_caption = torch.tensor([sos_idx] * (bs * num_outputs), dtype=torch.long).to(self.rank).unsqueeze(-1) + predicted_caption_prob = torch.zeros(bs * num_outputs).to(self.rank).unsqueeze(-1) + + dec_input_num_pads = [0]*(bs*num_outputs) + time_step = 0 + while (finished_flag_vector.sum() != bs * num_outputs) and time_step < max_seq_len: + dec_input = predicted_caption + log_probs = self.forward_dec(x, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=True) + + prob_dist = torch.distributions.Categorical(torch.exp(log_probs[:, time_step])) + sampled_word_indexes = prob_dist.sample() + + predicted_caption = torch.cat((predicted_caption, sampled_word_indexes.unsqueeze(-1)), dim=-1) + predicted_caption_prob = torch.cat((predicted_caption_prob, + log_probs[:, time_step].gather(index=sampled_word_indexes.unsqueeze(-1), dim=-1)), dim=-1) + time_step += 1 + + where_is_eos_vector = torch.min(where_is_eos_vector, + upperbound_vector.masked_fill(sampled_word_indexes == eos_vector, time_step)) + finished_flag_vector = torch.max(finished_flag_vector, + (sampled_word_indexes == eos_vector).type(torch.IntTensor)) + + res_predicted_caption = [] + for i in range(bs): + res_predicted_caption.append([]) + for j in range(num_outputs): + index = i*num_outputs + j + res_predicted_caption[i].append( + predicted_caption[index, :where_is_eos_vector[index].item()+1].tolist()) + + where_is_eos_vector = where_is_eos_vector.unsqueeze(-1).expand(-1, time_step+1) + arange_tensor = torch.arange(time_step+1).unsqueeze(0).expand(bs * num_outputs, -1).to(self.rank) + predicted_caption_prob.masked_fill_(arange_tensor > where_is_eos_vector, 0.0) + res_predicted_caption_prob = predicted_caption_prob.reshape(bs, num_outputs, -1) + + return res_predicted_caption, res_predicted_caption_prob diff --git a/models/ExpansionNet_v2.py b/models/ExpansionNet_v2.py new file mode 100644 index 0000000..0a9e9b0 --- /dev/null +++ b/models/ExpansionNet_v2.py @@ -0,0 +1,103 @@ +import torch +from models.layers import EmbeddingLayer, EncoderLayer, DecoderLayer +from utils.masking import create_pad_mask, create_no_peak_and_pad_mask +from models.captioning_model import CaptioningModel + +import torch.nn as nn + + +class ExpansionNet_v2(CaptioningModel): + def __init__(self, d_model, N_enc, N_dec, ff, num_heads, num_exp_enc_list, num_exp_dec, + output_word2idx, output_idx2word, max_seq_len, drop_args, img_feature_dim=2048, rank=0): + super().__init__() + self.output_word2idx = output_word2idx + self.output_idx2word = output_idx2word + self.max_seq_len = max_seq_len + + self.num_exp_dec = num_exp_dec + self.num_exp_enc_list = num_exp_enc_list + + self.N_enc = N_enc + self.N_dec = N_dec + self.d_model = d_model + + self.encoders = nn.ModuleList([EncoderLayer(d_model, ff, num_exp_enc_list, drop_args.enc) for _ in range(N_enc)]) + self.decoders = nn.ModuleList([DecoderLayer(d_model, num_heads, ff, num_exp_dec, drop_args.dec) for _ in range(N_dec)]) + + self.input_embedder_dropout = nn.Dropout(drop_args.enc_input) + self.input_linear = torch.nn.Linear(img_feature_dim, d_model) + self.vocab_linear = torch.nn.Linear(d_model, len(output_word2idx)) + self.log_softmax = nn.LogSoftmax(dim=-1) + + self.out_enc_dropout = nn.Dropout(drop_args.other) + self.out_dec_dropout = nn.Dropout(drop_args.other) + + self.out_embedder = EmbeddingLayer(len(output_word2idx), d_model, drop_args.dec_input) + self.pos_encoder = nn.Embedding(max_seq_len, d_model) + + self.enc_reduce_group = nn.Linear(d_model * self.N_enc, d_model) + self.enc_reduce_norm = nn.LayerNorm(d_model) + self.dec_reduce_group = nn.Linear(d_model * self.N_dec, d_model) + self.dec_reduce_norm = nn.LayerNorm(d_model) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + self.trained_steps = 0 + self.rank = rank + + def forward_enc(self, enc_input, enc_input_num_pads): + + x = self.input_embedder_dropout(self.input_linear(enc_input)) + + max_num_enc = sum(self.num_exp_enc_list) + pos_x = torch.arange(max_num_enc).unsqueeze(0).expand(enc_input.size(0), max_num_enc).to(self.rank) + pad_mask = create_pad_mask(mask_size=(enc_input.size(0), max_num_enc, enc_input.size(1)), + pad_along_row_input=[0] * enc_input.size(0), + pad_along_column_input=enc_input_num_pads, + rank=self.rank) + + x_list = [] + for i in range(self.N_enc): + x = self.encoders[i](x=x, n_indexes=pos_x, mask=pad_mask) + x_list.append(x) + x_list = torch.cat(x_list, dim=-1) + x = x + self.out_enc_dropout(self.enc_reduce_group(x_list)) + x = self.enc_reduce_norm(x) + return x + + def forward_dec(self, cross_input, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False): + + no_peak_and_pad_mask = create_no_peak_and_pad_mask( + mask_size=(dec_input.size(0), dec_input.size(1), dec_input.size(1)), + num_pads=dec_input_num_pads, + rank=self.rank) + + pad_mask = create_pad_mask(mask_size=(dec_input.size(0), dec_input.size(1), cross_input.size(1)), + pad_along_row_input=dec_input_num_pads, + pad_along_column_input=enc_input_num_pads, + rank=self.rank) + + y = self.out_embedder(dec_input) + pos_x = torch.arange(self.num_exp_dec).unsqueeze(0).expand(dec_input.size(0), self.num_exp_dec).to(self.rank) + pos_y = torch.arange(dec_input.size(1)).unsqueeze(0).expand(dec_input.size(0), dec_input.size(1)).to(self.rank) + y = y + self.pos_encoder(pos_y) + y_list = [] + for i in range(self.N_dec): + y = self.decoders[i](x=y, + n_indexes=pos_x, + cross_connection_x=cross_input, + input_attention_mask=no_peak_and_pad_mask, + cross_attention_mask=pad_mask) + y_list.append(y) + y_list = torch.cat(y_list, dim=-1) + y = y + self.out_dec_dropout(self.dec_reduce_group(y_list)) + y = self.dec_reduce_norm(y) + + y = self.vocab_linear(y) + + if apply_log_softmax: + y = self.log_softmax(y) + + return y diff --git a/models/captioning_model.py b/models/captioning_model.py new file mode 100644 index 0000000..56345ac --- /dev/null +++ b/models/captioning_model.py @@ -0,0 +1,241 @@ + + +import torch +import torch.nn as nn + + +class CaptioningModel(nn.Module): + def __init__(self): + super(CaptioningModel, self).__init__() + # mandatory attributes + # rank: to enable multiprocessing + self.rank = None + + def check_required_attributes(self): + if self.rank is None: + raise NotImplementedError("Subclass must assign the rank integer according to the GPU group") + + def forward_enc(self, enc_input, enc_input_num_pads): + raise NotImplementedError + + def forward_dec(self, cross_input, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False): + raise NotImplementedError + + def forward(self, enc_x, dec_x=None, + enc_x_num_pads=[0], dec_x_num_pads=[0], apply_log_softmax=False, + mode='forward', **kwargs): + if mode == 'forward': + x = self.forward_enc(enc_x, enc_x_num_pads) + y = self.forward_dec(x, enc_x_num_pads, dec_x, dec_x_num_pads, apply_log_softmax) + return y + else: + assert ('sos_idx' in kwargs.keys() or 'eos_idx' in kwargs.keys()), \ + 'sos and eos must be provided in case of batch sampling or beam search' + sos_idx = kwargs.get('sos_idx', -999) + eos_idx = kwargs.get('eos_idx', -999) + if mode == 'beam_search': + beam_size_arg = kwargs.get('beam_size', 5) + how_many_outputs_per_beam = kwargs.get('how_many_outputs', 1) + beam_max_seq_len = kwargs.get('beam_max_seq_len', 20) + sample_or_max = kwargs.get('sample_or_max', 'max') + out_classes, out_logprobs = self.beam_search( + enc_x, enc_x_num_pads, + beam_size=beam_size_arg, + sos_idx=sos_idx, + eos_idx=eos_idx, + how_many_outputs=how_many_outputs_per_beam, + max_seq_len=beam_max_seq_len, + sample_or_max=sample_or_max) + return out_classes, out_logprobs + if mode == 'sampling': + how_many_outputs = kwargs.get('how_many_outputs', 1) + sample_max_seq_len = kwargs.get('sample_max_seq_len', 20) + out_classes, out_logprobs = self.get_batch_multiple_sampled_prediction( + enc_x, enc_x_num_pads, num_outputs=how_many_outputs, + sos_idx=sos_idx, eos_idx=eos_idx, + max_seq_len=sample_max_seq_len) + return out_classes, out_logprobs + + def get_batch_multiple_sampled_prediction(self, enc_input, enc_input_num_pads, num_outputs, + sos_idx, eos_idx, max_seq_len): + bs, enc_seq_len, _ = enc_input.shape + + enc_input_num_pads = [enc_input_num_pads[i] for i in range(bs) for _ in range(num_outputs)] + + x = self.forward_enc(enc_input=enc_input, enc_input_num_pads=enc_input_num_pads) + x = x.unsqueeze(1).expand(-1, num_outputs, -1, -1).reshape(bs * num_outputs, enc_seq_len, x.shape[-1]) + + upperbound_vector = torch.tensor([max_seq_len] * bs * num_outputs, dtype=torch.int).to(self.rank) + where_is_eos_vector = upperbound_vector.clone() + eos_vector = torch.tensor([eos_idx] * bs * num_outputs, dtype=torch.long).to(self.rank) + finished_flag_vector = torch.zeros(bs * num_outputs).type(torch.int) + + predicted_caption = torch.tensor([sos_idx] * (bs * num_outputs), dtype=torch.long).to(self.rank).unsqueeze(-1) + predicted_caption_prob = torch.zeros(bs * num_outputs).to(self.rank).unsqueeze(-1) + + dec_input_num_pads = [0]*(bs*num_outputs) + time_step = 0 + while (finished_flag_vector.sum() != bs * num_outputs) and time_step < max_seq_len: + dec_input = predicted_caption + log_probs = self.forward_dec(x, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=True) + + prob_dist = torch.distributions.Categorical(torch.exp(log_probs[:, time_step])) + sampled_word_indexes = prob_dist.sample() + + predicted_caption = torch.cat((predicted_caption, sampled_word_indexes.unsqueeze(-1)), dim=-1) + predicted_caption_prob = torch.cat((predicted_caption_prob, + log_probs[:, time_step].gather(index=sampled_word_indexes.unsqueeze(-1), dim=-1)), dim=-1) + time_step += 1 + + where_is_eos_vector = torch.min(where_is_eos_vector, + upperbound_vector.masked_fill(sampled_word_indexes == eos_vector, time_step)) + finished_flag_vector = torch.max(finished_flag_vector, + (sampled_word_indexes == eos_vector).type(torch.IntTensor)) + + # remove the elements that come after the first eos from the sequence + res_predicted_caption = [] + for i in range(bs): + res_predicted_caption.append([]) + for j in range(num_outputs): + index = i*num_outputs + j + res_predicted_caption[i].append( + predicted_caption[index, :where_is_eos_vector[index].item()+1].tolist()) + + where_is_eos_vector = where_is_eos_vector.unsqueeze(-1).expand(-1, time_step+1) + arange_tensor = torch.arange(time_step+1).unsqueeze(0).expand(bs * num_outputs, -1).to(self.rank) + predicted_caption_prob.masked_fill_(arange_tensor > where_is_eos_vector, 0.0) + res_predicted_caption_prob = predicted_caption_prob.reshape(bs, num_outputs, -1) + + return res_predicted_caption, res_predicted_caption_prob + + def beam_search(self, enc_input, enc_input_num_pads, sos_idx, eos_idx, + beam_size=3, how_many_outputs=1, max_seq_len=20, sample_or_max='max',): + assert (how_many_outputs <= beam_size), "requested output per sequence must be lower than beam width" + assert (sample_or_max == 'max' or sample_or_max == 'sample'), "argument must be chosen between \'max\' and \'sample\'" + bs = enc_input.shape[0] + + cross_enc_output = self.forward_enc(enc_input, enc_input_num_pads) + + # init: ------------------------------------------------------------------ + init_dec_class = torch.tensor([sos_idx] * bs).unsqueeze(1).type(torch.long).to(self.rank) + init_dec_logprob = torch.tensor([0.0] * bs).unsqueeze(1).type(torch.float).to(self.rank) + log_probs = self.forward_dec(cross_input=cross_enc_output, enc_input_num_pads=enc_input_num_pads, + dec_input=init_dec_class, dec_input_num_pads=[0] * bs, + apply_log_softmax=True) + if sample_or_max == 'max': + _, topi = torch.topk(log_probs, k=beam_size, sorted=True) + else: # sample + topi = torch.exp(log_probs[:, 0, :]).multinomial(num_samples=beam_size, replacement=False) + topi = topi.unsqueeze(1) + + init_dec_class = init_dec_class.repeat(1, beam_size) + init_dec_class = init_dec_class.unsqueeze(-1) + top_beam_size_class = topi.transpose(-2, -1) + init_dec_class = torch.cat((init_dec_class, top_beam_size_class), dim=-1) + + init_dec_logprob = init_dec_logprob.repeat(1, beam_size) + init_dec_logprob = init_dec_logprob.unsqueeze(-1) + top_beam_size_logprob = log_probs.gather(dim=-1, index=topi) + top_beam_size_logprob = top_beam_size_logprob.transpose(-2, -1) + init_dec_logprob = torch.cat((init_dec_logprob, top_beam_size_logprob), dim=-1) + + bs, enc_seq_len, d_model = cross_enc_output.shape + cross_enc_output = cross_enc_output.unsqueeze(1) + cross_enc_output = cross_enc_output.expand(-1, beam_size, -1, -1) + cross_enc_output = cross_enc_output.reshape(bs * beam_size, enc_seq_len, d_model).contiguous() + enc_input_num_pads = [enc_input_num_pads[i] for i in range(bs) for _ in range(beam_size)] + + # loop: ----------------------------------------------------------------- + loop_dec_classes = init_dec_class + loop_dec_logprobs = init_dec_logprob + loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True) + + loop_num_elem_vector = torch.tensor([2] * (bs * beam_size)).to(self.rank) + + for time_step in range(2, max_seq_len): + loop_dec_classes = loop_dec_classes.reshape(bs * beam_size, time_step).contiguous() + + log_probs = self.forward_dec(cross_input=cross_enc_output, enc_input_num_pads=enc_input_num_pads, + dec_input=loop_dec_classes, + dec_input_num_pads=(time_step-loop_num_elem_vector).tolist(), + apply_log_softmax=True) + if sample_or_max == 'max': + _, topi = torch.topk(log_probs[:, time_step-1, :], k=beam_size, sorted=True) + else: # sample + topi = torch.exp(log_probs[:, time_step-1, :]).multinomial(num_samples=beam_size, + replacement=False) + + top_beam_size_word_classes = topi.reshape(bs, beam_size, beam_size) + + top_beam_size_word_logprobs = log_probs[:, time_step-1, :].gather(dim=-1, index=topi) + top_beam_size_word_logprobs = top_beam_size_word_logprobs.reshape(bs, beam_size, beam_size) + + # each sequence have now its best prediction, but some sequence may have already been terminated with EOS, + # in that case its candidates are simply ignored, and do not sum up in the "loop_dec_logprobs" their value + # are set to zero + there_is_eos_mask = (loop_dec_classes.view(bs, beam_size, time_step) == eos_idx). \ + sum(dim=-1, keepdims=True).type(torch.bool) + + # if we pad with -999 its candidates logprobabilities, also the sequence containing EOS would be + # straightforwardly discarded, instead we want to keep it in the exploration. Therefore we mask with 0.0 + # one arbitrary candidate word probability so the sequence probability is unchanged but it + # can still be discarded when a better candidate sequence is found + top_beam_size_word_logprobs[:, :, 0:1].masked_fill_(there_is_eos_mask, 0.0) + top_beam_size_word_logprobs[:, :, 1:].masked_fill_(there_is_eos_mask, -999.0) + + comparison_logprobs = loop_cumul_logprobs + top_beam_size_word_logprobs + + comparison_logprobs = comparison_logprobs.contiguous().view(bs, beam_size * beam_size) + _, topi = torch.topk(comparison_logprobs, k=beam_size, sorted=True) + which_sequence = topi // beam_size + which_word = topi % beam_size + + loop_dec_classes = loop_dec_classes.view(bs, beam_size, -1) + loop_dec_logprobs = loop_dec_logprobs.view(bs, beam_size, -1) + + bs_idxes = torch.arange(bs).unsqueeze(-1) + new_loop_dec_classes = loop_dec_classes[[bs_idxes, which_sequence]] + new_loop_dec_logprobs = loop_dec_logprobs[[bs_idxes, which_sequence]] + + which_sequence_top_beam_size_word_classes = top_beam_size_word_classes[[bs_idxes, which_sequence]] + which_sequence_top_beam_size_word_logprobs = top_beam_size_word_logprobs[ + [bs_idxes, which_sequence]] + which_word = which_word.unsqueeze(-1) + + lastword_top_beam_size_classes = which_sequence_top_beam_size_word_classes.gather(dim=-1, + index=which_word) + lastword_top_beam_size_logprobs = which_sequence_top_beam_size_word_logprobs.gather(dim=-1, index=which_word) + + new_loop_dec_classes = torch.cat((new_loop_dec_classes, lastword_top_beam_size_classes), dim=-1) + new_loop_dec_logprobs = torch.cat((new_loop_dec_logprobs, lastword_top_beam_size_logprobs), dim=-1) + loop_dec_classes = new_loop_dec_classes + loop_dec_logprobs = new_loop_dec_logprobs + + loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True) + + # -----------------------update loop_num_elem_vector ---------------------------- + loop_num_elem_vector = loop_num_elem_vector.view(bs, beam_size)[[bs_idxes, which_sequence]].view(bs * beam_size) + there_was_eos_mask = (loop_dec_classes[:, :, :-1].view(bs, beam_size, time_step) == eos_idx). \ + sum(dim=-1).type(torch.bool).view(bs * beam_size) + loop_num_elem_vector = loop_num_elem_vector + (1 * (1 - there_was_eos_mask.type(torch.int))) + + if (loop_num_elem_vector != time_step + 1).sum() == (bs * beam_size): + break + + # sort out the best result + loop_cumul_logprobs /= loop_num_elem_vector.reshape(bs, beam_size, 1) + _, topi = torch.topk(loop_cumul_logprobs.squeeze(-1), k=beam_size) + res_caption_pred = [[] for _ in range(bs)] + res_caption_logprob = [[] for _ in range(bs)] + for i in range(bs): + for j in range(how_many_outputs): + idx = topi[i, j].item() + res_caption_pred[i].append( + loop_dec_classes[i, idx, :loop_num_elem_vector[i * beam_size + idx]].tolist()) + res_caption_logprob[i].append(loop_dec_logprobs[i, idx, :loop_num_elem_vector[i * beam_size + idx]]) + + flatted_res_caption_logprob = [logprobs for i in range(bs) for logprobs in res_caption_logprob[i]] + flatted_res_caption_logprob = torch.nn.utils.rnn.pad_sequence(flatted_res_caption_logprob, batch_first=True) + res_caption_logprob = flatted_res_caption_logprob.view(bs, how_many_outputs, -1) + + return res_caption_pred, res_caption_logprob diff --git a/models/ensemble_captioning_model.py b/models/ensemble_captioning_model.py new file mode 100644 index 0000000..3498111 --- /dev/null +++ b/models/ensemble_captioning_model.py @@ -0,0 +1,187 @@ + +import torch +import torch.nn as nn +from models.captioning_model import CaptioningModel + + +class EsembleCaptioningModel(CaptioningModel): + def __init__(self, models_list, rank): + super().__init__() + self.num_models = len(models_list) + self.models_list = models_list + self.rank = rank + + self.dummy_linear = nn.Linear(1, 1) + + for model in self.models_list: + model.eval() + + def forward(self, enc_x, dec_x=None, + enc_x_num_pads=[0], dec_x_num_pads=[0], apply_log_softmax=False, + mode='beam_search', **kwargs): + assert (mode == 'beam_search'), "this class supports only beam search." + sos_idx = kwargs.get('sos_idx', -999) + eos_idx = kwargs.get('eos_idx', -999) + if mode == 'beam_search': + beam_size_arg = kwargs.get('beam_size', 5) + how_many_outputs_per_beam = kwargs.get('how_many_outputs', 1) + beam_max_seq_len = kwargs.get('beam_max_seq_len', 20) + sample_or_max = kwargs.get('sample_or_max', 'max') + out_classes, out_logprobs = self.ensemble_beam_search( + enc_x, enc_x_num_pads, + beam_size=beam_size_arg, + sos_idx=sos_idx, + eos_idx=eos_idx, + how_many_outputs=how_many_outputs_per_beam, + max_seq_len=beam_max_seq_len, + sample_or_max=sample_or_max) + return out_classes, out_logprobs + + def forward_enc(self, enc_input, enc_input_num_pads): + x_outputs_list = [] + for i in range(self.num_models): + x_outputs = self.models_list[i].forward_enc(enc_input, enc_input_num_pads) + x_outputs_list.append(x_outputs) + return x_outputs_list + + def forward_dec(self, cross_input_list, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False): + + import torch.nn.functional as F + y_outputs = [] + for i in range(self.num_models): + y_outputs.append( + F.softmax(self.models_list[i].forward_dec( + cross_input_list[i], enc_input_num_pads, + dec_input, dec_input_num_pads, False).unsqueeze(0), dim=-1)) + avg = torch.cat(y_outputs, dim=0).mean(dim=0).log() + + return avg + + # quite unclean coding, to be re-factored in the future... + # since it's a bit similar to the single model case + def ensemble_beam_search(self, enc_input, enc_input_num_pads, sos_idx, eos_idx, + beam_size=3, how_many_outputs=1, max_seq_len=20, sample_or_max='max',): + assert (how_many_outputs <= beam_size), "requested output per sequence must be lower than beam width" + assert (sample_or_max == 'max' or sample_or_max == 'sample'), "argument must be chosen between \'max\' and \'sample\'" + bs = enc_input.shape[0] + + # the cross_dec_input is computed once + cross_enc_output_list = self.forward_enc(enc_input, enc_input_num_pads) + + # init: ------------------------------------------------------------------ + init_dec_class = torch.tensor([sos_idx] * bs).unsqueeze(1).type(torch.long).to(self.rank) + init_dec_logprob = torch.tensor([0.0] * bs).unsqueeze(1).type(torch.float).to(self.rank) + log_probs = self.forward_dec(cross_input_list=cross_enc_output_list, enc_input_num_pads=enc_input_num_pads, + dec_input=init_dec_class, dec_input_num_pads=[0] * bs, + apply_log_softmax=True) + if sample_or_max == 'max': + _, topi = torch.topk(log_probs, k=beam_size, sorted=True) + else: # sample + topi = torch.exp(log_probs[:, 0, :]).multinomial(num_samples=beam_size, replacement=False) + topi = topi.unsqueeze(1) + + init_dec_class = init_dec_class.repeat(1, beam_size) + init_dec_class = init_dec_class.unsqueeze(-1) + top_beam_size_class = topi.transpose(-2, -1) + init_dec_class = torch.cat((init_dec_class, top_beam_size_class), dim=-1) + + init_dec_logprob = init_dec_logprob.repeat(1, beam_size) + init_dec_logprob = init_dec_logprob.unsqueeze(-1) + top_beam_size_logprob = log_probs.gather(dim=-1, index=topi) + top_beam_size_logprob = top_beam_size_logprob.transpose(-2, -1) + init_dec_logprob = torch.cat((init_dec_logprob, top_beam_size_logprob), dim=-1) + + tmp_cross_enc_output_list = [] + for cross_enc_output in cross_enc_output_list: + bs, enc_seq_len, d_model = cross_enc_output.shape + cross_enc_output = cross_enc_output.unsqueeze(1) + cross_enc_output = cross_enc_output.expand(-1, beam_size, -1, -1) + cross_enc_output = cross_enc_output.reshape(bs * beam_size, enc_seq_len, d_model).contiguous() + tmp_cross_enc_output_list.append(cross_enc_output) + cross_enc_output_list = tmp_cross_enc_output_list + enc_input_num_pads = [enc_input_num_pads[i] for i in range(bs) for _ in range(beam_size)] + + loop_dec_classes = init_dec_class + loop_dec_logprobs = init_dec_logprob + loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True) + + loop_num_elem_vector = torch.tensor([2] * (bs * beam_size)).to(self.rank) + + for time_step in range(2, max_seq_len): + loop_dec_classes = loop_dec_classes.reshape(bs * beam_size, time_step).contiguous() + + log_probs = self.forward_dec(cross_input_list=cross_enc_output_list, enc_input_num_pads=enc_input_num_pads, + dec_input=loop_dec_classes, + dec_input_num_pads=(time_step-loop_num_elem_vector).tolist(), + apply_log_softmax=True) + if sample_or_max == 'max': + _, topi = torch.topk(log_probs[:, time_step-1, :], k=beam_size, sorted=True) + else: # sample + topi = torch.exp(log_probs[:, time_step-1, :]).multinomial(num_samples=beam_size, + replacement=False) + + top_beam_size_word_classes = topi.reshape(bs, beam_size, beam_size) + + top_beam_size_word_logprobs = log_probs[:, time_step-1, :].gather(dim=-1, index=topi) + top_beam_size_word_logprobs = top_beam_size_word_logprobs.reshape(bs, beam_size, beam_size) + + there_is_eos_mask = (loop_dec_classes.view(bs, beam_size, time_step) == eos_idx). \ + sum(dim=-1, keepdims=True).type(torch.bool) + + top_beam_size_word_logprobs[:, :, 0:1].masked_fill_(there_is_eos_mask, 0.0) + top_beam_size_word_logprobs[:, :, 1:].masked_fill_(there_is_eos_mask, -999.0) + + comparison_logprobs = loop_cumul_logprobs + top_beam_size_word_logprobs + + comparison_logprobs = comparison_logprobs.contiguous().view(bs, beam_size * beam_size) + _, topi = torch.topk(comparison_logprobs, k=beam_size, sorted=True) + which_sequence = topi // beam_size + which_word = topi % beam_size + + loop_dec_classes = loop_dec_classes.view(bs, beam_size, -1) + loop_dec_logprobs = loop_dec_logprobs.view(bs, beam_size, -1) + + bs_idxes = torch.arange(bs).unsqueeze(-1) + new_loop_dec_classes = loop_dec_classes[[bs_idxes, which_sequence]] + new_loop_dec_logprobs = loop_dec_logprobs[[bs_idxes, which_sequence]] + + which_sequence_top_beam_size_word_classes = top_beam_size_word_classes[[bs_idxes, which_sequence]] + which_sequence_top_beam_size_word_logprobs = top_beam_size_word_logprobs[ + [bs_idxes, which_sequence]] + which_word = which_word.unsqueeze(-1) + + lastword_top_beam_size_classes = which_sequence_top_beam_size_word_classes.gather(dim=-1, + index=which_word) + lastword_top_beam_size_logprobs = which_sequence_top_beam_size_word_logprobs.gather(dim=-1, index=which_word) + + new_loop_dec_classes = torch.cat((new_loop_dec_classes, lastword_top_beam_size_classes), dim=-1) + new_loop_dec_logprobs = torch.cat((new_loop_dec_logprobs, lastword_top_beam_size_logprobs), dim=-1) + loop_dec_classes = new_loop_dec_classes + loop_dec_logprobs = new_loop_dec_logprobs + + loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True) + + loop_num_elem_vector = loop_num_elem_vector.view(bs, beam_size)[[bs_idxes, which_sequence]].view(bs * beam_size) + there_was_eos_mask = (loop_dec_classes[:, :, :-1].view(bs, beam_size, time_step) == eos_idx). \ + sum(dim=-1).type(torch.bool).view(bs * beam_size) + loop_num_elem_vector = loop_num_elem_vector + (1 * (1 - there_was_eos_mask.type(torch.int))) + + if (loop_num_elem_vector != time_step + 1).sum() == (bs * beam_size): + break + + loop_cumul_logprobs /= loop_num_elem_vector.reshape(bs, beam_size, 1) + _, topi = torch.topk(loop_cumul_logprobs.squeeze(-1), k=beam_size) + res_caption_pred = [[] for _ in range(bs)] + res_caption_logprob = [[] for _ in range(bs)] + for i in range(bs): + for j in range(how_many_outputs): + idx = topi[i, j].item() + res_caption_pred[i].append( + loop_dec_classes[i, idx, :loop_num_elem_vector[i * beam_size + idx]].tolist()) + res_caption_logprob[i].append(loop_dec_logprobs[i, idx, :loop_num_elem_vector[i * beam_size + idx]]) + + flatted_res_caption_logprob = [logprobs for i in range(bs) for logprobs in res_caption_logprob[i]] + flatted_res_caption_logprob = torch.nn.utils.rnn.pad_sequence(flatted_res_caption_logprob, batch_first=True) + res_caption_logprob = flatted_res_caption_logprob.view(bs, how_many_outputs, -1) + + return res_caption_pred, res_caption_logprob diff --git a/models/layers.py b/models/layers.py new file mode 100644 index 0000000..e126a58 --- /dev/null +++ b/models/layers.py @@ -0,0 +1,286 @@ +import torch +import torch.nn as nn +import math + +import numpy as np +import torch.nn.functional as F + +class EmbeddingLayer(nn.Module): + def __init__(self, vocab_size, d_model, dropout_perc): + super(EmbeddingLayer, self).__init__() + self.dropout = nn.Dropout(dropout_perc) + self.embed = nn.Embedding(vocab_size, d_model) + self.d_model = d_model + + def forward(self, x): + return self.dropout(self.embed(x)) * math.sqrt(float(self.d_model)) + + +class PositionalEncoder(nn.Module): + def __init__(self, d_model, max_seq_len, rank=0): + super().__init__() + assert d_model % 2 == 0, "d_model is not even, even number suggested" + self.d_model = d_model + self.pe = torch.zeros(max_seq_len, d_model).to(rank) + for pos in range(max_seq_len): + for i in range(0, d_model, 2): + self.pe.data[pos, i] = math.sin(pos / (10000.0 ** ((2.0 * i) / d_model))) + self.pe.data[pos, i + 1] = math.cos(pos / (10000.0 ** ((2.0 * i) / d_model))) + self.pe.data = self.pe.data.unsqueeze(0) + + def forward(self, x): + seq_len = x.shape[1] + return self.pe.data[0, :seq_len] + + + +class StaticExpansionBlock(nn.Module): + def __init__(self, d_model, num_enc_exp_list, dropout_perc, eps): + super().__init__() + self.d_model = d_model + self.num_enc_exp_list = num_enc_exp_list + + self.query_exp_vectors = nn.Embedding(sum(num_enc_exp_list), d_model) + self.bias_exp_vectors = nn.Embedding(sum(num_enc_exp_list), d_model) + + self.key_embed = nn.Linear(d_model, d_model) + self.class_a_embed = nn.Linear(d_model, d_model) + self.class_b_embed = nn.Linear(d_model, d_model) + + self.selector_embed = nn.Linear(d_model, d_model) + + self.dropout_class_a_fw = nn.Dropout(dropout_perc) + self.dropout_class_b_fw = nn.Dropout(dropout_perc) + + self.dropout_class_a_bw = nn.Dropout(dropout_perc) + self.dropout_class_b_bw = nn.Dropout(dropout_perc) + + self.Z_dropout = nn.Dropout(dropout_perc) + + self.eps = eps + + def forward(self, x, n_indexes, mask): + bs, enc_len, _ = x.shape + + query_exp = self.query_exp_vectors(n_indexes) + bias_exp = self.bias_exp_vectors(n_indexes) + x_key = self.key_embed(x) + + z = torch.matmul(query_exp, x_key.transpose(-1, -2)) / np.sqrt(self.d_model) + z = self.Z_dropout(z) + + class_a_fw = F.relu(z) + class_b_fw = F.relu(-z) + class_a_fw = class_a_fw.masked_fill(mask == 0, 0.0) + class_b_fw = class_b_fw.masked_fill(mask == 0, 0.0) + class_a_fw = class_a_fw / (class_a_fw.sum(dim=-1, keepdim=True) + self.eps) + class_b_fw = class_b_fw / (class_b_fw.sum(dim=-1, keepdim=True) + self.eps) + + class_a = torch.matmul(class_a_fw, self.class_a_embed(x)) + bias_exp + class_b = torch.matmul(class_b_fw, self.class_b_embed(x)) + bias_exp + class_a = self.dropout_class_a_fw(class_a) + class_b = self.dropout_class_b_fw(class_b) + + class_a_bw = F.relu(z.transpose(-2, -1)) + class_b_bw = F.relu(-z.transpose(-2, -1)) + + accum = 0 + class_a_bw_list = [] + class_b_bw_list = [] + for j in range(len(self.num_enc_exp_list)): + from_idx = accum + to_idx = accum + self.num_enc_exp_list[j] + accum += self.num_enc_exp_list[j] + class_a_bw_list.append(class_a_bw[:, :, from_idx:to_idx] / (class_a_bw[:, :, from_idx:to_idx].sum(dim=-1, keepdim=True) + self.eps)) + class_b_bw_list.append(class_b_bw[:, :, from_idx:to_idx] / (class_b_bw[:, :, from_idx:to_idx].sum(dim=-1, keepdim=True) + self.eps)) + class_a_bw = torch.cat(class_a_bw_list, dim=-1) + class_b_bw = torch.cat(class_b_bw_list, dim=-1) + + class_a = torch.matmul(class_a_bw, class_a) / len(self.num_enc_exp_list) + class_b = torch.matmul(class_b_bw, class_b) / len(self.num_enc_exp_list) + class_a = self.dropout_class_a_bw(class_a) + class_b = self.dropout_class_b_bw(class_b) + + selector = torch.sigmoid(self.selector_embed(x)) + x_result = selector * class_a + (1 - selector) * class_b + + return x_result + + +class EncoderLayer(nn.Module): + def __init__(self, d_model, d_ff, num_enc_exp_list, dropout_perc, eps=1e-9): + super().__init__() + self.norm_1 = nn.LayerNorm(d_model) + self.norm_2 = nn.LayerNorm(d_model) + self.dropout_1 = nn.Dropout(dropout_perc) + self.dropout_2 = nn.Dropout(dropout_perc) + + self.stc_exp = StaticExpansionBlock(d_model, num_enc_exp_list, dropout_perc, eps) + self.ff = FeedForward(d_model, d_ff, dropout_perc) + + def forward(self, x, n_indexes, mask): + x2 = self.norm_1(x) + x = x + self.dropout_1(self.stc_exp(x=x2, n_indexes=n_indexes, mask=mask)) + x2 = self.norm_2(x) + x = x + self.dropout_2(self.ff(x2)) + return x + + +class DynamicExpansionBlock(nn.Module): + def __init__(self, d_model, num_exp, dropout_perc, eps): + super().__init__() + self.d_model = d_model + + self.num_exp = num_exp + self.cond_embed = nn.Linear(d_model, d_model) + + self.query_exp_vectors = nn.Embedding(self.num_exp, d_model) + self.bias_exp_vectors = nn.Embedding(self.num_exp, d_model) + + self.key_linear = nn.Linear(d_model, d_model) + self.class_a_embed = nn.Linear(d_model, d_model) + self.class_b_embed = nn.Linear(d_model, d_model) + + self.selector_embed = nn.Linear(d_model, d_model) + + self.dropout_class_a_fw = nn.Dropout(dropout_perc) + self.dropout_class_b_fw = nn.Dropout(dropout_perc) + self.dropout_class_a_bw = nn.Dropout(dropout_perc) + self.dropout_class_b_bw = nn.Dropout(dropout_perc) + + self.Z_dropout = nn.Dropout(dropout_perc) + + self.eps = eps + + def forward(self, x, n_indexes, mask): + bs, dec_len, _ = x.shape + + cond = self.cond_embed(x).view(bs, dec_len, 1, self.d_model) + query_exp = self.query_exp_vectors(n_indexes).unsqueeze(1) + bias_exp = self.bias_exp_vectors(n_indexes).unsqueeze(1) + query_exp = (query_exp + cond).view(bs, dec_len * self.num_exp, self.d_model) + bias_exp = (bias_exp + cond).view(bs, dec_len * self.num_exp, self.d_model) + + x_key = self.key_linear(x) + z = torch.matmul(query_exp, x_key.transpose(-1, -2)) / np.sqrt(self.d_model) + z = self.Z_dropout(z) + + mod_mask_1 = mask.unsqueeze(2).expand(bs, dec_len, self.num_exp, dec_len).contiguous(). \ + view(bs, dec_len * self.num_exp, dec_len) + + class_a_fw = F.relu(z) + class_b_fw = F.relu(-z) + class_a_fw = class_a_fw.masked_fill(mod_mask_1 == 0, 0.0) + class_b_fw = class_b_fw.masked_fill(mod_mask_1 == 0, 0.0) + class_a_fw = class_a_fw / (class_a_fw.sum(dim=-1, keepdim=True) + self.eps) + class_b_fw = class_b_fw / (class_b_fw.sum(dim=-1, keepdim=True) + self.eps) + class_a = torch.matmul(class_a_fw, self.class_a_embed(x)) + class_b = torch.matmul(class_b_fw, self.class_b_embed(x)) + class_a = self.dropout_class_a_fw(class_a) + class_b = self.dropout_class_b_fw(class_b) + + mod_mask_2 = mask.unsqueeze(-1).expand(bs, dec_len, dec_len, self.num_exp).contiguous(). \ + view(bs, dec_len, dec_len * self.num_exp) + + class_a_bw = F.relu(z.transpose(-2, -1)) + class_b_bw = F.relu(-z.transpose(-2, -1)) + class_a_bw = class_a_bw.masked_fill(mod_mask_2 == 0, 0.0) + class_b_bw = class_b_bw.masked_fill(mod_mask_2 == 0, 0.0) + class_a_bw = class_a_bw / (class_a_bw.sum(dim=-1, keepdim=True) + self.eps) + class_b_bw = class_b_bw / (class_b_bw.sum(dim=-1, keepdim=True) + self.eps) + class_a = torch.matmul(class_a_bw, class_a + bias_exp) + class_b = torch.matmul(class_b_bw, class_b + bias_exp) + class_a = self.dropout_class_a_bw(class_a) + class_b = self.dropout_class_b_bw(class_b) + + selector = torch.sigmoid(self.selector_embed(x)) + x_result = selector * class_a + (1 - selector) * class_b + + return x_result + + +class DecoderLayer(nn.Module): + def __init__(self, d_model, num_heads, d_ff, num_exp, dropout_perc, eps=1e-9): + super().__init__() + self.norm_1 = nn.LayerNorm(d_model) + self.norm_2 = nn.LayerNorm(d_model) + self.norm_3 = nn.LayerNorm(d_model) + + self.dropout_1 = nn.Dropout(dropout_perc) + self.dropout_2 = nn.Dropout(dropout_perc) + self.dropout_3 = nn.Dropout(dropout_perc) + + self.mha = MultiHeadAttention(d_model, num_heads, dropout_perc) + self.dyn_exp = DynamicExpansionBlock(d_model, num_exp, dropout_perc, eps) + self.ff = FeedForward(d_model, d_ff, dropout_perc) + + def forward(self, x, n_indexes, cross_connection_x, input_attention_mask, cross_attention_mask): + + # Pre-LayerNorm + x2 = self.norm_1(x) + x = x + self.dropout_1(self.dyn_exp(x=x2, n_indexes=n_indexes, mask=input_attention_mask)) + + x2 = self.norm_2(x) + x = x + self.dropout_2(self.mha(q=x2, k=cross_connection_x, v=cross_connection_x, + mask=cross_attention_mask)) + + x2 = self.norm_3(x) + x = x + self.dropout_3(self.ff(x2)) + return x + + + +class MultiHeadAttention(nn.Module): + def __init__(self, d_model, num_heads, dropout_perc): + super(MultiHeadAttention, self).__init__() + assert d_model % num_heads == 0, "num heads must be multiple of d_model" + + self.d_model = d_model + self.d_k = int(d_model / num_heads) + self.num_heads = num_heads + + self.Wq = nn.Linear(d_model, self.d_k * num_heads) + self.Wk = nn.Linear(d_model, self.d_k * num_heads) + self.Wv = nn.Linear(d_model, self.d_k * num_heads) + + self.out_linear = nn.Linear(d_model, d_model) + + def forward(self, q, k, v, mask=None): + batch_size, q_seq_len, _ = q.shape + k_seq_len = k.size(1) + v_seq_len = v.size(1) + + k_proj = self.Wk(k).view(batch_size, k_seq_len, self.num_heads, self.d_k) + q_proj = self.Wq(q).view(batch_size, q_seq_len, self.num_heads, self.d_k) + v_proj = self.Wv(v).view(batch_size, v_seq_len, self.num_heads, self.d_k) + + k_proj = k_proj.transpose(2, 1) + q_proj = q_proj.transpose(2, 1) + v_proj = v_proj.transpose(2, 1) + + sim_scores = torch.matmul(q_proj, k_proj.transpose(3, 2)) + sim_scores = sim_scores / self.d_k ** 0.5 + + if mask is not None: + mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1) + sim_scores = sim_scores.masked_fill(mask == 0, value=-1e4) + sim_scores = F.softmax(input=sim_scores, dim=-1) + + attention_applied = torch.matmul(sim_scores, v_proj) + attention_applied_concatenated = attention_applied.permute(0, 2, 1, 3).contiguous()\ + .view(batch_size, q_seq_len, self.d_model) + + out = self.out_linear(attention_applied_concatenated) + return out + +class FeedForward(nn.Module): + def __init__(self, d_model, d_ff, dropout_perc): + super(FeedForward, self).__init__() + self.linear_1 = nn.Linear(d_model, d_ff) + self.dropout = nn.Dropout(dropout_perc) + self.linear_2 = nn.Linear(d_ff, d_model) + + def forward(self, x): + x = self.dropout(F.relu(self.linear_1(x))) + x = self.linear_2(x) + return x diff --git a/models/swin_transformer_mod.py b/models/swin_transformer_mod.py new file mode 100644 index 0000000..6b16abc --- /dev/null +++ b/models/swin_transformer_mod.py @@ -0,0 +1,655 @@ +# -------------------------------------------------------- +# Swin Transformer +# Copyright (c) 2021 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu +# -------------------------------------------------------- + +# --------------------------------- +# All credits due to Ze Liu: https://github.com/microsoft/Swin-Transformer +# and the additional sources: +# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py +# https://github.com/yukimasano/PASS/blob/main/vision_transformer.py +# --------------------------------- + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + + +class DropPath(nn.Module): + def __init__(self, drop_prob): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x): + if not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +import collections.abc +def to_2tuple(x): + if isinstance(x, collections.abc.Iterable): + return x + return (x, x) + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +import warnings +import math +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official repo master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + return (1. + math.erf(x / math.sqrt(2.))) / 2. + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + # mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, + embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False): + super().__init__() + + # self.num_classes = num_classes + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + input_resolution=(patches_resolution[0] // (2 ** i_layer), + patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint) + self.layers.append(layer) + + self.norm = norm_layer(self.num_features) + # self.avgpool = nn.AdaptiveAvgPool1d(1) + # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward_features(self, x): + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x = self.norm(x) # B L C + # x = self.avgpool(x.transpose(1, 2)) # B C 1 + # x = torch.flatten(x, 1) + return x + + def forward(self, x): + x = self.forward_features(x) + #x = self.head(x) + return x + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + #flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) + #flops += self.num_features * self.num_classes + return flops