logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

64 lines
1.9 KiB

2 years ago
import torch
import torchaudio
class MelSpec(torch.nn.Module):
    """Mel-spectrogram layer wrapping ``torchaudio.transforms.MelSpectrogram``.

    ``forward`` normalizes the input along its last dimension, applies the
    mel transform, adds a small positive offset, optionally takes a logarithm
    and, when ``spec_norm == 'max'``, shifts the result so that its maximum
    over the last two dimensions is zero.

    Args:
        sample_rate: Passed to ``MelSpectrogram`` as ``sample_rate``.
        stft_n: Passed to ``MelSpectrogram`` as ``n_fft``.
        stft_hop: Passed to ``MelSpectrogram`` as ``hop_length``.
        f_min: Passed to ``MelSpectrogram`` as ``f_min``.
        f_max: Passed to ``MelSpectrogram`` as ``f_max``.
        n_mels: Passed to ``MelSpectrogram`` as ``n_mels``.
        naf_mode: When true, the transform is built with ``power=1``,
            ``pad_mode='constant'``, ``norm='slaney'`` and
            ``mel_scale='slaney'``, and ``forward`` adds ``0.06`` to the
            spectrogram. Otherwise ``power=2``, ``pad_mode='reflect'``,
            ``norm=None`` and ``mel_scale='htk'`` are used and the offset
            is ``1e-8``.
        mel_log: ``'log10'`` applies ``torch.log10``, ``'log'`` applies
            ``torch.log``; any other value leaves the spectrogram
            un-logged.
        spec_norm: ``'max'`` normalizes the input with the infinity norm
            and subtracts the output maximum; any other value normalizes
            the input with the L2 norm and skips the subtraction.
    """

    def __init__(self,
                 sample_rate=8000,
                 stft_n=1024,
                 stft_hop=256,
                 f_min=300,
                 f_max=4000,
                 n_mels=256,
                 naf_mode=False,
                 mel_log='log',
                 spec_norm='l2'):
        super().__init__()

        self.naf_mode = naf_mode
        self.mel_log = mel_log
        self.spec_norm = spec_norm

        # The four trailing options switch together on ``naf_mode``.
        self.mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=stft_n,
            hop_length=stft_hop,
            f_min=f_min,
            f_max=f_max,
            n_mels=n_mels,
            window_fn=torch.hann_window,
            power=1 if naf_mode else 2,
            pad_mode='constant' if naf_mode else 'reflect',
            norm='slaney' if naf_mode else None,
            mel_scale='slaney' if naf_mode else 'htk',
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the (optionally log-compressed) mel spectrogram of ``x``.

        Args:
            x: Input tensor; it is normalized over its last dimension
                (presumably the sample axis of a waveform -- confirm
                against callers).

        Returns:
            The output of the mel transform after offsetting, optional
            logarithm and optional max-subtraction.
        """
        use_max_norm = self.spec_norm == 'max'

        # Normalize volume: infinity norm for 'max', L2 norm otherwise.
        # ``float('inf')`` is the explicit spelling of the value the old
        # ``1e999`` literal only reached through float overflow.
        norm_order = float('inf') if use_max_norm else 2
        x = torch.nn.functional.normalize(x, p=norm_order, dim=-1)

        # A positive offset keeps the logarithm below finite on silence.
        offset = 0.06 if self.naf_mode else 1e-8
        x = self.mel(x) + offset

        if self.mel_log == 'log10':
            x = torch.log10(x)
        elif self.mel_log == 'log':
            x = torch.log(x)

        if use_max_norm:
            # Shift so the largest value over the last two dims is 0.
            x = x - torch.amax(x, dim=(-2, -1), keepdim=True)

        return x
def build_mel_spec_layer(params):
    """Construct a ``MelSpec`` layer from a parameter mapping.

    ``params`` must provide ``sample_rate``, ``stft_n``, ``stft_hop``,
    ``f_min``, ``f_max`` and ``n_mels`` via item lookup. ``naf_mode``,
    ``mel_log`` and ``spec_norm`` are read with ``params.get`` and fall back
    to ``False``, ``'log'`` and ``'l2'`` respectively.
    """
    required_keys = (
        'sample_rate',
        'stft_n',
        'stft_hop',
        'f_min',
        'f_max',
        'n_mels',
    )
    # Mandatory settings first, in the same order as the MelSpec signature.
    layer_kwargs = {key: params[key] for key in required_keys}

    # Optional settings with their fallbacks.
    layer_kwargs['naf_mode'] = params.get('naf_mode', False)
    layer_kwargs['mel_log'] = params.get('mel_log', 'log')
    layer_kwargs['spec_norm'] = params.get('spec_norm', 'l2')

    return MelSpec(**layer_kwargs)