logo
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

65 lines
2.9 KiB

from torch.nn import functional as F
from models.transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward
import torch
from torch import nn
from models.transformer.attention import MultiHeadAttention
class EncoderLayer(nn.Module):
    """One encoder layer: an attention sub-layer followed by a position-wise feed-forward sub-layer.

    Both sub-layers are project modules (``MultiHeadAttention`` and
    ``PositionWiseFeedForward``). Any residual connection, normalisation or
    dropout is handled inside them and is not visible from this class.
    """

    def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False,
                 attention_module=None, attention_module_kwargs=None):
        super().__init__()
        reorder = identity_map_reordering
        # Kept on the instance; the same flag is also forwarded to both sub-layers.
        self.identity_map_reordering = reorder
        # Attention sub-layer: sizes are passed positionally, in MultiHeadAttention's parameter order.
        self.mhatt = MultiHeadAttention(
            d_model, d_k, d_v, h, dropout,
            identity_map_reordering=reorder,
            attention_module=attention_module,
            attention_module_kwargs=attention_module_kwargs,
        )
        # Feed-forward sub-layer applied to the attention output.
        self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout, identity_map_reordering=reorder)

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        """Attend ``queries`` over ``keys``/``values``, then apply the feed-forward sub-layer.

        ``attention_mask`` and ``attention_weights`` are passed through unchanged
        to the attention module. Returns whatever ``self.pwff`` returns.
        """
        attended = self.mhatt(queries, keys, values, attention_mask, attention_weights)
        return self.pwff(attended)
class Encoder(nn.Module):
    """Stack of ``N`` ``EncoderLayer`` modules applied to a projected input sequence.

    Pipeline in ``forward``:
      1. optionally add a frozen sinusoidal position embedding (width ``d_in``);
      2. ``Linear(d_in, d_model)`` -> ReLU -> dropout -> ``LayerNorm(d_model)``;
      3. feed the result through every layer as ``layer(out, out, out)``.

    Args:
        N: number of encoder layers.
        max_len: longest sequence supported by the position table (only
            relevant when ``with_pe`` is True).
        d_in: feature width of the input (last dimension).
        d_model, d_k, d_v, h, d_ff, dropout, identity_map_reordering,
        attention_module, attention_module_kwargs: forwarded to each
            ``EncoderLayer``; ``dropout`` is also the rate of the input dropout.
        with_pe: add the position embedding before the input projection.
        with_mesh: return every layer's output stacked along dim 1 instead of
            only the last layer's output.
    """

    def __init__(self, N, max_len, d_in, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1,
                 identity_map_reordering=False, attention_module=None, attention_module_kwargs=None,
                 with_pe=False, with_mesh=False):
        super(Encoder, self).__init__()
        self.d_in = d_in
        self.d_model = d_model
        self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout,
                                                  identity_map_reordering=identity_map_reordering,
                                                  attention_module=attention_module,
                                                  attention_module_kwargs=attention_module_kwargs)
                                     for _ in range(N)])
        # max_len + 1 rows because forward indexes positions 1..seq_len; row 0 is
        # presumably a padding row (third argument) -- confirm against sinusoid_encoding_table.
        self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, self.d_in, 0), freeze=True)
        self.fc = nn.Linear(d_in, self.d_model)
        # Assigned once as a module (previously a float was stored here first, then overwritten).
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(self.d_model)
        self.with_pe = with_pe
        self.with_mesh = with_mesh

    def forward(self, input):
        """Encode ``input``.

        Args:
            input: tensor of shape (b_s, seq_len, d_in).

        Returns:
            ``(out, None)``. Without ``with_mesh``, ``out`` is the last layer's
            output. With ``with_mesh``, ``out`` is all layer outputs stacked along
            dim 1, i.e. (b_s, N, ...). The second element is always ``None``.

        Raises:
            IndexError: if ``with_pe`` is set and ``seq_len`` exceeds ``max_len``.
        """
        out = input
        if self.with_pe:
            b_s, seq_len = input.shape[:2]
            max_len = self.pos_emb.num_embeddings - 1
            if seq_len > max_len:
                # Fail clearly instead of an opaque embedding index error / CUDA device assert.
                raise IndexError(
                    f"sequence length {seq_len} exceeds max_len={max_len} of the positional table")
            # 1-based positions, shared across the batch: (b_s, seq_len)
            seq = torch.arange(1, seq_len + 1, device=input.device).view(1, -1).expand(b_s, -1)
            out = out + self.pos_emb(seq)
        out = self.layer_norm(self.dropout(F.relu(self.fc(out))))

        layer_outputs = []
        for layer in self.layers:
            out = layer(out, out, out)
            if self.with_mesh:
                layer_outputs.append(out)
        if self.with_mesh:
            # Equivalent to cat of unsqueeze(1) outputs.
            return torch.stack(layer_outputs, dim=1), None
        return out, None