import os

import numpy as np
import torch
import torch.fft
from torch.utils.data import DataLoader, Dataset, Sampler, BatchSampler

from .melspec import build_mel_spec_layer
from .noise import NoiseData
from .ir import AIR, MicIRP
from .preprocess import preprocess_music


class NumpyMemmapDataset(Dataset):
    def __init__(self, path, dtype):
        self.path = path
        self.f = np.memmap(self.path, dtype=dtype)

    def __len__(self):
        return len(self.f)

    def __getitem__(self, idx):
        return self.f[idx]

    # These are needed; otherwise the pickler would read the whole file into memory
    def __getstate__(self):
        return {'path': self.path, 'dtype': self.f.dtype}

    def __setstate__(self, d):
        self.__init__(d['path'], d['dtype'])

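# Illustration of the pickling behaviour above (a minimal sketch; the cache file
# path and dtype below are assumptions, not part of this repo's data layout):
#
#   import pickle
#   ds = NumpyMemmapDataset('cache/music.bin', np.int16)   # hypothetical cache file
#   blob = pickle.dumps(ds)      # only {'path', 'dtype'} is serialized
#   ds2 = pickle.loads(blob)     # __setstate__ re-opens the memmap on demand
#   assert len(ds2) == len(ds)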

# data augmentation on music dataset
class MusicSegmentDataset(Dataset):
    def __init__(self, params, train_val):
        # load some configs
        assert train_val in {'train', 'validate'}
        sample_rate = params['sample_rate']
        self.augmented = True
        self.eval_time_shift = True
        self.segment_size = int(params['segment_size'] * sample_rate)
        self.hop_size = int(params['hop_size'] * sample_rate)
        self.time_offset = int(params['time_offset'] * sample_rate)  # select e.g. 1.2 s of audio, then take two random 1 s crops from it
        self.pad_start = int(params['pad_start'] * sample_rate)  # include extra audio to the left of a segment to simulate reverb
        self.params = params

        # get the FFT size needed for reverb convolution
        fftconv_n = 1024
        air_len = int(params['air']['length'] * sample_rate)
        ir_len = int(params['micirp']['length'] * sample_rate)
        fft_needed = self.segment_size + self.pad_start + air_len + ir_len
        while fftconv_n < fft_needed:
            fftconv_n *= 2
        self.fftconv_n = fftconv_n

        # datasets for data augmentation
        cache_dir = params['cache_dir']
        os.makedirs(cache_dir, exist_ok=True)
        if params['noise'][train_val]:
            self.noise = NoiseData(noise_dir=params['noise']['dir'], list_csv=params['noise'][train_val], sample_rate=sample_rate, cache_dir=cache_dir)
        else:
            self.noise = None
        if params['air'][train_val]:
            self.air = AIR(air_dir=params['air']['dir'], list_csv=params['air'][train_val], length=params['air']['length'], fftconv_n=fftconv_n, sample_rate=sample_rate)
        else:
            self.air = None
        if params['micirp'][train_val]:
            self.micirp = MicIRP(mic_dir=params['micirp']['dir'], list_csv=params['micirp'][train_val], length=params['micirp']['length'], fftconv_n=fftconv_n, sample_rate=sample_rate)
        else:
            self.micirp = None

        # load the music dataset as a memory-mapped file
        file_name = os.path.splitext(os.path.split(params[train_val + '_csv'])[1])[0]
        file_name = os.path.join(cache_dir, '1' + file_name)
        if os.path.exists(file_name + '.npy'):
            print('load cached music from %s.bin' % file_name)
        else:
            preprocess_music(params['music_dir'], params[train_val + '_csv'], sample_rate, file_name)
        self.f = NumpyMemmapDataset(file_name + '.bin', np.int16)

        # some segmentation settings
        song_len = np.load(file_name + '.npy')
        self.cues = []  # start location of segment i
        self.offset_left = []  # allowed left shift of segment i
        self.offset_right = []  # allowed right shift of segment i
        self.song_range = []  # range of song i is (start time, end time, start idx, end idx)
        t = 0
        for duration in song_len:
            num_segs = (duration - self.segment_size + self.hop_size) // self.hop_size
            start_cue = len(self.cues)
            for idx in range(num_segs):
                my_time = idx * self.hop_size
                self.cues.append(t + my_time)
                self.offset_left.append(my_time)
                self.offset_right.append(duration - my_time)
            end_cue = len(self.cues)
            self.song_range.append((t, t + duration, start_cue, end_cue))
            t += duration

        # convert to torch Tensors to allow sharing memory
        self.cues = torch.LongTensor(self.cues)
        self.offset_left = torch.LongTensor(self.offset_left)
        self.offset_right = torch.LongTensor(self.offset_right)

        # set up the mel spectrogram layer
        self.mel = build_mel_spec_layer(params)

    def get_single_segment(self, idx, offset, length):
        cue = int(self.cues[idx]) + offset
        left = int(self.offset_left[idx]) + offset
        right = int(self.offset_right[idx]) - offset

        # choose a segment
        # a segment looks like:
        #              cue
        #               v
        # [ pad_start | segment_size | ... ]
        #               \--- time_offset ---/
        segment = self.f[cue - min(left, self.pad_start): cue + min(right, length)]
        segment = np.pad(segment, [max(0, self.pad_start - left), max(0, length - right)])

        # convert 16-bit samples to float
        return segment * np.float32(1 / 32768)
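
    # (Example of the padding above: for a segment at the very start of a song,
    # left can be smaller than pad_start, so the slice begins at cue - left and
    # np.pad zero-fills the remaining pad_start - left samples on the left;
    # segments near the end of a song are zero-filled on the right in the same
    # way using length and right.)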

    def __getitem__(self, indices):
        if self.eval_time_shift:
            # collect all segments; batch processing is faster
            shift_range = self.segment_size // 2
            x = [self.get_single_segment(i, -self.segment_size // 4, self.segment_size + shift_range) for i in indices]

            # random time offset only on the augmented segment, which serves as the query audio
            segment_size = self.pad_start + self.segment_size
            offset1 = [self.segment_size // 4] * len(x)
            offset2 = torch.randint(high=shift_range + 1, size=(len(x),)).tolist()
        else:
            # collect all segments; batch processing is faster
            x = [self.get_single_segment(i, 0, self.time_offset) for i in indices]

            # random time offset for both views
            shift_range = self.time_offset - self.segment_size
            segment_size = self.pad_start + self.segment_size
            offset1 = torch.randint(high=shift_range + 1, size=(len(x),)).tolist()
            offset2 = torch.randint(high=shift_range + 1, size=(len(x),)).tolist()

        x_orig = [xi[off + self.pad_start: off + segment_size] for xi, off in zip(x, offset1)]
        x_orig = torch.Tensor(np.stack(x_orig).astype(np.float32))
        x_aug = [xi[off: off + segment_size] for xi, off in zip(x, offset2)]
        x_aug = torch.Tensor(np.stack(x_aug).astype(np.float32))

        if self.augmented:
            # add background noise
            if self.noise is not None:
                x_aug = self.noise.add_noises(x_aug, self.params['noise']['snr_min'], self.params['noise']['snr_max'])

            # apply room and microphone impulse responses via FFT convolution
            spec = torch.fft.rfft(x_aug, self.fftconv_n)
            if self.air is not None:
                spec *= self.air.random_choose(spec.shape[0])
            if self.micirp is not None:
                spec *= self.micirp.random_choose(spec.shape[0])
            x_aug = torch.fft.irfft(spec, self.fftconv_n)
            x_aug = x_aug[..., self.pad_start:segment_size]

        # output [x1_orig, x1_aug, x2_orig, x2_aug, ...]
        x = [x_orig, x_aug] if self.augmented else [x_orig]
        x = torch.stack(x, dim=1)
        if self.mel is not None:
            return self.mel(x)
        return x

    def fan_si_le(self):
        # '煩死了' roughly translates to "this is so annoying"
        raise NotImplementedError('煩死了')

    def zuo_bu_chu_lai(self):
        # '做不起來' roughly translates to "can't get it to work"
        raise NotImplementedError('做不起來')

    def __len__(self):
        return len(self.cues)

    def get_num_songs(self):
        return len(self.song_range)

    def get_song_segments(self, song_id):
        return self.song_range[song_id][2:4]

    def preload_song(self, song_id):
        start, end, _, _ = self.song_range[song_id]
        return self.f[start:end].copy()


class TwoStageShuffler(Sampler):
    def __init__(self, music_data: MusicSegmentDataset, shuffle_size):
        self.music_data = music_data
        self.shuffle_size = shuffle_size
        self.shuffle = True
        self.loaded = set()
        self.generator = torch.Generator()
        self.generator2 = torch.Generator()

    def set_epoch(self, epoch):
        self.generator.manual_seed(42 + epoch)
        self.generator2.manual_seed(42 + epoch)

    def __len__(self):
        return len(self.music_data)

    def preload(self, song_id):
        if song_id not in self.loaded:
            self.music_data.preload_song(song_id)
            self.loaded.add(song_id)

    def baseline_shuffle(self):
        # same as DataLoader with shuffle=True, but with a music preloader
        # 2021/10/18 sorry, I had to remove the preloader
        #for song in range(self.music_data.get_num_songs()):
        #    self.preload(song)

        yield from torch.randperm(len(self), generator=self.generator).tolist()

    def shuffling_iter(self):
        # shuffle the song list first
        shuffle_song = torch.randperm(self.music_data.get_num_songs(), generator=self.generator)

        # split the song list into chunks
        chunks = torch.split(shuffle_song, self.shuffle_size)
        for nc, songs in enumerate(chunks):
            # sort songs so the preloader reads more sequentially
            songs = torch.sort(songs)[0].tolist()

            buf = []
            for song in songs:
                # collect segment ids
                seg_start, seg_end = self.music_data.get_song_segments(song)
                buf += list(range(seg_start, seg_end))

            if nc == 0:
                # load the first chunk
                for song in songs:
                    self.preload(song)

            # shuffle segments within the song chunk
            shuffle_segs = torch.randperm(len(buf), generator=self.generator2)
            shuffle_segs = [buf[x] for x in shuffle_segs]
            preload_cnt = 0
            for i, idx in enumerate(shuffle_segs):
                # output a shuffled segment idx
                yield idx

                # gradually preload songs from the next chunk
                while len(self.loaded) < len(shuffle_song) and preload_cnt * len(shuffle_segs) < (i + 1) * self.shuffle_size:
                    song = shuffle_song[len(self.loaded)].item()
                    self.preload(song)
                    preload_cnt += 1

    def non_shuffling_iter(self):
        # just yield 0 ... len(dataset)-1 in order
        yield from range(len(self))

    def __iter__(self):
        if self.shuffle:
            if self.shuffle_size is None:
                return self.baseline_shuffle()
            else:
                return self.shuffling_iter()
        return self.non_shuffling_iter()
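
# Note on the sampler above: with a given shuffle_size, songs are shuffled in
# chunks and segments are only shuffled within the current chunk, so memmap
# reads stay roughly sequential while songs from the next chunk are preloaded
# incrementally as segments are yielded; shuffle_size=None falls back to a
# plain full permutation (baseline_shuffle), and shuffle=False yields indices
# in order.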


# Instantiating a DataLoader for MusicSegmentDataset by hand is tricky, so this builder wraps it
class SegmentedDataLoader:
    def __init__(self, train_val, configs, num_workers=4, pin_memory=False, prefetch_factor=2):
        assert train_val in {'train', 'validate'}
        self.dataset = MusicSegmentDataset(configs, train_val)
        assert configs['batch_size'] % 2 == 0
        self.batch_size = configs['batch_size']
        self.shuffler = TwoStageShuffler(self.dataset, configs['shuffle_size'])
        self.sampler = BatchSampler(self.shuffler, self.batch_size // 2, False)
        self.num_workers = num_workers
        self.configs = configs
        self.pin_memory = pin_memory
        self.prefetch_factor = prefetch_factor

        # you can change shuffle to True/False
        self.shuffle = True
        # you can change augmented to True/False
        self.augmented = True
        # you can change eval_time_shift to True/False
        self.eval_time_shift = False

        self.loader = DataLoader(
            self.dataset,
            sampler=self.sampler,
            batch_size=None,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            prefetch_factor=self.prefetch_factor
        )

    def set_epoch(self, epoch):
        self.shuffler.set_epoch(epoch)

    def __iter__(self):
        self.dataset.augmented = self.augmented
        self.dataset.eval_time_shift = self.eval_time_shift
        self.shuffler.shuffle = self.shuffle
        return iter(self.loader)

    def __len__(self):
        return len(self.loader)
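

# Minimal usage sketch (assumptions: `configs` is a parsed config dict whose
# keys match what MusicSegmentDataset reads above, plus whatever
# build_mel_spec_layer expects; the path 'configs/default.json' is hypothetical):
if __name__ == '__main__':
    import json
    with open('configs/default.json') as fin:
        configs = json.load(fin)
    loader = SegmentedDataLoader('train', configs, num_workers=4)
    loader.shuffle = True
    loader.augmented = True
    for epoch in range(2):
        loader.set_epoch(epoch)
        for batch in loader:
            # each batch stacks the clean and augmented views along dim 1,
            # already converted to mel spectrograms when the dataset's mel layer is set
            pass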