logo
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

110 lines
4.2 KiB

2 years ago
import csv
import os
import warnings
import tqdm
import numpy as np
import torch
with warnings.catch_warnings():
warnings.simplefilter("ignore")
import torchaudio
from .simpleutils import get_hash
from .audio import get_audio
class NoiseData:
    """In-memory bank of noise clips used to corrupt audio at a chosen SNR.

    Every clip listed in a CSV is loaded, down-mixed, trimmed of leading and
    trailing silence, resampled to a common rate, peak-normalized and
    concatenated into one tensor ``self.data``. Noise windows are drawn from
    anywhere in that tensor, so a single window may straddle two clips.

    Attributes:
        data: float32 tensor holding every clip back to back (indexed as 1-D
            by ``random_choose``).
        boundary: int64 tensor of cumulative clip lengths prefixed with 0,
            i.e. clip ``k`` occupies ``data[boundary[k]:boundary[k + 1]]``.
        names: clip file names, in the same order as the clips in ``data``.
    """

    def __init__(self, noise_dir, list_csv, sample_rate, cache_dir,
                 silence_threshold=0):
        """Load every clip listed in ``list_csv`` into ``self.data``.

        Args:
            noise_dir: directory the file names in ``list_csv`` are relative to.
            list_csv: CSV file whose first row is a header; the first column
                of every following non-blank row is a clip file name.
            sample_rate: rate every clip is resampled to.
            cache_dir: currently unused by the constructor. The cache hooks are
                not wired in because ``load_from_cache`` restores only
                ``data``, not ``boundary`` or ``names``.
            silence_threshold: samples whose absolute value is at most this
                are treated as silence when trimming (default 0, i.e. only
                exact zeros).

        Raises:
            RuntimeError: if no clip survives loading (empty list or every
                clip silent).
        """
        print('loading noise dataset')
        with open(list_csv, 'r', newline='') as fin:
            reader = csv.reader(fin)
            next(reader, None)  # skip the header row (tolerates an empty file)
            noises = [row[0] for row in reader if row]  # ignore blank lines

        clips = []
        self.names = []
        resamplers = {}  # one Resample module per source rate, built lazily
        for name in tqdm.tqdm(noises):
            smp, smprate = get_audio(os.path.join(noise_dir, name))
            smp = torch.from_numpy(smp.astype(np.float32))
            # Down-mix to mono by averaging over the first axis.
            # NOTE(review): assumes get_audio returns (channels, samples); a
            # 1-D mono array would collapse to a scalar here -- confirm.
            smp = smp.mean(dim=0)

            loud = torch.abs(smp) > silence_threshold
            if not bool(loud.any()):
                print('%s too silent' % name)
                continue
            # argmax returns the first maximal index, so on the 0/1 mask it
            # finds the first loud sample; on the flipped mask, the last one.
            has_sound = loud.to(torch.int)
            start = int(torch.argmax(has_sound))
            end = has_sound.shape[0] - int(torch.argmax(has_sound.flip(0)))
            smp = smp[start:end]

            resampler = resamplers.get(smprate)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(smprate, sample_rate)
                resamplers[smprate] = resampler
            resampled = resampler(smp)
            # p=inf divides by the peak |sample|, so each clip peaks at 1.
            resampled = torch.nn.functional.normalize(
                resampled, dim=0, p=float('inf'))
            clips.append(resampled)
            self.names.append(name)

        if not clips:
            raise RuntimeError('no usable noise clips found in %s' % list_csv)
        self.data = torch.cat(clips)
        lengths = [clip.shape[0] for clip in clips]
        self.boundary = torch.tensor([0] + lengths, dtype=torch.long).cumsum(0)
        del clips
        print(self.data.shape)

    @staticmethod
    def _cache_paths(list_csv, cache_dir):
        """Return the ``(audio_path, hash_path)`` cache files for ``list_csv``."""
        stem = os.path.join(cache_dir, os.path.basename(list_csv))
        return stem + '.npy', stem + '.hash'

    def load_from_cache(self, list_csv, cache_dir, hash):
        """Return cached audio for ``list_csv`` if its stored hash equals ``hash``.

        Returns:
            A 1-D float32 tensor, or None when either cache file is missing or
            the stored hash differs.
        """
        data_path, hash_path = self._cache_paths(list_csv, cache_dir)
        if not (os.path.exists(data_path) and os.path.exists(hash_path)):
            return None
        with open(hash_path, 'r') as fin:
            if fin.read().strip() != hash:
                return None
        print('cache hit!')
        # The file is a raw float32 dump (ndarray.tofile) despite the .npy
        # suffix; the suffix is kept so existing cache files still match.
        return torch.from_numpy(np.fromfile(data_path, dtype=np.float32))

    def save_to_cache(self, list_csv, cache_dir, hash, audio):
        """Store ``audio`` and ``hash`` so ``load_from_cache`` can restore them.

        The hash file acts as the commit marker. It is deleted first and only
        rewritten after the audio is fully written, so an interrupted save can
        never pair a valid hash with stale or partial audio.
        """
        os.makedirs(cache_dir, exist_ok=True)
        data_path, hash_path = self._cache_paths(list_csv, cache_dir)
        try:
            os.remove(hash_path)
        except FileNotFoundError:
            pass
        print('save to cache')
        # load_from_cache reads raw float32, so write exactly that, from
        # host memory (works for CUDA tensors and tensors requiring grad).
        samples = audio.detach().cpu().numpy().astype(np.float32, copy=False)
        samples.tofile(data_path)
        with open(hash_path, 'w') as fout:
            fout.write(hash)

    def random_choose(self, num, duration, out_name=False):
        """Draw ``num`` random contiguous windows of ``duration`` samples.

        Windows may start anywhere in ``self.data``, so one window can span
        the junction of two concatenated clips.

        Args:
            num: number of windows.
            duration: samples per window; must not exceed ``len(self.data)``.
            out_name: also return, for each window, the name of the clip in
                which it starts.

        Returns:
            A ``(num, duration)`` float32 tensor, or ``(tensor, names)`` when
            ``out_name`` is true.

        Raises:
            ValueError: if ``duration`` is longer than the noise bank.
        """
        total = self.data.shape[0]
        if duration > total:
            raise ValueError('requested %d noise samples but the bank only '
                             'holds %d' % (duration, total))
        # randint's upper bound is exclusive; +1 makes the final full window
        # (starting at total - duration) reachable.
        starts = torch.randint(0, total - duration + 1, size=(num,),
                               dtype=torch.long)
        offsets = torch.arange(duration, dtype=torch.long)
        # Row i is data[starts[i] : starts[i] + duration], gathered in one go.
        out = self.data[starts.unsqueeze(1) + offsets].to(torch.float32)
        if not out_name:
            return out
        # boundary holds cumulative clip lengths, so the right-sided search
        # minus one gives the index of the clip each window starts in.
        clip_ids = torch.searchsorted(self.boundary, starts, right=True) - 1
        return out, [self.names[i] for i in clip_ids.tolist()]

    def add_noises(self, x, snr_min, snr_max, out_name=False):
        """Mix a random noise window into each row of ``x`` at a random SNR.

        Args:
            x: 2-D tensor; row ``i`` receives a noise window of length
                ``x.shape[1]``.
            snr_min: lower bound of the per-row SNR in dB.
            snr_max: upper bound of the per-row SNR in dB (drawn uniformly).
            out_name: also return the noise clip names and the drawn SNRs.

        Returns:
            ``x_aug`` with the same shape as ``x``, or
            ``(x_aug, noise_names, snr)`` when ``out_name`` is true.
        """
        eps = 1e-12
        picked = self.random_choose(x.shape[0], x.shape[1], out_name=out_name)
        if out_name:
            noise, noise_name = picked
        else:
            noise = picked
        # The noise bank lives on the CPU; follow x so GPU batches work too.
        noise = noise.to(x.device)
        # Per-row RMS, clamped so all-zero rows do not divide by zero.
        vol_x = torch.clamp((x ** 2).mean(dim=1), min=eps).sqrt()
        vol_noise = torch.clamp((noise ** 2).mean(dim=1), min=eps).sqrt()
        snr = torch.empty(x.shape[0], dtype=torch.float32,
                          device=x.device).uniform_(snr_min, snr_max)
        # Scale the noise so that 20*log10(rms(x) / rms(scaled noise)) == snr.
        ratio = vol_x / vol_noise * 10 ** (-snr / 20)
        x_aug = x + ratio.unsqueeze(1) * noise
        if out_name:
            return x_aug, noise_name, snr
        return x_aug