nnfp
11 changed files with 997 additions and 0 deletions
@@ -0,0 +1,152 @@ (new file; name not shown in this capture, likely datautil/audio.py given "from .audio import get_audio" in the later files)

import io
import json
import os
import subprocess
import wave

import numpy as np


# The builtin wave module refuses WAVE_FORMAT_EXTENSIBLE files (which is what
# most tools emit for wav files with more than 2 channels). This wrapper
# rewrites the format tag at bytes 20-21 of the header to 0x0001 (plain PCM)
# on the fly so that wave.open accepts the stream.
class HackExtensibleWave:
    def __init__(self, stream):
        self.stream = stream
        self.pos = 0

    def read(self, n):
        r = self.stream.read(n)
        new_pos = self.pos + len(r)
        if self.pos < 20 and self.pos + n >= 20:
            r = r[:20-self.pos] + b'\x01\x00'[:new_pos-20] + r[22-self.pos:]
        elif 20 <= self.pos < 22:
            r = b'\x01\x00'[self.pos-20:new_pos-20] + r[22-self.pos:]
        self.pos = new_pos
        return r


def ffmpeg_get_audio(filename):
    """Decode any ffmpeg-supported file to a float array of shape (channels, frames)."""
    error_log = open(os.devnull, 'w')
    proc = subprocess.Popen(['ffmpeg', '-i', filename, '-f', 'wav', 'pipe:1'],
        stderr=error_log,
        stdin=open(os.devnull),
        stdout=subprocess.PIPE,
        bufsize=1000000)
    try:
        dat = proc.stdout.read()
        wav = wave.open(HackExtensibleWave(io.BytesIO(dat)))
        ch = wav.getnchannels()
        rate = wav.getframerate()
        n = wav.getnframes()
        dat = wav.readframes(n)
        del wav
        samples = np.frombuffer(dat, dtype=np.int16) / 32768
        samples = samples.reshape([-1, ch]).T
        return samples, rate
    except (wave.Error, EOFError):
        print('failed to decode %s. Maybe the file is broken!' % filename)
    return np.zeros([1, 0]), 44100


def wave_get_audio(filename):
    """Fast path for .wav files; returns None if the sample width is unsupported."""
    with open(filename, 'rb') as fin:
        wav = wave.open(HackExtensibleWave(fin))
        smpwidth = wav.getsampwidth()
        if smpwidth not in {1, 2, 3}:
            return None
        n = wav.getnframes()
        if smpwidth == 1:
            samples = np.frombuffer(wav.readframes(n), dtype=np.uint8) / 128 - 1
        elif smpwidth == 2:
            samples = np.frombuffer(wav.readframes(n), dtype=np.int16) / 32768
        elif smpwidth == 3:
            # sign-extend little-endian 24-bit samples to 32-bit: the added
            # fourth byte is 0xFF exactly when the top bit of the third byte is set
            a = np.frombuffer(wav.readframes(n), dtype=np.uint8)
            samples = np.stack([a[0::3], a[1::3], a[2::3], -(a[2::3]>>7)], axis=1).view(np.int32).squeeze(1)
            del a
            samples = samples / 8388608
        samples = samples.reshape([-1, wav.getnchannels()]).T
        return samples, wav.getframerate()


def get_audio(filename):
    if str(filename).endswith('.wav'):
        try:
            a = wave_get_audio(filename)
            if a: return a
        except Exception:
            pass
    # fall back to ffmpeg for non-wav files or wavs the fast path rejected
    return ffmpeg_get_audio(filename)


class FfmpegStream:
    """Streams s16le samples from a running ffmpeg process via a generator."""
    def __init__(self, proc, sample_rate, nchannels, tmpfile):
        self.proc = proc
        self.sample_rate = sample_rate
        self.nchannels = nchannels
        self.tmpfile = None  # set after gen_stream() so __del__ is safe if it throws
        self.stream = self.gen_stream()
        self.tmpfile = tmpfile

    def __del__(self):
        self.proc.terminate()
        self.proc.communicate()
        del self.proc
        if self.tmpfile:
            os.unlink(self.tmpfile)

    def gen_stream(self):
        num = yield np.array([], dtype=np.int16)
        if not num: num = 1024
        while True:
            to_read = num * self.nchannels * 2  # 2 bytes per int16 sample
            dat = self.proc.stdout.read(to_read)
            num = yield np.frombuffer(dat, dtype=np.int16)
            if not num: num = 1024
            if len(dat) < to_read:  # short read: end of stream
                break


def ffmpeg_stream_audio(filename, is_tmp=False):
    while 1:
        try:
            # opening os.devnull occasionally raises PermissionError; retry until it works
            stderr = open(os.devnull, 'w')
            stdin = open(os.devnull)
            break
        except PermissionError:
            print('PermissionError occurred, trying again')
    proc = subprocess.Popen(['ffprobe', '-i', filename, '-show_streams',
            '-select_streams', 'a', '-print_format', 'json'],
        stderr=stderr,
        stdin=stdin,
        stdout=subprocess.PIPE)
    prop = json.loads(proc.stdout.read())
    if 'streams' not in prop:
        raise RuntimeError('FFmpeg cannot decode audio')
    sample_rate = int(prop['streams'][0]['sample_rate'])
    nchannels = prop['streams'][0]['channels']
    proc = subprocess.Popen(['ffmpeg', '-i', filename,
            '-f', 's16le', '-acodec', 'pcm_s16le', 'pipe:1'],
        stderr=stderr,
        stdin=stdin,
        stdout=subprocess.PIPE)
    tmpfile = None
    if is_tmp:
        tmpfile = filename
    return FfmpegStream(proc, sample_rate, nchannels, tmpfile=tmpfile)


class WaveStream:
    def __init__(self, filename, is_tmp=False):
        self.is_tmp = None  # set last so __del__ is safe if open fails
        self.file = open(filename, 'rb')
        self.wave = wave.open(HackExtensibleWave(self.file))
        self.smpsize = self.wave.getnchannels() * self.wave.getsampwidth()
        self.sample_rate = self.wave.getframerate()
        self.nchannels = self.wave.getnchannels()
        if self.wave.getsampwidth() != 2:
            raise NotImplementedError('wave stream currently only supports 16bit wav')
        self.stream = self.gen_stream()
        self.is_tmp = filename if is_tmp else None

    def gen_stream(self):
        num = yield np.array([], dtype=np.int16)
        if not num: num = 1024
        while True:
            dat = self.wave.readframes(num)
            requested = num * self.smpsize
            num = yield np.frombuffer(dat, dtype=np.int16)
            if not num: num = 1024
            if len(dat) < requested:  # short read: end of file
                break

    def __del__(self):
        if self.is_tmp:
            os.unlink(self.is_tmp)
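For reference, a minimal usage sketch of the decoders and the generator-based streaming protocol above. This is not part of the diff; the 'example.mp3' path is a placeholder and the datautil.audio module path is assumed from the relative imports elsewhere in this changeset.

import numpy as np
from datautil.audio import get_audio, ffmpeg_stream_audio  # module path assumed

# one-shot decode: float samples of shape (channels, frames)
samples, rate = get_audio('example.mp3')  # placeholder file name
print(samples.shape, rate)

# streaming decode: the generator must be primed with next() first,
# after which send(n) returns the next n frames as interleaved int16
stream = ffmpeg_stream_audio('example.mp3')
gen = stream.stream
next(gen)                 # prime the generator; yields an empty array
chunk = gen.send(4096)    # 4096 frames * nchannels int16 samples
print(chunk.shape, stream.sample_rate, stream.nchannels)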
@@ -0,0 +1,301 @@ (new file; likely datautil/dataset_v2.py, the module imported as .datautil.dataset_v2 by the training script below)

import os

import numpy as np
import torch
import torch.fft
from torch.utils.data import DataLoader, Dataset, Sampler, BatchSampler

from .melspec import build_mel_spec_layer
from .noise import NoiseData
from .ir import AIR, MicIRP
from .preprocess import preprocess_music

class NumpyMemmapDataset(Dataset):
    def __init__(self, path, dtype):
        self.path = path
        self.f = np.memmap(self.path, dtype=dtype)

    def __len__(self):
        return len(self.f)

    def __getitem__(self, idx):
        return self.f[idx]

    # needed, otherwise the pickler would read the whole file into memory
    def __getstate__(self):
        return {'path': self.path, 'dtype': self.f.dtype}

    def __setstate__(self, d):
        self.__init__(d['path'], d['dtype'])

# music dataset with data augmentation
class MusicSegmentDataset(Dataset):
    def __init__(self, params, train_val):
        # load some configs
        assert train_val in {'train', 'validate'}
        sample_rate = params['sample_rate']
        self.augmented = True
        self.eval_time_shift = True
        self.segment_size = int(params['segment_size'] * sample_rate)
        self.hop_size = int(params['hop_size'] * sample_rate)
        self.time_offset = int(params['time_offset'] * sample_rate) # e.g. select 1.2s of audio, then pick two random 1s windows from it
        self.pad_start = int(params['pad_start'] * sample_rate) # include extra audio to the left of a segment to simulate reverb
        self.params = params

        # FFT size needed for the reverb convolution
        fftconv_n = 1024
        air_len = int(params['air']['length'] * sample_rate)
        ir_len = int(params['micirp']['length'] * sample_rate)
        fft_needed = self.segment_size + self.pad_start + air_len + ir_len
        while fftconv_n < fft_needed:
            fftconv_n *= 2
        self.fftconv_n = fftconv_n

        # datasets used for data augmentation
        cache_dir = params['cache_dir']
        os.makedirs(cache_dir, exist_ok=True)
        if params['noise'][train_val]:
            self.noise = NoiseData(noise_dir=params['noise']['dir'], list_csv=params['noise'][train_val], sample_rate=sample_rate, cache_dir=cache_dir)
        else: self.noise = None
        if params['air'][train_val]:
            self.air = AIR(air_dir=params['air']['dir'], list_csv=params['air'][train_val], length=params['air']['length'], fftconv_n=fftconv_n, sample_rate=sample_rate)
        else: self.air = None
        if params['micirp'][train_val]:
            self.micirp = MicIRP(mic_dir=params['micirp']['dir'], list_csv=params['micirp'][train_val], length=params['micirp']['length'], fftconv_n=fftconv_n, sample_rate=sample_rate)
        else: self.micirp = None

        # load the music dataset as a memory-mapped file
        file_name = os.path.splitext(os.path.split(params[train_val + '_csv'])[1])[0]
        file_name = os.path.join(cache_dir, '1' + file_name)
        if os.path.exists(file_name + '.npy'):
            print('load cached music from %s.bin' % file_name)
        else:
            preprocess_music(params['music_dir'], params[train_val + '_csv'], sample_rate, file_name)
        self.f = NumpyMemmapDataset(file_name + '.bin', np.int16)

        # segmentation bookkeeping
        song_len = np.load(file_name + '.npy')
        self.cues = [] # start location of segment i
        self.offset_left = [] # allowed left shift of segment i
        self.offset_right = [] # allowed right shift of segment i
        self.song_range = [] # range of song i as (start time, end time, start idx, end idx)
        t = 0
        for duration in song_len:
            num_segs = (duration - self.segment_size + self.hop_size) // self.hop_size
            start_cue = len(self.cues)
            for idx in range(num_segs):
                my_time = idx * self.hop_size
                self.cues.append(t + my_time)
                self.offset_left.append(my_time)
                self.offset_right.append(duration - my_time)
            end_cue = len(self.cues)
            self.song_range.append((t, t + duration, start_cue, end_cue))
            t += duration

        # convert to torch tensors so worker processes can share the memory
        self.cues = torch.LongTensor(self.cues)
        self.offset_left = torch.LongTensor(self.offset_left)
        self.offset_right = torch.LongTensor(self.offset_right)

        # set up the mel spectrogram layer
        self.mel = build_mel_spec_layer(params)

    def get_single_segment(self, idx, offset, length):
        cue = int(self.cues[idx]) + offset
        left = int(self.offset_left[idx]) + offset
        right = int(self.offset_right[idx]) - offset

        # choose a segment
        # the segment looks like:
        #            cue
        #             v
        # [ pad_start | segment_size |   ...  ]
        #             \---   time_offset   ---/
        segment = self.f[cue - min(left, self.pad_start): cue + min(right, length)]
        segment = np.pad(segment, [max(0, self.pad_start-left), max(0, length-right)])

        # convert 16 bit to float
        return segment * np.float32(1/32768)

    def __getitem__(self, indices):
        # indices is a whole batch of segment ids (the DataLoader below uses batch_size=None)
        if self.eval_time_shift:
            # collect all segments; batch processing is faster
            shift_range = self.segment_size//2
            x = [self.get_single_segment(i, -self.segment_size//4, self.segment_size+shift_range) for i in indices]

            # random time offset only on the augmented segment, which acts as the query audio
            segment_size = self.pad_start + self.segment_size
            offset1 = [self.segment_size//4] * len(x)
            offset2 = torch.randint(high=shift_range+1, size=(len(x),)).tolist()
        else:
            # collect all segments; batch processing is faster
            x = [self.get_single_segment(i, 0, self.time_offset) for i in indices]

            # random time offset for both views
            shift_range = self.time_offset - self.segment_size
            segment_size = self.pad_start + self.segment_size
            offset1 = torch.randint(high=shift_range+1, size=(len(x),)).tolist()
            offset2 = torch.randint(high=shift_range+1, size=(len(x),)).tolist()

        x_orig = [xi[off + self.pad_start : off + segment_size] for xi, off in zip(x, offset1)]
        x_orig = torch.Tensor(np.stack(x_orig).astype(np.float32))
        x_aug = [xi[off : off + segment_size] for xi, off in zip(x, offset2)]
        x_aug = torch.Tensor(np.stack(x_aug).astype(np.float32))

        if self.augmented:
            # background noise
            if self.noise is not None:
                x_aug = self.noise.add_noises(x_aug, self.params['noise']['snr_min'], self.params['noise']['snr_max'])

            # impulse responses (room reverb + microphone), applied by FFT convolution
            spec = torch.fft.rfft(x_aug, self.fftconv_n)
            if self.air is not None:
                spec *= self.air.random_choose(spec.shape[0])
            if self.micirp is not None:
                spec *= self.micirp.random_choose(spec.shape[0])
            x_aug = torch.fft.irfft(spec, self.fftconv_n)
            x_aug = x_aug[..., self.pad_start:segment_size]

        # output is interleaved as [x1_orig, x1_aug, x2_orig, x2_aug, ...]
        x = [x_orig, x_aug] if self.augmented else [x_orig]
        x = torch.stack(x, dim=1)
        if self.mel is not None:
            return self.mel(x)
        return x

    def fan_si_le(self):
        raise NotImplementedError('so annoying')  # originally '煩死了'

    def zuo_bu_chu_lai(self):
        raise NotImplementedError("can't get it to work")  # originally '做不起來'

    def __len__(self):
        return len(self.cues)

    def get_num_songs(self):
        return len(self.song_range)

    def get_song_segments(self, song_id):
        return self.song_range[song_id][2:4]

    def preload_song(self, song_id):
        start, end, _, _ = self.song_range[song_id]
        return self.f[start : end].copy()

class TwoStageShuffler(Sampler):
    def __init__(self, music_data: MusicSegmentDataset, shuffle_size):
        self.music_data = music_data
        self.shuffle_size = shuffle_size
        self.shuffle = True
        self.loaded = set()
        self.generator = torch.Generator()
        self.generator2 = torch.Generator()

    def set_epoch(self, epoch):
        self.generator.manual_seed(42 + epoch)
        self.generator2.manual_seed(42 + epoch)

    def __len__(self):
        return len(self.music_data)

    def preload(self, song_id):
        if song_id not in self.loaded:
            self.music_data.preload_song(song_id)
            self.loaded.add(song_id)

    def baseline_shuffle(self):
        # the same as DataLoader with shuffle=True, but with a music preloader
        # 2021/10/18 sorry, I had to remove the preloader
        #for song in range(self.music_data.get_num_songs()):
        #    self.preload(song)

        yield from torch.randperm(len(self), generator=self.generator).tolist()

    def shuffling_iter(self):
        # shuffle the song list first
        shuffle_song = torch.randperm(self.music_data.get_num_songs(), generator=self.generator)

        # split the song list into chunks
        chunks = torch.split(shuffle_song, self.shuffle_size)
        for nc, songs in enumerate(chunks):
            # sort the songs so the preloader reads more sequentially
            songs = torch.sort(songs)[0].tolist()

            buf = []
            for song in songs:
                # collect segment ids
                seg_start, seg_end = self.music_data.get_song_segments(song)
                buf += list(range(seg_start, seg_end))

            if nc == 0:
                # load the first chunk up front
                for song in songs:
                    self.preload(song)

            # shuffle the segments of this song chunk
            shuffle_segs = torch.randperm(len(buf), generator=self.generator2)
            shuffle_segs = [buf[x] for x in shuffle_segs]
            preload_cnt = 0
            for i, idx in enumerate(shuffle_segs):
                # output a shuffled segment id
                yield idx

                # preload the next chunk gradually while this one is consumed
                while len(self.loaded) < len(shuffle_song) and preload_cnt * len(shuffle_segs) < (i+1) * self.shuffle_size:
                    song = shuffle_song[len(self.loaded)].item()
                    self.preload(song)
                    preload_cnt += 1

    def non_shuffling_iter(self):
        # just yields 0 ... len(dataset)-1
        yield from range(len(self))

    def __iter__(self):
        if self.shuffle:
            if self.shuffle_size is None:
                return self.baseline_shuffle()
            else:
                return self.shuffling_iter()
        return self.non_shuffling_iter()

# instantiating a DataLoader for MusicSegmentDataset correctly is tricky, so a builder is provided
class SegmentedDataLoader:
    def __init__(self, train_val, configs, num_workers=4, pin_memory=False, prefetch_factor=2):
        assert train_val in {'train', 'validate'}
        self.dataset = MusicSegmentDataset(configs, train_val)
        assert configs['batch_size'] % 2 == 0
        self.batch_size = configs['batch_size']
        self.shuffler = TwoStageShuffler(self.dataset, configs['shuffle_size'])
        # each item yields 2 views, so the sampler batches batch_size//2 segments
        self.sampler = BatchSampler(self.shuffler, self.batch_size//2, False)
        self.num_workers = num_workers
        self.configs = configs
        self.pin_memory = pin_memory
        self.prefetch_factor = prefetch_factor

        # shuffle, augmented and eval_time_shift can each be toggled True/False
        self.shuffle = True
        self.augmented = True
        self.eval_time_shift = False

        self.loader = DataLoader(
            self.dataset,
            sampler=self.sampler,
            batch_size=None,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            prefetch_factor=self.prefetch_factor
        )

    def set_epoch(self, epoch):
        self.shuffler.set_epoch(epoch)

    def __iter__(self):
        self.dataset.augmented = self.augmented
        self.dataset.eval_time_shift = self.eval_time_shift
        self.shuffler.shuffle = self.shuffle
        return iter(self.loader)

    def __len__(self):
        return len(self.loader)
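A hypothetical configuration illustrating the keys MusicSegmentDataset and SegmentedDataLoader read; this sketch is not part of the diff, every value and path below is made up, and it assumes the listed csv files and data directories actually exist (real configs ship with the project as JSON).

from datautil.dataset_v2 import SegmentedDataLoader  # module path assumed

params = {
    'sample_rate': 8000,
    'segment_size': 1.0,    # seconds
    'hop_size': 0.5,        # seconds
    'time_offset': 1.2,     # seconds
    'pad_start': 1.0,       # seconds
    'batch_size': 640,
    'shuffle_size': 100,    # songs per shuffle chunk; None = plain shuffle
    'cache_dir': 'caches',
    'music_dir': 'data/music',
    'train_csv': 'lists/train.csv',
    'validate_csv': 'lists/validate.csv',
    'noise': {'dir': 'data/noise', 'train': 'lists/noise_train.csv',
              'validate': '', 'snr_min': 0, 'snr_max': 10},
    'air': {'dir': 'data/AIR', 'train': 'lists/air_train.csv',
            'validate': '', 'length': 1.0},
    'micirp': {'dir': 'data/micirp', 'train': 'lists/mic_train.csv',
               'validate': '', 'length': 0.5},
    # mel-spectrogram settings consumed by build_mel_spec_layer
    'stft_n': 1024, 'stft_hop': 256, 'f_min': 300, 'f_max': 4000, 'n_mels': 256,
}

loader = SegmentedDataLoader('train', params, num_workers=4)
loader.set_epoch(0)
for batch in loader:
    # batch shape: (batch_size//2, 2, n_mels, time); original and augmented views paired
    print(batch.shape)
    break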
@@ -0,0 +1,89 @@ (new file; likely datautil/ir.py, defining the AIR and MicIRP classes imported above)

import argparse
import csv
import os
import warnings

import scipy.io
import numpy as np
import torch
import torch.fft
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    import torchaudio

from .audio import get_audio

class AIR:
    """Aachen Impulse Response dataset, stored as rfft spectra for FFT convolution."""
    def __init__(self, air_dir, list_csv, length, fftconv_n, sample_rate=8000):
        print('loading Aachen IR dataset')
        with open(list_csv, 'r') as fin:
            reader = csv.reader(fin)
            airs = []
            firstrow = next(reader)  # skip the header row
            for row in reader:
                airs.append(row[0])
        data = []
        to_len = int(length * sample_rate)
        self.names = []
        for name in airs:
            mat = scipy.io.loadmat(os.path.join(air_dir, name))
            h_air = torch.tensor(mat['h_air'].astype(np.float32))
            assert h_air.shape[0] == 1
            h_air = h_air[0]
            air_info = mat['air_info']
            fs = int(air_info['fs'][0][0][0][0])
            self.names.append(str(air_info['room'][0][0][0]))
            resampled = torchaudio.transforms.Resample(fs, sample_rate)(h_air)
            truncated = resampled[0:to_len]
            freqd = torch.fft.rfft(truncated, fftconv_n)
            data.append(freqd)
        self.data = torch.stack(data)

    def random_choose(self, num):
        indices = torch.randint(0, self.data.shape[0], size=(num,), dtype=torch.long)
        return self.data[indices]

    def random_choose_name(self):
        index = torch.randint(0, self.data.shape[0], size=(1,), dtype=torch.long).item()
        return self.data[index], self.names[index]

class MicIRP:
    """Microphone impulse responses, likewise stored as rfft spectra."""
    def __init__(self, mic_dir, list_csv, length, fftconv_n, sample_rate=8000):
        print('loading microphone IR dataset')
        with open(list_csv, 'r') as fin:
            reader = csv.reader(fin)
            mics = []
            firstrow = next(reader)  # skip the header row
            for row in reader:
                mics.append(row[0])
        data = []
        to_len = int(length * sample_rate)
        for name in mics:
            smp, smprate = get_audio(os.path.join(mic_dir, name))
            smp = torch.FloatTensor(smp).mean(dim=0)  # downmix to mono
            resampled = torchaudio.transforms.Resample(smprate, sample_rate)(smp)
            truncated = resampled[0:to_len]
            freqd = torch.fft.rfft(truncated, fftconv_n)
            data.append(freqd)
        self.data = torch.stack(data)

    def random_choose(self, num):
        indices = torch.randint(0, self.data.shape[0], size=(num,), dtype=torch.long)
        return self.data[indices]

if __name__ == '__main__':
    # build a csv listing all .mat IR files in a directory
    args = argparse.ArgumentParser()
    args.add_argument('air')
    args.add_argument('out')
    args = args.parse_args()

    with open(args.out, 'w', encoding='utf8', newline='\n') as fout:
        writer = csv.writer(fout)
        writer.writerow(['file'])
        files = []
        for name in os.listdir(args.air):
            if name.endswith('.mat'):
                files.append(name)
        files.sort()
        for name in files:
            writer.writerow([name])
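The impulse responses are stored as rfft spectra so that applying reverb in the dataset reduces to one complex multiply per segment. A quick self-contained check of that FFT-convolution identity (a sketch, torch only; not part of the diff):

import torch
import torch.fft

torch.manual_seed(0)
sig = torch.randn(8000)   # 1s of audio at 8 kHz
ir = torch.randn(800)     # a 0.1s impulse response

fftconv_n = 16384         # power of 2 >= len(sig) + len(ir) - 1, as computed in the dataset
spec = torch.fft.rfft(sig, fftconv_n) * torch.fft.rfft(ir, fftconv_n)
fast = torch.fft.irfft(spec, fftconv_n)[:8000 + 800 - 1]

# direct full convolution for comparison; conv1d is cross-correlation, so flip the kernel
slow = torch.nn.functional.conv1d(
    sig.view(1, 1, -1),
    ir.flip(0).view(1, 1, -1),
    padding=800 - 1).view(-1)

print(torch.allclose(fast, slow, atol=1e-2))  # True: identical up to float error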
@@ -0,0 +1,63 @@ (new file; likely datautil/melspec.py, providing build_mel_spec_layer imported above)

import torch
import torchaudio

class MelSpec(torch.nn.Module):
    def __init__(self,
            sample_rate=8000,
            stft_n=1024,
            stft_hop=256,
            f_min=300,
            f_max=4000,
            n_mels=256,
            naf_mode=False,
            mel_log='log',
            spec_norm='l2'):
        super(MelSpec, self).__init__()
        self.naf_mode = naf_mode
        self.mel_log = mel_log
        self.spec_norm = spec_norm
        self.mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=stft_n,
            hop_length=stft_hop,
            f_min=f_min,
            f_max=f_max,
            n_mels=n_mels,
            window_fn=torch.hann_window,
            power = 1 if naf_mode else 2,
            pad_mode = 'constant' if naf_mode else 'reflect',
            norm = 'slaney' if naf_mode else None,
            mel_scale = 'slaney' if naf_mode else 'htk'
        )

    def forward(self, x):
        # normalize volume (max norm or L2 norm over time)
        p = float('inf') if self.spec_norm == 'max' else 2
        x = torch.nn.functional.normalize(x, p=p, dim=-1)

        # a small offset avoids log(0)
        if self.naf_mode:
            x = self.mel(x) + 0.06
        else:
            x = self.mel(x) + 1e-8

        if self.mel_log == 'log10':
            x = torch.log10(x)
        elif self.mel_log == 'log':
            x = torch.log(x)

        if self.spec_norm == 'max':
            x = x - torch.amax(x, dim=(-2,-1), keepdim=True)
        return x

def build_mel_spec_layer(params):
    return MelSpec(
        sample_rate = params['sample_rate'],
        stft_n = params['stft_n'],
        stft_hop = params['stft_hop'],
        f_min = params['f_min'],
        f_max = params['f_max'],
        n_mels = params['n_mels'],
        naf_mode = params.get('naf_mode', False),
        mel_log = params.get('mel_log', 'log'),
        spec_norm = params.get('spec_norm', 'l2')
    )
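A short shape check for the layer above (a sketch; the datautil.melspec import path is assumed):

import torch
from datautil.melspec import MelSpec  # module path assumed

mel = MelSpec(sample_rate=8000, stft_n=1024, stft_hop=256,
              f_min=300, f_max=4000, n_mels=256)
wav = torch.randn(4, 2, 8000)   # (batch, views, samples), 1s at 8 kHz
out = mel(wav)
print(out.shape)                # (4, 2, 256, 32): mel bins x ~1 + 8000/256 frames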
@@ -0,0 +1,109 @@ (new file; likely datautil/noise.py, defining the NoiseData class imported above)

import csv
import os
import warnings

import tqdm
import numpy as np
import torch
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    import torchaudio

from .simpleutils import get_hash
from .audio import get_audio

class NoiseData:
    def __init__(self, noise_dir, list_csv, sample_rate, cache_dir):
        print('loading noise dataset')
        hashes = []
        with open(list_csv, 'r') as fin:
            reader = csv.reader(fin)
            noises = []
            firstrow = next(reader)  # skip the header row
            for row in reader:
                noises.append(row[0])
                hashes.append(get_hash(row[0]))
        hash = get_hash(''.join(hashes))
        # caching is currently disabled:
        #self.data = self.load_from_cache(list_csv, cache_dir, hash)
        #if self.data is not None:
        #    print(self.data.shape)
        #    return
        data = []
        silence_threshold = 0
        self.names = []
        for name in tqdm.tqdm(noises):
            smp, smprate = get_audio(os.path.join(noise_dir, name))
            smp = torch.from_numpy(smp.astype(np.float32))

            # convert to mono
            smp = smp.mean(dim=0)

            # strip leading/trailing silence
            abs_smp = torch.abs(smp)
            if torch.max(abs_smp) <= silence_threshold:
                print('%s is too silent' % name)
                continue
            has_sound = (abs_smp > silence_threshold).to(torch.int)
            start = int(torch.argmax(has_sound))
            end = has_sound.shape[0] - int(torch.argmax(has_sound.flip(0)))
            smp = smp[max(start, 0) : end]

            resampled = torchaudio.transforms.Resample(smprate, sample_rate)(smp)
            # normalize to peak amplitude 1 (infinity norm)
            resampled = torch.nn.functional.normalize(resampled, dim=0, p=float('inf'))
            data.append(resampled)
            self.names.append(name)
        self.data = torch.cat(data)
        # boundary[i] is where clip i starts in the concatenated buffer
        self.boundary = [0] + [x.shape[0] for x in data]
        self.boundary = torch.LongTensor(self.boundary).cumsum(0)
        del data
        #self.save_to_cache(list_csv, cache_dir, hash, self.data)
        print(self.data.shape)

    def load_from_cache(self, list_csv, cache_dir, hash):
        loc = os.path.join(cache_dir, os.path.basename(list_csv) + '.npy')
        loc2 = os.path.join(cache_dir, os.path.basename(list_csv) + '.hash')
        if os.path.exists(loc) and os.path.exists(loc2):
            with open(loc2, 'r') as fin:
                read_hash = fin.read().strip()
            if read_hash != hash:
                return None
            print('cache hit!')
            return torch.from_numpy(np.fromfile(loc, dtype=np.float32))
        return None

    def save_to_cache(self, list_csv, cache_dir, hash, audio):
        os.makedirs(cache_dir, exist_ok=True)
        loc = os.path.join(cache_dir, os.path.basename(list_csv) + '.npy')
        loc2 = os.path.join(cache_dir, os.path.basename(list_csv) + '.hash')
        with open(loc2, 'w') as fout:
            fout.write(hash)
        print('save to cache')
        audio.numpy().tofile(loc)

    def random_choose(self, num, duration, out_name=False):
        # sample random windows from the concatenated noise buffer
        indices = torch.randint(0, self.data.shape[0] - duration, size=(num,), dtype=torch.long)
        out = torch.zeros([num, duration], dtype=torch.float32)
        for i in range(num):
            start = int(indices[i])
            end = start + duration
            out[i] = self.data[start:end]
        name_lookup = torch.searchsorted(self.boundary, indices, right=True) - 1
        if out_name:
            return out, [self.names[x] for x in name_lookup]
        return out

    # x is a 2-d array of audio segments
    def add_noises(self, x, snr_min, snr_max, out_name=False):
        eps = 1e-12
        noise = self.random_choose(x.shape[0], x.shape[1], out_name=out_name)
        if out_name:
            noise, noise_name = noise
        vol_x = torch.clamp((x ** 2).mean(dim=1), min=eps).sqrt()
        vol_noise = torch.clamp((noise ** 2).mean(dim=1), min=eps).sqrt()
        # scale the noise so that 20*log10(vol_x / (ratio*vol_noise)) hits the target SNR
        snr = torch.FloatTensor(x.shape[0]).uniform_(snr_min, snr_max)
        ratio = vol_x / vol_noise
        ratio *= 10 ** -(snr / 20)
        x_aug = x + ratio.unsqueeze(1) * noise
        if out_name:
            return x_aug, noise_name, snr
        return x_aug
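The scaling in add_noises makes the achieved signal-to-noise ratio equal the sampled target: with ratio = (vol_x / vol_noise) * 10^(-snr/20), the scaled noise power is ratio^2 * vol_noise^2, so 10*log10(vol_x^2 / (ratio^2 * vol_noise^2)) = snr. A self-contained numerical check of that identity (a sketch, not part of the diff):

import torch

torch.manual_seed(0)
x = torch.randn(3, 8000)
noise = torch.randn(3, 8000)

snr = torch.tensor([0., 6., 20.])   # target SNR in dB, one per row
vol_x = (x ** 2).mean(dim=1).sqrt()
vol_noise = (noise ** 2).mean(dim=1).sqrt()
ratio = vol_x / vol_noise * 10 ** -(snr / 20)
x_aug = x + ratio.unsqueeze(1) * noise

achieved = 10 * torch.log10((x ** 2).mean(dim=1) /
                            ((ratio.unsqueeze(1) * noise) ** 2).mean(dim=1))
print(achieved)   # approximately tensor([ 0.,  6., 20.])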
@@ -0,0 +1,55 @@ (new file; likely datautil/preprocess.py, providing preprocess_music imported above)

import csv
import os

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import torchaudio
import tqdm

from .audio import get_audio

class Preprocessor(Dataset):
    def __init__(self, files, dir, sample_rate):
        self.files = files
        self.dir = dir
        self.resampler = {}  # one cached Resample transform per source rate
        self.sample_rate = sample_rate

    def __getitem__(self, n):
        dat = get_audio(os.path.join(self.dir, self.files[n]))
        wav, smprate = dat
        if smprate not in self.resampler:
            self.resampler[smprate] = torchaudio.transforms.Resample(smprate, self.sample_rate)
        wav = torch.Tensor(wav)
        wav = wav.mean(dim=0)  # downmix to mono
        wav = self.resampler[smprate](wav)

        # quantize back to 16 bit
        wav *= 32768
        torch.clamp(wav, -32768, 32767, out=wav)
        wav = wav.to(torch.int16)
        return wav

    def __len__(self):
        return len(self.files)

def preprocess_music(music_dir, music_csv, sample_rate, preprocess_out):
    print('converting music to wav')
    with open(music_csv) as fin:
        reader = csv.reader(fin)
        next(reader)  # skip the header row
        files = [row[0] for row in reader]

    preprocessor = Preprocessor(files, music_dir, sample_rate)
    loader = DataLoader(preprocessor, num_workers=4, batch_size=None)
    out_file = open(preprocess_out + '.bin', 'wb')
    song_lens = []
    for wav in tqdm.tqdm(loader):
        # torch.set_num_threads(1) # default multithreading causes cpu contention

        wav = wav.numpy()
        out_file.write(wav.tobytes())
        song_lens.append(wav.shape[0])
    out_file.close()
    # the .npy file stores each song's length so the .bin file can be indexed
    np.save(preprocess_out, np.array(song_lens, dtype=np.int64))
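Reading the cache back is just a memmap plus prefix sums over the stored lengths. A sketch (not part of the diff; the cache file names below are placeholders following the '1' + csv-stem convention used in dataset_v2 above):

import numpy as np

song_lens = np.load('caches/1train.npy')     # frames per song, int64; placeholder path
data = np.memmap('caches/1train.bin', dtype=np.int16)

offsets = np.concatenate([[0], np.cumsum(song_lens)])
song3 = data[offsets[3]:offsets[4]]          # raw int16 samples of the 4th song
audio = song3 * np.float32(1 / 32768)        # back to float, as the dataset does
print(len(song_lens), audio.shape)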
@@ -0,0 +1,26 @@ (new file; likely datautil/simpleutils.py, providing get_hash and read_config imported above)

# utilities that don't depend on torch
import hashlib
import json
import logging
import multiprocessing as mp
import os


def get_hash(s):
    m = hashlib.md5()
    m.update(s.encode('utf8'))
    return m.hexdigest()


def read_config(path):
    with open(path, 'r') as fin:
        return json.load(fin)


def init_logger(app_name):
    os.makedirs('logs', exist_ok=True)
    logger = mp.get_logger()
    logger.setLevel(logging.INFO)
    handler = logging.FileHandler('logs/%s.log' % app_name, encoding="utf8")
    handler.setFormatter(logging.Formatter('[%(asctime)s] [%(processName)s/%(levelname)s] %(message)s'))
    logger.addHandler(handler)
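Usage of these helpers is straightforward (a sketch; import path and config path are assumptions):

from datautil.simpleutils import get_hash, read_config, init_logger  # path assumed

print(get_hash('hello'))   # md5 hex digest: 5d41402abc4b2a76b9719d911017c592
# params = read_config('configs/default.json')   # any JSON config file
# init_logger('train')   # logs to logs/train.log via the multiprocessing-safe logger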
@@ -0,0 +1,42 @@ (new file; likely datautil/specaug.py, defining the SpecAugment class imported below)

import torch

class SpecAugment:
    def __init__(self, params):
        # note: as written, all three mask types read the same two config keys
        # (cutout_min/cutout_max); the inline numbers appear to be earlier
        # absolute sizes, kept for reference
        self.freq_min = params.get('cutout_min', 0.1) # 5
        self.freq_max = params.get('cutout_max', 0.5) # 20
        self.time_min = params.get('cutout_min', 0.1) # 5
        self.time_max = params.get('cutout_max', 0.5) # 16

        self.cutout_min = params.get('cutout_min', 0.1) # 0.1
        self.cutout_max = params.get('cutout_max', 0.5) # 0.4

    def get_mask(self, F, T):
        mask = torch.zeros(F, T)
        # cutout: mask a random rectangle
        cutout_max = self.cutout_max
        cutout_min = self.cutout_min
        f = F * (cutout_min + torch.rand(1) * (cutout_max-cutout_min))
        f = int(f)
        f0 = torch.randint(0, F - f + 1, (1,))
        t = T * (cutout_min + torch.rand(1) * (cutout_max-cutout_min))
        t = int(t)
        t0 = torch.randint(0, T - t + 1, (1,))
        mask[f0:f0+f, t0:t0+t] = 1

        # frequency masking: mask a random band of rows
        f = F * (self.freq_min + torch.rand(1) * (self.freq_max - self.freq_min))
        f = int(f)
        f0 = torch.randint(0, F - f + 1, (1,))
        mask[f0:f0+f, :] = 1

        # time masking: mask a random band of columns
        t = T * (self.time_min + torch.rand(1) * (self.time_max - self.time_min))
        t = int(t)
        t0 = torch.randint(0, T - t + 1, (1,))
        mask[:, t0:t0+t] = 1
        return mask

    def augment(self, x):
        # zero out the masked cells of the spectrogram batch
        mask = self.get_mask(x.shape[-2], x.shape[-1]).to(x.device)
        x = x * (1 - mask)
        return x
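A quick way to see what the augmentation does (a sketch; the datautil.specaug import path is assumed). Note that augment draws a fresh mask internally on every call:

import torch
from datautil.specaug import SpecAugment  # module path assumed

torch.manual_seed(0)
aug = SpecAugment({'cutout_min': 0.1, 'cutout_max': 0.5})
spec = torch.randn(256, 32)         # (n_mels, time) dummy spectrogram

mask = aug.get_mask(256, 32)
print(float(mask.mean()))           # fraction of bins zeroed by this draw

out = aug.augment(spec)             # draws its own mask internally
print(float((out == 0).float().mean()))  # roughly matches a typical masked fraction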
@@ -0,0 +1,152 @@ (new file; the towhee training entry point, exact file name not shown in this capture)

## This training script is written with reference to https://github.com/stdio2016/pfann
from typing import Any

import numpy as np
import torch
from torch import nn
from towhee.trainer.callback import Callback
from towhee.trainer.trainer import Trainer
from towhee.trainer.training_config import TrainingConfig

from .datautil.dataset_v2 import SegmentedDataLoader
from .datautil.simpleutils import read_config
from .datautil.specaug import SpecAugment
from torch.cuda.amp import autocast, GradScaler


# other requirements: torchvision, torchmetrics==0.7.0


def similarity_loss(y, tau):
    # y holds L2-normalized embeddings interleaved as
    # [orig_0, aug_0, orig_1, aug_1, ...]; row i's positive is its neighbor
    a = torch.matmul(y, y.T)
    a /= tau
    Ls = []
    for i in range(y.shape[0]):
        nn_self = torch.cat([a[i, :i], a[i, i + 1:]])  # drop the self-similarity
        softmax = torch.nn.functional.log_softmax(nn_self, dim=0)
        # after dropping element i, the positive sits at index i (even rows)
        # or i-1 (odd rows)
        Ls.append(softmax[i if i % 2 == 0 else i - 1])
    Ls = torch.stack(Ls)

    loss = torch.sum(Ls) / -y.shape[0]
    return loss


class SetEpochCallback(Callback):
    # reseeds the two-stage shuffler at the start of every epoch
    def __init__(self, dataloader):
        super().__init__()
        self.dataloader = dataloader

    def on_epoch_begin(self, epochs, logs):
        self.dataloader.set_epoch(epochs)


class NNFPTrainer(Trainer):
    def __init__(self, model, training_config, train_dataset, eval_dataset, train_dataloader, eval_dataloader,
                 model_card, params):
        super().__init__(model, training_config, train_dataset, eval_dataset, train_dataloader, eval_dataloader,
                         model_card)
        self.specaug = SpecAugment(params)
        self.scaler = GradScaler()
        self.tau = params.get('tau', 0.05)
        self.losses = []
        self.set_epoch_callback = SetEpochCallback(self.train_dataloader)
        self.callbacks.add_callback(self.set_epoch_callback)
        print('evaluate before fine-tune...')
        self.evaluate(self.model, dict())

    def train_step(self, model: nn.Module, inputs: Any) -> dict:
        x = inputs
        self.optimizer.zero_grad()

        # (batch, 2, n_mels, time) -> (2*batch, n_mels, time); pairs stay adjacent
        x = torch.flatten(x, 0, 1)
        x = self.specaug.augment(x)

        # mixed-precision forward/backward
        with autocast():
            y = model(x.to(self.configs.device))
            loss = similarity_loss(y, self.tau)
        self.scaler.scale(loss).backward()

        self.scaler.step(self.optimizer)
        self.scaler.update()

        self.lr_scheduler.step()

        lossnum = float(loss.item())
        self.losses.append(lossnum)
        step_logs = {"step_loss": lossnum, "epoch_loss": np.mean(self.losses), "epoch_metric": 0}
        return step_logs

    @torch.no_grad()
    def evaluate(self, model: nn.Module, logs: dict) -> dict:
        validate_N = 0
        y_embed = []
        minibatch = 40
        if torch.cuda.get_device_properties(0).total_memory > 11e9:
            minibatch = 640
        for x in self.eval_dataloader:
            x = torch.flatten(x, 0, 1)
            for xx in torch.split(x, minibatch):
                y = model(xx.to(self.configs.device)).cpu()
                y_embed.append(y)
        y_embed = torch.cat(y_embed)
        y_embed_org = y_embed[0::2]                           # original segments
        y_embed_aug = y_embed[1::2].to(self.configs.device)   # their augmented queries

        # compute the matched-pair scores on GPU
        self_score = []
        for embeds in torch.split(y_embed_org, 320):
            A = torch.matmul(y_embed_aug, embeds.T.to(self.configs.device))
            self_score.append(A.diagonal(-validate_N).cpu())
            validate_N += embeds.shape[0]
        self_score = torch.cat(self_score).to(self.configs.device)

        # rank of each matched pair among all candidate segments
        ranks = torch.zeros(validate_N, dtype=torch.long).to(self.configs.device)

        for embeds in torch.split(y_embed_org, 320):
            A = torch.matmul(y_embed_aug, embeds.T.to(self.configs.device))
            ranks += (A.T >= self_score).sum(dim=0)
        # top-1 accuracy: the true segment outranks every other one
        acc = int((ranks == 1).sum())
        print('\nvalidate score: %f' % (acc / validate_N,))
        # logs['epoch_metric'] = acc / validate_N
        return logs


def train_nnfp(model, config_json_path):
    model.train()
    params = read_config(config_json_path)
    train_dataloader = SegmentedDataLoader('train', params, num_workers=4)
    print('training data contains %d samples' % len(train_dataloader.dataset))
    train_dataloader.shuffle = True
    train_dataloader.eval_time_shift = False
    train_dataloader.augmented = True
    train_dataloader.set_epoch(0)

    val_dataloader = SegmentedDataLoader('validate', params, num_workers=4)
    val_dataloader.shuffle = False
    val_dataloader.eval_time_shift = True
    val_dataloader.set_epoch(-1)

    training_config = TrainingConfig(
        batch_size=params['batch_size'],
        epoch_num=params['epoch'],
        output_dir='fine_tune_output',
        lr_scheduler_type='cosine',
        eval_strategy='epoch',
        lr=params['lr']
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=training_config.lr)

    nnfp_trainer = NNFPTrainer(
        model=model,
        training_config=training_config,
        train_dataset=train_dataloader.dataset,
        eval_dataset=None,
        train_dataloader=train_dataloader,
        eval_dataloader=val_dataloader,
        model_card=None,
        params=params
    )
    nnfp_trainer.set_optimizer(optimizer)

    nnfp_trainer.train()
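The loss expects embeddings interleaved as [orig_0, aug_0, orig_1, aug_1, ...], which is exactly what flattening the dataset's (batch, 2, ...) output produces. A small self-check of that pairing convention (a sketch; importing similarity_loss from this training module is assumed, and "nnfp_train" is a placeholder module name):

import torch
from nnfp_train import similarity_loss  # module name assumed

torch.manual_seed(0)
d = 128
orig = torch.nn.functional.normalize(torch.randn(16, d), dim=1)
aug = torch.nn.functional.normalize(orig + 0.05 * torch.randn(16, d), dim=1)

# interleave rows as [orig_0, aug_0, orig_1, aug_1, ...]
y = torch.stack([orig, aug], dim=1).flatten(0, 1)

print(float(similarity_loss(y, tau=0.05)))   # small: each pair agrees
print(float(similarity_loss(
    torch.nn.functional.normalize(torch.randn(32, d), dim=1),
    tau=0.05)))                              # larger: random, unpaired embeddings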