logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

42 lines
1.5 KiB

import torch
class SpecAugment:
    """Zero out random rectangular regions of a 2-D feature map.

    Three regions are drawn independently for every mask:

    * a *cutout* rectangle covering part of both axes,
    * a *frequency* band covering some rows across every column,
    * a *time* band covering some columns across every row.

    The extent of each region is a random fraction of the corresponding
    axis, drawn uniformly between a ``*_min`` and a ``*_max`` setting.
    """

    def __init__(self, params):
        """Read the masking fractions from ``params``.

        Args:
            params: mapping with a ``get`` method. Recognised keys:
                ``cutout_min`` / ``cutout_max`` (defaults 0.1 / 0.5),
                ``freq_min`` / ``freq_max`` and ``time_min`` / ``time_max``.
                A missing ``freq_*`` / ``time_*`` key falls back to the
                matching ``cutout_*`` value, which is what every band used
                before these keys were honoured.

        All values are used as fractions of an axis length; a fraction above
        1 can make ``get_mask`` fail because the region no longer fits.
        """
        self.cutout_min = params.get('cutout_min', 0.1)
        self.cutout_max = params.get('cutout_max', 0.5)
        # Previously these four were all read from the 'cutout_*' keys, so
        # the bands could not be tuned on their own. Dedicated keys now take
        # precedence; the cutout values remain the fallback.
        # NOTE(review): the original trailing comments listed 5/20 (freq),
        # 5/16 (time) and 0.1/0.4 (cutout) next to these fields - presumably
        # earlier settings, possibly in absolute bins; confirm units before
        # reusing them as fractions.
        self.freq_min = params.get('freq_min', self.cutout_min)
        self.freq_max = params.get('freq_max', self.cutout_max)
        self.time_min = params.get('time_min', self.cutout_min)
        self.time_max = params.get('time_max', self.cutout_max)

    @staticmethod
    def _random_span(size, lo, hi):
        """Return ``(start, stop)`` of a random span on an axis of ``size``.

        The length is ``int(size * u)`` with ``u`` uniform in ``[lo, hi)``;
        the start is uniform over every position where the span still fits.
        One ``torch.rand`` draw is followed by one ``torch.randint`` draw.
        """
        length = int(size * (lo + torch.rand(1) * (hi - lo)))
        # randint's upper bound is exclusive, hence the +1. It raises if the
        # span is longer than the axis (empty range).
        start = int(torch.randint(0, size - length + 1, (1,)))
        return start, start + length

    def get_mask(self, F, T):
        """Build a random ``(F, T)`` mask with 1 where data should be dropped.

        Args:
            F: number of rows (the axis the code calls frequency).
            T: number of columns (the axis the code calls time).

        Returns:
            A tensor of shape ``(F, T)`` created by ``torch.zeros`` whose
            entries are 0 (keep) or 1 (mask).
        """
        mask = torch.zeros(F, T)

        # cutout: one rectangle limited on both axes
        row_lo, row_hi = self._random_span(F, self.cutout_min, self.cutout_max)
        col_lo, col_hi = self._random_span(T, self.cutout_min, self.cutout_max)
        mask[row_lo:row_hi, col_lo:col_hi] = 1

        # frequency masking: a band of rows across all columns
        row_lo, row_hi = self._random_span(F, self.freq_min, self.freq_max)
        mask[row_lo:row_hi, :] = 1

        # time masking: a band of columns across all rows
        col_lo, col_hi = self._random_span(T, self.time_min, self.time_max)
        mask[:, col_lo:col_hi] = 1

        return mask

    def augment(self, x):
        """Return ``x`` with the masked positions set to zero.

        A single mask sized from the last two dimensions of ``x`` is drawn
        and moved to ``x.device``; multiplying by ``1 - mask`` broadcasts it
        over any leading dimensions, so they all share the same mask.
        """
        rows, cols = x.shape[-2], x.shape[-1]
        keep = 1 - self.get_mask(rows, cols).to(x.device)
        return x * keep