import os
import numpy as np
import torch
import torch.fft
from torch.utils.data import DataLoader, Dataset, Sampler, BatchSampler
from .melspec import build_mel_spec_layer
from .noise import NoiseData
from .ir import AIR, MicIRP
from .preprocess import preprocess_music
class NumpyMemmapDataset(Dataset):
    def __init__(self, path, dtype):
        self.path = path
        self.f = np.memmap(self.path, dtype=dtype)

    def __len__(self):
        return len(self.f)

    def __getitem__(self, idx):
        return self.f[idx]

    # need these, otherwise the pickler will read the whole data into memory
    def __getstate__(self):
        return {'path': self.path, 'dtype': self.f.dtype}

    def __setstate__(self, d):
        self.__init__(d['path'], d['dtype'])
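
# --- Editor's illustration (not part of the original file) ---
# Because __getstate__ keeps only the path and dtype, pickling a
# NumpyMemmapDataset (e.g. when DataLoader spawns worker processes) does not
# serialize the mapped audio data; each process re-opens its own memmap.
# A minimal sketch of that round trip; the cache path used here is hypothetical.
def _memmap_pickle_roundtrip_example(path='cache/1train.bin'):
    import pickle
    ds = NumpyMemmapDataset(path, np.int16)
    clone = pickle.loads(pickle.dumps(ds))  # only {'path', 'dtype'} cross the pickle boundary
    assert len(clone) == len(ds)
    return clone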
# data augmentation on music dataset
class MusicSegmentDataset(Dataset):
    def __init__(self, params, train_val):
        # load some configs
        assert train_val in {'train', 'validate'}
        sample_rate = params['sample_rate']
        self.augmented = True
        self.eval_time_shift = True
        self.segment_size = int(params['segment_size'] * sample_rate)
        self.hop_size = int(params['hop_size'] * sample_rate)
        self.time_offset = int(params['time_offset'] * sample_rate)  # select 1.2s of audio, then choose two random 1s windows from it
        self.pad_start = int(params['pad_start'] * sample_rate)  # include more audio at the left of a segment to simulate reverb
        self.params = params

        # get fft size needed for reverb
        fftconv_n = 1024
        air_len = int(params['air']['length'] * sample_rate)
        ir_len = int(params['micirp']['length'] * sample_rate)
        fft_needed = self.segment_size + self.pad_start + air_len + ir_len
        while fftconv_n < fft_needed:
            fftconv_n *= 2
        self.fftconv_n = fftconv_n

        # data augmentation datasets
        cache_dir = params['cache_dir']
        os.makedirs(cache_dir, exist_ok=True)
        if params['noise'][train_val]:
            self.noise = NoiseData(noise_dir=params['noise']['dir'], list_csv=params['noise'][train_val], sample_rate=sample_rate, cache_dir=cache_dir)
        else:
            self.noise = None
        if params['air'][train_val]:
            self.air = AIR(air_dir=params['air']['dir'], list_csv=params['air'][train_val], length=params['air']['length'], fftconv_n=fftconv_n, sample_rate=sample_rate)
        else:
            self.air = None
        if params['micirp'][train_val]:
            self.micirp = MicIRP(mic_dir=params['micirp']['dir'], list_csv=params['micirp'][train_val], length=params['micirp']['length'], fftconv_n=fftconv_n, sample_rate=sample_rate)
        else:
            self.micirp = None

        # load music dataset as a memory-mapped file
        file_name = os.path.splitext(os.path.split(params[train_val + '_csv'])[1])[0]
        file_name = os.path.join(cache_dir, '1' + file_name)
        if os.path.exists(file_name + '.npy'):
            print('load cached music from %s.bin' % file_name)
        else:
            preprocess_music(params['music_dir'], params[train_val + '_csv'], sample_rate, file_name)
        self.f = NumpyMemmapDataset(file_name + '.bin', np.int16)

        # some segmentation settings
        song_len = np.load(file_name + '.npy')
        self.cues = []  # start location of segment i
        self.offset_left = []  # allowed left shift of segment i
        self.offset_right = []  # allowed right shift of segment i
        self.song_range = []  # range of song i is (start time, end time, start idx, end idx)
        t = 0
        for duration in song_len:
            num_segs = (duration - self.segment_size + self.hop_size) // self.hop_size
            start_cue = len(self.cues)
            for idx in range(num_segs):
                my_time = idx * self.hop_size
                self.cues.append(t + my_time)
                self.offset_left.append(my_time)
                self.offset_right.append(duration - my_time)
            end_cue = len(self.cues)
            self.song_range.append((t, t + duration, start_cue, end_cue))
            t += duration

        # convert to torch Tensor to allow sharing memory
        self.cues = torch.LongTensor(self.cues)
        self.offset_left = torch.LongTensor(self.offset_left)
        self.offset_right = torch.LongTensor(self.offset_right)

        # setup mel spectrogram
        self.mel = build_mel_spec_layer(params)
    def get_single_segment(self, idx, offset, length):
        cue = int(self.cues[idx]) + offset
        left = int(self.offset_left[idx]) + offset
        right = int(self.offset_right[idx]) - offset

        # choose a segment
        # segment looks like:
        #                cue
        #                 v
        #    [ pad_start | segment_size | ... ]
        #                 \--- time_offset ---/
        segment = self.f[cue - min(left, self.pad_start): cue + min(right, length)]
        segment = np.pad(segment, [max(0, self.pad_start - left), max(0, length - right)])

        # convert 16 bit to float
        return segment * np.float32(1 / 32768)

    def __getitem__(self, indices):
        if self.eval_time_shift:
            # collect all segments. It is faster to do batch processing
            shift_range = self.segment_size // 2
            x = [self.get_single_segment(i, -self.segment_size // 4, self.segment_size + shift_range) for i in indices]

            # random time offset only on augmented segment, as a query audio
            segment_size = self.pad_start + self.segment_size
            offset1 = [self.segment_size // 4] * len(x)
            offset2 = torch.randint(high=shift_range + 1, size=(len(x),)).tolist()
        else:
            # collect all segments. It is faster to do batch processing
            x = [self.get_single_segment(i, 0, self.time_offset) for i in indices]

            # random time offset
            shift_range = self.time_offset - self.segment_size
            segment_size = self.pad_start + self.segment_size
            offset1 = torch.randint(high=shift_range + 1, size=(len(x),)).tolist()
            offset2 = torch.randint(high=shift_range + 1, size=(len(x),)).tolist()
        x_orig = [xi[off + self.pad_start: off + segment_size] for xi, off in zip(x, offset1)]
        x_orig = torch.Tensor(np.stack(x_orig).astype(np.float32))
        x_aug = [xi[off: off + segment_size] for xi, off in zip(x, offset2)]
        x_aug = torch.Tensor(np.stack(x_aug).astype(np.float32))

        if self.augmented:
            # background noise
            if self.noise is not None:
                x_aug = self.noise.add_noises(x_aug, self.params['noise']['snr_min'], self.params['noise']['snr_max'])

            # impulse response
            spec = torch.fft.rfft(x_aug, self.fftconv_n)
            if self.air is not None:
                spec *= self.air.random_choose(spec.shape[0])
            if self.micirp is not None:
                spec *= self.micirp.random_choose(spec.shape[0])
            x_aug = torch.fft.irfft(spec, self.fftconv_n)
            x_aug = x_aug[..., self.pad_start:segment_size]

        # output [x1_orig, x1_aug, x2_orig, x2_aug, ...]
        x = [x_orig, x_aug] if self.augmented else [x_orig]
        x = torch.stack(x, dim=1)
        if self.mel is not None:
            return self.mel(x)
        return x
    def fan_si_le(self):
        raise NotImplementedError('煩死了 (so annoying)')

    def zuo_bu_chu_lai(self):
        raise NotImplementedError('做不起來 (cannot get it to work)')
    def __len__(self):
        return len(self.cues)

    def get_num_songs(self):
        return len(self.song_range)

    def get_song_segments(self, song_id):
        return self.song_range[song_id][2:4]

    def preload_song(self, song_id):
        start, end, _, _ = self.song_range[song_id]
        return self.f[start: end].copy()
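
# --- Editor's note (not part of the original file) ---
# MusicSegmentDataset.__getitem__ takes a *list* of segment indices (it is fed
# by a BatchSampler with batch_size=None; see SegmentedDataLoader below) and
# builds the whole mini-batch at once: original and augmented views are stacked
# along dim=1, then passed through the mel layer when one is built.
# A rough sketch, assuming `dataset` was constructed from a valid config:
def _batch_shape_example(dataset: MusicSegmentDataset):
    batch = dataset[[0, 1, 2]]  # one call produces 3 (original, augmented) pairs
    # with a mel layer the result is roughly (3, 2, n_mels, n_frames);
    # without one, raw waveforms of shape (3, 2, segment_size)
    return batch.shape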
class TwoStageShuffler(Sampler):
    def __init__(self, music_data: MusicSegmentDataset, shuffle_size):
        self.music_data = music_data
        self.shuffle_size = shuffle_size
        self.shuffle = True
        self.loaded = set()
        self.generator = torch.Generator()
        self.generator2 = torch.Generator()

    def set_epoch(self, epoch):
        self.generator.manual_seed(42 + epoch)
        self.generator2.manual_seed(42 + epoch)

    def __len__(self):
        return len(self.music_data)

    def preload(self, song_id):
        if song_id not in self.loaded:
            self.music_data.preload_song(song_id)
            self.loaded.add(song_id)

    def baseline_shuffle(self):
        # the same as DataLoader with shuffle=True, but with a music preloader
        # 2021/10/18 sorry I have to remove preloader
        #for song in range(self.music_data.get_num_songs()):
        #    self.preload(song)
        yield from torch.randperm(len(self), generator=self.generator).tolist()

    def shuffling_iter(self):
        # shuffle song list first
        shuffle_song = torch.randperm(self.music_data.get_num_songs(), generator=self.generator)

        # split song list into chunks
        chunks = torch.split(shuffle_song, self.shuffle_size)
        for nc, songs in enumerate(chunks):
            # sort songs to make preloader read more sequentially
            songs = torch.sort(songs)[0].tolist()
            buf = []
            for song in songs:
                # collect segment ids
                seg_start, seg_end = self.music_data.get_song_segments(song)
                buf += list(range(seg_start, seg_end))
            if nc == 0:
                # load first chunk
                for song in songs:
                    self.preload(song)

            # shuffle segments from song chunk
            shuffle_segs = torch.randperm(len(buf), generator=self.generator2)
            shuffle_segs = [buf[x] for x in shuffle_segs]
            preload_cnt = 0
            for i, idx in enumerate(shuffle_segs):
                # output shuffled segment idx
                yield idx

                # preload next chunk
                while len(self.loaded) < len(shuffle_song) and preload_cnt * len(shuffle_segs) < (i + 1) * self.shuffle_size:
                    song = shuffle_song[len(self.loaded)].item()
                    self.preload(song)
                    preload_cnt += 1

    def non_shuffling_iter(self):
        # just return 0 ... len(dataset)-1
        yield from range(len(self))

    def __iter__(self):
        if self.shuffle:
            if self.shuffle_size is None:
                return self.baseline_shuffle()
            else:
                return self.shuffling_iter()
        return self.non_shuffling_iter()
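
# --- Editor's illustration (not part of the original file) ---
# The shuffling is two-stage: songs are shuffled and split into chunks of
# `shuffle_size` songs, then segment indices are shuffled within each chunk
# while the next chunk's songs are preloaded. Wrapping the shuffler in a
# BatchSampler (as SegmentedDataLoader does) yields lists of segment indices.
# The parameter values below are illustrative only:
def _shuffler_example(dataset: MusicSegmentDataset, shuffle_size=100, batch_size=32):
    shuffler = TwoStageShuffler(dataset, shuffle_size)
    shuffler.set_epoch(0)                  # deterministic order for this epoch
    batches = BatchSampler(shuffler, batch_size, False)
    return next(iter(batches))             # a list of `batch_size` segment ids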
# since instantiating a DataLoader of MusicSegmentDataset is hard, I provide a data loader builder
class SegmentedDataLoader:
    def __init__(self, train_val, configs, num_workers=4, pin_memory=False, prefetch_factor=2):
        assert train_val in {'train', 'validate'}
        self.dataset = MusicSegmentDataset(configs, train_val)
        assert configs['batch_size'] % 2 == 0
        self.batch_size = configs['batch_size']
        self.shuffler = TwoStageShuffler(self.dataset, configs['shuffle_size'])
        self.sampler = BatchSampler(self.shuffler, self.batch_size // 2, False)
        self.num_workers = num_workers
        self.configs = configs
        self.pin_memory = pin_memory
        self.prefetch_factor = prefetch_factor

        # you can change shuffle to True/False
        self.shuffle = True
        # you can change augmented to True/False
        self.augmented = True
        # you can change eval time shift to True/False
        self.eval_time_shift = False

        self.loader = DataLoader(
            self.dataset,
            sampler=self.sampler,
            batch_size=None,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            prefetch_factor=self.prefetch_factor
        )

    def set_epoch(self, epoch):
        self.shuffler.set_epoch(epoch)

    def __iter__(self):
        self.dataset.augmented = self.augmented
        self.dataset.eval_time_shift = self.eval_time_shift
        self.shuffler.shuffle = self.shuffle
        return iter(self.loader)

    def __len__(self):
        return len(self.loader)
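
# --- Editor's usage sketch (not part of the original file) ---
# How the builder might be driven in a training loop. `configs` must provide
# the keys read above ('sample_rate', 'segment_size', 'hop_size', 'time_offset',
# 'pad_start', 'air', 'micirp', 'noise', 'cache_dir', 'music_dir', 'train_csv',
# 'validate_csv', 'batch_size', 'shuffle_size', plus whatever build_mel_spec_layer
# expects); `train_step` is a hypothetical placeholder for the model update.
def _training_loop_example(configs, num_epochs, train_step):
    loader = SegmentedDataLoader('train', configs, num_workers=4)
    loader.shuffle = True           # see the flags set in __init__ above
    loader.augmented = True
    loader.eval_time_shift = False
    for epoch in range(num_epochs):
        loader.set_epoch(epoch)     # reseed the shuffler for this epoch
        for batch in loader:
            train_step(batch)       # batch holds stacked (original, augmented) views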