logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

23 lines
930 B

import torch
def create_pad_mask(mask_size, pad_along_row_input, pad_along_column_input, rank):
    """Build a batch of masks of ones whose padded trailing rows/columns are 0.

    Args:
        mask_size: Triple ``(batch_size, output_seq_len, input_seq_len)``.
        pad_along_row_input: Per-sample number of trailing rows (second axis)
            to zero, indexed by batch position.
        pad_along_column_input: Per-sample number of trailing columns (third
            axis) to zero, indexed by batch position.
        rank: Target handed to ``Tensor.to`` for the returned mask
            (presumably a device / process rank -- confirm against callers).

    Returns:
        A ``torch.int8`` tensor of shape
        ``(batch_size, output_seq_len, input_seq_len)`` holding 1 everywhere
        except in the padded trailing rows and columns of each sample, which
        hold 0. A pad count greater than or equal to the axis length zeroes
        the whole axis for that sample.
    """
    batch_size, output_seq_len, input_seq_len = mask_size
    mask = torch.ones(
        size=(batch_size, output_seq_len, input_seq_len), dtype=torch.int8
    ).to(rank)

    for batch_idx in range(batch_size):
        # Clamp at 0: a pad count above the axis length would otherwise turn
        # into a negative slice start (counted from the end) and zero only
        # part of the axis instead of all of it.
        first_pad_col = max(input_seq_len - pad_along_column_input[batch_idx], 0)
        first_pad_row = max(output_seq_len - pad_along_row_input[batch_idx], 0)
        mask[batch_idx, :, first_pad_col:] = 0
        mask[batch_idx, first_pad_row:, :] = 0
    return mask
def create_no_peak_and_pad_mask(mask_size, num_pads, rank):
    """Build a batch of lower-triangular masks with padded positions zeroed.

    Args:
        mask_size: Triple ``(batch_size, seq_len, seq_len)``. Only the first
            and last entries are used; the mask is square with side equal to
            the last entry.
        num_pads: Per-sample number of trailing positions to zero along both
            the row and the column axis, indexed by batch position.
        rank: Target handed to ``Tensor.to`` for the returned mask
            (presumably a device / process rank -- confirm against callers).

    Returns:
        A ``torch.int8`` tensor of shape ``(batch_size, seq_len, seq_len)``.
        Each sample is a lower-triangular matrix of ones (diagonal included)
        whose last ``num_pads[i]`` rows and columns are 0. A pad count
        greater than or equal to ``seq_len`` zeroes the whole sample.
    """
    # The middle entry is ignored: the original unpacking bound ``seq_len``
    # twice, so the last dimension is the one that takes effect.
    batch_size, _, seq_len = mask_size

    lower_triangular = torch.tril(
        torch.ones(size=(seq_len, seq_len), dtype=torch.int8), diagonal=0
    )
    mask = lower_triangular.unsqueeze(0).repeat(batch_size, 1, 1).to(rank)

    for batch_idx in range(batch_size):
        # Clamp at 0: a pad count above seq_len would otherwise turn into a
        # negative slice start (counted from the end) and zero only part of
        # the rows/columns instead of all of them.
        first_pad = max(seq_len - num_pads[batch_idx], 0)
        mask[batch_idx, :, first_pad:] = 0
        mask[batch_idx, first_pad:, :] = 0
    return mask