import os

import numpy as np
import torch
import torch.fft
from torch.utils.data import DataLoader, Dataset, Sampler, BatchSampler

from .melspec import build_mel_spec_layer
from .noise import NoiseData
from .ir import AIR, MicIRP
from .preprocess import preprocess_music


class NumpyMemmapDataset(Dataset):
    def __init__(self, path, dtype):
        self.path = path
        self.f = np.memmap(self.path, dtype=dtype)

    def __len__(self):
        return len(self.f)

    def __getitem__(self, idx):
        return self.f[idx]

    # need these, otherwise the pickler will read the whole data in memory
    def __getstate__(self):
        return {'path': self.path, 'dtype': self.f.dtype}

    def __setstate__(self, d):
        self.__init__(d['path'], d['dtype'])


# data augmentation on music dataset
class MusicSegmentDataset(Dataset):
    def __init__(self, params, train_val):
        # load some configs
        assert train_val in {'train', 'validate'}
        sample_rate = params['sample_rate']
        self.augmented = True
        self.eval_time_shift = True
        self.segment_size = int(params['segment_size'] * sample_rate)
        self.hop_size = int(params['hop_size'] * sample_rate)
        # select 1.2s of audio, then choose two random 1s of audios
        self.time_offset = int(params['time_offset'] * sample_rate)
        # include more audio at the left of a segment to simulate reverb
        self.pad_start = int(params['pad_start'] * sample_rate)
        self.params = params

        # get fft size needed for reverb
        fftconv_n = 1024
        air_len = int(params['air']['length'] * sample_rate)
        ir_len = int(params['micirp']['length'] * sample_rate)
        fft_needed = self.segment_size + self.pad_start + air_len + ir_len
        while fftconv_n < fft_needed:
            fftconv_n *= 2
        self.fftconv_n = fftconv_n

        # datasets for data augmentation
        cache_dir = params['cache_dir']
        os.makedirs(cache_dir, exist_ok=True)
        if params['noise'][train_val]:
            self.noise = NoiseData(noise_dir=params['noise']['dir'],
                                   list_csv=params['noise'][train_val],
                                   sample_rate=sample_rate, cache_dir=cache_dir)
        else:
            self.noise = None
        if params['air'][train_val]:
            self.air = AIR(air_dir=params['air']['dir'],
                           list_csv=params['air'][train_val],
                           length=params['air']['length'],
                           fftconv_n=fftconv_n, sample_rate=sample_rate)
        else:
            self.air = None
        if params['micirp'][train_val]:
            self.micirp = MicIRP(mic_dir=params['micirp']['dir'],
                                 list_csv=params['micirp'][train_val],
                                 length=params['micirp']['length'],
                                 fftconv_n=fftconv_n, sample_rate=sample_rate)
        else:
            self.micirp = None

        # Load music dataset as memory mapped file
        file_name = os.path.splitext(os.path.split(params[train_val + '_csv'])[1])[0]
        file_name = os.path.join(cache_dir, '1' + file_name)
        if os.path.exists(file_name + '.npy'):
            print('load cached music from %s.bin' % file_name)
        else:
            preprocess_music(params['music_dir'], params[train_val + '_csv'], sample_rate, file_name)
        self.f = NumpyMemmapDataset(file_name + '.bin', np.int16)

        # some segmentation settings
        song_len = np.load(file_name + '.npy')
        self.cues = []          # start location of segment i
        self.offset_left = []   # allowed left shift of segment i
        self.offset_right = []  # allowed right shift of segment i
        self.song_range = []    # range of song i is (start time, end time, start idx, end idx)
        t = 0
        for duration in song_len:
            num_segs = (duration - self.segment_size + self.hop_size) // self.hop_size
            start_cue = len(self.cues)
            for idx in range(num_segs):
                my_time = idx * self.hop_size
                self.cues.append(t + my_time)
                self.offset_left.append(my_time)
                self.offset_right.append(duration - my_time)
            end_cue = len(self.cues)
            self.song_range.append((t, t + duration, start_cue, end_cue))
            t += duration

        # convert to torch Tensor to allow sharing memory
        self.cues = torch.LongTensor(self.cues)
        self.offset_left = torch.LongTensor(self.offset_left)
        self.offset_right = torch.LongTensor(self.offset_right)

        # setup mel spectrogram
        self.mel = build_mel_spec_layer(params)

    def get_single_segment(self, idx, offset, length):
        cue = int(self.cues[idx]) + offset
        left = int(self.offset_left[idx]) + offset
        right = int(self.offset_right[idx]) - offset

        # choose a segment
        # segment looks like:
        #             cue
        #              v
        # [ pad_start | segment_size | ... ]
        #              \--- time_offset ---/
        segment = self.f[cue - min(left, self.pad_start): cue + min(right, length)]
        segment = np.pad(segment, [max(0, self.pad_start - left), max(0, length - right)])

        # convert 16 bit to float
        return segment * np.float32(1 / 32768)

    def __getitem__(self, indices):
        if self.eval_time_shift:
            # collect all segments. It is faster to do batch processing
            shift_range = self.segment_size // 2
            x = [self.get_single_segment(i, -self.segment_size // 4, self.segment_size + shift_range) for i in indices]

            # random time offset only on augmented segment, as a query audio
            segment_size = self.pad_start + self.segment_size
            offset1 = [self.segment_size // 4] * len(x)
            offset2 = torch.randint(high=shift_range + 1, size=(len(x),)).tolist()
        else:
            # collect all segments. It is faster to do batch processing
            x = [self.get_single_segment(i, 0, self.time_offset) for i in indices]

            # random time offset
            shift_range = self.time_offset - self.segment_size
            segment_size = self.pad_start + self.segment_size
            offset1 = torch.randint(high=shift_range + 1, size=(len(x),)).tolist()
            offset2 = torch.randint(high=shift_range + 1, size=(len(x),)).tolist()

        x_orig = [xi[off + self.pad_start: off + segment_size] for xi, off in zip(x, offset1)]
        x_orig = torch.Tensor(np.stack(x_orig).astype(np.float32))
        x_aug = [xi[off: off + segment_size] for xi, off in zip(x, offset2)]
        x_aug = torch.Tensor(np.stack(x_aug).astype(np.float32))

        if self.augmented:
            # background noise
            if self.noise is not None:
                x_aug = self.noise.add_noises(x_aug, self.params['noise']['snr_min'], self.params['noise']['snr_max'])

            # impulse response
            spec = torch.fft.rfft(x_aug, self.fftconv_n)
            if self.air is not None:
                spec *= self.air.random_choose(spec.shape[0])
            if self.micirp is not None:
                spec *= self.micirp.random_choose(spec.shape[0])
            x_aug = torch.fft.irfft(spec, self.fftconv_n)
            x_aug = x_aug[..., self.pad_start:segment_size]

        # output [x1_orig, x1_aug, x2_orig, x2_aug, ...]
        x = [x_orig, x_aug] if self.augmented else [x_orig]
        x = torch.stack(x, dim=1)
        if self.mel is not None:
            return self.mel(x)
        return x

    def fan_si_le(self):
        raise NotImplementedError('煩死了')  # "this is so annoying"

    def zuo_bu_chu_lai(self):
        raise NotImplementedError('做不起來')  # "can't get it to work"

    def __len__(self):
        return len(self.cues)

    def get_num_songs(self):
        return len(self.song_range)

    def get_song_segments(self, song_id):
        return self.song_range[song_id][2:4]

    def preload_song(self, song_id):
        start, end, _, _ = self.song_range[song_id]
        return self.f[start:end].copy()


class TwoStageShuffler(Sampler):
    def __init__(self, music_data: MusicSegmentDataset, shuffle_size):
        self.music_data = music_data
        self.shuffle_size = shuffle_size
        self.shuffle = True
        self.loaded = set()
        self.generator = torch.Generator()
        self.generator2 = torch.Generator()

    def set_epoch(self, epoch):
        self.generator.manual_seed(42 + epoch)
        self.generator2.manual_seed(42 + epoch)

    def __len__(self):
        return len(self.music_data)

    def preload(self, song_id):
        if song_id not in self.loaded:
            self.music_data.preload_song(song_id)
            self.loaded.add(song_id)

    def baseline_shuffle(self):
        # the same as DataLoader with shuffle=True, but with a music preloader
        # 2021/10/18 sorry I have to remove preloader
        #for song in range(self.music_data.get_num_songs()):
        #    self.preload(song)
        yield from torch.randperm(len(self), generator=self.generator).tolist()

    def shuffling_iter(self):
        # shuffle song list first
        shuffle_song = torch.randperm(self.music_data.get_num_songs(), generator=self.generator)

        # split song list into chunks
        chunks = torch.split(shuffle_song, self.shuffle_size)
        for nc, songs in enumerate(chunks):
            # sort songs to make the preloader read more sequentially
            songs = torch.sort(songs)[0].tolist()

            buf = []
            for song in songs:
                # collect segment ids
                seg_start, seg_end = self.music_data.get_song_segments(song)
                buf += list(range(seg_start, seg_end))

            if nc == 0:
                # load first chunk
                for song in songs:
                    self.preload(song)

            # shuffle segments from song chunk
            shuffle_segs = torch.randperm(len(buf), generator=self.generator2)
            shuffle_segs = [buf[x] for x in shuffle_segs]
            preload_cnt = 0
            for i, idx in enumerate(shuffle_segs):
                # output shuffled segment idx
                yield idx

                # preload next chunk
                while len(self.loaded) < len(shuffle_song) and preload_cnt * len(shuffle_segs) < (i + 1) * self.shuffle_size:
                    song = shuffle_song[len(self.loaded)].item()
                    self.preload(song)
                    preload_cnt += 1

    def non_shuffling_iter(self):
        # just yield 0 ... len(dataset)-1 in order
        yield from range(len(self))

    def __iter__(self):
        if self.shuffle:
            if self.shuffle_size is None:
                return self.baseline_shuffle()
            else:
                return self.shuffling_iter()
        return self.non_shuffling_iter()


# since instantiating a DataLoader of MusicSegmentDataset is hard, I provide a data loader builder
class SegmentedDataLoader:
    def __init__(self, train_val, configs, num_workers=4, pin_memory=False, prefetch_factor=2):
        assert train_val in {'train', 'validate'}
        self.dataset = MusicSegmentDataset(configs, train_val)
        assert configs['batch_size'] % 2 == 0
        self.batch_size = configs['batch_size']
        # each index yields an (original, augmented) pair when augmentation is on,
        # so the sampler only draws batch_size//2 indices per batch
        self.shuffler = TwoStageShuffler(self.dataset, configs['shuffle_size'])
        self.sampler = BatchSampler(self.shuffler, self.batch_size // 2, False)
        self.num_workers = num_workers
        self.configs = configs
        self.pin_memory = pin_memory
        self.prefetch_factor = prefetch_factor

        # you can change shuffle to True/False
        self.shuffle = True
        # you can change augmented to True/False
        self.augmented = True
        # you can change eval time shift to True/False
        self.eval_time_shift = False

        # batch_size=None disables automatic batching: the BatchSampler yields a whole
        # batch of indices at once and MusicSegmentDataset.__getitem__ builds the batch itself
        self.loader = DataLoader(
            self.dataset,
            sampler=self.sampler,
            batch_size=None,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            prefetch_factor=self.prefetch_factor
        )

    def set_epoch(self, epoch):
        self.shuffler.set_epoch(epoch)

    def __iter__(self):
        self.dataset.augmented = self.augmented
        self.dataset.eval_time_shift = self.eval_time_shift
        self.shuffler.shuffle = self.shuffle
        return iter(self.loader)

    def __len__(self):
        return len(self.loader)
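

# Minimal usage sketch (added for illustration; not part of the original module).
# It drives SegmentedDataLoader the way a training script might. The config path
# 'configs/default.json' and the epoch count are assumptions; the JSON must supply
# every key the classes above read (sample_rate, segment_size, hop_size, time_offset,
# pad_start, the noise/air/micirp sections, cache_dir, music_dir, train_csv/validate_csv,
# batch_size, shuffle_size, and the mel spectrogram settings). Because this file uses
# relative imports, run it as a module, e.g. `python -m <package>.<this_module>`.
if __name__ == '__main__':
    import json

    with open('configs/default.json') as fin:  # hypothetical config file
        configs = json.load(fin)

    loader = SegmentedDataLoader('train', configs, num_workers=4)
    loader.shuffle = True            # two-stage shuffle; falls back to baseline shuffle if shuffle_size is None
    loader.augmented = True          # add background noise and room/mic impulse responses to the query copy
    loader.eval_time_shift = False   # the fixed-offset shift mode is intended for evaluation, not training

    for epoch in range(3):           # epoch count chosen arbitrarily for this sketch
        loader.set_epoch(epoch)      # reseed the shuffler so each epoch uses a different permutation
        for batch in loader:
            # batch holds paired segments ordered [x1_orig, x1_aug, x2_orig, x2_aug, ...],
            # already converted to mel spectrograms when a mel layer is configured
            pass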