From 7f5d813615a0f9e0ef4d9375dc6ad9a6a0279f77 Mon Sep 17 00:00:00 2001 From: ChengZi Date: Thu, 12 Jan 2023 16:24:01 +0800 Subject: [PATCH] train --- datautil/__init__.py | 0 datautil/audio.py | 152 ++++++++++++++++++++ datautil/dataset_v2.py | 301 ++++++++++++++++++++++++++++++++++++++++ datautil/ir.py | 89 ++++++++++++ datautil/melspec.py | 63 +++++++++ datautil/noise.py | 109 +++++++++++++++ datautil/preprocess.py | 55 ++++++++ datautil/simpleutils.py | 26 ++++ datautil/specaug.py | 42 ++++++ nn_fingerprint.py | 8 ++ train_nnfp.py | 152 ++++++++++++++++++++ 11 files changed, 997 insertions(+) create mode 100644 datautil/__init__.py create mode 100644 datautil/audio.py create mode 100644 datautil/dataset_v2.py create mode 100644 datautil/ir.py create mode 100644 datautil/melspec.py create mode 100644 datautil/noise.py create mode 100644 datautil/preprocess.py create mode 100644 datautil/simpleutils.py create mode 100644 datautil/specaug.py create mode 100644 train_nnfp.py diff --git a/datautil/__init__.py b/datautil/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/datautil/audio.py b/datautil/audio.py new file mode 100644 index 0000000..4e5f384 --- /dev/null +++ b/datautil/audio.py @@ -0,0 +1,152 @@ +import json +import os +import subprocess + +import numpy as np +import wave +import io + + +# because builtin wave won't read wav files with more than 2 channels +class HackExtensibleWave: + def __init__(self, stream): + self.stream = stream + self.pos = 0 + def read(self, n): + r = self.stream.read(n) + new_pos = self.pos + len(r) + if self.pos < 20 and self.pos + n >= 20: + r = r[:20-self.pos] + b'\x01\x00'[:new_pos-20] + r[22-self.pos:] + elif 20 <= self.pos < 22: + r = b'\x01\x00'[self.pos-20:new_pos-20] + r[22-self.pos:] + self.pos = new_pos + return r + +def ffmpeg_get_audio(filename): + error_log = open(os.devnull, 'w') + proc = subprocess.Popen(['ffmpeg', '-i', filename, '-f', 'wav', 'pipe:1'], + stderr=error_log, + stdin=open(os.devnull), + stdout=subprocess.PIPE, + bufsize=1000000) + try: + dat = proc.stdout.read() + wav = wave.open(HackExtensibleWave(io.BytesIO(dat))) + ch = wav.getnchannels() + rate = wav.getframerate() + n = wav.getnframes() + dat = wav.readframes(n) + del wav + samples = np.frombuffer(dat, dtype=np.int16) / 32768 + samples = samples.reshape([-1, ch]).T + return samples, rate + except (wave.Error, EOFError): + print('failed to decode %s. maybe the file is broken!' 
% filename)
+        return np.zeros([1, 0]), 44100
+
+def wave_get_audio(filename):
+    with open(filename, 'rb') as fin:
+        wav = wave.open(HackExtensibleWave(fin))
+        smpwidth = wav.getsampwidth()
+        if smpwidth not in {1, 2, 3}:
+            return None
+        n = wav.getnframes()
+        if smpwidth == 1:
+            samples = np.frombuffer(wav.readframes(n), dtype=np.uint8) / 128 - 1
+        elif smpwidth == 2:
+            samples = np.frombuffer(wav.readframes(n), dtype=np.int16) / 32768
+        elif smpwidth == 3:
+            a = np.frombuffer(wav.readframes(n), dtype=np.uint8)
+            samples = np.stack([a[0::3], a[1::3], a[2::3], -(a[2::3]>>7)], axis=1).view(np.int32).squeeze(1)
+            del a
+            samples = samples / 8388608
+        samples = samples.reshape([-1, wav.getnchannels()]).T
+        return samples, wav.getframerate()
+
+def get_audio(filename):
+    if str(filename).endswith('.wav'):
+        try:
+            a = wave_get_audio(filename)
+            if a: return a
+        except Exception:
+            pass
+    return ffmpeg_get_audio(filename)
+
+class FfmpegStream:
+    def __init__(self, proc, sample_rate, nchannels, tmpfile):
+        self.proc = proc
+        self.sample_rate = sample_rate
+        self.nchannels = nchannels
+        self.tmpfile = None
+        self.stream = self.gen_stream()
+        self.tmpfile = tmpfile
+    def __del__(self):
+        self.proc.terminate()
+        self.proc.communicate()
+        del self.proc
+        if self.tmpfile:
+            os.unlink(self.tmpfile)
+    def gen_stream(self):
+        num = yield np.array([], dtype=np.int16)
+        if not num: num = 1024
+        while True:
+            to_read = num * self.nchannels * 2
+            dat = self.proc.stdout.read(to_read)
+            num = yield np.frombuffer(dat, dtype=np.int16)
+            if not num: num = 1024
+            if len(dat) < to_read:
+                break
+
+def ffmpeg_stream_audio(filename, is_tmp=False):
+    while 1:
+        try:
+            stderr = open(os.devnull, 'w')
+            stdin = open(os.devnull)
+            break
+        except PermissionError:
+            print('PermissionError occurred, trying again')
+    proc = subprocess.Popen(['ffprobe', '-i', filename, '-show_streams',
+            '-select_streams', 'a', '-print_format', 'json'],
+        stderr=stderr,
+        stdin=stdin,
+        stdout=subprocess.PIPE)
+    prop = json.loads(proc.stdout.read())
+    if 'streams' not in prop:
+        raise RuntimeError('FFmpeg cannot decode audio')
+    sample_rate = int(prop['streams'][0]['sample_rate'])
+    nchannels = prop['streams'][0]['channels']
+    proc = subprocess.Popen(['ffmpeg', '-i', filename,
+            '-f', 's16le', '-acodec', 'pcm_s16le', 'pipe:1'],
+        stderr=stderr,
+        stdin=stdin,
+        stdout=subprocess.PIPE)
+    tmpfile = None
+    if is_tmp:
+        tmpfile = filename
+    return FfmpegStream(proc, sample_rate, nchannels, tmpfile=tmpfile)
+
+class WaveStream:
+    def __init__(self, filename, is_tmp=False):
+        self.is_tmp = None
+        self.file = open(filename, 'rb')
+        self.wave = wave.open(HackExtensibleWave(self.file))
+        self.smpsize = self.wave.getnchannels() * self.wave.getsampwidth()
+        self.sample_rate = self.wave.getframerate()
+        self.nchannels = self.wave.getnchannels()
+        if self.wave.getsampwidth() != 2:
+            raise NotImplementedError('wave stream currently only supports 16-bit wav')
+        self.stream = self.gen_stream()
+        self.is_tmp = filename if is_tmp else None
+    def gen_stream(self):
+        num = yield np.array([], dtype=np.int16)
+        if not num: num = 1024
+        while True:
+            dat = self.wave.readframes(num)
+            num = yield np.frombuffer(dat, dtype=np.int16)
+            if not num: num = 1024
+            if len(dat) < num * self.smpsize:
+                break
+    def __del__(self):
+        if self.is_tmp:
+            os.unlink(self.is_tmp)
+
diff --git a/datautil/dataset_v2.py b/datautil/dataset_v2.py
new file mode 100644
index 0000000..daabd2a
--- /dev/null
+++ b/datautil/dataset_v2.py
@@ -0,0 +1,301 @@
+import os
+
+import numpy as np
+import torch
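+# NOTE: the explicit `import torch.fft` below appears to target older torch
+# releases (around 1.7), where the fft module was not usable from a bare
+# `import torch`; on recent versions it is redundant but harmless.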
+import torch.fft +from torch.utils.data import DataLoader, Dataset, Sampler, BatchSampler + +from .melspec import build_mel_spec_layer +from .noise import NoiseData +from .ir import AIR, MicIRP +from .preprocess import preprocess_music + +class NumpyMemmapDataset(Dataset): + def __init__(self, path, dtype): + self.path = path + self.f = np.memmap(self.path, dtype=dtype) + + def __len__(self): + return len(self.f) + + def __getitem__(self, idx): + return self.f[idx] + + # need these, otherwise the pickler will read the whole data in memory + def __getstate__(self): + return {'path': self.path, 'dtype': self.f.dtype} + + def __setstate__(self, d): + self.__init__(d['path'], d['dtype']) + +# data augmentation on music dataset +class MusicSegmentDataset(Dataset): + def __init__(self, params, train_val): + # load some configs + assert train_val in {'train', 'validate'} + sample_rate = params['sample_rate'] + self.augmented = True + self.eval_time_shift = True + self.segment_size = int(params['segment_size'] * sample_rate) + self.hop_size = int(params['hop_size'] * sample_rate) + self.time_offset = int(params['time_offset'] * sample_rate) # select 1.2s of audio, then choose two random 1s of audios + self.pad_start = int(params['pad_start'] * sample_rate) # include more audio at the left of a segment to simulate reverb + self.params = params + + # get fft size needed for reverb + fftconv_n = 1024 + air_len = int(params['air']['length'] * sample_rate) + ir_len = int(params['micirp']['length'] * sample_rate) + fft_needed = self.segment_size + self.pad_start + air_len + ir_len + while fftconv_n < fft_needed: + fftconv_n *= 2 + self.fftconv_n = fftconv_n + + # datasets data augmentation + cache_dir = params['cache_dir'] + os.makedirs(cache_dir, exist_ok=True) + if params['noise'][train_val]: + self.noise = NoiseData(noise_dir=params['noise']['dir'], list_csv=params['noise'][train_val], sample_rate=sample_rate, cache_dir=cache_dir) + else: self.noise = None + if params['air'][train_val]: + self.air = AIR(air_dir=params['air']['dir'], list_csv=params['air'][train_val], length=params['air']['length'], fftconv_n=fftconv_n, sample_rate=sample_rate) + else: self.air = None + if params['micirp'][train_val]: + self.micirp = MicIRP(mic_dir=params['micirp']['dir'], list_csv=params['micirp'][train_val], length=params['micirp']['length'], fftconv_n=fftconv_n, sample_rate=sample_rate) + else: self.micirp = None + + # Load music dataset as memory mapped file + file_name = os.path.splitext(os.path.split(params[train_val + '_csv'])[1])[0] + file_name = os.path.join(cache_dir, '1' + file_name) + if os.path.exists(file_name + '.npy'): + print('load cached music from %s.bin' % file_name) + else: + preprocess_music(params['music_dir'], params[train_val + '_csv'], sample_rate, file_name) + self.f = NumpyMemmapDataset(file_name + '.bin', np.int16) + + # some segmentation settings + song_len = np.load(file_name + '.npy') + self.cues = [] # start location of segment i + self.offset_left = [] # allowed left shift of segment i + self.offset_right = [] # allowed right shift of segment i + self.song_range = [] # range of song i is (start time, end time, start idx, end idx) + t = 0 + for duration in song_len: + num_segs = (duration - self.segment_size + self.hop_size) // self.hop_size + start_cue = len(self.cues) + for idx in range(num_segs): + my_time = idx * self.hop_size + self.cues.append(t + my_time) + self.offset_left.append(my_time) + self.offset_right.append(duration - my_time) + end_cue = len(self.cues) + 
self.song_range.append((t, t + duration, start_cue, end_cue)) + t += duration + + # convert to torch Tensor to allow sharing memory + self.cues = torch.LongTensor(self.cues) + self.offset_left = torch.LongTensor(self.offset_left) + self.offset_right = torch.LongTensor(self.offset_right) + + # setup mel spectrogram + self.mel = build_mel_spec_layer(params) + + def get_single_segment(self, idx, offset, length): + cue = int(self.cues[idx]) + offset + left = int(self.offset_left[idx]) + offset + right = int(self.offset_right[idx]) - offset + + # choose a segment + # segment looks like: + # cue + # v + # [ pad_start | segment_size | ... ] + # \--- time_offset ---/ + segment = self.f[cue - min(left, self.pad_start): cue + min(right, length)] + segment = np.pad(segment, [max(0, self.pad_start-left), max(0, length-right)]) + + # convert 16 bit to float + return segment * np.float32(1/32768) + + def __getitem__(self, indices): + if self.eval_time_shift: + # collect all segments. It is faster to do batch processing + shift_range = self.segment_size//2 + x = [self.get_single_segment(i, -self.segment_size//4, self.segment_size+shift_range) for i in indices] + + # random time offset only on augmented segment, as a query audio + segment_size = self.pad_start + self.segment_size + offset1 = [self.segment_size//4] * len(x) + offset2 = torch.randint(high=shift_range+1, size=(len(x),)).tolist() + else: + # collect all segments. It is faster to do batch processing + x = [self.get_single_segment(i, 0, self.time_offset) for i in indices] + + # random time offset + shift_range = self.time_offset - self.segment_size + segment_size = self.pad_start + self.segment_size + offset1 = torch.randint(high=shift_range+1, size=(len(x),)).tolist() + offset2 = torch.randint(high=shift_range+1, size=(len(x),)).tolist() + + x_orig = [xi[off + self.pad_start : off + segment_size] for xi, off in zip(x, offset1)] + x_orig = torch.Tensor(np.stack(x_orig).astype(np.float32)) + x_aug = [xi[off : off + segment_size] for xi, off in zip(x, offset2)] + x_aug = torch.Tensor(np.stack(x_aug).astype(np.float32)) + + if self.augmented: + # background noise + if self.noise is not None: + x_aug = self.noise.add_noises(x_aug, self.params['noise']['snr_min'], self.params['noise']['snr_max']) + + # impulse response + spec = torch.fft.rfft(x_aug, self.fftconv_n) + if self.air is not None: + spec *= self.air.random_choose(spec.shape[0]) + if self.micirp is not None: + spec *= self.micirp.random_choose(spec.shape[0]) + x_aug = torch.fft.irfft(spec, self.fftconv_n) + x_aug = x_aug[..., self.pad_start:segment_size] + + # output [x1_orig, x1_aug, x2_orig, x2_aug, ...] 
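+        # stacking as [batch, 2, ...] keeps each clean segment adjacent to its
+        # augmented copy, so the training loop can torch.flatten the first two
+        # axes and recover the interleaved order similarity_loss expects
+        # (even index = original, odd index = augmented)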
+        x = [x_orig, x_aug] if self.augmented else [x_orig]
+        x = torch.stack(x, dim=1)
+        if self.mel is not None:
+            return self.mel(x)
+        return x
+
+    def fan_si_le(self):
+        raise NotImplementedError('So annoying!')
+
+    def zuo_bu_chu_lai(self):
+        raise NotImplementedError("Can't get it to work")
+
+    def __len__(self):
+        return len(self.cues)
+
+    def get_num_songs(self):
+        return len(self.song_range)
+
+    def get_song_segments(self, song_id):
+        return self.song_range[song_id][2:4]
+
+    def preload_song(self, song_id):
+        start, end, _, _ = self.song_range[song_id]
+        return self.f[start : end].copy()
+
+class TwoStageShuffler(Sampler):
+    def __init__(self, music_data: MusicSegmentDataset, shuffle_size):
+        self.music_data = music_data
+        self.shuffle_size = shuffle_size
+        self.shuffle = True
+        self.loaded = set()
+        self.generator = torch.Generator()
+        self.generator2 = torch.Generator()
+
+    def set_epoch(self, epoch):
+        self.generator.manual_seed(42 + epoch)
+        self.generator2.manual_seed(42 + epoch)
+
+    def __len__(self):
+        return len(self.music_data)
+
+    def preload(self, song_id):
+        if song_id not in self.loaded:
+            self.music_data.preload_song(song_id)
+            self.loaded.add(song_id)
+
+    def baseline_shuffle(self):
+        # the same as DataLoader with shuffle=True, but with a music preloader
+        # 2021/10/18 sorry, I had to remove the preloader
+        #for song in range(self.music_data.get_num_songs()):
+        #    self.preload(song)
+
+        yield from torch.randperm(len(self), generator=self.generator).tolist()
+
+    def shuffling_iter(self):
+        # shuffle song list first
+        shuffle_song = torch.randperm(self.music_data.get_num_songs(), generator=self.generator)
+
+        # split song list into chunks
+        chunks = torch.split(shuffle_song, self.shuffle_size)
+        for nc, songs in enumerate(chunks):
+            # sort songs so the preloader reads more sequentially
+            songs = torch.sort(songs)[0].tolist()
+
+            buf = []
+            for song in songs:
+                # collect segment ids
+                seg_start, seg_end = self.music_data.get_song_segments(song)
+                buf += list(range(seg_start, seg_end))
+
+            if nc == 0:
+                # load first chunk
+                for song in songs:
+                    self.preload(song)
+
+            # shuffle segments from song chunk
+            shuffle_segs = torch.randperm(len(buf), generator=self.generator2)
+            shuffle_segs = [buf[x] for x in shuffle_segs]
+            preload_cnt = 0
+            for i, idx in enumerate(shuffle_segs):
+                # output shuffled segment idx
+                yield idx
+
+                # preload next chunk
+                while len(self.loaded) < len(shuffle_song) and preload_cnt * len(shuffle_segs) < (i+1) * self.shuffle_size:
+                    song = shuffle_song[len(self.loaded)].item()
+                    self.preload(song)
+                    preload_cnt += 1
+
+    def non_shuffling_iter(self):
+        # just return 0 ...
len(dataset)-1 + yield from range(len(self)) + + def __iter__(self): + if self.shuffle: + if self.shuffle_size is None: + return self.baseline_shuffle() + else: + return self.shuffling_iter() + return self.non_shuffling_iter() + +# since instantiating a DataLoader of MusicSegmentDataset is hard, I provide a data loader builder +class SegmentedDataLoader: + def __init__(self, train_val, configs, num_workers=4, pin_memory=False, prefetch_factor=2): + assert train_val in {'train', 'validate'} + self.dataset = MusicSegmentDataset(configs, train_val) + assert configs['batch_size'] % 2 == 0 + self.batch_size = configs['batch_size'] + self.shuffler = TwoStageShuffler(self.dataset, configs['shuffle_size']) + self.sampler = BatchSampler(self.shuffler, self.batch_size//2, False) + self.num_workers = num_workers + self.configs = configs + self.pin_memory = pin_memory + self.prefetch_factor = prefetch_factor + + # you can change shuffle to True/False + self.shuffle = True + # you can change augmented to True/False + self.augmented = True + # you can change eval time shift to True/False + self.eval_time_shift = False + + self.loader = DataLoader( + self.dataset, + sampler=self.sampler, + batch_size=None, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + prefetch_factor=self.prefetch_factor + ) + + def set_epoch(self, epoch): + self.shuffler.set_epoch(epoch) + + def __iter__(self): + self.dataset.augmented = self.augmented + self.dataset.eval_time_shift = self.eval_time_shift + self.shuffler.shuffle = self.shuffle + return iter(self.loader) + + def __len__(self): + return len(self.loader) diff --git a/datautil/ir.py b/datautil/ir.py new file mode 100644 index 0000000..2ec35af --- /dev/null +++ b/datautil/ir.py @@ -0,0 +1,89 @@ +import argparse +import csv +import os +import warnings + +import scipy.io +import numpy as np +import torch +import torch.fft +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + import torchaudio + +from .audio import get_audio + +class AIR: + def __init__(self, air_dir, list_csv, length, fftconv_n, sample_rate=8000): + print('loading Aachen IR dataset') + with open(list_csv, 'r') as fin: + reader = csv.reader(fin) + airs = [] + firstrow = next(reader) + for row in reader: + airs.append(row[0]) + data = [] + to_len = int(length * sample_rate) + self.names = [] + for name in airs: + mat = scipy.io.loadmat(os.path.join(air_dir, name)) + h_air = torch.tensor(mat['h_air'].astype(np.float32)) + assert h_air.shape[0] == 1 + h_air = h_air[0] + air_info = mat['air_info'] + fs = int(air_info['fs'][0][0][0][0]) + self.names.append(str(air_info['room'][0][0][0])) + resampled = torchaudio.transforms.Resample(fs, sample_rate)(h_air) + truncated = resampled[0:to_len] + freqd = torch.fft.rfft(truncated, fftconv_n) + data.append(freqd) + self.data = torch.stack(data) + + def random_choose(self, num): + indices = torch.randint(0, self.data.shape[0], size=(num,), dtype=torch.long) + return self.data[indices] + + def random_choose_name(self): + index = torch.randint(0, self.data.shape[0], size=(1,), dtype=torch.long).item() + return self.data[index], self.names[index] + +class MicIRP: + def __init__(self, mic_dir, list_csv, length, fftconv_n, sample_rate=8000): + print('loading microphone IR dataset') + with open(list_csv, 'r') as fin: + reader = csv.reader(fin) + mics = [] + firstrow = next(reader) + for row in reader: + mics.append(row[0]) + data = [] + to_len = int(length * sample_rate) + for name in mics: + smp, smprate = get_audio(os.path.join(mic_dir, name)) + 
smp = torch.FloatTensor(smp).mean(dim=0) + resampled = torchaudio.transforms.Resample(smprate, sample_rate)(smp) + truncated = resampled[0:to_len] + freqd = torch.fft.rfft(truncated, fftconv_n) + data.append(freqd) + self.data = torch.stack(data) + + def random_choose(self, num): + indices = torch.randint(0, self.data.shape[0], size=(num,), dtype=torch.long) + return self.data[indices] + +if __name__ == '__main__': + args = argparse.ArgumentParser() + args.add_argument('air') + args.add_argument('out') + args = args.parse_args() + + with open(args.out, 'w', encoding='utf8', newline='\n') as fout: + writer = csv.writer(fout) + writer.writerow(['file']) + files = [] + for name in os.listdir(args.air): + if name.endswith('.mat'): + files.append(name) + files.sort() + for name in files: + writer.writerow([name]) diff --git a/datautil/melspec.py b/datautil/melspec.py new file mode 100644 index 0000000..508b893 --- /dev/null +++ b/datautil/melspec.py @@ -0,0 +1,63 @@ +import torch +import torchaudio + +class MelSpec(torch.nn.Module): + def __init__(self, + sample_rate=8000, + stft_n=1024, + stft_hop=256, + f_min=300, + f_max=4000, + n_mels=256, + naf_mode=False, + mel_log='log', + spec_norm='l2'): + super(MelSpec, self).__init__() + self.naf_mode = naf_mode + self.mel_log = mel_log + self.spec_norm = spec_norm + self.mel = torchaudio.transforms.MelSpectrogram( + sample_rate=sample_rate, + n_fft=stft_n, + hop_length=stft_hop, + f_min=f_min, + f_max=f_max, + n_mels=n_mels, + window_fn=torch.hann_window, + power = 1 if naf_mode else 2, + pad_mode = 'constant' if naf_mode else 'reflect', + norm = 'slaney' if naf_mode else None, + mel_scale = 'slaney' if naf_mode else 'htk' + ) + + def forward(self, x): + # normalize volume + p = 1e999 if self.spec_norm == 'max' else 2 + x = torch.nn.functional.normalize(x, p=p, dim=-1) + + if self.naf_mode: + x = self.mel(x) + 0.06 + else: + x = self.mel(x) + 1e-8 + + if self.mel_log == 'log10': + x = torch.log10(x) + elif self.mel_log == 'log': + x = torch.log(x) + + if self.spec_norm == 'max': + x = x - torch.amax(x, dim=(-2,-1), keepdim=True) + return x + +def build_mel_spec_layer(params): + return MelSpec( + sample_rate = params['sample_rate'], + stft_n = params['stft_n'], + stft_hop = params['stft_hop'], + f_min = params['f_min'], + f_max = params['f_max'], + n_mels = params['n_mels'], + naf_mode = params.get('naf_mode', False), + mel_log = params.get('mel_log', 'log'), + spec_norm = params.get('spec_norm', 'l2') + ) diff --git a/datautil/noise.py b/datautil/noise.py new file mode 100644 index 0000000..1a428f3 --- /dev/null +++ b/datautil/noise.py @@ -0,0 +1,109 @@ +import csv +import os +import warnings + +import tqdm +import numpy as np +import torch +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + import torchaudio + +from .simpleutils import get_hash +from .audio import get_audio + +class NoiseData: + def __init__(self, noise_dir, list_csv, sample_rate, cache_dir): + print('loading noise dataset') + hashes = [] + with open(list_csv, 'r') as fin: + reader = csv.reader(fin) + noises = [] + firstrow = next(reader) + for row in reader: + noises.append(row[0]) + hashes.append(get_hash(row[0])) + hash = get_hash(''.join(hashes)) + #self.data = self.load_from_cache(list_csv, cache_dir, hash) + #if self.data is not None: + # print(self.data.shape) + # return + data = [] + silence_threshold = 0 + self.names = [] + for name in tqdm.tqdm(noises): + smp, smprate = get_audio(os.path.join(noise_dir, name)) + smp = 
torch.from_numpy(smp.astype(np.float32))
+
+            # convert to mono
+            smp = smp.mean(dim=0)
+
+            # strip silence start/end
+            abs_smp = torch.abs(smp)
+            if torch.max(abs_smp) <= silence_threshold:
+                print('%s too silent' % name)
+                continue
+            has_sound = (abs_smp > silence_threshold).to(torch.int)
+            start = int(torch.argmax(has_sound))
+            end = has_sound.shape[0] - int(torch.argmax(has_sound.flip(0)))
+            smp = smp[max(start, 0) : end]
+
+            resampled = torchaudio.transforms.Resample(smprate, sample_rate)(smp)
+            resampled = torch.nn.functional.normalize(resampled, dim=0, p=1e999)
+            data.append(resampled)
+            self.names.append(name)
+        self.data = torch.cat(data)
+        self.boundary = [0] + [x.shape[0] for x in data]
+        self.boundary = torch.LongTensor(self.boundary).cumsum(0)
+        del data
+        #self.save_to_cache(list_csv, cache_dir, hash, self.data)
+        print(self.data.shape)
+
+    def load_from_cache(self, list_csv, cache_dir, hash):
+        loc = os.path.join(cache_dir, os.path.basename(list_csv) + '.npy')
+        loc2 = os.path.join(cache_dir, os.path.basename(list_csv) + '.hash')
+        if os.path.exists(loc) and os.path.exists(loc2):
+            with open(loc2, 'r') as fin:
+                read_hash = fin.read().strip()
+            if read_hash != hash:
+                return None
+            print('cache hit!')
+            return torch.from_numpy(np.fromfile(loc, dtype=np.float32))
+        return None
+
+    def save_to_cache(self, list_csv, cache_dir, hash, audio):
+        os.makedirs(cache_dir, exist_ok=True)
+        loc = os.path.join(cache_dir, os.path.basename(list_csv) + '.npy')
+        loc2 = os.path.join(cache_dir, os.path.basename(list_csv) + '.hash')
+        with open(loc2, 'w') as fout:
+            fout.write(hash)
+        print('save to cache')
+        audio.numpy().tofile(loc)
+
+    def random_choose(self, num, duration, out_name=False):
+        indices = torch.randint(0, self.data.shape[0] - duration, size=(num,), dtype=torch.long)
+        out = torch.zeros([num, duration], dtype=torch.float32)
+        for i in range(num):
+            start = int(indices[i])
+            end = start + duration
+            out[i] = self.data[start:end]
+        name_lookup = torch.searchsorted(self.boundary, indices, right=True) - 1
+        if out_name:
+            return out, [self.names[x] for x in name_lookup]
+        return out
+
+    # x is a 2d array
+    def add_noises(self, x, snr_min, snr_max, out_name=False):
+        eps = 1e-12
+        noise = self.random_choose(x.shape[0], x.shape[1], out_name=out_name)
+        if out_name:
+            noise, noise_name = noise
+        vol_x = torch.clamp((x ** 2).mean(dim=1), min=eps).sqrt()
+        vol_noise = torch.clamp((noise ** 2).mean(dim=1), min=eps).sqrt()
+        snr = torch.FloatTensor(x.shape[0]).uniform_(snr_min, snr_max)
+        ratio = vol_x / vol_noise
+        ratio *= 10 ** -(snr / 20)
+        x_aug = x + ratio.unsqueeze(1) * noise
+        if out_name:
+            return x_aug, noise_name, snr
+        return x_aug
diff --git a/datautil/preprocess.py b/datautil/preprocess.py
new file mode 100644
index 0000000..55f52e5
--- /dev/null
+++ b/datautil/preprocess.py
@@ -0,0 +1,55 @@
+import csv
+import os
+
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, Dataset
+import torchaudio
+import tqdm
+
+from .audio import get_audio
+
+class Preprocessor(Dataset):
+    def __init__(self, files, dir, sample_rate):
+        self.files = files
+        self.dir = dir
+        self.resampler = {}
+        self.sample_rate = sample_rate
+
+    def __getitem__(self, n):
+        dat = get_audio(os.path.join(self.dir, self.files[n]))
+        wav, smprate = dat
+        if smprate not in self.resampler:
+            self.resampler[smprate] = torchaudio.transforms.Resample(smprate, self.sample_rate)
+        wav = torch.Tensor(wav)
+        wav = wav.mean(dim=0)
+        wav = self.resampler[smprate](wav)
+
+        # quantize back to 16-bit
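+        # scale to the int16 range and clamp before casting; the dataset later
+        # undoes this with `segment * np.float32(1/32768)`, so the round trip
+        # is lossless apart from 16-bit rounding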
+        wav *= 32768
+        torch.clamp(wav, -32768, 32767, out=wav)
+        wav = wav.to(torch.int16)
+        return wav
+
+    def __len__(self):
+        return len(self.files)
+
+def preprocess_music(music_dir, music_csv, sample_rate, preprocess_out):
+    print('converting music to wav')
+    with open(music_csv) as fin:
+        reader = csv.reader(fin)
+        next(reader)
+        files = [row[0] for row in reader]
+
+    preprocessor = Preprocessor(files, music_dir, sample_rate)
+    loader = DataLoader(preprocessor, num_workers=4, batch_size=None)
+    out_file = open(preprocess_out + '.bin', 'wb')
+    song_lens = []
+    for wav in tqdm.tqdm(loader):
+        # torch.set_num_threads(1) # default multithreading causes cpu contention
+
+        wav = wav.numpy()
+        out_file.write(wav.tobytes())
+        song_lens.append(wav.shape[0])
+    out_file.close()
+    np.save(preprocess_out, np.array(song_lens, dtype=np.int64))
diff --git a/datautil/simpleutils.py b/datautil/simpleutils.py
new file mode 100644
index 0000000..a7f9ac2
--- /dev/null
+++ b/datautil/simpleutils.py
@@ -0,0 +1,26 @@
+# utilities that do not use torch
+import hashlib
+import json
+import logging
+import multiprocessing as mp
+import os
+
+
+def get_hash(s):
+    m = hashlib.md5()
+    m.update(s.encode('utf8'))
+    return m.hexdigest()
+
+
+def read_config(path):
+    with open(path, 'r') as fin:
+        return json.load(fin)
+
+
+def init_logger(app_name):
+    os.makedirs('logs', exist_ok=True)
+    logger = mp.get_logger()
+    logger.setLevel(logging.INFO)
+    handler = logging.FileHandler('logs/%s.log' % app_name, encoding="utf8")
+    handler.setFormatter(logging.Formatter('[%(asctime)s] [%(processName)s/%(levelname)s] %(message)s'))
+    logger.addHandler(handler)
diff --git a/datautil/specaug.py b/datautil/specaug.py
new file mode 100644
index 0000000..e307521
--- /dev/null
+++ b/datautil/specaug.py
@@ -0,0 +1,42 @@
+import torch
+
+class SpecAugment:
+    def __init__(self, params):
+        self.freq_min = params.get('freq_min', 0.1) # 5
+        self.freq_max = params.get('freq_max', 0.5) # 20
+        self.time_min = params.get('time_min', 0.1) # 5
+        self.time_max = params.get('time_max', 0.5) # 16
+
+        self.cutout_min = params.get('cutout_min', 0.1) # 0.1
+        self.cutout_max = params.get('cutout_max', 0.5) # 0.4
+
+    def get_mask(self, F, T):
+        mask = torch.zeros(F, T)
+        # cutout
+        cutout_max = self.cutout_max
+        cutout_min = self.cutout_min
+        f = F * (cutout_min + torch.rand(1) * (cutout_max-cutout_min))
+        f = int(f)
+        f0 = torch.randint(0, F - f + 1, (1,))
+        t = T * (cutout_min + torch.rand(1) * (cutout_max-cutout_min))
+        t = int(t)
+        t0 = torch.randint(0, T - t + 1, (1,))
+        mask[f0:f0+f, t0:t0+t] = 1
+
+        # frequency masking
+        f = F * (self.freq_min + torch.rand(1) * (self.freq_max - self.freq_min))
+        f = int(f)
+        f0 = torch.randint(0, F - f + 1, (1,))
+        mask[f0:f0+f, :] = 1
+
+        # time masking
+        t = T * (self.time_min + torch.rand(1) * (self.time_max - self.time_min))
+        t = int(t)
+        t0 = torch.randint(0, T - t + 1, (1,))
+        mask[:, t0:t0+t] = 1
+        return mask
+
+    def augment(self, x):
+        mask = self.get_mask(x.shape[-2], x.shape[-1]).to(x.device)
+        x = x * (1 - mask)
+        return x
diff --git a/nn_fingerprint.py b/nn_fingerprint.py
index d2fe4a4..dbefb44 100644
--- a/nn_fingerprint.py
+++ b/nn_fingerprint.py
@@ -235,3 +235,11 @@ class NNFingerprint(NNOperator):
         else:
             log.error(f'Unsupported format "{format}".')
             return Path(path).resolve()
+
+    def train(self, **kwargs):
+        from .train_nnfp import train_nnfp
+        config_json_path = kwargs['config_json_path']
+        train_nnfp(
+            self._model,
+            config_json_path=config_json_path
+        )
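Note: train_nnfp in the next file reads every hyperparameter from a JSON config. Below is a minimal sketch of that config, written as a Python dict for reference. The key names are taken from the readers in this patch (dataset_v2.py, melspec.py, noise.py, train_nnfp.py); all values and paths are illustrative placeholders, not part of the patch:

# Illustrative training config; serialize with json.dump(example_config, f)
# and pass the resulting file path as config_json_path. Values are assumptions.
example_config = {
    'sample_rate': 8000,    # Hz; all audio is resampled to this rate
    'segment_size': 1,      # seconds per fingerprint segment
    'hop_size': 0.5,        # seconds between segment starts
    'time_offset': 1.2,     # seconds; window the two views are cut from
    'pad_start': 0.2,       # seconds of left context kept for reverb
    'music_dir': 'data/music',
    'train_csv': 'lists/train.csv',
    'validate_csv': 'lists/validate.csv',
    'cache_dir': 'caches',
    'noise':  {'dir': 'data/noise',  'train': 'lists/noise_train.csv',
               'validate': 'lists/noise_val.csv', 'snr_min': 0, 'snr_max': 10},
    'air':    {'dir': 'data/AIR',    'train': 'lists/air_train.csv',
               'validate': 'lists/air_val.csv',   'length': 1.0},
    'micirp': {'dir': 'data/micirp', 'train': 'lists/mic_train.csv',
               'validate': 'lists/mic_val.csv',   'length': 0.5},
    'stft_n': 1024, 'stft_hop': 256,              # mel spectrogram front end
    'f_min': 300, 'f_max': 4000, 'n_mels': 256,
    'batch_size': 640,      # must be even: each clean/augmented pair uses two slots
    'shuffle_size': 100,    # songs per chunk in TwoStageShuffler
    'epoch': 10, 'lr': 1e-4, 'tau': 0.05,         # tau = contrastive temperature
}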
diff --git a/train_nnfp.py b/train_nnfp.py
new file mode 100644
index 0000000..09342dd
--- /dev/null
+++ b/train_nnfp.py
@@ -0,0 +1,152 @@
+## This training script is adapted from https://github.com/stdio2016/pfann
+from typing import Any
+
+import numpy as np
+import torch
+from torch import nn
+from towhee.trainer.callback import Callback
+from towhee.trainer.trainer import Trainer
+from towhee.trainer.training_config import TrainingConfig
+
+from .datautil.dataset_v2 import SegmentedDataLoader
+from .datautil.simpleutils import read_config
+from .datautil.specaug import SpecAugment
+from torch.cuda.amp import autocast, GradScaler
+
+
+# other requirements: torchvision, torchmetrics==0.7.0
+
+
+def similarity_loss(y, tau):
+    a = torch.matmul(y, y.T)
+    a /= tau
+    Ls = []
+    for i in range(y.shape[0]):
+        nn_self = torch.cat([a[i, :i], a[i, i + 1:]])
+        softmax = torch.nn.functional.log_softmax(nn_self, dim=0)
+        Ls.append(softmax[i if i % 2 == 0 else i - 1])
+    Ls = torch.stack(Ls)
+
+    loss = torch.sum(Ls) / -y.shape[0]
+    return loss
+
+
+class SetEpochCallback(Callback):
+    def __init__(self, dataloader):
+        super().__init__()
+        self.dataloader = dataloader
+
+    def on_epoch_begin(self, epochs, logs):
+        self.dataloader.set_epoch(epochs)
+
+
+class NNFPTrainer(Trainer):
+    def __init__(self, model, training_config, train_dataset, eval_dataset, train_dataloader, eval_dataloader,
+                 model_card, params):
+        super().__init__(model, training_config, train_dataset, eval_dataset, train_dataloader, eval_dataloader,
+                         model_card)
+        self.specaug = SpecAugment(params)
+        self.scaler = GradScaler()
+        self.tau = params.get('tau', 0.05)
+        self.losses = []
+        self.set_epoch_callback = SetEpochCallback(self.train_dataloader)
+        self.callbacks.add_callback(self.set_epoch_callback)
+        print('evaluating before fine-tuning...')
+        self.evaluate(self.model, dict())
+
+    def train_step(self, model: nn.Module, inputs: Any) -> dict:
+        x = inputs
+        self.optimizer.zero_grad()
+
+        x = torch.flatten(x, 0, 1)
+        x = self.specaug.augment(x)
+
+        with autocast():
+            y = model(x.to(self.configs.device))
+            loss = similarity_loss(y, self.tau)
+        self.scaler.scale(loss).backward()
+
+        self.scaler.step(self.optimizer)
+        self.scaler.update()
+
+        self.lr_scheduler.step()
+
+        lossnum = float(loss.item())
+        self.losses.append(lossnum)
+        step_logs = {"step_loss": lossnum, "epoch_loss": np.mean(self.losses), "epoch_metric": 0}
+        return step_logs
+
+    @torch.no_grad()
+    def evaluate(self, model: nn.Module, logs: dict) -> dict:
+        validate_N = 0
+        y_embed = []
+        minibatch = 40
+        if torch.cuda.is_available() and torch.cuda.get_device_properties(0).total_memory > 11e9:
+            minibatch = 640
+        for x in self.eval_dataloader:
+            x = torch.flatten(x, 0, 1)
+            for xx in torch.split(x, minibatch):
+                y = model(xx.to(self.configs.device)).cpu()
+                y_embed.append(y)
+        y_embed = torch.cat(y_embed)
+        y_embed_org = y_embed[0::2]
+        y_embed_aug = y_embed[1::2].to(self.configs.device)
+
+        # compute validation score on GPU
+        self_score = []
+        for embeds in torch.split(y_embed_org, 320):
+            A = torch.matmul(y_embed_aug, embeds.T.to(self.configs.device))
+            self_score.append(A.diagonal(-validate_N).cpu())
+            validate_N += embeds.shape[0]
+        self_score = torch.cat(self_score).to(self.configs.device)
+
+        ranks = torch.zeros(validate_N, dtype=torch.long).to(self.configs.device)
+
+        for embeds in torch.split(y_embed_org, 320):
+            A = torch.matmul(y_embed_aug, embeds.T.to(self.configs.device))
+            ranks += (A.T >= self_score).sum(dim=0)
+        acc = int((ranks == 1).sum())
+        print('\nvalidate score: %f' % (acc / validate_N,))
+        # logs['epoch_metric'] = acc / validate_N
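+        # ranks[i] == 1 means the i-th augmented query scored its own clean
+        # segment highest among all originals, so acc / validate_N is the
+        # top-1 retrieval accuracy over the validation set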
+        return logs
+
+
+def train_nnfp(model, config_json_path):
+    model.train()
+    params = read_config(config_json_path)
+    train_dataloader = SegmentedDataLoader('train', params, num_workers=4)
+    print('training data contains %d samples' % len(train_dataloader.dataset))
+    train_dataloader.shuffle = True
+    train_dataloader.eval_time_shift = False
+    train_dataloader.augmented = True
+    train_dataloader.set_epoch(0)
+
+    val_dataloader = SegmentedDataLoader('validate', params, num_workers=4)
+    val_dataloader.shuffle = False
+    val_dataloader.eval_time_shift = True
+    val_dataloader.set_epoch(-1)
+
+    training_config = TrainingConfig(
+        batch_size=params['batch_size'],
+        epoch_num=params['epoch'],
+        output_dir='fine_tune_output',
+        lr_scheduler_type='cosine',
+        eval_strategy='epoch',
+        lr=params['lr']
+    )
+
+    optimizer = torch.optim.Adam(model.parameters(), lr=training_config.lr)
+
+    nnfp_trainer = NNFPTrainer(
+        model=model,
+        training_config=training_config,
+        train_dataset=train_dataloader.dataset,
+        eval_dataset=None,
+        train_dataloader=train_dataloader,
+        eval_dataloader=val_dataloader,
+        model_card=None,
+        params=params
+    )
+    nnfp_trainer.set_optimizer(optimizer)
+
+    nnfp_trainer.train()
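For completeness, a minimal usage sketch of the new train() entry point added to nn_fingerprint.py. The import path and config filename are illustrative assumptions; the only keyword argument the patch reads is config_json_path:

# Hypothetical fine-tuning run; assumes the operator class is importable and
# that nnfp_train_config.json follows the config sketch shown earlier.
from nn_fingerprint import NNFingerprint

op = NNFingerprint()  # constructor arguments, if any, omitted here
op.train(config_json_path='nnfp_train_config.json')  # dispatches to train_nnfp(self._model, ...)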