logo
Browse Source

init the operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
8f436cc2cb
  1. 18
      __init__.py
  2. BIN
      demo_coco_tokens.pickle
  3. 55
      expansionnet_v2.py
  4. 187
      models/End_ExpansionNet_v2.py
  5. 103
      models/ExpansionNet_v2.py
  6. 241
      models/captioning_model.py
  7. 187
      models/ensemble_captioning_model.py
  8. 286
      models/layers.py
  9. 655
      models/swin_transformer_mod.py

18
__init__.py

@ -0,0 +1,18 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .expansionnet_v2 import ExpansionNetV2
def expansionnet_v2(model_name: str):
return ExpansionNetV2(model_name)

BIN
demo_coco_tokens.pickle

Binary file not shown.

55
expansionnet_v2.py

@ -0,0 +1,55 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
from pathlib import Path
import torch
from torchvision import transforms
from transformers import GPT2Tokenizer
from towhee.types.arg import arg, to_image_color
from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee import register
from towhee.models import clip
class ExpansionNetV2(NNOperator):
"""
ExpansionNet V2 image captioning operator
"""
def __init__(self, model_name: str):
super().__init__()
sys.path.append(str(Path(__file__).parent))
from models.End_ExpansionNet_v2 import End_ExpansionNet_v2
sys.path.pop()
with open('demo_coco_tokens.pickle') as fw:
self.model = End_ExpansionNet_v2(swin_img_size=img_size, swin_patch_size=4, swin_in_chans=3,
swin_embed_dim=192, swin_depths=[2, 2, 18, 2], swin_num_heads=[6, 12, 24, 48],
swin_window_size=12, swin_mlp_ratio=4., swin_qkv_bias=True, swin_qk_scale=None,
swin_drop_rate=0.0, swin_attn_drop_rate=0.0, swin_drop_path_rate=0.0,
swin_norm_layer=torch.nn.LayerNorm, swin_ape=False, swin_patch_norm=True,
swin_use_checkpoint=False,
final_swin_dim=1536,
d_model=model_args.model_dim, N_enc=model_args.N_enc,
N_dec=model_args.N_dec, num_heads=8, ff=2048,
num_exp_enc_list=[32, 64, 128, 256, 512],
num_exp_dec=16,
output_word2idx=coco_tokens['word2idx_dict'],
output_idx2word=coco_tokens['idx2word_list'],
max_seq_len=args.max_seq_len, drop_args=model_args.drop_args,
rank='cpu')

187
models/End_ExpansionNet_v2.py

