expansionnet-v2
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
23 lines
930 B
23 lines
930 B
2 years ago
|
|
||
|
import torch
|
||
|
|
||
|
|
||
|
def create_pad_mask(mask_size, pad_along_row_input, pad_along_column_input, rank):
    """Build a per-sample padding mask.

    Args:
        mask_size: ``(batch_size, output_seq_len, input_seq_len)``, the shape
            of the mask to build.
        pad_along_row_input: per-sample count of trailing padded positions
            along the row (output) axis; indexed by batch index.
        pad_along_column_input: per-sample count of trailing padded positions
            along the column (input) axis; indexed by batch index.
        rank: target device for the mask, passed as ``device=`` to
            ``torch.ones``. Presumably a device string or GPU/process rank
            ordinal; TODO confirm against callers.

    Returns:
        ``torch.int8`` tensor of shape ``mask_size``. An entry is 1 when both
        its row and its column lie outside that sample's padded tail, and 0
        otherwise. A pad count >= the axis length zeroes the whole axis.
    """
    batch_size, output_seq_len, input_seq_len = mask_size

    # Allocate on the target device directly rather than building the mask on
    # the host and copying it over.
    mask = torch.ones(size=(batch_size, output_seq_len, input_seq_len),
                      dtype=torch.int8, device=rank)

    for batch_idx in range(batch_size):
        # Clamp at 0: a pad count larger than the axis length would otherwise
        # give a negative slice start. Python wraps negative starts around, so
        # only part of the axis would be masked.
        first_pad_col = max(input_seq_len - int(pad_along_column_input[batch_idx]), 0)
        first_pad_row = max(output_seq_len - int(pad_along_row_input[batch_idx]), 0)
        mask[batch_idx, :, first_pad_col:] = 0
        mask[batch_idx, first_pad_row:, :] = 0

    return mask
|
||
|
|
||
|
|
||
|
def create_no_peak_and_pad_mask(mask_size, num_pads, rank):
    """Build a lower-triangular mask combined with per-sample padding.

    Args:
        mask_size: ``(batch_size, seq_len, seq_len)``. Only the first and last
            entries are read. The middle entry is ignored, so the mask is
            always square with side ``mask_size[2]``.
        num_pads: per-sample count of trailing padded positions, indexed by
            batch index. The same count is applied to rows and columns.
        rank: target device for the mask, passed as ``device=`` to
            ``torch.ones``. Presumably a device string or GPU/process rank
            ordinal; TODO confirm against callers.

    Returns:
        ``torch.int8`` tensor of shape ``(batch_size, seq_len, seq_len)``.
        Entries on or below the diagonal are 1 and entries above it are 0.
        For sample ``b``, the last ``num_pads[b]`` rows and columns are
        additionally zeroed; a pad count >= ``seq_len`` zeroes the whole
        sample.
    """
    # The middle entry is intentionally unused; the last entry fixes the side.
    batch_size, _, seq_len = mask_size

    lower_triangle = torch.tril(
        torch.ones(size=(seq_len, seq_len), dtype=torch.int8, device=rank),
        diagonal=0)
    # ``repeat`` (not ``expand``) so each sample owns its own storage and the
    # per-sample writes below cannot bleed into other samples.
    mask = lower_triangle.unsqueeze(0).repeat(batch_size, 1, 1)

    for batch_idx in range(batch_size):
        # Clamp at 0 so a pad count larger than seq_len masks everything
        # instead of wrapping around via a negative slice start.
        first_pad = max(seq_len - int(num_pads[batch_idx]), 0)
        mask[batch_idx, :, first_pad:] = 0
        mask[batch_idx, first_pad:, :] = 0

    return mask
|