nnfp
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
109 lines
4.2 KiB
109 lines
4.2 KiB
import csv
|
|
import os
|
|
import warnings
|
|
|
|
import tqdm
|
|
import numpy as np
|
|
import torch
|
|
with warnings.catch_warnings():
|
|
warnings.simplefilter("ignore")
|
|
import torchaudio
|
|
|
|
from .simpleutils import get_hash
|
|
from .audio import get_audio
|
|
|
|
class NoiseData:
    """Pool of noise recordings held in memory for audio augmentation.

    Every usable recording listed in a CSV file is converted to mono,
    stripped of leading/trailing silence, resampled, peak-normalized and
    appended to one long 1-D tensor from which random windows are drawn.
    """

    def __init__(self, noise_dir, list_csv, sample_rate, cache_dir):
        """Load all noise recordings listed in ``list_csv``.

        Args:
            noise_dir: Directory that the file names in the CSV are joined to.
            list_csv: Path of a CSV file. The first row is skipped as a
                header; the first column of every following non-empty row
                is taken as a file name.
            sample_rate: Sample rate every recording is resampled to.
            cache_dir: Currently unused by the constructor (see NOTE below).

        Raises:
            RuntimeError: If no recording survives loading, so there is
                nothing to concatenate.
        """
        print('loading noise dataset')
        # newline='' is what the csv module asks for when it is handed a file.
        with open(list_csv, 'r', newline='') as fin:
            reader = csv.reader(fin)
            next(reader, None)  # skip the header; tolerate an empty file
            # csv.reader yields [] for blank lines, which have no column 0.
            noises = [row[0] for row in reader if row]

        # NOTE: load_from_cache()/save_to_cache() are not used here. The cache
        # blob holds only the concatenated samples, so a cache hit would leave
        # self.names and self.boundary unset.

        data = []
        silence_threshold = 0
        # File name of every recording that was kept, in load order.
        self.names = []
        # One Resample transform per distinct source rate, built lazily.
        resamplers = {}
        for name in tqdm.tqdm(noises):
            smp, smprate = get_audio(os.path.join(noise_dir, name))
            smp = torch.from_numpy(smp.astype(np.float32))

            # convert to mono
            # (assumes get_audio returns (channels, samples) -- TODO confirm)
            smp = smp.mean(dim=0)

            # strip silence start/end
            abs_smp = torch.abs(smp)
            # torch.max raises on an empty tensor, so test the size first.
            if abs_smp.numel() == 0 or torch.max(abs_smp) <= silence_threshold:
                print('%s too silent' % name)
                continue
            has_sound = (abs_smp > silence_threshold).to(torch.int)
            # argmax returns the first maximal index: the first audible sample
            # going forward, and (on the flipped mask) the last one backward.
            start = int(torch.argmax(has_sound))
            end = has_sound.shape[0] - int(torch.argmax(has_sound.flip(0)))
            smp = smp[start:end]

            resampler = resamplers.get(smprate)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(smprate, sample_rate)
                resamplers[smprate] = resampler
            resampled = resampler(smp)
            # Infinity-norm normalization scales the peak magnitude to 1.
            resampled = torch.nn.functional.normalize(
                resampled, dim=0, p=float('inf'))
            data.append(resampled)
            self.names.append(name)

        if not data:
            raise RuntimeError(
                'no usable noise recordings listed in %s' % list_csv)

        # All kept recordings back to back in one 1-D float32 tensor.
        self.data = torch.cat(data)
        # boundary[i] is the offset in self.data where recording i starts;
        # the final entry equals the total number of samples.
        lengths = [0] + [x.shape[0] for x in data]
        self.boundary = torch.tensor(lengths, dtype=torch.long).cumsum(0)
        del data
        print(self.data.shape)

    def load_from_cache(self, list_csv, cache_dir, hash):
        """Return the cached sample tensor, or None when it is not usable.

        Args:
            list_csv: CSV path; its base name selects the cache files.
            cache_dir: Directory holding the cache files.
            hash: Expected content of the ``.hash`` file.

        Returns:
            A 1-D float32 tensor read from the data file when both cache
            files exist and the stored hash equals ``hash``, otherwise None.
        """
        base = os.path.join(cache_dir, os.path.basename(list_csv))
        loc = base + '.npy'
        loc2 = base + '.hash'
        if not (os.path.exists(loc) and os.path.exists(loc2)):
            return None
        with open(loc2, 'r') as fin:
            read_hash = fin.read().strip()
        if read_hash != hash:
            return None
        print('cache hit!')
        # The data file is raw float32 written by ndarray.tofile, not the
        # .npy container format, hence np.fromfile rather than np.load.
        return torch.from_numpy(np.fromfile(loc, dtype=np.float32))

    def save_to_cache(self, list_csv, cache_dir, hash, audio):
        """Write ``audio`` and its ``hash`` into ``cache_dir``.

        Args:
            list_csv: CSV path; its base name selects the cache files.
            cache_dir: Directory to write to; created when missing.
            hash: String stored in the ``.hash`` file.
            audio: Tensor whose raw samples are dumped to the data file.
        """
        os.makedirs(cache_dir, exist_ok=True)
        base = os.path.join(cache_dir, os.path.basename(list_csv))
        loc = base + '.npy'
        loc2 = base + '.hash'
        print('save to cache')
        # Samples first, hash last: an interrupted save must not leave a
        # matching hash next to a truncated data file.
        audio.numpy().tofile(loc)
        with open(loc2, 'w') as fout:
            fout.write(hash)

    def random_choose(self, num, duration, out_name=False):
        """Draw ``num`` random windows of ``duration`` samples from the pool.

        Windows are cut from the concatenated tensor, so one window may run
        across the border between two recordings.

        Args:
            num: Number of windows to draw.
            duration: Length of each window in samples.
            out_name: When true, also return the name of the recording that
                contains the first sample of each window.

        Returns:
            A float32 tensor of shape ``(num, duration)``, or a tuple of
            that tensor and a list of ``num`` names when ``out_name`` is true.

        Raises:
            ValueError: If the pool is empty or shorter than ``duration``.
        """
        total = self.data.shape[0]
        if total == 0 or duration > total:
            raise ValueError(
                'cannot draw %d samples from a noise pool of %d samples'
                % (duration, total))
        # randint's upper bound is exclusive: the +1 makes the last full
        # window reachable, and min() keeps every start a valid sample index.
        high = min(total - duration + 1, total)
        indices = torch.randint(0, high, size=(num,), dtype=torch.long)
        out = torch.zeros([num, duration], dtype=torch.float32)
        for i, start in enumerate(indices.tolist()):
            out[i] = self.data[start:start + duration]
        if not out_name:
            return out
        # With right=True, searchsorted finds the first boundary strictly
        # greater than the start; the recording is the one just before it.
        name_lookup = torch.searchsorted(self.boundary, indices, right=True) - 1
        return out, [self.names[x] for x in name_lookup.tolist()]

    # x is a 2d array
    def add_noises(self, x, snr_min, snr_max, out_name=False):
        """Mix a random noise window into every row of ``x``.

        Each row gets its own signal-to-noise ratio in dB, drawn uniformly
        from ``[snr_min, snr_max)``. The noise is scaled so that the ratio
        of the row's RMS to the scaled noise's RMS is ``10 ** (snr / 20)``.

        Args:
            x: 2-D tensor with one signal per row.
            snr_min: Lower bound of the SNR range.
            snr_max: Upper bound of the SNR range.
            out_name: When true, also return noise names and drawn SNRs.

        Returns:
            The noisy signals, or a tuple ``(x_aug, noise_name, snr)`` when
            ``out_name`` is true.
        """
        eps = 1e-12  # floor on the mean power so silent rows do not divide by 0
        noise = self.random_choose(x.shape[0], x.shape[1], out_name=out_name)
        if out_name:
            noise, noise_name = noise
        vol_x = torch.clamp((x ** 2).mean(dim=1), min=eps).sqrt()
        vol_noise = torch.clamp((noise ** 2).mean(dim=1), min=eps).sqrt()
        snr = torch.empty(x.shape[0], dtype=torch.float32).uniform_(
            snr_min, snr_max)
        ratio = (vol_x / vol_noise) * 10 ** (-snr / 20)
        x_aug = x + ratio.unsqueeze(1) * noise
        if out_name:
            return x_aug, noise_name, snr
        return x_aug
|