import torch
from torch import nn
from torch.nn import functional as F


def position_embedding(input, d_model):
    # Sinusoidal positional encoding: even channels get sin, odd channels get cos,
    # each at frequency 1 / 10000^(2i / d_model).
    input = input.view(-1, 1)
    dim = torch.arange(d_model // 2, dtype=input.dtype, device=input.device).view(1, -1)
    sin = torch.sin(input / 10000 ** (2 * dim / d_model))
    cos = torch.cos(input / 10000 ** (2 * dim / d_model))

    out = torch.zeros((input.shape[0], d_model), device=input.device)
    out[:, ::2] = sin
    out[:, 1::2] = cos
    return out


def sinusoid_encoding_table(max_len, d_model, padding_idx=None, dtype=torch.float32):
    # Precompute the (max_len, d_model) table of positional encodings.
    pos = torch.arange(max_len, dtype=dtype)
    out = position_embedding(pos, d_model)

    if padding_idx is not None:
        # Zero the row reserved for padding so padded positions carry no signal.
        out[padding_idx] = 0
    return out
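
# Usage sketch (illustrative values; assumes the table is consumed through a
# frozen nn.Embedding so encodings can be looked up by position index):
#
#   pos_table = sinusoid_encoding_table(max_len=100, d_model=512, padding_idx=0)
#   pos_emb = nn.Embedding.from_pretrained(pos_table, freeze=True)
#   positions = torch.arange(10).unsqueeze(0)    # (1, seq_len) position indices
#   x = pos_emb(positions)                        # (1, seq_len, d_model) encodings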


class PositionWiseFeedForward(nn.Module):
    """
    Position-wise feed-forward layer.
    """

    def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False):
        super(PositionWiseFeedForward, self).__init__()
        self.identity_map_reordering = identity_map_reordering
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(p=dropout)
        self.dropout_2 = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, input):
        if self.identity_map_reordering:
            # Pre-norm variant: normalize first, run the FFN, then add the
            # re-activated output back onto the unnormalized input.
            out = self.layer_norm(input)
            out = self.fc2(self.dropout_2(F.relu(self.fc1(out))))
            out = input + self.dropout(torch.relu(out))
        else:
            # Post-norm variant: FFN, dropout, residual connection, then LayerNorm.
            out = self.fc2(self.dropout_2(F.relu(self.fc1(input))))
            out = self.dropout(out)
            out = self.layer_norm(input + out)
        return out
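

if __name__ == "__main__":
    # Quick smoke test; the batch size, sequence length, and hyper-parameters
    # here are illustrative assumptions, not values mandated by the module.
    table = sinusoid_encoding_table(max_len=7, d_model=512, padding_idx=0)
    print(table.shape)   # torch.Size([7, 512])

    ff = PositionWiseFeedForward(d_model=512, d_ff=2048, dropout=0.1)
    x = torch.randn(2, 7, 512)   # (batch, seq_len, d_model)
    y = ff(x)
    print(y.shape)               # torch.Size([2, 7, 512])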