@ -0,0 +1,187 @@
import torch
from models.layers import EmbeddingLayer, DecoderLayer, EncoderLayer
from utils.masking import create_pad_mask, create_no_peak_and_pad_mask
from models.captioning_model import CaptioningModel
from models.swin_transformer_mod import SwinTransformer
import torch.nn as nn
class End_ExpansionNet_v2(CaptioningModel):
def __init__(self,
# swin transf
swin_img_size, swin_patch_size, swin_in_chans,
swin_embed_dim, swin_depths, swin_num_heads,
swin_window_size, swin_mlp_ratio, swin_qkv_bias, swin_qk_scale,
swin_drop_rate, swin_attn_drop_rate, swin_drop_path_rate,
swin_norm_layer, swin_ape, swin_patch_norm,
swin_use_checkpoint,
# linear_size,
final_swin_dim,
# captioning
d_model, N_enc, N_dec, ff, num_heads, num_exp_enc_list, num_exp_dec,
output_word2idx, output_idx2word, max_seq_len, drop_args, rank=0):
super(End_ExpansionNet_v2, self).__init__()
self.swin_transf = SwinTransformer(
img_size=swin_img_size, patch_size=swin_patch_size, in_chans=swin_in_chans,
embed_dim=swin_embed_dim, depths=swin_depths, num_heads=swin_num_heads,
window_size=swin_window_size, mlp_ratio=swin_mlp_ratio, qkv_bias=swin_qkv_bias, qk_scale=swin_qk_scale,
drop_rate=swin_drop_rate, attn_drop_rate=swin_attn_drop_rate, drop_path_rate=swin_drop_path_rate,
norm_layer=swin_norm_layer, ape=swin_ape, patch_norm=swin_patch_norm,
use_checkpoint=swin_use_checkpoint)
self.output_word2idx = output_word2idx
self.output_idx2word = output_idx2word
self.max_seq_len = max_seq_len
self.num_exp_dec = num_exp_dec
self.num_exp_enc_list = num_exp_enc_list
self.N_enc = N_enc
self.N_dec = N_dec
self.d_model = d_model
self.encoders = nn.ModuleList([EncoderLayer(d_model, ff, num_exp_enc_list, drop_args.enc) for _ in range(N_enc)])
self.decoders = nn.ModuleList([DecoderLayer(d_model, num_heads, ff, num_exp_dec, drop_args.dec) for _ in range(N_dec)])
self.input_embedder_dropout = nn.Dropout(drop_args.enc_input)
self.input_linear = torch.nn.Linear(final_swin_dim, d_model)
self.vocab_linear = torch.nn.Linear(d_model, len(output_word2idx))
self.log_softmax = nn.LogSoftmax(dim=-1)
self.out_enc_dropout = nn.Dropout(drop_args.other)
self.out_dec_dropout = nn.Dropout(drop_args.other)
self.out_embedder = EmbeddingLayer(len(output_word2idx), d_model, drop_args.dec_input)
self.pos_encoder = nn.Embedding(max_seq_len, d_model)
self.enc_reduce_group = nn.Linear(d_model * self.N_enc, d_model)
self.enc_reduce_norm = nn.LayerNorm(d_model)
self.dec_reduce_group = nn.Linear(d_model * self.N_dec, d_model)
self.dec_reduce_norm = nn.LayerNorm(d_model)
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
self.trained_steps = 0
self.rank = rank
self.check_required_attributes()
def forward_enc(self, enc_input, enc_input_num_pads):
assert (enc_input_num_pads is None or enc_input_num_pads == ([0] * enc_input.size(0))), "End to End case have no padding"
x = self.swin_transf(enc_input)
# --------------- Normale parte di Captioning ---------------------------------
enc_input = self.input_embedder_dropout(self.input_linear(x))
x = enc_input
enc_input_num_pads = [0] * enc_input.size(0)
max_num_enc = sum(self.num_exp_enc_list)
pos_x = torch.arange(max_num_enc).unsqueeze(0).expand(enc_input.size(0), max_num_enc).to(self.rank)
pad_mask = create_pad_mask(mask_size=(enc_input.size(0), max_num_enc, enc_input.size(1)),
pad_along_row_input=[0] * enc_input.size(0),
pad_along_column_input=enc_input_num_pads,
rank=self.rank)
x_list = []
for i in range(self.N_enc):
x = self.encoders[i](x=x, n_indexes=pos_x, mask=pad_mask)
x_list.append(x)
x_list = torch.cat(x_list, dim=-1)
x = x + self.out_enc_dropout(self.enc_reduce_group(x_list))
x = self.enc_reduce_norm(x)
return x
def forward_dec(self, cross_input, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False):
assert (enc_input_num_pads is None or enc_input_num_pads == ([0] * cross_input.size(0))), "enc_input_num_pads should be no None"
enc_input_num_pads = [0] * dec_input.size(0)
no_peak_and_pad_mask = create_no_peak_and_pad_mask(
mask_size=(dec_input.size(0), dec_input.size(1), dec_input.size(1)),
num_pads=dec_input_num_pads,
rank=self.rank)
pad_mask = create_pad_mask(mask_size=(dec_input.size(0), dec_input.size(1), cross_input.size(1)),
pad_along_row_input=dec_input_num_pads,
pad_along_column_input=enc_input_num_pads,
rank=self.rank)
y = self.out_embedder(dec_input)
pos_x = torch.arange(self.num_exp_dec).unsqueeze(0).expand(dec_input.size(0), self.num_exp_dec).to(self.rank)
pos_y = torch.arange(dec_input.size(1)).unsqueeze(0).expand(dec_input.size(0), dec_input.size(1)).to(self.rank)
y = y + self.pos_encoder(pos_y)
y_list = []
for i in range(self.N_dec):
y = self.decoders[i](x=y,
n_indexes=pos_x,
cross_connection_x=cross_input,
input_attention_mask=no_peak_and_pad_mask,
cross_attention_mask=pad_mask)
y_list.append(y)
y_list = torch.cat(y_list, dim=-1)
y = y + self.out_dec_dropout(self.dec_reduce_group(y_list))
y = self.dec_reduce_norm(y)
y = self.vocab_linear(y)
if apply_log_softmax:
y = self.log_softmax(y)
return y
def get_batch_multiple_sampled_prediction(self, enc_input, enc_input_num_pads, num_outputs,
sos_idx, eos_idx, max_seq_len):
bs = enc_input.size(0)
x = self.forward_enc(enc_input=enc_input, enc_input_num_pads=enc_input_num_pads)
enc_seq_len = x.size(1)
x = x.unsqueeze(1).expand(-1, num_outputs, -1, -1).reshape(bs * num_outputs, enc_seq_len, x.shape[-1])
upperbound_vector = torch.tensor([max_seq_len] * bs * num_outputs, dtype=torch.int).to(self.rank)
where_is_eos_vector = upperbound_vector.clone()
eos_vector = torch.tensor([eos_idx] * bs * num_outputs, dtype=torch.long).to(self.rank)
finished_flag_vector = torch.zeros(bs * num_outputs).type(torch.int)
predicted_caption = torch.tensor([sos_idx] * (bs * num_outputs), dtype=torch.long).to(self.rank).unsqueeze(-1)
predicted_caption_prob = torch.zeros(bs * num_outputs).to(self.rank).unsqueeze(-1)
dec_input_num_pads = [0]*(bs*num_outputs)
time_step = 0
while (finished_flag_vector.sum() != bs * num_outputs) and time_step < max_seq_len:
dec_input = predicted_caption
log_probs = self.forward_dec(x, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=True)
prob_dist = torch.distributions.Categorical(torch.exp(log_probs[:, time_step]))
sampled_word_indexes = prob_dist.sample()
predicted_caption = torch.cat((predicted_caption, sampled_word_indexes.unsqueeze(-1)), dim=-1)
predicted_caption_prob = torch.cat((predicted_caption_prob,
log_probs[:, time_step].gather(index=sampled_word_indexes.unsqueeze(-1), dim=-1)), dim=-1)
time_step += 1
where_is_eos_vector = torch.min(where_is_eos_vector,
upperbound_vector.masked_fill(sampled_word_indexes == eos_vector, time_step))
finished_flag_vector = torch.max(finished_flag_vector,
(sampled_word_indexes == eos_vector).type(torch.IntTensor))
res_predicted_caption = []
for i in range(bs):
res_predicted_caption.append([])
for j in range(num_outputs):
index = i*num_outputs + j
res_predicted_caption[i].append(
predicted_caption[index, :where_is_eos_vector[index].item()+1].tolist())
where_is_eos_vector = where_is_eos_vector.unsqueeze(-1).expand(-1, time_step+1)
arange_tensor = torch.arange(time_step+1).unsqueeze(0).expand(bs * num_outputs, -1).to(self.rank)
predicted_caption_prob.masked_fill_(arange_tensor > where_is_eos_vector, 0.0)
res_predicted_caption_prob = predicted_caption_prob.reshape(bs, num_outputs, -1)
return res_predicted_caption, res_predicted_caption_prob

103
models/ExpansionNet_v2.py

@ -0,0 +1,103 @@
import torch
from models.layers import EmbeddingLayer, EncoderLayer, DecoderLayer
from utils.masking import create_pad_mask, create_no_peak_and_pad_mask
from models.captioning_model import CaptioningModel
import torch.nn as nn
class ExpansionNet_v2(CaptioningModel):
def __init__(self, d_model, N_enc, N_dec, ff, num_heads, num_exp_enc_list, num_exp_dec,
output_word2idx, output_idx2word, max_seq_len, drop_args, img_feature_dim=2048, rank=0):
super().__init__()
self.output_word2idx = output_word2idx
self.output_idx2word = output_idx2word
self.max_seq_len = max_seq_len
self.num_exp_dec = num_exp_dec
self.num_exp_enc_list = num_exp_enc_list
self.N_enc = N_enc
self.N_dec = N_dec
self.d_model = d_model
self.encoders = nn.ModuleList([EncoderLayer(d_model, ff, num_exp_enc_list, drop_args.enc) for _ in range(N_enc)])
self.decoders = nn.ModuleList([DecoderLayer(d_model, num_heads, ff, num_exp_dec, drop_args.dec) for _ in range(N_dec)])
self.input_embedder_dropout = nn.Dropout(drop_args.enc_input)
self.input_linear = torch.nn.Linear(img_feature_dim, d_model)
self.vocab_linear = torch.nn.Linear(d_model, len(output_word2idx))
self.log_softmax = nn.LogSoftmax(dim=-1)
self.out_enc_dropout = nn.Dropout(drop_args.other)
self.out_dec_dropout = nn.Dropout(drop_args.other)
self.out_embedder = EmbeddingLayer(len(output_word2idx), d_model, drop_args.dec_input)
self.pos_encoder = nn.Embedding(max_seq_len, d_model)
self.enc_reduce_group = nn.Linear(d_model * self.N_enc, d_model)
self.enc_reduce_norm = nn.LayerNorm(d_model)
self.dec_reduce_group = nn.Linear(d_model * self.N_dec, d_model)
self.dec_reduce_norm = nn.LayerNorm(d_model)
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
self.trained_steps = 0
self.rank = rank
def forward_enc(self, enc_input, enc_input_num_pads):
x = self.input_embedder_dropout(self.input_linear(enc_input))
max_num_enc = sum(self.num_exp_enc_list)
pos_x = torch.arange(max_num_enc).unsqueeze(0).expand(enc_input.size(0), max_num_enc).to(self.rank)
pad_mask = create_pad_mask(mask_size=(enc_input.size(0), max_num_enc, enc_input.size(1)),
pad_along_row_input=[0] * enc_input.size(0),
pad_along_column_input=enc_input_num_pads,
rank=self.rank)
x_list = []
for i in range(self.N_enc):
x = self.encoders[i](x=x, n_indexes=pos_x, mask=pad_mask)
x_list.append(x)
x_list = torch.cat(x_list, dim=-1)
x = x + self.out_enc_dropout(self.enc_reduce_group(x_list))
x = self.enc_reduce_norm(x)
return x
def forward_dec(self, cross_input, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False):
no_peak_and_pad_mask = create_no_peak_and_pad_mask(
mask_size=(dec_input.size(0), dec_input.size(1), dec_input.size(1)),
num_pads=dec_input_num_pads,
rank=self.rank)
pad_mask = create_pad_mask(mask_size=(dec_input.size(0), dec_input.size(1), cross_input.size(1)),
pad_along_row_input=dec_input_num_pads,
pad_along_column_input=enc_input_num_pads,
rank=self.rank)
y = self.out_embedder(dec_input)
pos_x = torch.arange(self.num_exp_dec).unsqueeze(0).expand(dec_input.size(0), self.num_exp_dec).to(self.rank)
pos_y = torch.arange(dec_input.size(1)).unsqueeze(0).expand(dec_input.size(0), dec_input.size(1)).to(self.rank)
y = y + self.pos_encoder(pos_y)
y_list = []
for i in range(self.N_dec):
y = self.decoders[i](x=y,
n_indexes=pos_x,
cross_connection_x=cross_input,
input_attention_mask=no_peak_and_pad_mask,
cross_attention_mask=pad_mask)
y_list.append(y)
y_list = torch.cat(y_list, dim=-1)
y = y + self.out_dec_dropout(self.dec_reduce_group(y_list))
y = self.dec_reduce_norm(y)
y = self.vocab_linear(y)
if apply_log_softmax:
y = self.log_softmax(y)
return y

