logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

196 lines
7.9 KiB

import numpy as np
import torch
from torch import nn
from models.containers import Module
class ScaledDotProductAttention(nn.Module):
    """
    Scaled dot-product attention over ``h`` heads.

    Queries, keys and values are projected by separate linear layers, split
    into heads, combined as ``softmax(Q K^T / sqrt(d_k)) V`` and projected
    back to ``d_model`` by an output linear layer.
    """

    def __init__(self, d_model, d_k, d_v, h):
        """
        :param d_model: Output dimensionality of the model
        :param d_k: Dimensionality of queries and keys
        :param d_v: Dimensionality of values
        :param h: Number of heads
        """
        super(ScaledDotProductAttention, self).__init__()
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h

        self.init_weights()

    def init_weights(self):
        """Xavier-uniform initialise every projection weight and zero every bias."""
        for layer in (self.fc_q, self.fc_k, self.fc_v, self.fc_o):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.constant_(layer.bias, 0)

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        """
        Computes multi-head scaled dot-product attention.

        :param queries: Queries (b_s, nq, d_model)
        :param keys: Keys (b_s, nk, d_model)
        :param values: Values (b_s, nk, d_model)
        :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
        :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
        :return: Attended and re-projected output (b_s, nq, d_model)
        """
        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]

        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        # Keys are laid out already transposed so that q @ k yields the score matrix.
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)
        if attention_weights is not None:
            # Weights rescale the raw scores before the softmax.
            att = att * attention_weights
        if attention_mask is not None:
            # -inf scores receive exactly zero probability after the softmax.
            att = att.masked_fill(attention_mask, -np.inf)
        att = torch.softmax(att, -1)

        # Merge the heads back into a single feature axis.
        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)
        return out
class ScaledDotProductAttentionMemory(nn.Module):
    """
    Scaled dot-product attention with memory.

    In addition to the projected keys and values, ``m`` learned key/value
    slots are appended along the key axis, so every query can also attend to
    input-independent memory vectors.
    """

    def __init__(self, d_model, d_k, d_v, h, m):
        """
        :param d_model: Output dimensionality of the model
        :param d_k: Dimensionality of queries and keys
        :param d_v: Dimensionality of values
        :param h: Number of heads
        :param m: Number of memory slots
        """
        super(ScaledDotProductAttentionMemory, self).__init__()
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h
        self.m = m

        if self.m > 0:
            # torch.empty follows the default dtype, like the nn.Linear layers
            # these slots are concatenated with; values are set in init_weights.
            self.m_k = nn.Parameter(torch.empty(1, m, h * d_k))
            self.m_v = nn.Parameter(torch.empty(1, m, h * d_v))

        self.init_weights()

    def init_weights(self):
        """
        Xavier-uniform initialise the projection weights, zero the biases and
        draw the memory slots from zero-mean normals (std ``1 / d_k`` for keys,
        ``1 / m`` for values).
        """
        for layer in (self.fc_q, self.fc_k, self.fc_v, self.fc_o):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.constant_(layer.bias, 0)
        if self.m > 0:
            nn.init.normal_(self.m_k, 0, 1 / self.d_k)
            nn.init.normal_(self.m_v, 0, 1 / self.m)

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        """
        Computes multi-head scaled dot-product attention over the keys/values
        extended with the memory slots.

        :param queries: Queries (b_s, nq, d_model)
        :param keys: Keys (b_s, nk, d_model)
        :param values: Values (b_s, nk, d_model)
        :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
            Only the ``nk`` input positions are masked; memory slots always stay visible.
        :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
            Only the ``nk`` input positions are rescaled.
        :return: Attended and re-projected output (b_s, nq, d_model)
        """
        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]
        n_slots = nk + self.m  # input positions followed by memory slots

        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)

        k = self.fc_k(keys)
        v = self.fc_v(values)
        if self.m > 0:
            # Rescale the stored memory and broadcast it over the batch. The
            # tensor is the left operand so torch, not NumPy, dispatches the op.
            m_k = self.m_k.expand(b_s, self.m, self.h * self.d_k) * np.sqrt(self.d_k)
            m_v = self.m_v.expand(b_s, self.m, self.h * self.d_v) * np.sqrt(self.m)
            k = torch.cat([k, m_k], 1)
            v = torch.cat([v, m_v], 1)

        k = k.view(b_s, n_slots, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk + m)
        v = v.view(b_s, n_slots, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk + m, d_v)

        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk + m)
        if attention_weights is not None:
            # Weights cover only the first nk columns; memory columns pass through.
            att = torch.cat([att[:, :, :, :nk] * attention_weights, att[:, :, :, nk:]], -1)
        if attention_mask is not None:
            # Same split as above, built out-of-place instead of assigning into att.
            att = torch.cat([att[:, :, :, :nk].masked_fill(attention_mask, -np.inf), att[:, :, :, nk:]], -1)
        att = torch.softmax(att, -1)

        # Merge the heads back into a single feature axis.
        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)
        return out
