import torch
import torch.nn as nn
import math
import numpy as np
import torch.nn.functional as F

class EmbeddingLayer(nn.Module):
    def __init__(self, vocab_size, d_model, dropout_perc):
        super(EmbeddingLayer, self).__init__()
        self.dropout = nn.Dropout(dropout_perc)
        self.embed = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, x):
        # Token embeddings scaled by sqrt(d_model), with dropout applied to the embeddings.
        return self.dropout(self.embed(x)) * math.sqrt(float(self.d_model))

class PositionalEncoder(nn.Module):
    def __init__(self, d_model, max_seq_len, rank=0):
        super().__init__()
        assert d_model % 2 == 0, "d_model must be an even number"
        self.d_model = d_model
        # Pre-compute the sinusoidal positional encoding table on the given device (rank).
        self.pe = torch.zeros(max_seq_len, d_model).to(rank)
        for pos in range(max_seq_len):
            for i in range(0, d_model, 2):
                self.pe.data[pos, i] = math.sin(pos / (10000.0 ** ((2.0 * i) / d_model)))
                self.pe.data[pos, i + 1] = math.cos(pos / (10000.0 ** ((2.0 * i) / d_model)))
        self.pe.data = self.pe.data.unsqueeze(0)

    def forward(self, x):
        # Return the encodings for the first seq_len positions of the input.
        seq_len = x.shape[1]
        return self.pe.data[0, :seq_len]

class StaticExpansionBlock(nn.Module):
    def __init__(self, d_model, num_enc_exp_list, dropout_perc, eps):
        super().__init__()
        self.d_model = d_model
        self.num_enc_exp_list = num_enc_exp_list
        self.query_exp_vectors = nn.Embedding(sum(num_enc_exp_list), d_model)
        self.bias_exp_vectors = nn.Embedding(sum(num_enc_exp_list), d_model)
        self.key_embed = nn.Linear(d_model, d_model)
        self.class_a_embed = nn.Linear(d_model, d_model)
        self.class_b_embed = nn.Linear(d_model, d_model)
        self.selector_embed = nn.Linear(d_model, d_model)
        self.dropout_class_a_fw = nn.Dropout(dropout_perc)
        self.dropout_class_b_fw = nn.Dropout(dropout_perc)
        self.dropout_class_a_bw = nn.Dropout(dropout_perc)
        self.dropout_class_b_bw = nn.Dropout(dropout_perc)
        self.Z_dropout = nn.Dropout(dropout_perc)
        self.eps = eps

    def forward(self, x, n_indexes, mask):
        bs, enc_len, _ = x.shape

        # Forward expansion: score the sequence against a fixed set of learned expansion queries.
        query_exp = self.query_exp_vectors(n_indexes)
        bias_exp = self.bias_exp_vectors(n_indexes)
        x_key = self.key_embed(x)

        z = torch.matmul(query_exp, x_key.transpose(-1, -2)) / np.sqrt(self.d_model)
        z = self.Z_dropout(z)

        # Split the scores into two classes (positive and negative parts), mask and normalize each.
        class_a_fw = F.relu(z)
        class_b_fw = F.relu(-z)
        class_a_fw = class_a_fw.masked_fill(mask == 0, 0.0)
        class_b_fw = class_b_fw.masked_fill(mask == 0, 0.0)
        class_a_fw = class_a_fw / (class_a_fw.sum(dim=-1, keepdim=True) + self.eps)
        class_b_fw = class_b_fw / (class_b_fw.sum(dim=-1, keepdim=True) + self.eps)
        class_a = torch.matmul(class_a_fw, self.class_a_embed(x)) + bias_exp
        class_b = torch.matmul(class_b_fw, self.class_b_embed(x)) + bias_exp
        class_a = self.dropout_class_a_fw(class_a)
        class_b = self.dropout_class_b_fw(class_b)

        # Backward contraction: redistribute the expanded vectors back onto the sequence,
        # normalizing each expansion group separately before averaging over the groups.
        class_a_bw = F.relu(z.transpose(-2, -1))
        class_b_bw = F.relu(-z.transpose(-2, -1))

        accum = 0
        class_a_bw_list = []
        class_b_bw_list = []
        for j in range(len(self.num_enc_exp_list)):
            from_idx = accum
            to_idx = accum + self.num_enc_exp_list[j]
            accum += self.num_enc_exp_list[j]
            class_a_bw_list.append(class_a_bw[:, :, from_idx:to_idx] / (class_a_bw[:, :, from_idx:to_idx].sum(dim=-1, keepdim=True) + self.eps))
            class_b_bw_list.append(class_b_bw[:, :, from_idx:to_idx] / (class_b_bw[:, :, from_idx:to_idx].sum(dim=-1, keepdim=True) + self.eps))
        class_a_bw = torch.cat(class_a_bw_list, dim=-1)
        class_b_bw = torch.cat(class_b_bw_list, dim=-1)

        class_a = torch.matmul(class_a_bw, class_a) / len(self.num_enc_exp_list)
        class_b = torch.matmul(class_b_bw, class_b) / len(self.num_enc_exp_list)
        class_a = self.dropout_class_a_bw(class_a)
        class_b = self.dropout_class_b_bw(class_b)

        # Blend the two classes with a learned sigmoid gate.
        selector = torch.sigmoid(self.selector_embed(x))
        x_result = selector * class_a + (1 - selector) * class_b

        return x_result

class EncoderLayer(nn.Module):
    def __init__(self, d_model, d_ff, num_enc_exp_list, dropout_perc, eps=1e-9):
        super().__init__()
        self.norm_1 = nn.LayerNorm(d_model)
        self.norm_2 = nn.LayerNorm(d_model)
        self.dropout_1 = nn.Dropout(dropout_perc)
        self.dropout_2 = nn.Dropout(dropout_perc)
        self.stc_exp = StaticExpansionBlock(d_model, num_enc_exp_list, dropout_perc, eps)
        self.ff = FeedForward(d_model, d_ff, dropout_perc)

    def forward(self, x, n_indexes, mask):
        # Pre-LayerNorm residual blocks: static expansion followed by feed-forward.
        x2 = self.norm_1(x)
        x = x + self.dropout_1(self.stc_exp(x=x2, n_indexes=n_indexes, mask=mask))
        x2 = self.norm_2(x)
        x = x + self.dropout_2(self.ff(x2))
        return x

class DynamicExpansionBlock(nn.Module):
    def __init__(self, d_model, num_exp, dropout_perc, eps):
        super().__init__()
        self.d_model = d_model
        self.num_exp = num_exp
        self.cond_embed = nn.Linear(d_model, d_model)
        self.query_exp_vectors = nn.Embedding(self.num_exp, d_model)
        self.bias_exp_vectors = nn.Embedding(self.num_exp, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.class_a_embed = nn.Linear(d_model, d_model)
        self.class_b_embed = nn.Linear(d_model, d_model)
        self.selector_embed = nn.Linear(d_model, d_model)
        self.dropout_class_a_fw = nn.Dropout(dropout_perc)
        self.dropout_class_b_fw = nn.Dropout(dropout_perc)
        self.dropout_class_a_bw = nn.Dropout(dropout_perc)
        self.dropout_class_b_bw = nn.Dropout(dropout_perc)
        self.Z_dropout = nn.Dropout(dropout_perc)
        self.eps = eps

    def forward(self, x, n_indexes, mask):
        bs, dec_len, _ = x.shape

        # Condition the learned expansion queries on each decoder position,
        # producing num_exp expansion vectors per position.
        cond = self.cond_embed(x).view(bs, dec_len, 1, self.d_model)
        query_exp = self.query_exp_vectors(n_indexes).unsqueeze(1)
        bias_exp = self.bias_exp_vectors(n_indexes).unsqueeze(1)
        query_exp = (query_exp + cond).view(bs, dec_len * self.num_exp, self.d_model)
        bias_exp = (bias_exp + cond).view(bs, dec_len * self.num_exp, self.d_model)

        x_key = self.key_linear(x)

        z = torch.matmul(query_exp, x_key.transpose(-1, -2)) / np.sqrt(self.d_model)
        z = self.Z_dropout(z)

        # Forward expansion: the attention mask is replicated for each expansion vector.
        mod_mask_1 = mask.unsqueeze(2).expand(bs, dec_len, self.num_exp, dec_len).contiguous() \
            .view(bs, dec_len * self.num_exp, dec_len)

        class_a_fw = F.relu(z)
        class_b_fw = F.relu(-z)
        class_a_fw = class_a_fw.masked_fill(mod_mask_1 == 0, 0.0)
        class_b_fw = class_b_fw.masked_fill(mod_mask_1 == 0, 0.0)
        class_a_fw = class_a_fw / (class_a_fw.sum(dim=-1, keepdim=True) + self.eps)
        class_b_fw = class_b_fw / (class_b_fw.sum(dim=-1, keepdim=True) + self.eps)
        class_a = torch.matmul(class_a_fw, self.class_a_embed(x))
        class_b = torch.matmul(class_b_fw, self.class_b_embed(x))
        class_a = self.dropout_class_a_fw(class_a)
        class_b = self.dropout_class_b_fw(class_b)

        # Backward contraction: project the expanded vectors back to one vector per position.
        mod_mask_2 = mask.unsqueeze(-1).expand(bs, dec_len, dec_len, self.num_exp).contiguous() \
            .view(bs, dec_len, dec_len * self.num_exp)

        class_a_bw = F.relu(z.transpose(-2, -1))
        class_b_bw = F.relu(-z.transpose(-2, -1))
        class_a_bw = class_a_bw.masked_fill(mod_mask_2 == 0, 0.0)
        class_b_bw = class_b_bw.masked_fill(mod_mask_2 == 0, 0.0)
        class_a_bw = class_a_bw / (class_a_bw.sum(dim=-1, keepdim=True) + self.eps)
        class_b_bw = class_b_bw / (class_b_bw.sum(dim=-1, keepdim=True) + self.eps)
        class_a = torch.matmul(class_a_bw, class_a + bias_exp)
        class_b = torch.matmul(class_b_bw, class_b + bias_exp)
        class_a = self.dropout_class_a_bw(class_a)
        class_b = self.dropout_class_b_bw(class_b)

        # Blend the two classes with a learned sigmoid gate.
        selector = torch.sigmoid(self.selector_embed(x))
        x_result = selector * class_a + (1 - selector) * class_b

        return x_result

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, num_exp, dropout_perc, eps=1e-9):
        super().__init__()
        self.norm_1 = nn.LayerNorm(d_model)
        self.norm_2 = nn.LayerNorm(d_model)
        self.norm_3 = nn.LayerNorm(d_model)
        self.dropout_1 = nn.Dropout(dropout_perc)
        self.dropout_2 = nn.Dropout(dropout_perc)
        self.dropout_3 = nn.Dropout(dropout_perc)
        self.mha = MultiHeadAttention(d_model, num_heads, dropout_perc)
        self.dyn_exp = DynamicExpansionBlock(d_model, num_exp, dropout_perc, eps)
        self.ff = FeedForward(d_model, d_ff, dropout_perc)

    def forward(self, x, n_indexes, cross_connection_x, input_attention_mask, cross_attention_mask):
        # Pre-LayerNorm residual blocks: dynamic expansion, cross-attention, feed-forward.
        x2 = self.norm_1(x)
        x = x + self.dropout_1(self.dyn_exp(x=x2, n_indexes=n_indexes, mask=input_attention_mask))
        x2 = self.norm_2(x)
        x = x + self.dropout_2(self.mha(q=x2, k=cross_connection_x, v=cross_connection_x,
                                        mask=cross_attention_mask))
        x2 = self.norm_3(x)
        x = x + self.dropout_3(self.ff(x2))
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout_perc):
        super(MultiHeadAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        # NOTE: dropout_perc is accepted for a uniform interface but no dropout is applied here.
        self.d_model = d_model
        self.d_k = int(d_model / num_heads)
        self.num_heads = num_heads
        self.Wq = nn.Linear(d_model, self.d_k * num_heads)
        self.Wk = nn.Linear(d_model, self.d_k * num_heads)
        self.Wv = nn.Linear(d_model, self.d_k * num_heads)
        self.out_linear = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        batch_size, q_seq_len, _ = q.shape
        k_seq_len = k.size(1)
        v_seq_len = v.size(1)

        # Project and reshape into (batch, heads, seq_len, d_k).
        k_proj = self.Wk(k).view(batch_size, k_seq_len, self.num_heads, self.d_k)
        q_proj = self.Wq(q).view(batch_size, q_seq_len, self.num_heads, self.d_k)
        v_proj = self.Wv(v).view(batch_size, v_seq_len, self.num_heads, self.d_k)
        k_proj = k_proj.transpose(2, 1)
        q_proj = q_proj.transpose(2, 1)
        v_proj = v_proj.transpose(2, 1)

        # Scaled dot-product attention.
        sim_scores = torch.matmul(q_proj, k_proj.transpose(3, 2))
        sim_scores = sim_scores / self.d_k ** 0.5
        if mask is not None:
            mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1)
            sim_scores = sim_scores.masked_fill(mask == 0, value=-1e4)
        sim_scores = F.softmax(input=sim_scores, dim=-1)
        attention_applied = torch.matmul(sim_scores, v_proj)

        # Concatenate the heads and apply the output projection.
        attention_applied_concatenated = attention_applied.permute(0, 2, 1, 3).contiguous() \
            .view(batch_size, q_seq_len, self.d_model)
        out = self.out_linear(attention_applied_concatenated)
        return out

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout_perc):
        super(FeedForward, self).__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout_perc)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        x = self.dropout(F.relu(self.linear_1(x)))
        x = self.linear_2(x)
        return x
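

# ---------------------------------------------------------------------------
# Minimal smoke test: a hedged sketch, not part of the original module. The
# hyper-parameter values, the expansion-index tensors and the mask construction
# below are illustrative assumptions about how these layers can be called; they
# are not taken from the original training code.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    bs, enc_len, dec_len = 2, 7, 5
    d_model, d_ff, num_heads = 512, 2048, 8
    num_enc_exp_list = [32, 64]   # assumed static expansion group sizes
    num_dec_exp = 16              # assumed dynamic expansion size
    vocab_size, max_seq_len = 100, 50

    enc_layer = EncoderLayer(d_model, d_ff, num_enc_exp_list, dropout_perc=0.1).eval()
    dec_layer = DecoderLayer(d_model, num_heads, d_ff, num_dec_exp, dropout_perc=0.1).eval()
    word_emb = EmbeddingLayer(vocab_size, d_model, dropout_perc=0.1).eval()
    pos_enc = PositionalEncoder(d_model, max_seq_len, rank="cpu")  # keep the table on CPU

    with torch.no_grad():
        # Encoder: one index per static expansion vector, all-ones (no padding) mask.
        enc_x = torch.randn(bs, enc_len, d_model)
        enc_indexes = torch.arange(sum(num_enc_exp_list)).unsqueeze(0).repeat(bs, 1)
        enc_mask = torch.ones(bs, sum(num_enc_exp_list), enc_len)
        enc_out = enc_layer(enc_x, enc_indexes, enc_mask)

        # Decoder: embedded tokens plus positional encodings, causal self-attention mask,
        # all-ones cross-attention mask over the encoder output.
        tokens = torch.randint(0, vocab_size, (bs, dec_len))
        dec_x = word_emb(tokens) + pos_enc(tokens)
        dec_indexes = torch.arange(num_dec_exp).unsqueeze(0).repeat(bs, 1)
        self_mask = torch.tril(torch.ones(dec_len, dec_len)).unsqueeze(0).repeat(bs, 1, 1)
        cross_mask = torch.ones(bs, dec_len, enc_len)
        dec_out = dec_layer(dec_x, dec_indexes, enc_out, self_mask, cross_mask)

    print("encoder layer output:", tuple(enc_out.shape))  # expected (bs, enc_len, d_model)
    print("decoder layer output:", tuple(dec_out.shape))  # expected (bs, dec_len, d_model)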