241
models/captioning_model.py

@ -0,0 +1,241 @@
import torch
import torch.nn as nn
class CaptioningModel(nn.Module):
def __init__(self):
super(CaptioningModel, self).__init__()
# mandatory attributes
# rank: to enable multiprocessing
self.rank = None
def check_required_attributes(self):
if self.rank is None:
raise NotImplementedError("Subclass must assign the rank integer according to the GPU group")
def forward_enc(self, enc_input, enc_input_num_pads):
raise NotImplementedError
def forward_dec(self, cross_input, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False):
raise NotImplementedError
def forward(self, enc_x, dec_x=None,
enc_x_num_pads=[0], dec_x_num_pads=[0], apply_log_softmax=False,
mode='forward', **kwargs):
if mode == 'forward':
x = self.forward_enc(enc_x, enc_x_num_pads)
y = self.forward_dec(x, enc_x_num_pads, dec_x, dec_x_num_pads, apply_log_softmax)
return y
else:
assert ('sos_idx' in kwargs.keys() or 'eos_idx' in kwargs.keys()), \
'sos and eos must be provided in case of batch sampling or beam search'
sos_idx = kwargs.get('sos_idx', -999)
eos_idx = kwargs.get('eos_idx', -999)
if mode == 'beam_search':
beam_size_arg = kwargs.get('beam_size', 5)
how_many_outputs_per_beam = kwargs.get('how_many_outputs', 1)
beam_max_seq_len = kwargs.get('beam_max_seq_len', 20)
sample_or_max = kwargs.get('sample_or_max', 'max')
out_classes, out_logprobs = self.beam_search(
enc_x, enc_x_num_pads,
beam_size=beam_size_arg,
sos_idx=sos_idx,
eos_idx=eos_idx,
how_many_outputs=how_many_outputs_per_beam,
max_seq_len=beam_max_seq_len,
sample_or_max=sample_or_max)
return out_classes, out_logprobs
if mode == 'sampling':
how_many_outputs = kwargs.get('how_many_outputs', 1)
sample_max_seq_len = kwargs.get('sample_max_seq_len', 20)
out_classes, out_logprobs = self.get_batch_multiple_sampled_prediction(
enc_x, enc_x_num_pads, num_outputs=how_many_outputs,
sos_idx=sos_idx, eos_idx=eos_idx,
max_seq_len=sample_max_seq_len)
return out_classes, out_logprobs
def get_batch_multiple_sampled_prediction(self, enc_input, enc_input_num_pads, num_outputs,
sos_idx, eos_idx, max_seq_len):
bs, enc_seq_len, _ = enc_input.shape
enc_input_num_pads = [enc_input_num_pads[i] for i in range(bs) for _ in range(num_outputs)]
x = self.forward_enc(enc_input=enc_input, enc_input_num_pads=enc_input_num_pads)
x = x.unsqueeze(1).expand(-1, num_outputs, -1, -1).reshape(bs * num_outputs, enc_seq_len, x.shape[-1])
upperbound_vector = torch.tensor([max_seq_len] * bs * num_outputs, dtype=torch.int).to(self.rank)
where_is_eos_vector = upperbound_vector.clone()
eos_vector = torch.tensor([eos_idx] * bs * num_outputs, dtype=torch.long).to(self.rank)
finished_flag_vector = torch.zeros(bs * num_outputs).type(torch.int)
predicted_caption = torch.tensor([sos_idx] * (bs * num_outputs), dtype=torch.long).to(self.rank).unsqueeze(-1)
predicted_caption_prob = torch.zeros(bs * num_outputs).to(self.rank).unsqueeze(-1)
dec_input_num_pads = [0]*(bs*num_outputs)
time_step = 0
while (finished_flag_vector.sum() != bs * num_outputs) and time_step < max_seq_len:
dec_input = predicted_caption
log_probs = self.forward_dec(x, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=True)
prob_dist = torch.distributions.Categorical(torch.exp(log_probs[:, time_step]))
sampled_word_indexes = prob_dist.sample()
predicted_caption = torch.cat((predicted_caption, sampled_word_indexes.unsqueeze(-1)), dim=-1)
predicted_caption_prob = torch.cat((predicted_caption_prob,
log_probs[:, time_step].gather(index=sampled_word_indexes.unsqueeze(-1), dim=-1)), dim=-1)
time_step += 1
where_is_eos_vector = torch.min(where_is_eos_vector,
upperbound_vector.masked_fill(sampled_word_indexes == eos_vector, time_step))
finished_flag_vector = torch.max(finished_flag_vector,
(sampled_word_indexes == eos_vector).type(torch.IntTensor))
# remove the elements that come after the first eos from the sequence
res_predicted_caption = []
for i in range(bs):
res_predicted_caption.append([])
for j in range(num_outputs):
index = i*num_outputs + j
res_predicted_caption[i].append(
predicted_caption[index, :where_is_eos_vector[index].item()+1].tolist())
where_is_eos_vector = where_is_eos_vector.unsqueeze(-1).expand(-1, time_step+1)
arange_tensor = torch.arange(time_step+1).unsqueeze(0).expand(bs * num_outputs, -1).to(self.rank)
predicted_caption_prob.masked_fill_(arange_tensor > where_is_eos_vector, 0.0)
res_predicted_caption_prob = predicted_caption_prob.reshape(bs, num_outputs, -1)
return res_predicted_caption, res_predicted_caption_prob
def beam_search(self, enc_input, enc_input_num_pads, sos_idx, eos_idx,
beam_size=3, how_many_outputs=1, max_seq_len=20, sample_or_max='max',):
assert (how_many_outputs <= beam_size), "requested output per sequence must be lower than beam width"
assert (sample_or_max == 'max' or sample_or_max == 'sample'), "argument must be chosen between \'max\' and \'sample\'"
bs = enc_input.shape[0]
cross_enc_output = self.forward_enc(enc_input, enc_input_num_pads)
# init: ------------------------------------------------------------------
init_dec_class = torch.tensor([sos_idx] * bs).unsqueeze(1).type(torch.long).to(self.rank)
init_dec_logprob = torch.tensor([0.0] * bs).unsqueeze(1).type(torch.float).to(self.rank)
log_probs = self.forward_dec(cross_input=cross_enc_output, enc_input_num_pads=enc_input_num_pads,
dec_input=init_dec_class, dec_input_num_pads=[0] * bs,
apply_log_softmax=True)
if sample_or_max == 'max':
_, topi = torch.topk(log_probs, k=beam_size, sorted=True)
else: # sample
topi = torch.exp(log_probs[:, 0, :]).multinomial(num_samples=beam_size, replacement=False)
topi = topi.unsqueeze(1)
init_dec_class = init_dec_class.repeat(1, beam_size)
init_dec_class = init_dec_class.unsqueeze(-1)
top_beam_size_class = topi.transpose(-2, -1)
init_dec_class = torch.cat((init_dec_class, top_beam_size_class), dim=-1)
init_dec_logprob = init_dec_logprob.repeat(1, beam_size)
init_dec_logprob = init_dec_logprob.unsqueeze(-1)
top_beam_size_logprob = log_probs.gather(dim=-1, index=topi)
top_beam_size_logprob = top_beam_size_logprob.transpose(-2, -1)
init_dec_logprob = torch.cat((init_dec_logprob, top_beam_size_logprob), dim=-1)
bs, enc_seq_len, d_model = cross_enc_output.shape
cross_enc_output = cross_enc_output.unsqueeze(1)
cross_enc_output = cross_enc_output.expand(-1, beam_size, -1, -1)
cross_enc_output = cross_enc_output.reshape(bs * beam_size, enc_seq_len, d_model).contiguous()
enc_input_num_pads = [enc_input_num_pads[i] for i in range(bs) for _ in range(beam_size)]
# loop: -----------------------------------------------------------------
loop_dec_classes = init_dec_class
loop_dec_logprobs = init_dec_logprob
loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True)
loop_num_elem_vector = torch.tensor([2] * (bs * beam_size)).to(self.rank)
for time_step in range(2, max_seq_len):
loop_dec_classes = loop_dec_classes.reshape(bs * beam_size, time_step).contiguous()
log_probs = self.forward_dec(cross_input=cross_enc_output, enc_input_num_pads=enc_input_num_pads,
dec_input=loop_dec_classes,
dec_input_num_pads=(time_step-loop_num_elem_vector).tolist(),
apply_log_softmax=True)
if sample_or_max == 'max':
_, topi = torch.topk(log_probs[:, time_step-1, :], k=beam_size, sorted=True)
else: # sample
topi = torch.exp(log_probs[:, time_step-1, :]).multinomial(num_samples=beam_size,
replacement=False)
top_beam_size_word_classes = topi.reshape(bs, beam_size, beam_size)
top_beam_size_word_logprobs = log_probs[:, time_step-1, :].gather(dim=-1, index=topi)
top_beam_size_word_logprobs = top_beam_size_word_logprobs.reshape(bs, beam_size, beam_size)
# each sequence have now its best prediction, but some sequence may have already been terminated with EOS,
# in that case its candidates are simply ignored, and do not sum up in the "loop_dec_logprobs" their value
# are set to zero
there_is_eos_mask = (loop_dec_classes.view(bs, beam_size, time_step) == eos_idx). \
sum(dim=-1, keepdims=True).type(torch.bool)
# if we pad with -999 its candidates logprobabilities, also the sequence containing EOS would be
# straightforwardly discarded, instead we want to keep it in the exploration. Therefore we mask with 0.0
# one arbitrary candidate word probability so the sequence probability is unchanged but it
# can still be discarded when a better candidate sequence is found
top_beam_size_word_logprobs[:, :, 0:1].masked_fill_(there_is_eos_mask, 0.0)
top_beam_size_word_logprobs[:, :, 1:].masked_fill_(there_is_eos_mask, -999.0)
comparison_logprobs = loop_cumul_logprobs + top_beam_size_word_logprobs
comparison_logprobs = comparison_logprobs.contiguous().view(bs, beam_size * beam_size)
_, topi = torch.topk(comparison_logprobs, k=beam_size, sorted=True)
which_sequence = topi // beam_size
which_word = topi % beam_size
loop_dec_classes = loop_dec_classes.view(bs, beam_size, -1)
loop_dec_logprobs = loop_dec_logprobs.view(bs, beam_size, -1)
bs_idxes = torch.arange(bs).unsqueeze(-1)
new_loop_dec_classes = loop_dec_classes[[bs_idxes, which_sequence]]
new_loop_dec_logprobs = loop_dec_logprobs[[bs_idxes, which_sequence]]
which_sequence_top_beam_size_word_classes = top_beam_size_word_classes[[bs_idxes, which_sequence]]
which_sequence_top_beam_size_word_logprobs = top_beam_size_word_logprobs[
[bs_idxes, which_sequence]]
which_word = which_word.unsqueeze(-1)
lastword_top_beam_size_classes = which_sequence_top_beam_size_word_classes.gather(dim=-1,
index=which_word)
lastword_top_beam_size_logprobs = which_sequence_top_beam_size_word_logprobs.gather(dim=-1, index=which_word)
new_loop_dec_classes = torch.cat((new_loop_dec_classes, lastword_top_beam_size_classes), dim=-1)
new_loop_dec_logprobs = torch.cat((new_loop_dec_logprobs, lastword_top_beam_size_logprobs), dim=-1)
loop_dec_classes = new_loop_dec_classes
loop_dec_logprobs = new_loop_dec_logprobs
loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True)
# -----------------------update loop_num_elem_vector ----------------------------
loop_num_elem_vector = loop_num_elem_vector.view(bs, beam_size)[[bs_idxes, which_sequence]].view(bs * beam_size)
there_was_eos_mask = (loop_dec_classes[:, :, :-1].view(bs, beam_size, time_step) == eos_idx). \
sum(dim=-1).type(torch.bool).view(bs * beam_size)
loop_num_elem_vector = loop_num_elem_vector + (1 * (1 - there_was_eos_mask.type(torch.int)))
if (loop_num_elem_vector != time_step + 1).sum() == (bs * beam_size):
break
# sort out the best result
loop_cumul_logprobs /= loop_num_elem_vector.reshape(bs, beam_size, 1)
_, topi = torch.topk(loop_cumul_logprobs.squeeze(-1), k=beam_size)
res_caption_pred = [[] for _ in range(bs)]
res_caption_logprob = [[] for _ in range(bs)]
for i in range(bs):
for j in range(how_many_outputs):
idx = topi[i, j].item()
res_caption_pred[i].append(
loop_dec_classes[i, idx, :loop_num_elem_vector[i * beam_size + idx]].tolist())
res_caption_logprob[i].append(loop_dec_logprobs[i, idx, :loop_num_elem_vector[i * beam_size + idx]])
flatted_res_caption_logprob = [logprobs for i in range(bs) for logprobs in res_caption_logprob[i]]
flatted_res_caption_logprob = torch.nn.utils.rnn.pad_sequence(flatted_res_caption_logprob, batch_first=True)
res_caption_logprob = flatted_res_caption_logprob.view(bs, how_many_outputs, -1)
return res_caption_pred, res_caption_logprob

