import torch
from torch import nn
from torch.nn import functional as F


def position_embedding(input, d_model):
    # Sinusoidal position embeddings for a 1-D tensor of positions.
    # Even feature indices receive sin components, odd indices cos, with
    # frequencies 1 / 10000^(2i / d_model) as in Vaswani et al. (2017).
    input = input.view(-1, 1)
    dim = torch.arange(d_model // 2, dtype=input.dtype, device=input.device).view(1, -1)
    sin = torch.sin(input / 10000 ** (2 * dim / d_model))
    cos = torch.cos(input / 10000 ** (2 * dim / d_model))

    out = torch.zeros((input.shape[0], d_model), device=input.device)
    out[:, ::2] = sin
    out[:, 1::2] = cos
    return out


def sinusoid_encoding_table(max_len, d_model, padding_idx=None, dtype=torch.float32):
    # Build a (max_len, d_model) table of sinusoidal encodings; the row at
    # padding_idx, if given, is zeroed so padded positions carry no signal.
    pos = torch.arange(max_len, dtype=dtype)
    out = position_embedding(pos, d_model)

    if padding_idx is not None:
        out[padding_idx] = 0
    return out


class PositionWiseFeedForward(nn.Module):
    """
    Position-wise feed-forward layer: two linear projections with a ReLU in
    between, plus dropout, a residual connection, and layer normalization.
    """

    def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False):
        super(PositionWiseFeedForward, self).__init__()
        self.identity_map_reordering = identity_map_reordering
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(p=dropout)
        self.dropout_2 = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, input):
        if self.identity_map_reordering:
            # Pre-norm variant: normalize first, then add the activated,
            # dropped-out transformation back onto the unnormalized input.
            out = self.layer_norm(input)
            out = self.fc2(self.dropout_2(F.relu(self.fc1(out))))
            out = input + self.dropout(torch.relu(out))
        else:
            # Post-norm variant: transform, apply dropout, then normalize
            # the residual sum.
            out = self.fc2(self.dropout_2(F.relu(self.fc1(input))))
            out = self.dropout(out)
            out = self.layer_norm(input + out)
        return out
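

# --- Usage sketch (not part of the original module) ---
# A minimal, hypothetical sanity check showing how these pieces are commonly
# wired together: the sinusoid table is loaded into a frozen nn.Embedding,
# added to token features, and the result is passed through the feed-forward
# block. Shapes and hyperparameters below are illustrative assumptions.
if __name__ == '__main__':
    batch_size, seq_len, d_model = 2, 10, 512

    # Frozen positional lookup built from the sinusoid table.
    pos_emb = nn.Embedding.from_pretrained(
        sinusoid_encoding_table(seq_len, d_model), freeze=True)

    tokens = torch.randn(batch_size, seq_len, d_model)  # stand-in token features
    positions = torch.arange(seq_len).unsqueeze(0).expand(batch_size, -1)
    x = tokens + pos_emb(positions)  # add positional signal

    ff = PositionWiseFeedForward(d_model=d_model)
    out = ff(x)
    assert out.shape == (batch_size, seq_len, d_model)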