Commit 7f5d813615 by ChengZi, 2 years ago
11 changed files with 997 additions and 0 deletions:

  datautil/__init__.py     +0
  datautil/audio.py        +152
  datautil/dataset_v2.py   +301
  datautil/ir.py           +89
  datautil/melspec.py      +63
  datautil/noise.py        +109
  datautil/preprocess.py   +55
  datautil/simpleutils.py  +26
  datautil/specaug.py      +42
  nn_fingerprint.py        +8
  train_nnfp.py            +152

datautil/__init__.py  +0 (new, empty file)

datautil/audio.py  +152

@@ -0,0 +1,152 @@
import json
import os
import subprocess
import numpy as np
import wave
import io


# The builtin wave module won't read WAVE_FORMAT_EXTENSIBLE files (e.g. wav
# files with more than 2 channels), so patch the format tag to PCM on the fly.
class HackExtensibleWave:
    def __init__(self, stream):
        self.stream = stream
        self.pos = 0

    def read(self, n):
        r = self.stream.read(n)
        new_pos = self.pos + len(r)
        # bytes 20-21 of a canonical wav header hold the format tag;
        # rewrite them to 0x0001 (PCM) as they pass through
        if self.pos < 20 and self.pos + n >= 20:
            r = r[:20-self.pos] + b'\x01\x00'[:new_pos-20] + r[22-self.pos:]
        elif 20 <= self.pos < 22:
            r = b'\x01\x00'[self.pos-20:new_pos-20] + r[22-self.pos:]
        self.pos = new_pos
        return r


def ffmpeg_get_audio(filename):
    error_log = open(os.devnull, 'w')
    proc = subprocess.Popen(['ffmpeg', '-i', filename, '-f', 'wav', 'pipe:1'],
                            stderr=error_log,
                            stdin=open(os.devnull),
                            stdout=subprocess.PIPE,
                            bufsize=1000000)
    try:
        dat = proc.stdout.read()
        wav = wave.open(HackExtensibleWave(io.BytesIO(dat)))
        ch = wav.getnchannels()
        rate = wav.getframerate()
        n = wav.getnframes()
        dat = wav.readframes(n)
        del wav
        samples = np.frombuffer(dat, dtype=np.int16) / 32768
        samples = samples.reshape([-1, ch]).T
        return samples, rate
    except (wave.Error, EOFError):
        print('failed to decode %s. maybe the file is broken!' % filename)
        return np.zeros([1, 0]), 44100


def wave_get_audio(filename):
    with open(filename, 'rb') as fin:
        wav = wave.open(HackExtensibleWave(fin))
        smpwidth = wav.getsampwidth()
        if smpwidth not in {1, 2, 3}:
            return None
        n = wav.getnframes()
        if smpwidth == 1:
            samples = np.frombuffer(wav.readframes(n), dtype=np.uint8) / 128 - 1
        elif smpwidth == 2:
            samples = np.frombuffer(wav.readframes(n), dtype=np.int16) / 32768
        elif smpwidth == 3:
            # unpack 24-bit samples: append a sign-extension byte, then view as int32
            a = np.frombuffer(wav.readframes(n), dtype=np.uint8)
            samples = np.stack([a[0::3], a[1::3], a[2::3], -(a[2::3] >> 7)], axis=1).view(np.int32).squeeze(1)
            del a
            samples = samples / 8388608
        samples = samples.reshape([-1, wav.getnchannels()]).T
        return samples, wav.getframerate()


def get_audio(filename):
    if str(filename).endswith('.wav'):
        try:
            a = wave_get_audio(filename)
            if a:
                return a
        except Exception:
            pass
    # fall back to ffmpeg for non-wav files and wav files the wave module rejects
    return ffmpeg_get_audio(filename)


class FfmpegStream:
    def __init__(self, proc, sample_rate, nchannels, tmpfile):
        self.proc = proc
        self.sample_rate = sample_rate
        self.nchannels = nchannels
        self.tmpfile = None  # set last, so __del__ is safe if setup fails
        self.stream = self.gen_stream()
        self.tmpfile = tmpfile

    def __del__(self):
        self.proc.terminate()
        self.proc.communicate()
        del self.proc
        if self.tmpfile:
            os.unlink(self.tmpfile)

    def gen_stream(self):
        # generator protocol: send the number of frames wanted, receive int16 samples
        num = yield np.array([], dtype=np.int16)
        if not num:
            num = 1024
        while True:
            to_read = num * self.nchannels * 2
            dat = self.proc.stdout.read(to_read)
            num = yield np.frombuffer(dat, dtype=np.int16)
            if not num:
                num = 1024
            if len(dat) < to_read:
                break


def ffmpeg_stream_audio(filename, is_tmp=False):
    while 1:
        try:
            stderr = open(os.devnull, 'w')
            stdin = open(os.devnull)
            break
        except PermissionError:
            print('PermissionError occurred, try again')
    proc = subprocess.Popen(['ffprobe', '-i', filename, '-show_streams',
                             '-select_streams', 'a', '-print_format', 'json'],
                            stderr=stderr,
                            stdin=stdin,
                            stdout=subprocess.PIPE)
    prop = json.loads(proc.stdout.read())
    if 'streams' not in prop:
        raise RuntimeError('FFmpeg cannot decode audio')
    sample_rate = int(prop['streams'][0]['sample_rate'])
    nchannels = prop['streams'][0]['channels']
    proc = subprocess.Popen(['ffmpeg', '-i', filename,
                             '-f', 's16le', '-acodec', 'pcm_s16le', 'pipe:1'],
                            stderr=stderr,
                            stdin=stdin,
                            stdout=subprocess.PIPE)
    tmpfile = filename if is_tmp else None
    return FfmpegStream(proc, sample_rate, nchannels, tmpfile=tmpfile)


class WaveStream:
    def __init__(self, filename, is_tmp=False):
        self.is_tmp = None  # set last, so __del__ is safe if open fails
        self.file = open(filename, 'rb')
        self.wave = wave.open(HackExtensibleWave(self.file))
        self.smpsize = self.wave.getnchannels() * self.wave.getsampwidth()
        self.sample_rate = self.wave.getframerate()
        self.nchannels = self.wave.getnchannels()
        if self.wave.getsampwidth() != 2:
            raise NotImplementedError('wave stream currently only supports 16bit wav')
        self.stream = self.gen_stream()
        self.is_tmp = filename if is_tmp else None

    def gen_stream(self):
        num = yield np.array([], dtype=np.int16)
        if not num:
            num = 1024
        while True:
            dat = self.wave.readframes(num)
            num = yield np.frombuffer(dat, dtype=np.int16)
            if not num:
                num = 1024
            if len(dat) < num * self.smpsize:
                break

    def __del__(self):
        if self.is_tmp:
            os.unlink(self.is_tmp)
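
A minimal usage sketch of the two entry points ('song.mp3' is a placeholder path, and ffmpeg/ffprobe must be on PATH): get_audio decodes a whole file at once, while ffmpeg_stream_audio exposes a send()-driven generator that yields interleaved int16 samples.

# usage sketch; paths are hypothetical
from datautil.audio import get_audio, ffmpeg_stream_audio

samples, rate = get_audio('song.mp3')   # samples: float array, shape (channels, frames)

stream = ffmpeg_stream_audio('song.mp3')
stream.stream.send(None)                # prime the generator (returns an empty array)
chunk = stream.stream.send(4096)        # next 4096 frames as interleaved int16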

datautil/dataset_v2.py  +301

@@ -0,0 +1,301 @@
import os
import numpy as np
import torch
import torch.fft
from torch.utils.data import DataLoader, Dataset, Sampler, BatchSampler

from .melspec import build_mel_spec_layer
from .noise import NoiseData
from .ir import AIR, MicIRP
from .preprocess import preprocess_music


class NumpyMemmapDataset(Dataset):
    def __init__(self, path, dtype):
        self.path = path
        self.f = np.memmap(self.path, dtype=dtype)

    def __len__(self):
        return len(self.f)

    def __getitem__(self, idx):
        return self.f[idx]

    # need these, otherwise the pickler would read the whole data into memory
    def __getstate__(self):
        return {'path': self.path, 'dtype': self.f.dtype}

    def __setstate__(self, d):
        self.__init__(d['path'], d['dtype'])


# music dataset with data augmentation
class MusicSegmentDataset(Dataset):
    def __init__(self, params, train_val):
        # load some configs
        assert train_val in {'train', 'validate'}
        sample_rate = params['sample_rate']
        self.augmented = True
        self.eval_time_shift = True
        self.segment_size = int(params['segment_size'] * sample_rate)
        self.hop_size = int(params['hop_size'] * sample_rate)
        # e.g. select 1.2s of audio, then cut two random 1s windows from it
        self.time_offset = int(params['time_offset'] * sample_rate)
        # include extra audio to the left of a segment to simulate reverb tails
        self.pad_start = int(params['pad_start'] * sample_rate)
        self.params = params

        # FFT size needed for reverb: next power of two that fits segment + IRs
        fftconv_n = 1024
        air_len = int(params['air']['length'] * sample_rate)
        ir_len = int(params['micirp']['length'] * sample_rate)
        fft_needed = self.segment_size + self.pad_start + air_len + ir_len
        while fftconv_n < fft_needed:
            fftconv_n *= 2
        self.fftconv_n = fftconv_n

        # datasets used for data augmentation
        cache_dir = params['cache_dir']
        os.makedirs(cache_dir, exist_ok=True)
        if params['noise'][train_val]:
            self.noise = NoiseData(noise_dir=params['noise']['dir'], list_csv=params['noise'][train_val],
                                   sample_rate=sample_rate, cache_dir=cache_dir)
        else:
            self.noise = None
        if params['air'][train_val]:
            self.air = AIR(air_dir=params['air']['dir'], list_csv=params['air'][train_val],
                           length=params['air']['length'], fftconv_n=fftconv_n, sample_rate=sample_rate)
        else:
            self.air = None
        if params['micirp'][train_val]:
            self.micirp = MicIRP(mic_dir=params['micirp']['dir'], list_csv=params['micirp'][train_val],
                                 length=params['micirp']['length'], fftconv_n=fftconv_n, sample_rate=sample_rate)
        else:
            self.micirp = None

        # load music dataset as a memory-mapped file
        file_name = os.path.splitext(os.path.split(params[train_val + '_csv'])[1])[0]
        file_name = os.path.join(cache_dir, '1' + file_name)
        if os.path.exists(file_name + '.npy'):
            print('load cached music from %s.bin' % file_name)
        else:
            preprocess_music(params['music_dir'], params[train_val + '_csv'], sample_rate, file_name)
        self.f = NumpyMemmapDataset(file_name + '.bin', np.int16)

        # segmentation bookkeeping
        song_len = np.load(file_name + '.npy')
        self.cues = []          # start location of segment i
        self.offset_left = []   # allowed left shift of segment i
        self.offset_right = []  # allowed right shift of segment i
        self.song_range = []    # range of song i is (start time, end time, start idx, end idx)
        t = 0
        for duration in song_len:
            num_segs = (duration - self.segment_size + self.hop_size) // self.hop_size
            start_cue = len(self.cues)
            for idx in range(num_segs):
                my_time = idx * self.hop_size
                self.cues.append(t + my_time)
                self.offset_left.append(my_time)
                self.offset_right.append(duration - my_time)
            end_cue = len(self.cues)
            self.song_range.append((t, t + duration, start_cue, end_cue))
            t += duration

        # convert to torch tensors to allow sharing memory across workers
        self.cues = torch.LongTensor(self.cues)
        self.offset_left = torch.LongTensor(self.offset_left)
        self.offset_right = torch.LongTensor(self.offset_right)

        # set up mel spectrogram
        self.mel = build_mel_spec_layer(params)

    def get_single_segment(self, idx, offset, length):
        cue = int(self.cues[idx]) + offset
        left = int(self.offset_left[idx]) + offset
        right = int(self.offset_right[idx]) - offset

        # a segment looks like:
        #               cue
        #                v
        # [ pad_start | segment_size | extra ]
        #              \---- time_offset ----/
        segment = self.f[cue - min(left, self.pad_start): cue + min(right, length)]
        segment = np.pad(segment, [max(0, self.pad_start - left), max(0, length - right)])

        # convert 16 bit to float
        return segment * np.float32(1 / 32768)

    def __getitem__(self, indices):
        if self.eval_time_shift:
            # collect all segments; batch processing is faster
            shift_range = self.segment_size // 2
            x = [self.get_single_segment(i, -self.segment_size // 4, self.segment_size + shift_range) for i in indices]

            # random time offset only on the augmented segment, which acts as the query audio
            segment_size = self.pad_start + self.segment_size
            offset1 = [self.segment_size // 4] * len(x)
            offset2 = torch.randint(high=shift_range + 1, size=(len(x),)).tolist()
        else:
            # collect all segments; batch processing is faster
            x = [self.get_single_segment(i, 0, self.time_offset) for i in indices]

            # random time offset for both views
            shift_range = self.time_offset - self.segment_size
            segment_size = self.pad_start + self.segment_size
            offset1 = torch.randint(high=shift_range + 1, size=(len(x),)).tolist()
            offset2 = torch.randint(high=shift_range + 1, size=(len(x),)).tolist()

        x_orig = [xi[off + self.pad_start: off + segment_size] for xi, off in zip(x, offset1)]
        x_orig = torch.Tensor(np.stack(x_orig).astype(np.float32))
        x_aug = [xi[off: off + segment_size] for xi, off in zip(x, offset2)]
        x_aug = torch.Tensor(np.stack(x_aug).astype(np.float32))

        if self.augmented:
            # background noise
            if self.noise is not None:
                x_aug = self.noise.add_noises(x_aug, self.params['noise']['snr_min'], self.params['noise']['snr_max'])

            # room and microphone impulse responses, applied by FFT convolution
            spec = torch.fft.rfft(x_aug, self.fftconv_n)
            if self.air is not None:
                spec *= self.air.random_choose(spec.shape[0])
            if self.micirp is not None:
                spec *= self.micirp.random_choose(spec.shape[0])
            x_aug = torch.fft.irfft(spec, self.fftconv_n)
            x_aug = x_aug[..., self.pad_start:segment_size]

        # output [x1_orig, x1_aug, x2_orig, x2_aug, ...]
        x = [x_orig, x_aug] if self.augmented else [x_orig]
        x = torch.stack(x, dim=1)
        if self.mel is not None:
            return self.mel(x)
        return x

    def fan_si_le(self):
        raise NotImplementedError('so annoying')  # originally: 煩死了

    def zuo_bu_chu_lai(self):
        raise NotImplementedError("can't make it work")  # originally: 做不起來

    def __len__(self):
        return len(self.cues)

    def get_num_songs(self):
        return len(self.song_range)

    def get_song_segments(self, song_id):
        return self.song_range[song_id][2:4]

    def preload_song(self, song_id):
        start, end, _, _ = self.song_range[song_id]
        return self.f[start: end].copy()

class TwoStageShuffler(Sampler):
    def __init__(self, music_data: MusicSegmentDataset, shuffle_size):
        self.music_data = music_data
        self.shuffle_size = shuffle_size
        self.shuffle = True
        self.loaded = set()
        self.generator = torch.Generator()
        self.generator2 = torch.Generator()

    def set_epoch(self, epoch):
        self.generator.manual_seed(42 + epoch)
        self.generator2.manual_seed(42 + epoch)

    def __len__(self):
        return len(self.music_data)

    def preload(self, song_id):
        if song_id not in self.loaded:
            self.music_data.preload_song(song_id)
            self.loaded.add(song_id)

    def baseline_shuffle(self):
        # the same as DataLoader with shuffle=True, but with a music preloader
        # 2021/10/18: the preloader had to be removed
        # for song in range(self.music_data.get_num_songs()):
        #     self.preload(song)
        yield from torch.randperm(len(self), generator=self.generator).tolist()

    def shuffling_iter(self):
        # shuffle the song list first
        shuffle_song = torch.randperm(self.music_data.get_num_songs(), generator=self.generator)

        # split the song list into chunks
        chunks = torch.split(shuffle_song, self.shuffle_size)
        for nc, songs in enumerate(chunks):
            # sort songs so the preloader reads more sequentially
            songs = torch.sort(songs)[0].tolist()
            buf = []
            for song in songs:
                # collect segment ids
                seg_start, seg_end = self.music_data.get_song_segments(song)
                buf += list(range(seg_start, seg_end))
            if nc == 0:
                # load the first chunk up front
                for song in songs:
                    self.preload(song)

            # shuffle segments within the song chunk
            shuffle_segs = torch.randperm(len(buf), generator=self.generator2)
            shuffle_segs = [buf[x] for x in shuffle_segs]
            preload_cnt = 0
            for i, idx in enumerate(shuffle_segs):
                # output shuffled segment idx
                yield idx

                # preload the next chunk gradually while this one is consumed
                while len(self.loaded) < len(shuffle_song) and preload_cnt * len(shuffle_segs) < (i + 1) * self.shuffle_size:
                    song = shuffle_song[len(self.loaded)].item()
                    self.preload(song)
                    preload_cnt += 1

    def non_shuffling_iter(self):
        # just yield 0 ... len(dataset)-1
        yield from range(len(self))

    def __iter__(self):
        if self.shuffle:
            if self.shuffle_size is None:
                return self.baseline_shuffle()
            else:
                return self.shuffling_iter()
        return self.non_shuffling_iter()

# Instantiating a DataLoader for MusicSegmentDataset is tricky, so this builder wraps it.
class SegmentedDataLoader:
    def __init__(self, train_val, configs, num_workers=4, pin_memory=False, prefetch_factor=2):
        assert train_val in {'train', 'validate'}
        self.dataset = MusicSegmentDataset(configs, train_val)
        assert configs['batch_size'] % 2 == 0
        self.batch_size = configs['batch_size']
        self.shuffler = TwoStageShuffler(self.dataset, configs['shuffle_size'])
        # each dataset item already yields 2 views, so sample batch_size//2 indices
        self.sampler = BatchSampler(self.shuffler, self.batch_size // 2, False)
        self.num_workers = num_workers
        self.configs = configs
        self.pin_memory = pin_memory
        self.prefetch_factor = prefetch_factor

        # shuffle, augmented and eval_time_shift can each be toggled True/False
        self.shuffle = True
        self.augmented = True
        self.eval_time_shift = False

        self.loader = DataLoader(
            self.dataset,
            sampler=self.sampler,
            batch_size=None,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            prefetch_factor=self.prefetch_factor
        )

    def set_epoch(self, epoch):
        self.shuffler.set_epoch(epoch)

    def __iter__(self):
        self.dataset.augmented = self.augmented
        self.dataset.eval_time_shift = self.eval_time_shift
        self.shuffler.shuffle = self.shuffle
        return iter(self.loader)

    def __len__(self):
        return len(self.loader)
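
A usage sketch of the builder; configs is assumed to be a dict loaded with datautil.simpleutils.read_config that defines the keys referenced above (sample_rate, segment_size, batch_size, shuffle_size, noise/air/micirp sections, ...), and the config path is hypothetical.

from datautil.simpleutils import read_config
from datautil.dataset_v2 import SegmentedDataLoader

configs = read_config('configs/default.json')   # placeholder path
loader = SegmentedDataLoader('train', configs, num_workers=4)
loader.set_epoch(0)
for batch in loader:
    # batch: [batch_size//2, 2, n_mels, frames] mel spectrograms,
    # pairing each original segment with its augmented version
    pass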

datautil/ir.py  +89

@@ -0,0 +1,89 @@
import argparse
import csv
import os
import warnings

import scipy.io
import numpy as np
import torch
import torch.fft
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    import torchaudio

from .audio import get_audio


class AIR:
    def __init__(self, air_dir, list_csv, length, fftconv_n, sample_rate=8000):
        print('loading Aachen IR dataset')
        with open(list_csv, 'r') as fin:
            reader = csv.reader(fin)
            airs = []
            firstrow = next(reader)  # skip the csv header
            for row in reader:
                airs.append(row[0])
        data = []
        to_len = int(length * sample_rate)
        self.names = []
        for name in airs:
            mat = scipy.io.loadmat(os.path.join(air_dir, name))
            h_air = torch.tensor(mat['h_air'].astype(np.float32))
            assert h_air.shape[0] == 1
            h_air = h_air[0]
            air_info = mat['air_info']
            fs = int(air_info['fs'][0][0][0][0])
            self.names.append(str(air_info['room'][0][0][0]))
            resampled = torchaudio.transforms.Resample(fs, sample_rate)(h_air)
            truncated = resampled[0:to_len]
            # store the spectrum so augmentation is a plain multiply in the frequency domain
            freqd = torch.fft.rfft(truncated, fftconv_n)
            data.append(freqd)
        self.data = torch.stack(data)

    def random_choose(self, num):
        indices = torch.randint(0, self.data.shape[0], size=(num,), dtype=torch.long)
        return self.data[indices]

    def random_choose_name(self):
        index = torch.randint(0, self.data.shape[0], size=(1,), dtype=torch.long).item()
        return self.data[index], self.names[index]


class MicIRP:
    def __init__(self, mic_dir, list_csv, length, fftconv_n, sample_rate=8000):
        print('loading microphone IR dataset')
        with open(list_csv, 'r') as fin:
            reader = csv.reader(fin)
            mics = []
            firstrow = next(reader)  # skip the csv header
            for row in reader:
                mics.append(row[0])
        data = []
        to_len = int(length * sample_rate)
        for name in mics:
            smp, smprate = get_audio(os.path.join(mic_dir, name))
            smp = torch.FloatTensor(smp).mean(dim=0)
            resampled = torchaudio.transforms.Resample(smprate, sample_rate)(smp)
            truncated = resampled[0:to_len]
            freqd = torch.fft.rfft(truncated, fftconv_n)
            data.append(freqd)
        self.data = torch.stack(data)

    def random_choose(self, num):
        indices = torch.randint(0, self.data.shape[0], size=(num,), dtype=torch.long)
        return self.data[indices]


if __name__ == '__main__':
    # build a csv list of the .mat impulse response files in a directory
    args = argparse.ArgumentParser()
    args.add_argument('air')
    args.add_argument('out')
    args = args.parse_args()
    with open(args.out, 'w', encoding='utf8', newline='\n') as fout:
        writer = csv.writer(fout)
        writer.writerow(['file'])
        files = []
        for name in os.listdir(args.air):
            if name.endswith('.mat'):
                files.append(name)
        files.sort()
        for name in files:
            writer.writerow([name])
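
The pre-transformed spectra returned by random_choose are meant to be multiplied with the rfft of zero-padded audio. A self-contained sketch of that FFT convolution, with random stand-in data and hypothetical sizes in place of real impulse responses:

import torch
import torch.fft

fftconv_n = 32768                              # padded FFT size; must cover segment + IR length
x = torch.randn(8, 9600)                       # batch of 8 audio segments
ir = torch.randn(8, 4000)                      # stand-in impulse responses
ir_spec = torch.fft.rfft(ir, fftconv_n)        # what AIR.random_choose(8) would return
spec = torch.fft.rfft(x, fftconv_n) * ir_spec  # multiplying spectra == convolving in time
x_reverb = torch.fft.irfft(spec, fftconv_n)[:, :9600]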

datautil/melspec.py  +63

@@ -0,0 +1,63 @@
import torch
import torchaudio


class MelSpec(torch.nn.Module):
    def __init__(self,
                 sample_rate=8000,
                 stft_n=1024,
                 stft_hop=256,
                 f_min=300,
                 f_max=4000,
                 n_mels=256,
                 naf_mode=False,
                 mel_log='log',
                 spec_norm='l2'):
        super(MelSpec, self).__init__()
        self.naf_mode = naf_mode
        self.mel_log = mel_log
        self.spec_norm = spec_norm
        self.mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=stft_n,
            hop_length=stft_hop,
            f_min=f_min,
            f_max=f_max,
            n_mels=n_mels,
            window_fn=torch.hann_window,
            power=1 if naf_mode else 2,
            pad_mode='constant' if naf_mode else 'reflect',
            norm='slaney' if naf_mode else None,
            mel_scale='slaney' if naf_mode else 'htk'
        )

    def forward(self, x):
        # normalize volume (the literal 1e999 overflows to inf, i.e. max norm)
        p = 1e999 if self.spec_norm == 'max' else 2
        x = torch.nn.functional.normalize(x, p=p, dim=-1)

        if self.naf_mode:
            x = self.mel(x) + 0.06
        else:
            x = self.mel(x) + 1e-8

        if self.mel_log == 'log10':
            x = torch.log10(x)
        elif self.mel_log == 'log':
            x = torch.log(x)

        if self.spec_norm == 'max':
            x = x - torch.amax(x, dim=(-2, -1), keepdim=True)
        return x


def build_mel_spec_layer(params):
    return MelSpec(
        sample_rate=params['sample_rate'],
        stft_n=params['stft_n'],
        stft_hop=params['stft_hop'],
        f_min=params['f_min'],
        f_max=params['f_max'],
        n_mels=params['n_mels'],
        naf_mode=params.get('naf_mode', False),
        mel_log=params.get('mel_log', 'log'),
        spec_norm=params.get('spec_norm', 'l2')
    )
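
A usage sketch of the factory; the parameter values below simply mirror this file's defaults rather than any known training config.

import torch
from datautil.melspec import build_mel_spec_layer

params = {'sample_rate': 8000, 'stft_n': 1024, 'stft_hop': 256,
          'f_min': 300, 'f_max': 4000, 'n_mels': 256}
mel = build_mel_spec_layer(params)
wav = torch.randn(4, 2, 8000)   # [batch, 2 views, 1 s of 8 kHz audio]
spec = mel(wav)                 # -> [4, 2, 256, frames] log-mel features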

datautil/noise.py  +109

@@ -0,0 +1,109 @@
import csv
import os
import warnings

import tqdm
import numpy as np
import torch
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    import torchaudio

from .simpleutils import get_hash
from .audio import get_audio


class NoiseData:
    def __init__(self, noise_dir, list_csv, sample_rate, cache_dir):
        print('loading noise dataset')
        hashes = []
        with open(list_csv, 'r') as fin:
            reader = csv.reader(fin)
            noises = []
            firstrow = next(reader)  # skip the csv header
            for row in reader:
                noises.append(row[0])
                hashes.append(get_hash(row[0]))
        hash = get_hash(''.join(hashes))
        # caching is currently disabled:
        # self.data = self.load_from_cache(list_csv, cache_dir, hash)
        # if self.data is not None:
        #     print(self.data.shape)
        #     return
        data = []
        silence_threshold = 0
        self.names = []
        for name in tqdm.tqdm(noises):
            smp, smprate = get_audio(os.path.join(noise_dir, name))
            smp = torch.from_numpy(smp.astype(np.float32))

            # convert to mono
            smp = smp.mean(dim=0)

            # strip leading/trailing silence
            abs_smp = torch.abs(smp)
            if torch.max(abs_smp) <= silence_threshold:
                print('%s too silent' % name)
                continue
            has_sound = (abs_smp > silence_threshold).to(torch.int)
            start = int(torch.argmax(has_sound))
            end = has_sound.shape[0] - int(torch.argmax(has_sound.flip(0)))
            smp = smp[max(start, 0): end]

            resampled = torchaudio.transforms.Resample(smprate, sample_rate)(smp)
            # peak-normalize (p=1e999 overflows to inf, i.e. max norm)
            resampled = torch.nn.functional.normalize(resampled, dim=0, p=1e999)
            data.append(resampled)
            self.names.append(name)
        self.data = torch.cat(data)
        # boundary[i] is where clip i starts in the concatenated buffer
        self.boundary = [0] + [x.shape[0] for x in data]
        self.boundary = torch.LongTensor(self.boundary).cumsum(0)
        del data
        # self.save_to_cache(list_csv, cache_dir, hash, self.data)
        print(self.data.shape)

    def load_from_cache(self, list_csv, cache_dir, hash):
        loc = os.path.join(cache_dir, os.path.basename(list_csv) + '.npy')
        loc2 = os.path.join(cache_dir, os.path.basename(list_csv) + '.hash')
        if os.path.exists(loc) and os.path.exists(loc2):
            with open(loc2, 'r') as fin:
                read_hash = fin.read().strip()
            if read_hash != hash:
                return None
            print('cache hit!')
            return torch.from_numpy(np.fromfile(loc, dtype=np.float32))
        return None

    def save_to_cache(self, list_csv, cache_dir, hash, audio):
        os.makedirs(cache_dir, exist_ok=True)
        loc = os.path.join(cache_dir, os.path.basename(list_csv) + '.npy')
        loc2 = os.path.join(cache_dir, os.path.basename(list_csv) + '.hash')
        with open(loc2, 'w') as fout:
            fout.write(hash)
        print('save to cache')
        audio.numpy().tofile(loc)

    def random_choose(self, num, duration, out_name=False):
        indices = torch.randint(0, self.data.shape[0] - duration, size=(num,), dtype=torch.long)
        out = torch.zeros([num, duration], dtype=torch.float32)
        for i in range(num):
            start = int(indices[i])
            end = start + duration
            out[i] = self.data[start:end]
        if out_name:
            name_lookup = torch.searchsorted(self.boundary, indices, right=True) - 1
            return out, [self.names[x] for x in name_lookup]
        return out

    # x is a 2d array of audio segments
    def add_noises(self, x, snr_min, snr_max, out_name=False):
        eps = 1e-12
        noise = self.random_choose(x.shape[0], x.shape[1], out_name=out_name)
        if out_name:
            noise, noise_name = noise
        vol_x = torch.clamp((x ** 2).mean(dim=1), min=eps).sqrt()
        vol_noise = torch.clamp((noise ** 2).mean(dim=1), min=eps).sqrt()
        snr = torch.FloatTensor(x.shape[0]).uniform_(snr_min, snr_max)
        # scale the noise so the mixture hits the sampled SNR
        ratio = vol_x / vol_noise
        ratio *= 10 ** -(snr / 20)
        x_aug = x + ratio.unsqueeze(1) * noise
        if out_name:
            return x_aug, noise_name, snr
        return x_aug
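
The scaling in add_noises implements the usual SNR definition: the noise is scaled so that rms(x) / rms(scaled noise) = 10^(snr/20), i.e. the mixture hits the sampled SNR in dB. A self-contained check with random data:

import torch

x, noise = torch.randn(2, 8000), torch.randn(2, 8000)
snr = torch.tensor([10.0, 10.0])                 # target SNR in dB
ratio = x.pow(2).mean(dim=1).sqrt() / noise.pow(2).mean(dim=1).sqrt()
ratio = ratio * 10 ** (-snr / 20)
mixed = x + ratio.unsqueeze(1) * noise           # SNR(mixed) = 10 dB by construction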

datautil/preprocess.py  +55

@@ -0,0 +1,55 @@
import csv
import os

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import torchaudio
import tqdm

from .audio import get_audio


class Preprocessor(Dataset):
    def __init__(self, files, dir, sample_rate):
        self.files = files
        self.dir = dir
        self.resampler = {}  # one cached Resample transform per input rate
        self.sample_rate = sample_rate

    def __getitem__(self, n):
        dat = get_audio(os.path.join(self.dir, self.files[n]))
        wav, smprate = dat
        if smprate not in self.resampler:
            self.resampler[smprate] = torchaudio.transforms.Resample(smprate, self.sample_rate)
        wav = torch.Tensor(wav)
        wav = wav.mean(dim=0)  # downmix to mono
        wav = self.resampler[smprate](wav)

        # quantize back to 16 bit
        wav *= 32768
        torch.clamp(wav, -32768, 32767, out=wav)
        wav = wav.to(torch.int16)
        return wav

    def __len__(self):
        return len(self.files)


def preprocess_music(music_dir, music_csv, sample_rate, preprocess_out):
    print('converting music to wav')
    with open(music_csv) as fin:
        reader = csv.reader(fin)
        next(reader)  # skip the csv header
        files = [row[0] for row in reader]
    preprocessor = Preprocessor(files, music_dir, sample_rate)
    loader = DataLoader(preprocessor, num_workers=4, batch_size=None)
    out_file = open(preprocess_out + '.bin', 'wb')
    song_lens = []
    for wav in tqdm.tqdm(loader):
        # torch.set_num_threads(1)  # default multithreading causes cpu contention
        wav = wav.numpy()
        out_file.write(wav.tobytes())
        song_lens.append(wav.shape[0])
    out_file.close()
    np.save(preprocess_out, np.array(song_lens, dtype=np.int64))
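
For reference, the cache written here is what MusicSegmentDataset consumes: the .bin file is every song's int16 samples concatenated, and the .npy file holds per-song lengths, so song boundaries fall out of a cumulative sum. A sketch with hypothetical cache paths (the '1' prefix matches the naming used in dataset_v2.py):

import numpy as np

song_lens = np.load('cache/1train.npy')            # per-song sample counts
samples = np.memmap('cache/1train.bin', dtype=np.int16)
offsets = np.concatenate([[0], np.cumsum(song_lens)])
first_song = samples[offsets[0]:offsets[1]]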

datautil/simpleutils.py  +26

@@ -0,0 +1,26 @@
# utilities that don't depend on torch
import hashlib
import json
import logging
import multiprocessing as mp
import os


def get_hash(s):
    m = hashlib.md5()
    m.update(s.encode('utf8'))
    return m.hexdigest()


def read_config(path):
    with open(path, 'r') as fin:
        return json.load(fin)


def init_logger(app_name):
    os.makedirs('logs', exist_ok=True)
    logger = mp.get_logger()
    logger.setLevel(logging.INFO)
    handler = logging.FileHandler('logs/%s.log' % app_name, encoding="utf8")
    handler.setFormatter(logging.Formatter('[%(asctime)s] [%(processName)s/%(levelname)s] %(message)s'))
    logger.addHandler(handler)
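
A quick sketch of how these helpers are used elsewhere in this commit (the config path is a placeholder):

from datautil.simpleutils import read_config, init_logger

params = read_config('configs/default.json')   # hypothetical path; feeds train_nnfp
init_logger('train_nnfp')                      # appends to logs/train_nnfp.log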

datautil/specaug.py  +42

@@ -0,0 +1,42 @@
import torch


class SpecAugment:
    def __init__(self, params):
        # note: all four masking knobs currently read the shared cutout_min/cutout_max
        # keys; the trailing numbers are the absolute sizes noted in the original code
        self.freq_min = params.get('cutout_min', 0.1)  # 5
        self.freq_max = params.get('cutout_max', 0.5)  # 20
        self.time_min = params.get('cutout_min', 0.1)  # 5
        self.time_max = params.get('cutout_max', 0.5)  # 16
        self.cutout_min = params.get('cutout_min', 0.1)  # 0.1
        self.cutout_max = params.get('cutout_max', 0.5)  # 0.4

    def get_mask(self, F, T):
        mask = torch.zeros(F, T)

        # cutout: mask a random rectangle
        cutout_max = self.cutout_max
        cutout_min = self.cutout_min
        f = int(F * (cutout_min + torch.rand(1) * (cutout_max - cutout_min)))
        f0 = torch.randint(0, F - f + 1, (1,))
        t = int(T * (cutout_min + torch.rand(1) * (cutout_max - cutout_min)))
        t0 = torch.randint(0, T - t + 1, (1,))
        mask[f0:f0+f, t0:t0+t] = 1

        # frequency masking: mask a random band of rows
        f = int(F * (self.freq_min + torch.rand(1) * (self.freq_max - self.freq_min)))
        f0 = torch.randint(0, F - f + 1, (1,))
        mask[f0:f0+f, :] = 1

        # time masking: mask a random band of columns
        t = int(T * (self.time_min + torch.rand(1) * (self.time_max - self.time_min)))
        t0 = torch.randint(0, T - t + 1, (1,))
        mask[:, t0:t0+t] = 1
        return mask

    def augment(self, x):
        # one mask per call, broadcast across the whole batch; masked cells are zeroed
        mask = self.get_mask(x.shape[-2], x.shape[-1]).to(x.device)
        x = x * (1 - mask)
        return x
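
A usage sketch on a batch of mel spectrograms; shapes are illustrative. Since get_mask is called once per augment call, the same cutout and masking bands apply to every item in the batch.

import torch

aug = SpecAugment({'cutout_min': 0.1, 'cutout_max': 0.5})
spec = torch.randn(16, 2, 256, 32)   # [batch, views, n_mels, frames]
masked = aug.augment(spec)           # masked cells are zeroed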

nn_fingerprint.py  +8

@@ -235,3 +235,11 @@ class NNFingerprint(NNOperator):
        else:
            log.error(f'Unsupported format "{format}".')
        return Path(path).resolve()

    def train(self, **kwargs):
        from .train_nnfp import train_nnfp
        config_json_path = kwargs['config_json_path']
        train_nnfp(
            self._model,
            config_json_path=config_json_path
        )

train_nnfp.py  +152

@@ -0,0 +1,152 @@
# This training script is adapted from https://github.com/stdio2016/pfann
from typing import Any

import numpy as np
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from towhee.trainer.callback import Callback
from towhee.trainer.trainer import Trainer
from towhee.trainer.training_config import TrainingConfig

from .datautil.dataset_v2 import SegmentedDataLoader
from .datautil.simpleutils import read_config
from .datautil.specaug import SpecAugment

# other requirements: torchvision, torchmetrics==0.7.0


def similarity_loss(y, tau):
    # rows of y are interleaved (original, augmented) pairs; each row's positive
    # is its pair and every other row in the batch is a negative
    a = torch.matmul(y, y.T)
    a /= tau
    Ls = []
    for i in range(y.shape[0]):
        nn_self = torch.cat([a[i, :i], a[i, i + 1:]])  # drop the self-similarity
        softmax = torch.nn.functional.log_softmax(nn_self, dim=0)
        Ls.append(softmax[i if i % 2 == 0 else i - 1])
    Ls = torch.stack(Ls)
    loss = torch.sum(Ls) / -y.shape[0]
    return loss
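

# Worked check of the pairing index above (hypothetical numbers): with rows
# ordered [x0_orig, x0_aug, x1_orig, x1_aug, ...], row i's positive is its
# neighbor. After the self column at index i is dropped, an even i finds its
# pair at index i and an odd i at index i-1, which is exactly what is used:
#
#   y = torch.nn.functional.normalize(torch.randn(8, 128), dim=1)
#   loss = similarity_loss(y, tau=0.05)   # scalar contrastive loss over 4 pairs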
class SetEpochCallback(Callback):
    def __init__(self, dataloader):
        super().__init__()
        self.dataloader = dataloader

    def on_epoch_begin(self, epochs, logs):
        self.dataloader.set_epoch(epochs)


class NNFPTrainer(Trainer):
    def __init__(self, model, training_config, train_dataset, eval_dataset, train_dataloader, eval_dataloader,
                 model_card, params):
        super().__init__(model, training_config, train_dataset, eval_dataset, train_dataloader, eval_dataloader,
                         model_card)
        self.specaug = SpecAugment(params)
        self.scaler = GradScaler()
        self.tau = params.get('tau', 0.05)
        self.losses = []
        self.set_epoch_callback = SetEpochCallback(self.train_dataloader)
        self.callbacks.add_callback(self.set_epoch_callback)
        print('evaluate before fine-tune...')
        self.evaluate(self.model, dict())

    def train_step(self, model: nn.Module, inputs: Any) -> dict:
        x = inputs
        self.optimizer.zero_grad()
        # flatten [batch, 2, ...] into interleaved rows, as similarity_loss expects
        x = torch.flatten(x, 0, 1)
        x = self.specaug.augment(x)
        with autocast():
            y = model(x.to(self.configs.device))
            loss = similarity_loss(y, self.tau)
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.lr_scheduler.step()
        lossnum = float(loss.item())
        self.losses.append(lossnum)
        step_logs = {"step_loss": lossnum, "epoch_loss": np.mean(self.losses), "epoch_metric": 0}
        return step_logs

    @torch.no_grad()
    def evaluate(self, model: nn.Module, logs: dict) -> dict:
        validate_N = 0
        y_embed = []
        minibatch = 40
        if torch.cuda.get_device_properties(0).total_memory > 11e9:
            minibatch = 640
        for x in self.eval_dataloader:
            x = torch.flatten(x, 0, 1)
            for xx in torch.split(x, minibatch):
                y = model(xx.to(self.configs.device)).cpu()
                y_embed.append(y)
        y_embed = torch.cat(y_embed)
        y_embed_org = y_embed[0::2]
        y_embed_aug = y_embed[1::2].to(self.configs.device)

        # compute the validation score on GPU:
        # self_score[i] is the similarity of query i to its own reference
        self_score = []
        for embeds in torch.split(y_embed_org, 320):
            A = torch.matmul(y_embed_aug, embeds.T.to(self.configs.device))
            self_score.append(A.diagonal(-validate_N).cpu())
            validate_N += embeds.shape[0]
        self_score = torch.cat(self_score).to(self.configs.device)

        # rank of the true reference among all references, for each query
        ranks = torch.zeros(validate_N, dtype=torch.long).to(self.configs.device)
        for embeds in torch.split(y_embed_org, 320):
            A = torch.matmul(y_embed_aug, embeds.T.to(self.configs.device))
            ranks += (A.T >= self_score).sum(dim=0)
        acc = int((ranks == 1).sum())  # top-1 hits
        print('\nvalidate score: %f' % (acc / validate_N,))
        # logs['epoch_metric'] = acc / validate_N
        return logs


def train_nnfp(model, config_json_path):
    model.train()
    params = read_config(config_json_path)

    train_dataloader = SegmentedDataLoader('train', params, num_workers=4)
    print('training data contains %d samples' % len(train_dataloader.dataset))
    train_dataloader.shuffle = True
    train_dataloader.eval_time_shift = False
    train_dataloader.augmented = True
    train_dataloader.set_epoch(0)

    val_dataloader = SegmentedDataLoader('validate', params, num_workers=4)
    val_dataloader.shuffle = False
    val_dataloader.eval_time_shift = True
    val_dataloader.set_epoch(-1)

    training_config = TrainingConfig(
        batch_size=params['batch_size'],
        epoch_num=params['epoch'],
        output_dir='fine_tune_output',
        lr_scheduler_type='cosine',
        eval_strategy='epoch',
        lr=params['lr']
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=training_config.lr)

    nnfp_trainer = NNFPTrainer(
        model=model,
        training_config=training_config,
        train_dataset=train_dataloader.dataset,
        eval_dataset=None,
        train_dataloader=train_dataloader,
        eval_dataloader=val_dataloader,
        model_card=None,
        params=params
    )
    nnfp_trainer.set_optimizer(optimizer)
    nnfp_trainer.train()
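
A usage sketch of the fine-tuning entry point this commit adds to nn_fingerprint.py. It assumes a pfann-style JSON config defining the keys read above (sample_rate, segment_size, batch_size, epoch, lr, and the noise/air/micirp sections); the operator construction and config path are illustrative only.

#   op = NNFingerprint()
#   op.train(config_json_path='configs/default.json')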