187
models/ensemble_captioning_model.py

@ -0,0 +1,187 @@
import torch
import torch.nn as nn
from models.captioning_model import CaptioningModel
class EsembleCaptioningModel(CaptioningModel):
def __init__(self, models_list, rank):
super().__init__()
self.num_models = len(models_list)
self.models_list = models_list
self.rank = rank
self.dummy_linear = nn.Linear(1, 1)
for model in self.models_list:
model.eval()
def forward(self, enc_x, dec_x=None,
enc_x_num_pads=[0], dec_x_num_pads=[0], apply_log_softmax=False,
mode='beam_search', **kwargs):
assert (mode == 'beam_search'), "this class supports only beam search."
sos_idx = kwargs.get('sos_idx', -999)
eos_idx = kwargs.get('eos_idx', -999)
if mode == 'beam_search':
beam_size_arg = kwargs.get('beam_size', 5)
how_many_outputs_per_beam = kwargs.get('how_many_outputs', 1)
beam_max_seq_len = kwargs.get('beam_max_seq_len', 20)
sample_or_max = kwargs.get('sample_or_max', 'max')
out_classes, out_logprobs = self.ensemble_beam_search(
enc_x, enc_x_num_pads,
beam_size=beam_size_arg,
sos_idx=sos_idx,
eos_idx=eos_idx,
how_many_outputs=how_many_outputs_per_beam,
max_seq_len=beam_max_seq_len,
sample_or_max=sample_or_max)
return out_classes, out_logprobs
def forward_enc(self, enc_input, enc_input_num_pads):
x_outputs_list = []
for i in range(self.num_models):
x_outputs = self.models_list[i].forward_enc(enc_input, enc_input_num_pads)
x_outputs_list.append(x_outputs)
return x_outputs_list
def forward_dec(self, cross_input_list, enc_input_num_pads, dec_input, dec_input_num_pads, apply_log_softmax=False):
import torch.nn.functional as F
y_outputs = []
for i in range(self.num_models):
y_outputs.append(
F.softmax(self.models_list[i].forward_dec(
cross_input_list[i], enc_input_num_pads,
dec_input, dec_input_num_pads, False).unsqueeze(0), dim=-1))
avg = torch.cat(y_outputs, dim=0).mean(dim=0).log()
return avg
# quite unclean coding, to be re-factored in the future...
# since it's a bit similar to the single model case
def ensemble_beam_search(self, enc_input, enc_input_num_pads, sos_idx, eos_idx,
beam_size=3, how_many_outputs=1, max_seq_len=20, sample_or_max='max',):
assert (how_many_outputs <= beam_size), "requested output per sequence must be lower than beam width"
assert (sample_or_max == 'max' or sample_or_max == 'sample'), "argument must be chosen between \'max\' and \'sample\'"
bs = enc_input.shape[0]
# the cross_dec_input is computed once
cross_enc_output_list = self.forward_enc(enc_input, enc_input_num_pads)
# init: ------------------------------------------------------------------
init_dec_class = torch.tensor([sos_idx] * bs).unsqueeze(1).type(torch.long).to(self.rank)
init_dec_logprob = torch.tensor([0.0] * bs).unsqueeze(1).type(torch.float).to(self.rank)
log_probs = self.forward_dec(cross_input_list=cross_enc_output_list, enc_input_num_pads=enc_input_num_pads,
dec_input=init_dec_class, dec_input_num_pads=[0] * bs,
apply_log_softmax=True)
if sample_or_max == 'max':
_, topi = torch.topk(log_probs, k=beam_size, sorted=True)
else: # sample
topi = torch.exp(log_probs[:, 0, :]).multinomial(num_samples=beam_size, replacement=False)
topi = topi.unsqueeze(1)
init_dec_class = init_dec_class.repeat(1, beam_size)
init_dec_class = init_dec_class.unsqueeze(-1)
top_beam_size_class = topi.transpose(-2, -1)
init_dec_class = torch.cat((init_dec_class, top_beam_size_class), dim=-1)
init_dec_logprob = init_dec_logprob.repeat(1, beam_size)
init_dec_logprob = init_dec_logprob.unsqueeze(-1)
top_beam_size_logprob = log_probs.gather(dim=-1, index=topi)
top_beam_size_logprob = top_beam_size_logprob.transpose(-2, -1)
init_dec_logprob = torch.cat((init_dec_logprob, top_beam_size_logprob), dim=-1)
tmp_cross_enc_output_list = []
for cross_enc_output in cross_enc_output_list:
bs, enc_seq_len, d_model = cross_enc_output.shape
cross_enc_output = cross_enc_output.unsqueeze(1)
cross_enc_output = cross_enc_output.expand(-1, beam_size, -1, -1)
cross_enc_output = cross_enc_output.reshape(bs * beam_size, enc_seq_len, d_model).contiguous()
tmp_cross_enc_output_list.append(cross_enc_output)
cross_enc_output_list = tmp_cross_enc_output_list
enc_input_num_pads = [enc_input_num_pads[i] for i in range(bs) for _ in range(beam_size)]
loop_dec_classes = init_dec_class
loop_dec_logprobs = init_dec_logprob
loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True)
loop_num_elem_vector = torch.tensor([2] * (bs * beam_size)).to(self.rank)
for time_step in range(2, max_seq_len):
loop_dec_classes = loop_dec_classes.reshape(bs * beam_size, time_step).contiguous()
log_probs = self.forward_dec(cross_input_list=cross_enc_output_list, enc_input_num_pads=enc_input_num_pads,
dec_input=loop_dec_classes,
dec_input_num_pads=(time_step-loop_num_elem_vector).tolist(),
apply_log_softmax=True)
if sample_or_max == 'max':
_, topi = torch.topk(log_probs[:, time_step-1, :], k=beam_size, sorted=True)
else: # sample
topi = torch.exp(log_probs[:, time_step-1, :]).multinomial(num_samples=beam_size,
replacement=False)
top_beam_size_word_classes = topi.reshape(bs, beam_size, beam_size)
top_beam_size_word_logprobs = log_probs[:, time_step-1, :].gather(dim=-1, index=topi)
top_beam_size_word_logprobs = top_beam_size_word_logprobs.reshape(bs, beam_size, beam_size)
there_is_eos_mask = (loop_dec_classes.view(bs, beam_size, time_step) == eos_idx). \
sum(dim=-1, keepdims=True).type(torch.bool)
top_beam_size_word_logprobs[:, :, 0:1].masked_fill_(there_is_eos_mask, 0.0)
top_beam_size_word_logprobs[:, :, 1:].masked_fill_(there_is_eos_mask, -999.0)
comparison_logprobs = loop_cumul_logprobs + top_beam_size_word_logprobs
comparison_logprobs = comparison_logprobs.contiguous().view(bs, beam_size * beam_size)
_, topi = torch.topk(comparison_logprobs, k=beam_size, sorted=True)
which_sequence = topi // beam_size
which_word = topi % beam_size
loop_dec_classes = loop_dec_classes.view(bs, beam_size, -1)
loop_dec_logprobs = loop_dec_logprobs.view(bs, beam_size, -1)
bs_idxes = torch.arange(bs).unsqueeze(-1)
new_loop_dec_classes = loop_dec_classes[[bs_idxes, which_sequence]]
new_loop_dec_logprobs = loop_dec_logprobs[[bs_idxes, which_sequence]]
which_sequence_top_beam_size_word_classes = top_beam_size_word_classes[[bs_idxes, which_sequence]]
which_sequence_top_beam_size_word_logprobs = top_beam_size_word_logprobs[
[bs_idxes, which_sequence]]
which_word = which_word.unsqueeze(-1)
lastword_top_beam_size_classes = which_sequence_top_beam_size_word_classes.gather(dim=-1,
index=which_word)
lastword_top_beam_size_logprobs = which_sequence_top_beam_size_word_logprobs.gather(dim=-1, index=which_word)
new_loop_dec_classes = torch.cat((new_loop_dec_classes, lastword_top_beam_size_classes), dim=-1)
new_loop_dec_logprobs = torch.cat((new_loop_dec_logprobs, lastword_top_beam_size_logprobs), dim=-1)
loop_dec_classes = new_loop_dec_classes
loop_dec_logprobs = new_loop_dec_logprobs
loop_cumul_logprobs = loop_dec_logprobs.sum(dim=-1, keepdims=True)
loop_num_elem_vector = loop_num_elem_vector.view(bs, beam_size)[[bs_idxes, which_sequence]].view(bs * beam_size)
there_was_eos_mask = (loop_dec_classes[:, :, :-1].view(bs, beam_size, time_step) == eos_idx). \
sum(dim=-1).type(torch.bool).view(bs * beam_size)
loop_num_elem_vector = loop_num_elem_vector + (1 * (1 - there_was_eos_mask.type(torch.int)))
if (loop_num_elem_vector != time_step + 1).sum() == (bs * beam_size):
break
loop_cumul_logprobs /= loop_num_elem_vector.reshape(bs, beam_size, 1)
_, topi = torch.topk(loop_cumul_logprobs.squeeze(-1), k=beam_size)
res_caption_pred = [[] for _ in range(bs)]
res_caption_logprob = [[] for _ in range(bs)]
for i in range(bs):
for j in range(how_many_outputs):
idx = topi[i, j].item()
res_caption_pred[i].append(
loop_dec_classes[i, idx, :loop_num_elem_vector[i * beam_size + idx]].tolist())
res_caption_logprob[i].append(loop_dec_logprobs[i, idx, :loop_num_elem_vector[i * beam_size + idx]])
flatted_res_caption_logprob = [logprobs for i in range(bs) for logprobs in res_caption_logprob[i]]
flatted_res_caption_logprob = torch.nn.utils.rnn.pad_sequence(flatted_res_caption_logprob, batch_first=True)
res_caption_logprob = flatted_res_caption_logprob.view(bs, how_many_outputs, -1)
return res_caption_pred, res_caption_logprob