class MultiHeadAttention(Module):
    """
    Multi-head attention layer with Dropout and Layer Normalization.

    Wraps an attention module with a residual connection around the queries.
    Can optionally accumulate keys and values across calls (stateful mode).
    """

    def __init__(self, d_model, d_k, d_v, h, dropout=.1, identity_map_reordering=False, can_be_stateful=False,
                 attention_module=None, attention_module_kwargs=None):
        """
        :param d_model: Output dimensionality of the model
        :param d_k: Dimensionality of queries and keys
        :param d_v: Dimensionality of values
        :param h: Number of heads
        :param dropout: Dropout probability applied to the attention output
        :param identity_map_reordering: If True, layer-normalise the inputs before attention
            instead of normalising the residual sum afterwards
        :param can_be_stateful: If True, register running key/value states on this module
        :param attention_module: Attention class to instantiate; ScaledDotProductAttention when None
        :param attention_module_kwargs: Extra keyword arguments for ``attention_module``
            (ignored when ``attention_module`` is None)
        """
        super(MultiHeadAttention, self).__init__()
        self.identity_map_reordering = identity_map_reordering

        if attention_module is None:
            self.attention = ScaledDotProductAttention(d_model=d_model, d_k=d_k, d_v=d_v, h=h)
        else:
            extra_kwargs = {} if attention_module_kwargs is None else attention_module_kwargs
            self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h, **extra_kwargs)

        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm = nn.LayerNorm(d_model)

        self.can_be_stateful = can_be_stateful
        if self.can_be_stateful:
            # register_state comes from the project's Module base class
            # (presumably it creates self.running_keys / self.running_values
            # with the given default -- confirm against models.containers).
            self.register_state('running_keys', None)
            self.register_state('running_values', None)

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        """
        Apply attention with a residual connection around ``queries``.

        :param queries: Queries, passed to the attention module and used as the residual
        :param keys: Keys; in stateful mode they are appended to the running keys along dim 1
        :param values: Values; in stateful mode they are appended to the running values along dim 1
        :param attention_mask: Forwarded unchanged to the attention module
        :param attention_weights: Forwarded unchanged to the attention module
        :return: Output of the residual attention block
        """
        # _is_stateful is provided by the project's Module base class.
        if self.can_be_stateful and self._is_stateful:
            if self.running_keys is None:
                # First call since the state was (re)set: start the cache.
                self.running_keys = keys
                self.running_values = values
            else:
                self.running_keys = torch.cat([self.running_keys, keys], 1)
                self.running_values = torch.cat([self.running_values, values], 1)
            # Attend over everything accumulated so far, not just this call's input.
            keys = self.running_keys
            values = self.running_values

        if self.identity_map_reordering:
            # Normalise first, then attention, ReLU, dropout and the residual sum.
            q_norm = self.layer_norm(queries)
            k_norm = self.layer_norm(keys)
            v_norm = self.layer_norm(values)
            out = self.attention(q_norm, k_norm, v_norm, attention_mask, attention_weights)
            out = queries + self.dropout(torch.relu(out))
        else:
            # Attention and dropout first, then normalise the residual sum.
            out = self.attention(queries, keys, values, attention_mask, attention_weights)
            out = self.dropout(out)
            out = self.layer_norm(queries + out)
        return out