nnfp
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
110 lines
4.2 KiB
110 lines
4.2 KiB
2 years ago
|
import csv
|
||
|
import os
|
||
|
import warnings
|
||
|
|
||
|
import tqdm
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
with warnings.catch_warnings():
|
||
|
warnings.simplefilter("ignore")
|
||
|
import torchaudio
|
||
|
|
||
|
from .simpleutils import get_hash
|
||
|
from .audio import get_audio
|
||
|
|
||
|
class NoiseData:
    """In-memory pool of background-noise clips used for additive noise augmentation.

    Every clip listed in a CSV is loaded, averaged down to mono, trimmed of
    leading/trailing silence, resampled to a common rate, peak-normalized and
    concatenated into one 1-D tensor, ``self.data``.

    ``self.boundary`` holds the cumulative start offset of each clip inside
    ``self.data``; its last entry is the total length. ``self.names`` holds the
    matching file names, so a random window can be traced back to the clip it
    starts in.
    """

    def __init__(self, noise_dir, list_csv, sample_rate, cache_dir):
        """Load every noise clip listed in *list_csv*.

        Args:
            noise_dir: directory the file names in the CSV are joined onto.
            list_csv: CSV file. The first row is skipped as a header, and the
                first column of every remaining non-blank row is a clip file name.
            sample_rate: sample rate every clip is resampled to.
            cache_dir: directory for the on-disk cache. Currently unused,
                because caching is disabled (see the NOTE below).

        Raises:
            RuntimeError: if no usable (non-silent, non-empty) clip was found.
        """
        print('loading noise dataset')
        with open(list_csv, 'r', newline='') as fin:
            reader = csv.reader(fin)
            next(reader, None)  # skip the header row; tolerate an empty file
            # csv yields [] for blank lines; skip them instead of failing on row[0]
            noises = [row[0] for row in reader if row]

        # NOTE(review): the on-disk cache (load_from_cache / save_to_cache) is
        # disabled. It only persists ``self.data``, not ``self.names`` or
        # ``self.boundary``, so re-enabling it also requires caching those.
        # The previous cache key was get_hash(''.join(get_hash(n) for n in noises)).

        silence_threshold = 0
        data = []
        self.names = []
        for name in tqdm.tqdm(noises):
            clip = self._load_clip(os.path.join(noise_dir, name), sample_rate, silence_threshold)
            if clip is None:
                print('%s too silent' % name)
                continue
            data.append(clip)
            self.names.append(name)

        if not data:
            # torch.cat([]) would raise an opaque RuntimeError; say what went wrong
            raise RuntimeError('no usable noise clips listed in %s' % list_csv)
        self.data = torch.cat(data)
        # boundary[i] = offset of clip i in self.data; boundary[-1] = total length
        self.boundary = torch.LongTensor([0] + [x.shape[0] for x in data]).cumsum(0)
        del data
        print(self.data.shape)

    @staticmethod
    def _load_clip(path, sample_rate, silence_threshold):
        """Load one clip as a mono, silence-trimmed, resampled, peak-normalized 1-D tensor.

        Returns None when the clip is empty or no sample's absolute value
        exceeds *silence_threshold*.
        """
        smp, smprate = get_audio(path)
        smp = torch.from_numpy(smp.astype(np.float32))

        # convert to mono by averaging over dim 0, presumably get_audio's channel
        # axis (TODO confirm); a 1-D result is taken to be mono already
        if smp.dim() > 1:
            smp = smp.mean(dim=0)

        # strip silence at start/end
        abs_smp = torch.abs(smp)
        if abs_smp.numel() == 0 or torch.max(abs_smp) <= silence_threshold:
            return None
        has_sound = (abs_smp > silence_threshold).to(torch.int)
        start = int(torch.argmax(has_sound))  # first sample above the threshold
        end = has_sound.shape[0] - int(torch.argmax(has_sound.flip(0)))  # one past the last
        smp = smp[start:end]

        resampled = torchaudio.transforms.Resample(smprate, sample_rate)(smp)
        # infinity-norm normalization: divide by the peak absolute sample value
        return torch.nn.functional.normalize(resampled, dim=0, p=float('inf'))

    def load_from_cache(self, list_csv, cache_dir, hash):
        """Return the cached concatenated noise tensor, or None on a cache miss.

        The cache consists of two files in *cache_dir*, named after *list_csv*:
          - ``<name>.npy`` holds raw float32 samples written with ``ndarray.tofile``.
            Despite the suffix, this is not the NumPy ``.npy`` format.
          - ``<name>.hash`` holds the key those samples were written for.

        It is a miss when either file is absent or the stored key differs from *hash*.
        """
        base = os.path.join(cache_dir, os.path.basename(list_csv))
        data_path = base + '.npy'
        hash_path = base + '.hash'
        if not (os.path.exists(data_path) and os.path.exists(hash_path)):
            return None
        with open(hash_path, 'r') as fin:
            if fin.read().strip() != hash:
                return None
        print('cache hit!')
        return torch.from_numpy(np.fromfile(data_path, dtype=np.float32))

    def save_to_cache(self, list_csv, cache_dir, hash, audio):
        """Store the 1-D tensor *audio* under key *hash* in the format load_from_cache reads."""
        os.makedirs(cache_dir, exist_ok=True)
        base = os.path.join(cache_dir, os.path.basename(list_csv))
        data_path = base + '.npy'
        hash_path = base + '.hash'
        print('save to cache')
        # Invalidate any existing key first, and write the key only after the data:
        # an interrupted write then leaves no matching .hash, so a truncated .npy
        # can never be reported as a cache hit.
        if os.path.exists(hash_path):
            os.remove(hash_path)
        audio.detach().cpu().numpy().tofile(data_path)
        with open(hash_path, 'w') as fout:
            fout.write(hash)

    def random_choose(self, num, duration, out_name=False):
        """Draw *num* random windows of *duration* samples from the noise pool.

        Window starts are uniform over every valid offset, 0 to
        ``len(self.data) - duration`` inclusive. A window may span the join
        between two consecutive clips.

        Returns:
            A float32 tensor of shape ``(num, duration)``. With *out_name*, also
            a list holding, for each window, the name of the clip it starts in.

        Raises:
            ValueError: if *duration* exceeds the total length of the noise pool.
        """
        total = self.data.shape[0]
        if duration > total:
            raise ValueError(
                'requested %d noise samples but the pool only has %d' % (duration, total))
        # randint's upper bound is exclusive; +1 makes the last full window reachable
        indices = torch.randint(0, total - duration + 1, size=(num,), dtype=torch.long)
        out = torch.zeros([num, duration], dtype=torch.float32)
        for i, start in enumerate(indices.tolist()):
            out[i] = self.data[start:start + duration]
        if not out_name:
            return out
        # boundary holds clip start offsets, so a window's clip is the last
        # boundary entry <= its start
        clip_ids = torch.searchsorted(self.boundary, indices, right=True) - 1
        return out, [self.names[k] for k in clip_ids.tolist()]

    def add_noises(self, x, snr_min, snr_max, out_name=False):
        """Mix a random noise window into each row of *x* at a random SNR.

        Args:
            x: 2-D tensor with one signal per row. Each row gets its own noise
                window and its own SNR.
            snr_min, snr_max: range (dB) the per-row SNR is drawn uniformly from.
            out_name: also return the noise clip names and the drawn SNRs.

        Returns:
            ``x`` plus scaled noise, with the same shape as *x*. With *out_name*,
            the tuple ``(x_aug, noise_names, snr)``, where *snr* is a CPU float tensor.
        """
        eps = 1e-12  # floor on mean power, so silent rows do not divide by zero
        noise = self.random_choose(x.shape[0], x.shape[1], out_name=out_name)
        if out_name:
            noise, noise_name = noise
        # the noise pool lives on the CPU; follow x so non-CPU inputs work too
        noise = noise.to(x.device)
        vol_x = torch.clamp((x ** 2).mean(dim=1), min=eps).sqrt()
        vol_noise = torch.clamp((noise ** 2).mean(dim=1), min=eps).sqrt()
        snr = torch.FloatTensor(x.shape[0]).uniform_(snr_min, snr_max)
        # scale the noise so that 20 * log10(rms(x) / rms(scaled noise)) == snr
        ratio = vol_x / vol_noise * 10 ** -(snr.to(x.device) / 20)
        x_aug = x + ratio.unsqueeze(1) * noise
        if out_name:
            return x_aug, noise_name, snr
        return x_aug