286
models/layers.py

@ -0,0 +1,286 @@
import torch
import torch.nn as nn
import math
import numpy as np
import torch.nn.functional as F
class EmbeddingLayer(nn.Module):
def __init__(self, vocab_size, d_model, dropout_perc):
super(EmbeddingLayer, self).__init__()
self.dropout = nn.Dropout(dropout_perc)
self.embed = nn.Embedding(vocab_size, d_model)
self.d_model = d_model
def forward(self, x):
return self.dropout(self.embed(x)) * math.sqrt(float(self.d_model))
class PositionalEncoder(nn.Module):
def __init__(self, d_model, max_seq_len, rank=0):
super().__init__()
assert d_model % 2 == 0, "d_model is not even, even number suggested"
self.d_model = d_model
self.pe = torch.zeros(max_seq_len, d_model).to(rank)
for pos in range(max_seq_len):
for i in range(0, d_model, 2):
self.pe.data[pos, i] = math.sin(pos / (10000.0 ** ((2.0 * i) / d_model)))
self.pe.data[pos, i + 1] = math.cos(pos / (10000.0 ** ((2.0 * i) / d_model)))
self.pe.data = self.pe.data.unsqueeze(0)
def forward(self, x):
seq_len = x.shape[1]
return self.pe.data[0, :seq_len]
class StaticExpansionBlock(nn.Module):
def __init__(self, d_model, num_enc_exp_list, dropout_perc, eps):
super().__init__()
self.d_model = d_model
self.num_enc_exp_list = num_enc_exp_list
self.query_exp_vectors = nn.Embedding(sum(num_enc_exp_list), d_model)
self.bias_exp_vectors = nn.Embedding(sum(num_enc_exp_list), d_model)
self.key_embed = nn.Linear(d_model, d_model)
self.class_a_embed = nn.Linear(d_model, d_model)
self.class_b_embed = nn.Linear(d_model, d_model)
self.selector_embed = nn.Linear(d_model, d_model)
self.dropout_class_a_fw = nn.Dropout(dropout_perc)
self.dropout_class_b_fw = nn.Dropout(dropout_perc)
self.dropout_class_a_bw = nn.Dropout(dropout_perc)
self.dropout_class_b_bw = nn.Dropout(dropout_perc)
self.Z_dropout = nn.Dropout(dropout_perc)
self.eps = eps
def forward(self, x, n_indexes, mask):
bs, enc_len, _ = x.shape
query_exp = self.query_exp_vectors(n_indexes)
bias_exp = self.bias_exp_vectors(n_indexes)
x_key = self.key_embed(x)
z = torch.matmul(query_exp, x_key.transpose(-1, -2)) / np.sqrt(self.d_model)
z = self.Z_dropout(z)
class_a_fw = F.relu(z)
class_b_fw = F.relu(-z)
class_a_fw = class_a_fw.masked_fill(mask == 0, 0.0)
class_b_fw = class_b_fw.masked_fill(mask == 0, 0.0)
class_a_fw = class_a_fw / (class_a_fw.sum(dim=-1, keepdim=True) + self.eps)
class_b_fw = class_b_fw / (class_b_fw.sum(dim=-1, keepdim=True) + self.eps)
class_a = torch.matmul(class_a_fw, self.class_a_embed(x)) + bias_exp
class_b = torch.matmul(class_b_fw, self.class_b_embed(x)) + bias_exp
class_a = self.dropout_class_a_fw(class_a)
class_b = self.dropout_class_b_fw(class_b)
class_a_bw = F.relu(z.transpose(-2, -1))
class_b_bw = F.relu(-z.transpose(-2, -1))
accum = 0
class_a_bw_list = []
class_b_bw_list = []
for j in range(len(self.num_enc_exp_list)):
from_idx = accum
to_idx = accum + self.num_enc_exp_list[j]
accum += self.num_enc_exp_list[j]
class_a_bw_list.append(class_a_bw[:, :, from_idx:to_idx] / (class_a_bw[:, :, from_idx:to_idx].sum(dim=-1, keepdim=True) + self.eps))
class_b_bw_list.append(class_b_bw[:, :, from_idx:to_idx] / (class_b_bw[:, :, from_idx:to_idx].sum(dim=-1, keepdim=True) + self.eps))
class_a_bw = torch.cat(class_a_bw_list, dim=-1)
class_b_bw = torch.cat(class_b_bw_list, dim=-1)
class_a = torch.matmul(class_a_bw, class_a) / len(self.num_enc_exp_list)
class_b = torch.matmul(class_b_bw, class_b) / len(self.num_enc_exp_list)
class_a = self.dropout_class_a_bw(class_a)
class_b = self.dropout_class_b_bw(class_b)
selector = torch.sigmoid(self.selector_embed(x))
x_result = selector * class_a + (1 - selector) * class_b
return x_result
class EncoderLayer(nn.Module):
def __init__(self, d_model, d_ff, num_enc_exp_list, dropout_perc, eps=1e-9):
super().__init__()
self.norm_1 = nn.LayerNorm(d_model)
self.norm_2 = nn.LayerNorm(d_model)
self.dropout_1 = nn.Dropout(dropout_perc)
self.dropout_2 = nn.Dropout(dropout_perc)
self.stc_exp = StaticExpansionBlock(d_model, num_enc_exp_list, dropout_perc, eps)
self.ff = FeedForward(d_model, d_ff, dropout_perc)
def forward(self, x, n_indexes, mask):
x2 = self.norm_1(x)
x = x + self.dropout_1(self.stc_exp(x=x2, n_indexes=n_indexes, mask=mask))
x2 = self.norm_2(x)
x = x + self.dropout_2(self.ff(x2))
return x
class DynamicExpansionBlock(nn.Module):
def __init__(self, d_model, num_exp, dropout_perc, eps):
super().__init__()
self.d_model = d_model
self.num_exp = num_exp
self.cond_embed = nn.Linear(d_model, d_model)
self.query_exp_vectors = nn.Embedding(self.num_exp, d_model)
self.bias_exp_vectors = nn.Embedding(self.num_exp, d_model)
self.key_linear = nn.Linear(d_model, d_model)
self.class_a_embed = nn.Linear(d_model, d_model)
self.class_b_embed = nn.Linear(d_model, d_model)
self.selector_embed = nn.Linear(d_model, d_model)
self.dropout_class_a_fw = nn.Dropout(dropout_perc)
self.dropout_class_b_fw = nn.Dropout(dropout_perc)
self.dropout_class_a_bw = nn.Dropout(dropout_perc)
self.dropout_class_b_bw = nn.Dropout(dropout_perc)
self.Z_dropout = nn.Dropout(dropout_perc)
self.eps = eps
def forward(self, x, n_indexes, mask):
bs, dec_len, _ = x.shape
cond = self.cond_embed(x).view(bs, dec_len, 1, self.d_model)
query_exp = self.query_exp_vectors(n_indexes).unsqueeze(1)
bias_exp = self.bias_exp_vectors(n_indexes).unsqueeze(1)
query_exp = (query_exp + cond).view(bs, dec_len * self.num_exp, self.d_model)
bias_exp = (bias_exp + cond).view(bs, dec_len * self.num_exp, self.d_model)
x_key = self.key_linear(x)
z = torch.matmul(query_exp, x_key.transpose(-1, -2)) / np.sqrt(self.d_model)
z = self.Z_dropout(z)
mod_mask_1 = mask.unsqueeze(2).expand(bs, dec_len, self.num_exp, dec_len).contiguous(). \
view(bs, dec_len * self.num_exp, dec_len)
class_a_fw = F.relu(z)
class_b_fw = F.relu(-z)
class_a_fw = class_a_fw.masked_fill(mod_mask_1 == 0, 0.0)
class_b_fw = class_b_fw.masked_fill(mod_mask_1 == 0, 0.0)
class_a_fw = class_a_fw / (class_a_fw.sum(dim=-1, keepdim=True) + self.eps)
class_b_fw = class_b_fw / (class_b_fw.sum(dim=-1, keepdim=True) + self.eps)
class_a = torch.matmul(class_a_fw, self.class_a_embed(x))
class_b = torch.matmul(class_b_fw, self.class_b_embed(x))
class_a = self.dropout_class_a_fw(class_a)
class_b = self.dropout_class_b_fw(class_b)
mod_mask_2 = mask.unsqueeze(-1).expand(bs, dec_len, dec_len, self.num_exp).contiguous(). \
view(bs, dec_len, dec_len * self.num_exp)
class_a_bw = F.relu(z.transpose(-2, -1))
class_b_bw = F.relu(-z.transpose(-2, -1))
class_a_bw = class_a_bw.masked_fill(mod_mask_2 == 0, 0.0)
class_b_bw = class_b_bw.masked_fill(mod_mask_2 == 0, 0.0)
class_a_bw = class_a_bw / (class_a_bw.sum(dim=-1, keepdim=True) + self.eps)
class_b_bw = class_b_bw / (class_b_bw.sum(dim=-1, keepdim=True) + self.eps)
class_a = torch.matmul(class_a_bw, class_a + bias_exp)
class_b = torch.matmul(class_b_bw, class_b + bias_exp)
class_a = self.dropout_class_a_bw(class_a)
class_b = self.dropout_class_b_bw(class_b)
selector = torch.sigmoid(self.selector_embed(x))
x_result = selector * class_a + (1 - selector) * class_b
return x_result
class DecoderLayer(nn.Module):
def __init__(self, d_model, num_heads, d_ff, num_exp, dropout_perc, eps=1e-9):
super().__init__()
self.norm_1 = nn.LayerNorm(d_model)
self.norm_2 = nn.LayerNorm(d_model)
self.norm_3 = nn.LayerNorm(d_model)
self.dropout_1 = nn.Dropout(dropout_perc)
self.dropout_2 = nn.Dropout(dropout_perc)
self.dropout_3 = nn.Dropout(dropout_perc)
self.mha = MultiHeadAttention(d_model, num_heads, dropout_perc)
self.dyn_exp = DynamicExpansionBlock(d_model, num_exp, dropout_perc, eps)
self.ff = FeedForward(d_model, d_ff, dropout_perc)
def forward(self, x, n_indexes, cross_connection_x, input_attention_mask, cross_attention_mask):
# Pre-LayerNorm
x2 = self.norm_1(x)
x = x + self.dropout_1(self.dyn_exp(x=x2, n_indexes=n_indexes, mask=input_attention_mask))
x2 = self.norm_2(x)
x = x + self.dropout_2(self.mha(q=x2, k=cross_connection_x, v=cross_connection_x,
mask=cross_attention_mask))
x2 = self.norm_3(x)
x = x + self.dropout_3(self.ff(x2))
return x
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads, dropout_perc):
super(MultiHeadAttention, self).__init__()
assert d_model % num_heads == 0, "num heads must be multiple of d_model"
self.d_model = d_model
self.d_k = int(d_model / num_heads)
self.num_heads = num_heads
self.Wq = nn.Linear(d_model, self.d_k * num_heads)
self.Wk = nn.Linear(d_model, self.d_k * num_heads)
self.Wv = nn.Linear(d_model, self.d_k * num_heads)
self.out_linear = nn.Linear(d_model, d_model)
def forward(self, q, k, v, mask=None):
batch_size, q_seq_len, _ = q.shape
k_seq_len = k.size(1)
v_seq_len = v.size(1)
k_proj = self.Wk(k).view(batch_size, k_seq_len, self.num_heads, self.d_k)
q_proj = self.Wq(q).view(batch_size, q_seq_len, self.num_heads, self.d_k)
v_proj = self.Wv(v).view(batch_size, v_seq_len, self.num_heads, self.d_k)
k_proj = k_proj.transpose(2, 1)
q_proj = q_proj.transpose(2, 1)
v_proj = v_proj.transpose(2, 1)
sim_scores = torch.matmul(q_proj, k_proj.transpose(3, 2))
sim_scores = sim_scores / self.d_k ** 0.5
if mask is not None:
mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
sim_scores = sim_scores.masked_fill(mask == 0, value=-1e4)
sim_scores = F.softmax(input=sim_scores, dim=-1)
attention_applied = torch.matmul(sim_scores, v_proj)
attention_applied_concatenated = attention_applied.permute(0, 2, 1, 3).contiguous()\
.view(batch_size, q_seq_len, self.d_model)
out = self.out_linear(attention_applied_concatenated)
return out
class FeedForward(nn.Module):
def __init__(self, d_model, d_ff, dropout_perc):
super(FeedForward, self).__init__()
self.linear_1 = nn.Linear(d_model, d_ff)
self.dropout = nn.Dropout(dropout_perc)
self.linear_2 = nn.Linear(d_ff, d_model)
def forward(self, x):
x = self.dropout(F.relu(self.linear_1(x)))
x = self.linear_2(x)
return x

655
models/swin_transformer_mod.py

@ -0,0 +1,655 @@
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
# ---------------------------------
# All credits due to Ze Liu: https://github.com/microsoft/Swin-Transformer
# and the additional sources:
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
# https://github.com/yukimasano/PASS/blob/main/vision_transformer.py
# ---------------------------------
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
class DropPath(nn.Module):
def __init__(self, drop_prob):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x):
if not self.training:
return x
keep_prob = 1 - self.drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
import collections.abc
def to_2tuple(x):
if isinstance(x, collections.abc.Iterable):
return x
return (x, x)
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
import warnings
import math
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official repo master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
# mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size, H, W):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
trunc_normal_(self.relative_position_bias_table, std=.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
def flops(self, N):
# calculate flops for 1 window with token length of N
flops = 0
# qkv = self.qkv(x)
flops += N * self.dim * 3 * self.dim
# attn = (q @ k.transpose(-2, -1))
flops += self.num_heads * N * (self.dim // self.num_heads) * N
# x = (attn @ v)
flops += self.num_heads * N * N * (self.dim // self.num_heads)
# x = self.proj(x)
flops += N * self.dim * self.dim
return flops
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
if min(self.input_resolution) <= self.window_size:
# if window size is larger than input resolution, we don't partition windows
self.shift_size = 0
self.window_size = min(self.input_resolution)
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if self.shift_size > 0:
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
h_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
slice(-self.window_size, -self.shift_size),
slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def forward(self, x):
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = x
x = self.norm1(x)
x = x.view(B, H, W, C)
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
# FFN
x = shortcut + self.drop_path(x)
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
def flops(self):
flops = 0
H, W = self.input_resolution
# norm1
flops += self.dim * H * W
# W-MSA/SW-MSA
nW = H * W / self.window_size / self.window_size
flops += nW * self.attn.flops(self.window_size * self.window_size)
# mlp
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
# norm2
flops += self.dim * H * W
return flops
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
def extra_repr(self) -> str:
return f"input_resolution={self.input_resolution}, dim={self.dim}"
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
if self.downsample is not None:
x = self.downsample(x)
return x
def extra_repr(self) -> str:
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
def flops(self):
flops = 0
for blk in self.blocks:
flops += blk.flops()
if self.downsample is not None:
flops += self.downsample.flops()
return flops
class PatchEmbed(nn.Module):
r""" Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input image channels. Default: 3.
embed_dim (int): Number of linear projection output channels. Default: 96.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
if self.norm is not None:
x = self.norm(x)
return x
def flops(self):
Ho, Wo = self.patches_resolution
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops
class SwinTransformer(nn.Module):
r""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
img_size (int | tuple(int)): Input image size. Default 224
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
"""
def __init__(self, img_size=224, patch_size=4, in_chans=3,
embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False):
super().__init__()
# self.num_classes = num_classes
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.ape = ape
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
input_resolution=(patches_resolution[0] // (2 ** i_layer),
patches_resolution[1] // (2 ** i_layer)),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
# self.avgpool = nn.AdaptiveAvgPool1d(1)
# self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'relative_position_bias_table'}
def forward_features(self, x):
x = self.patch_embed(x)
if self.ape:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x) # B L C
# x = self.avgpool(x.transpose(1, 2)) # B C 1
# x = torch.flatten(x, 1)
return x
def forward(self, x):
x = self.forward_features(x)
#x = self.head(x)
return x
def flops(self):
flops = 0
flops += self.patch_embed.flops()
for i, layer in enumerate(self.layers):
flops += layer.flops()
#flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
#flops += self.num_features * self.num_classes
return flops
Loading…
Cancel
Save