11 changed files with 997 additions and 0 deletions
@ -0,0 +1,152 @@ |
import json |
import os |
import subprocess |
import numpy as np |
import wave |
import io |
# because builtin wave won't read wav files with more than 2 channels |
class HackExtensibleWave:
    """Read-only stream wrapper that overwrites bytes 20-21 with 0x01 0x00.

    The wrapper tracks how many bytes have been read so far and substitutes
    the two bytes at absolute stream offsets 20 and 21, however the reads are
    chunked.  It is handed to ``wave.open`` so that files the builtin module
    would otherwise reject can be parsed.
    # NOTE(review): in a canonical RIFF/WAVE header offsets 20-21 hold the
    # fmt-chunk format tag and 0x0001 means PCM -- confirm for exotic headers.
    """

    _PATCH_START = 20
    _PATCH = b'\x01\x00'
    _PATCH_END = _PATCH_START + len(_PATCH)

    def __init__(self, stream):
        self.stream = stream
        self.pos = 0  # number of bytes returned so far

    def read(self, n):
        """Read up to ``n`` bytes, patching any that fall on offsets 20-21."""
        r = self.stream.read(n)
        start = self.pos
        end = start + len(r)
        # Work from the bytes actually returned, not the requested count, so a
        # short read that stops before offset 20 is passed through untouched.
        lo = max(start, self._PATCH_START)
        hi = min(end, self._PATCH_END)
        if lo < hi:
            patch = self._PATCH[lo - self._PATCH_START:hi - self._PATCH_START]
            r = r[:lo - start] + patch + r[hi - start:]
        self.pos = end
        return r
def ffmpeg_get_audio(filename):
    """Decode ``filename`` by piping it through ``ffmpeg -f wav``.

    Returns ``(samples, rate)`` where ``samples`` is a float array of shape
    ``(channels, frames)`` obtained by dividing 16-bit samples by 32768, and
    ``rate`` is the frame rate reported by the WAV header.  If the ffmpeg
    output cannot be parsed as WAV, a message is printed and an empty
    ``(1, 0)`` array with rate 44100 is returned.
    """
    command = ['ffmpeg', '-i', filename, '-f', 'wav', 'pipe:1']
    # DEVNULL avoids leaking two os.devnull handles per call, and the context
    # manager closes the pipe and reaps the child process.
    with subprocess.Popen(command,
                          stderr=subprocess.DEVNULL,
                          stdin=subprocess.DEVNULL,
                          stdout=subprocess.PIPE,
                          bufsize=1000000) as proc:
        dat = proc.stdout.read()
    try:
        wav = wave.open(HackExtensibleWave(io.BytesIO(dat)))
        try:
            ch = wav.getnchannels()
            rate = wav.getframerate()
            n = wav.getnframes()
            dat = wav.readframes(n)
        finally:
            wav.close()
        samples = np.frombuffer(dat, dtype=np.int16) / 32768
        # de-interleave: one row per channel
        samples = samples.reshape([-1, ch]).T
        return samples, rate
    except (wave.Error, EOFError):
        print('failed to decode %s. maybe the file is broken!' % filename)
        return np.zeros([1, 0]), 44100
def wave_get_audio(filename):
    """Decode a WAV file with the builtin ``wave`` module.

    Handles sample widths of 1, 2 and 3 bytes and returns
    ``(samples, rate)`` with ``samples`` shaped ``(channels, frames)`` and
    scaled by 1/128 (minus 1), 1/32768 or 1/8388608 respectively.  Returns
    ``None`` for any other sample width.
    """
    with open(filename, 'rb') as fin:
        reader = wave.open(HackExtensibleWave(fin))
        width = reader.getsampwidth()
        if width not in (1, 2, 3):
            return None
        raw = reader.readframes(reader.getnframes())
        if width == 1:
            # unsigned 8-bit, recentred around zero
            samples = np.frombuffer(raw, dtype=np.uint8) / 128 - 1
        elif width == 2:
            samples = np.frombuffer(raw, dtype=np.int16) / 32768
        else:
            octets = np.frombuffer(raw, dtype=np.uint8)
            low, mid, high = octets[0::3], octets[1::3], octets[2::3]
            # fourth byte is 0x00 or 0xFF depending on the top bit of `high`,
            # i.e. a sign extension to 32 bits before reinterpreting as int32
            # NOTE(review): the view() relies on a little-endian host -- confirm
            sign = -(high >> 7)
            samples = np.stack([low, mid, high, sign], axis=1).view(np.int32).squeeze(1)
            del octets
            samples = samples / 8388608
        samples = samples.reshape([-1, reader.getnchannels()]).T
        return samples, reader.getframerate()
def get_audio(filename):
    """Load audio as ``(samples, rate)``, preferring the pure-Python WAV path.

    Names ending in ``.wav`` are first tried with ``wave_get_audio``; if that
    raises or yields nothing, decoding falls back to ``ffmpeg_get_audio``.
    """
    if str(filename).endswith('.wav'):
        try:
            decoded = wave_get_audio(filename)
        except Exception:
            # best effort only: any failure here just means "let ffmpeg try"
            decoded = None
        if decoded:
            return decoded
    return ffmpeg_get_audio(filename)
class FfmpegStream:
    """Pull 16-bit samples from a running process' stdout in chunks.

    ``stream`` is a generator: prime it with ``next()``, then ``send(num)`` to
    request ``num`` frames (``nchannels`` samples of 2 bytes each); a falsy
    request means 1024 frames.  The generator ends after the first short read.
    """

    def __init__(self, proc, sample_rate, nchannels, tmpfile):
        self.proc = proc
        self.sample_rate = sample_rate
        self.nchannels = nchannels
        # set to None first so __del__ is safe if anything below fails
        self.tmpfile = None
        self.stream = self.gen_stream()
        self.tmpfile = tmpfile

    def __del__(self):
        self.proc.terminate()
        self.proc.communicate()
        del self.proc
        if self.tmpfile:
            os.unlink(self.tmpfile)

    def gen_stream(self):
        """Generator yielding int16 arrays; see the class docstring."""
        frames = (yield np.array([], dtype=np.int16)) or 1024
        exhausted = False
        while not exhausted:
            wanted = frames * self.nchannels * 2
            payload = self.proc.stdout.read(wanted)
            # a short read marks the last chunk; it is still yielded below
            exhausted = len(payload) < wanted
            frames = (yield np.frombuffer(payload, dtype=np.int16)) or 1024
def ffmpeg_stream_audio(filename, is_tmp=False):
    """Start ffmpeg decoding ``filename`` to raw s16le and wrap it in a stream.

    ffprobe is run first to read the sample rate and channel count of the
    first audio stream.  When ``is_tmp`` is true the resulting
    ``FfmpegStream`` is given ``filename`` as its temp file to delete.

    Raises:
        RuntimeError: if ffprobe reports no audio stream.
    """
    # Opening os.devnull has been seen to fail with PermissionError; keep
    # retrying as before, but do not leak the first handle if the second fails.
    while 1:
        try:
            stderr = open(os.devnull, 'w')
            try:
                stdin = open(os.devnull)
            except PermissionError:
                stderr.close()
                raise
            break
        except PermissionError:
            print('PermissionError occured, try again')
    try:
        # run() waits for ffprobe, so no unreaped child is left behind
        probe = subprocess.run(['ffprobe', '-i', filename, '-show_streams',
                                '-select_streams', 'a', '-print_format', 'json'],
                               stderr=stderr,
                               stdin=stdin,
                               stdout=subprocess.PIPE)
        prop = json.loads(probe.stdout)
        # an empty list means "no audio" just as much as a missing key does
        if 'streams' not in prop or not prop['streams']:
            raise RuntimeError('FFmpeg cannot decode audio')
        sample_rate = int(prop['streams'][0]['sample_rate'])
        nchannels = prop['streams'][0]['channels']
        proc = subprocess.Popen(['ffmpeg', '-i', filename,
                                 '-f', 's16le', '-acodec', 'pcm_s16le', 'pipe:1'],
                                stderr=stderr,
                                stdin=stdin,
                                stdout=subprocess.PIPE)
    finally:
        # the children hold their own copies of these handles
        stderr.close()
        stdin.close()
    tmpfile = None
    if is_tmp:
        tmpfile = filename
    return FfmpegStream(proc, sample_rate, nchannels, tmpfile=tmpfile)
class WaveStream:
    """Chunked reader for 16-bit WAV files with the same protocol as FfmpegStream.

    ``stream`` is a generator: prime it with ``next()``, then ``send(num)`` to
    request ``num`` frames; a falsy request means 1024 frames.  The generator
    ends after the first read that returns fewer frames than were requested.

    Raises:
        NotImplementedError: if the file's sample width is not 2 bytes.
    """

    def __init__(self, filename, is_tmp=False):
        # set first so __del__ is safe if anything below raises
        self.is_tmp = None
        self.file = open(filename, 'rb')
        self.wave = wave.open(HackExtensibleWave(self.file))
        self.smpsize = self.wave.getnchannels() * self.wave.getsampwidth()
        self.sample_rate = self.wave.getframerate()
        self.nchannels = self.wave.getnchannels()
        if self.wave.getsampwidth() != 2:
            raise NotImplementedError('wave stream currently only supports 16bit wav')
        self.stream = self.gen_stream()
        self.is_tmp = filename if is_tmp else None

    def gen_stream(self):
        """Generator yielding int16 arrays; see the class docstring."""
        num = yield np.array([], dtype=np.int16)
        if not num:
            num = 1024
        while True:
            # Remember what was asked of THIS read: `num` is overwritten by the
            # consumer's next request before the end-of-stream check runs.
            requested = num
            dat = self.wave.readframes(requested)
            num = yield np.frombuffer(dat, dtype=np.int16)
            if not num:
                num = 1024
            if len(dat) < requested * self.smpsize:
                break

    def __del__(self):
        # close before unlinking; an open handle blocks deletion on some OSes
        handle = getattr(self, 'file', None)
        if handle is not None:
            handle.close()
        if getattr(self, 'is_tmp', None):
            os.unlink(self.is_tmp)
@ -0,0 +1,301 @@ |
import os |
import numpy as np |
import torch |
import torch.fft |
from torch.utils.data import DataLoader, Dataset, Sampler, BatchSampler |
from .melspec import build_mel_spec_layer |
from .noise import NoiseData |
from .ir import AIR, MicIRP |
from .preprocess import preprocess_music |
class NumpyMemmapDataset(Dataset):
    """Dataset view over a flat binary file opened with ``np.memmap``."""

    def __init__(self, path, dtype):
        self.path = path
        self.f = np.memmap(self.path, dtype=dtype)

    def __len__(self):
        return len(self.f)

    def __getitem__(self, idx):
        return self.f[idx]

    # Pickle only the constructor arguments; otherwise the pickler would read
    # the whole mapped file into memory.
    def __getstate__(self):
        return dict(path=self.path, dtype=self.f.dtype)

    def __setstate__(self, d):
        # re-open the mapping from the saved path and dtype
        self.__init__(d['path'], d['dtype'])
# data augmentation on music dataset |
class MusicSegmentDataset(Dataset):
    """Segmented music dataset producing (original, augmented) pairs.

    ``__getitem__`` receives a list of segment indices (a whole batch) and
    returns a tensor stacked as ``[x_orig, x_aug]`` along dim 1 -- or only
    ``[x_orig]`` when ``augmented`` is False -- passed through the mel layer
    if one is configured.
    """

    def __init__(self, params, train_val):
        # load some configs
        assert train_val in {'train', 'validate'}
        sample_rate = params['sample_rate']
        self.augmented = True
        self.eval_time_shift = True
        self.segment_size = int(params['segment_size'] * sample_rate)
        self.hop_size = int(params['hop_size'] * sample_rate)
        # e.g. select 1.2s of audio, then choose two random 1s windows from it
        self.time_offset = int(params['time_offset'] * sample_rate)
        # extra audio kept to the left of a segment to simulate reverb
        self.pad_start = int(params['pad_start'] * sample_rate)
        self.params = params

        # smallest power of two (>= 1024) that fits segment + padding + both IRs
        air_len = int(params['air']['length'] * sample_rate)
        ir_len = int(params['micirp']['length'] * sample_rate)
        fft_needed = self.segment_size + self.pad_start + air_len + ir_len
        fftconv_n = 1024
        while fftconv_n < fft_needed:
            fftconv_n *= 2
        self.fftconv_n = fftconv_n

        # augmentation sources; each is optional and disabled by a falsy csv entry
        cache_dir = params['cache_dir']
        os.makedirs(cache_dir, exist_ok=True)
        noise_cfg = params['noise']
        self.noise = None
        if noise_cfg[train_val]:
            self.noise = NoiseData(noise_dir=noise_cfg['dir'], list_csv=noise_cfg[train_val],
                                   sample_rate=sample_rate, cache_dir=cache_dir)
        air_cfg = params['air']
        self.air = None
        if air_cfg[train_val]:
            self.air = AIR(air_dir=air_cfg['dir'], list_csv=air_cfg[train_val],
                           length=air_cfg['length'], fftconv_n=fftconv_n, sample_rate=sample_rate)
        mic_cfg = params['micirp']
        self.micirp = None
        if mic_cfg[train_val]:
            self.micirp = MicIRP(mic_dir=mic_cfg['dir'], list_csv=mic_cfg[train_val],
                                 length=mic_cfg['length'], fftconv_n=fftconv_n, sample_rate=sample_rate)

        # music is stored as one memory-mapped int16 file plus a .npy of song lengths
        list_csv = params[train_val + '_csv']
        base = os.path.splitext(os.path.basename(list_csv))[0]
        file_name = os.path.join(cache_dir, '1' + base)
        if not os.path.exists(file_name + '.npy'):
            preprocess_music(params['music_dir'], list_csv, sample_rate, file_name)
        else:
            print('load cached music from %s.bin' % file_name)
        self.f = NumpyMemmapDataset(file_name + '.bin', np.int16)

        self._build_segment_index(np.load(file_name + '.npy'))

        # setup mel spectrogram
        self.mel = build_mel_spec_layer(params)

    def _build_segment_index(self, song_len):
        """Fill cues / offset_left / offset_right / song_range from song lengths."""
        cues = []      # start location of segment i
        lefts = []     # allowed left shift of segment i
        rights = []    # allowed right shift of segment i
        # range of song i is (start time, end time, start idx, end idx)
        self.song_range = []
        t = 0
        for duration in song_len:
            num_segs = (duration - self.segment_size + self.hop_size) // self.hop_size
            first = len(cues)
            for local_start in (k * self.hop_size for k in range(num_segs)):
                cues.append(t + local_start)
                lefts.append(local_start)
                rights.append(duration - local_start)
            self.song_range.append((t, t + duration, first, len(cues)))
            t += duration
        # convert to torch Tensor to allow sharing memory
        self.cues = torch.LongTensor(cues)
        self.offset_left = torch.LongTensor(lefts)
        self.offset_right = torch.LongTensor(rights)

    def get_single_segment(self, idx, offset, length):
        """Return ``pad_start + length`` float samples around segment ``idx``.

        The window starts ``pad_start`` samples before ``cue + offset`` and is
        zero-padded wherever it would cross the boundaries of its song:

                          cue
                           v
            [ pad_start | segment_size |   ...  ]
                        \\--- time_offset ---/
        """
        cue = int(self.cues[idx]) + offset
        room_left = int(self.offset_left[idx]) + offset
        room_right = int(self.offset_right[idx]) - offset
        begin = cue - min(room_left, self.pad_start)
        end = cue + min(room_right, length)
        pad_before = max(0, self.pad_start - room_left)
        pad_after = max(0, length - room_right)
        segment = np.pad(self.f[begin:end], [pad_before, pad_after])
        # convert 16 bit to float
        return segment * np.float32(1/32768)

    def __getitem__(self, indices):
        # collect all segments first; it is faster to do batch processing
        if self.eval_time_shift:
            shift_range = self.segment_size//2
            x = [self.get_single_segment(i, -self.segment_size//4, self.segment_size+shift_range)
                 for i in indices]
            # random time offset only on augmented segment, as a query audio
            offset1 = [self.segment_size//4] * len(x)
        else:
            x = [self.get_single_segment(i, 0, self.time_offset) for i in indices]
            # random time offset for both views
            shift_range = self.time_offset - self.segment_size
            offset1 = torch.randint(high=shift_range+1, size=(len(x),)).tolist()
        offset2 = torch.randint(high=shift_range+1, size=(len(x),)).tolist()
        segment_size = self.pad_start + self.segment_size

        # original view drops the left padding; augmented view keeps it for reverb
        originals = [xi[off + self.pad_start : off + segment_size] for xi, off in zip(x, offset1)]
        x_orig = torch.Tensor(np.stack(originals).astype(np.float32))
        queries = [xi[off : off + segment_size] for xi, off in zip(x, offset2)]
        x_aug = torch.Tensor(np.stack(queries).astype(np.float32))

        if self.augmented:
            # background noise
            if self.noise is not None:
                noise_cfg = self.params['noise']
                x_aug = self.noise.add_noises(x_aug, noise_cfg['snr_min'], noise_cfg['snr_max'])
            # impulse responses applied as products in the frequency domain
            spec = torch.fft.rfft(x_aug, self.fftconv_n)
            for ir_bank in (self.air, self.micirp):
                if ir_bank is not None:
                    spec *= ir_bank.random_choose(spec.shape[0])
            x_aug = torch.fft.irfft(spec, self.fftconv_n)
            x_aug = x_aug[..., self.pad_start:segment_size]

        # output [x1_orig, x1_aug, x2_orig, x2_aug, ...]
        views = [x_orig, x_aug] if self.augmented else [x_orig]
        x = torch.stack(views, dim=1)
        if self.mel is None:
            return x
        return self.mel(x)

    def fan_si_le(self):
        raise NotImplementedError('煩死了')

    def zuo_bu_chu_lai(self):
        raise NotImplementedError('做不起來')

    def __len__(self):
        return len(self.cues)

    def get_num_songs(self):
        return len(self.song_range)

    def get_song_segments(self, song_id):
        """Return ``(start idx, end idx)`` of the segments belonging to a song."""
        return self.song_range[song_id][2:4]

    def preload_song(self, song_id):
        """Return an in-memory copy of the samples of song ``song_id``."""
        start, end, _, _ = self.song_range[song_id]
        return self.f[start : end].copy()
class TwoStageShuffler(Sampler):
    """Sampler that shuffles songs in chunks, then segments within each chunk.

    With ``shuffle_size`` set to None it falls back to one global permutation;
    with ``shuffle`` set to False it yields ``0 .. len-1`` in order.
    """

    def __init__(self, music_data: 'MusicSegmentDataset', shuffle_size):
        self.music_data = music_data
        self.shuffle_size = shuffle_size
        self.shuffle = True
        self.loaded = set()  # song ids already passed to preload_song
        self.generator = torch.Generator()   # song order / baseline permutation
        self.generator2 = torch.Generator()  # segment order inside a chunk

    def set_epoch(self, epoch):
        """Reseed both generators from the epoch number."""
        seed = 42 + epoch
        self.generator.manual_seed(seed)
        self.generator2.manual_seed(seed)

    def __len__(self):
        return len(self.music_data)

    def preload(self, song_id):
        """Ask the dataset to read a song once; later calls are no-ops."""
        if song_id in self.loaded:
            return
        self.music_data.preload_song(song_id)
        self.loaded.add(song_id)

    def baseline_shuffle(self):
        # the same as DataLoader with shuffle=True
        # (2021/10/18: the music preloader that used to run here was removed)
        yield from torch.randperm(len(self), generator=self.generator).tolist()

    def shuffling_iter(self):
        # shuffle song list first, then walk it in chunks of shuffle_size songs
        song_order = torch.randperm(self.music_data.get_num_songs(), generator=self.generator)
        total_songs = len(song_order)
        for chunk_no, chunk in enumerate(torch.split(song_order, self.shuffle_size)):
            # sort songs to make preloader read more sequential
            chunk_songs = torch.sort(chunk)[0].tolist()
            pool = []
            for song in chunk_songs:
                first, last = self.music_data.get_song_segments(song)
                pool.extend(range(first, last))
            if chunk_no == 0:
                # nothing has been preloaded yet, so load the first chunk up front
                for song in chunk_songs:
                    self.preload(song)
            # shuffle segments from song chunk
            order = torch.randperm(len(pool), generator=self.generator2).tolist()
            picked = [pool[k] for k in order]
            issued = 0
            for served, seg in enumerate(picked, start=1):
                yield seg
                # spread preloading of the next songs evenly over this chunk
                while (len(self.loaded) < total_songs
                       and issued * len(picked) < served * self.shuffle_size):
                    self.preload(song_order[len(self.loaded)].item())
                    issued += 1

    def non_shuffling_iter(self):
        # just return 0 ... len(dataset)-1
        yield from range(len(self))

    def __iter__(self):
        if not self.shuffle:
            return self.non_shuffling_iter()
        if self.shuffle_size is None:
            return self.baseline_shuffle()
        return self.shuffling_iter()
# since instantiating a DataLoader of MusicSegmentDataset is hard, I provide a data loader builder |
class SegmentedDataLoader:
    """Builds a ``DataLoader`` over ``MusicSegmentDataset`` with batch sampling.

    The ``shuffle``, ``augmented`` and ``eval_time_shift`` attributes may be
    changed after construction; they are pushed to the dataset and shuffler
    each time iteration starts.  Each batch holds ``batch_size // 2`` segment
    indices, which the dataset receives as one list.
    """

    def __init__(self, train_val, configs, num_workers=4, pin_memory=False, prefetch_factor=2):
        assert train_val in {'train', 'validate'}
        self.dataset = MusicSegmentDataset(configs, train_val)
        assert configs['batch_size'] % 2 == 0
        self.batch_size = configs['batch_size']
        self.shuffler = TwoStageShuffler(self.dataset, configs['shuffle_size'])
        self.sampler = BatchSampler(self.shuffler, self.batch_size//2, False)
        self.num_workers = num_workers
        self.configs = configs
        self.pin_memory = pin_memory
        self.prefetch_factor = prefetch_factor

        # you can change shuffle to True/False
        self.shuffle = True
        # you can change augmented to True/False
        self.augmented = True
        # you can change eval time shift to True/False
        self.eval_time_shift = False

        loader_kwargs = dict(
            sampler=self.sampler,
            batch_size=None,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )
        # prefetch_factor is only valid with worker processes; newer PyTorch
        # raises ValueError if it is passed together with num_workers=0.
        if self.num_workers > 0:
            loader_kwargs['prefetch_factor'] = self.prefetch_factor
        self.loader = DataLoader(self.dataset, **loader_kwargs)

    def set_epoch(self, epoch):
        self.shuffler.set_epoch(epoch)

    def __iter__(self):
        # apply the current switches right before iterating
        self.dataset.augmented = self.augmented
        self.dataset.eval_time_shift = self.eval_time_shift
        self.shuffler.shuffle = self.shuffle
        return iter(self.loader)

    def __len__(self):
        return len(self.loader)
@ -0,0 +1,89 @@ |
import argparse |
import csv |
import os |
import warnings |
import scipy.io |
import numpy as np |
import torch |
import torch.fft |
with warnings.catch_warnings(): |
warnings.simplefilter("ignore") |
import torchaudio |
from .audio import get_audio |
class AIR:
    """Bank of room impulse responses loaded from ``.mat`` files.

    Each response is resampled to ``sample_rate``, truncated to ``length``
    seconds and stored as its ``rfft`` of size ``fftconv_n``.
    """

    def __init__(self, air_dir, list_csv, length, fftconv_n, sample_rate=8000):
        print('loading Aachen IR dataset')
        with open(list_csv, 'r') as fin:
            reader = csv.reader(fin)
            next(reader)  # skip the header row
            airs = [row[0] for row in reader]

        to_len = int(length * sample_rate)
        spectra = []
        self.names = []
        for name in airs:
            mat = scipy.io.loadmat(os.path.join(air_dir, name))
            h_air = torch.tensor(mat['h_air'].astype(np.float32))
            assert h_air.shape[0] == 1
            h_air = h_air[0]
            air_info = mat['air_info']
            fs = int(air_info['fs'][0][0][0][0])
            self.names.append(str(air_info['room'][0][0][0]))
            resampler = torchaudio.transforms.Resample(fs, sample_rate)
            truncated = resampler(h_air)[0:to_len]
            spectra.append(torch.fft.rfft(truncated, fftconv_n))
        self.data = torch.stack(spectra)

    def random_choose(self, num):
        """Return ``num`` spectra drawn uniformly with replacement."""
        picks = torch.randint(0, self.data.shape[0], size=(num,), dtype=torch.long)
        return self.data[picks]

    def random_choose_name(self):
        """Return one randomly chosen spectrum together with its room name."""
        pick = torch.randint(0, self.data.shape[0], size=(1,), dtype=torch.long).item()
        return self.data[pick], self.names[pick]
class MicIRP:
    """Bank of microphone impulse responses loaded from audio files.

    Each file is averaged over its channels, resampled to ``sample_rate``,
    truncated to ``length`` seconds and stored as its ``rfft`` of size
    ``fftconv_n``.
    """

    def __init__(self, mic_dir, list_csv, length, fftconv_n, sample_rate=8000):
        print('loading microphone IR dataset')
        with open(list_csv, 'r') as fin:
            reader = csv.reader(fin)
            next(reader)  # skip the header row
            mics = [row[0] for row in reader]

        to_len = int(length * sample_rate)
        spectra = []
        for name in mics:
            smp, smprate = get_audio(os.path.join(mic_dir, name))
            mono = torch.FloatTensor(smp).mean(dim=0)
            resampler = torchaudio.transforms.Resample(smprate, sample_rate)
            truncated = resampler(mono)[0:to_len]
            spectra.append(torch.fft.rfft(truncated, fftconv_n))
        self.data = torch.stack(spectra)

    def random_choose(self, num):
        """Return ``num`` spectra drawn uniformly with replacement."""
        picks = torch.randint(0, self.data.shape[0], size=(num,), dtype=torch.long)
        return self.data[picks]
if __name__ == '__main__':
    # Command line helper: write a one-column csv ("file" header) listing the
    # .mat files found in a directory, sorted by name.
    parser = argparse.ArgumentParser()
    parser.add_argument('air')
    parser.add_argument('out')
    args = parser.parse_args()
    with open(args.out, 'w', encoding='utf8', newline='\n') as fout:
        writer = csv.writer(fout)
        writer.writerow(['file'])
        mat_files = sorted(name for name in os.listdir(args.air) if name.endswith('.mat'))
        for name in mat_files:
            writer.writerow([name])
@ -0,0 +1,63 @@ |
import torch |
import torchaudio |
class MelSpec(torch.nn.Module):
    """Volume-normalised (log-)mel spectrogram layer.

    ``forward`` normalises the waveform along its last dim (max-norm when
    ``spec_norm == 'max'``, otherwise L2), applies the mel transform, adds a
    small offset, optionally takes ``log`` / ``log10``, and for
    ``spec_norm == 'max'`` subtracts the per-spectrogram maximum.
    """

    def __init__(self,
                 sample_rate=8000,
                 stft_n=1024,
                 stft_hop=256,
                 f_min=300,
                 f_max=4000,
                 n_mels=256,
                 naf_mode=False,
                 mel_log='log',
                 spec_norm='l2'):
        super().__init__()
        self.naf_mode = naf_mode
        self.mel_log = mel_log
        self.spec_norm = spec_norm
        # naf_mode switches four MelSpectrogram options at once
        if naf_mode:
            power, pad_mode, norm, mel_scale = 1, 'constant', 'slaney', 'slaney'
        else:
            power, pad_mode, norm, mel_scale = 2, 'reflect', None, 'htk'
        self.mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=stft_n,
            hop_length=stft_hop,
            f_min=f_min,
            f_max=f_max,
            n_mels=n_mels,
            window_fn=torch.hann_window,
            power=power,
            pad_mode=pad_mode,
            norm=norm,
            mel_scale=mel_scale
        )

    def forward(self, x):
        # normalize volume
        use_max = self.spec_norm == 'max'
        order = float('inf') if use_max else 2
        x = torch.nn.functional.normalize(x, p=order, dim=-1)
        # offset keeps the logarithm finite on silent bins
        floor = 0.06 if self.naf_mode else 1e-8
        x = self.mel(x) + floor
        if self.mel_log == 'log10':
            x = torch.log10(x)
        elif self.mel_log == 'log':
            x = torch.log(x)
        if use_max:
            x = x - torch.amax(x, dim=(-2, -1), keepdim=True)
        return x
def build_mel_spec_layer(params):
    """Create a ``MelSpec`` from a config dict.

    ``sample_rate``, ``stft_n``, ``stft_hop``, ``f_min``, ``f_max`` and
    ``n_mels`` are required keys; ``naf_mode``, ``mel_log`` and ``spec_norm``
    default to ``False``, ``'log'`` and ``'l2'``.
    """
    required = ('sample_rate', 'stft_n', 'stft_hop', 'f_min', 'f_max', 'n_mels')
    kwargs = {key: params[key] for key in required}
    kwargs['naf_mode'] = params.get('naf_mode', False)
    kwargs['mel_log'] = params.get('mel_log', 'log')
    kwargs['spec_norm'] = params.get('spec_norm', 'l2')
    return MelSpec(**kwargs)
@ -0,0 +1,109 @@ |
import csv |
import os |
import warnings |
import tqdm |
import numpy as np |
import torch |
with warnings.catch_warnings(): |
warnings.simplefilter("ignore") |
import torchaudio |
from .simpleutils import get_hash |
from .audio import get_audio |
class NoiseData:
    """Concatenated bank of background-noise clips.

    All clips are converted to mono, stripped of leading/trailing zeros,
    resampled, peak-normalised and concatenated into ``self.data``;
    ``self.boundary`` holds the cumulative start offsets of the clips and
    ``self.names`` their file names.
    """

    def __init__(self, noise_dir, list_csv, sample_rate, cache_dir):
        print('loading noise dataset')
        name_hashes = []
        with open(list_csv, 'r') as fin:
            reader = csv.reader(fin)
            noises = []
            next(reader)  # skip the header row
            for row in reader:
                noises.append(row[0])
                name_hashes.append(get_hash(row[0]))
        # Digest of the file list; only needed by the cache helpers below,
        # which are currently not called from here.
        list_hash = get_hash(''.join(name_hashes))

        silence_threshold = 0
        clips = []
        self.names = []
        for name in tqdm.tqdm(noises):
            smp, smprate = get_audio(os.path.join(noise_dir, name))
            # convert to mono
            smp = torch.from_numpy(smp.astype(np.float32)).mean(dim=0)
            # strip silence start/end
            magnitude = torch.abs(smp)
            if torch.max(magnitude) <= silence_threshold:
                print('%s too silent' % name)
                continue
            audible = (magnitude > silence_threshold).to(torch.int)
            # argmax returns the first audible index (from each side)
            start = int(torch.argmax(audible))
            end = audible.shape[0] - int(torch.argmax(audible.flip(0)))
            smp = smp[max(start, 0) : end]
            resampled = torchaudio.transforms.Resample(smprate, sample_rate)(smp)
            # p = inf: scale so that the peak magnitude is 1
            resampled = torch.nn.functional.normalize(resampled, dim=0, p=1e999)
            clips.append(resampled)
            self.names.append(name)
        self.data = torch.cat(clips)
        lengths = [0] + [clip.shape[0] for clip in clips]
        self.boundary = torch.LongTensor(lengths).cumsum(0)
        del clips
        print(self.data.shape)

    def load_from_cache(self, list_csv, cache_dir, hash):
        """Return cached audio if a cache with a matching hash exists, else None."""
        stem = os.path.join(cache_dir, os.path.basename(list_csv))
        loc = stem + '.npy'
        loc2 = stem + '.hash'
        if not (os.path.exists(loc) and os.path.exists(loc2)):
            return None
        with open(loc2, 'r') as fin:
            read_hash = fin.read().strip()
        if read_hash != hash:
            return None
        print('cache hit!')
        return torch.from_numpy(np.fromfile(loc, dtype=np.float32))

    def save_to_cache(self, list_csv, cache_dir, hash, audio):
        """Write ``audio`` and its list hash next to each other in ``cache_dir``."""
        os.makedirs(cache_dir, exist_ok=True)
        stem = os.path.join(cache_dir, os.path.basename(list_csv))
        loc = stem + '.npy'
        loc2 = stem + '.hash'
        with open(loc2, 'w') as fout:
            fout.write(hash)
        print('save to cache')
        audio.numpy().tofile(loc)

    def random_choose(self, num, duration, out_name=False):
        """Cut ``num`` random windows of ``duration`` samples from the bank.

        With ``out_name`` the name of the clip containing each window's first
        sample is returned as well.
        """
        starts = torch.randint(0, self.data.shape[0] - duration, size=(num,), dtype=torch.long)
        out = torch.zeros([num, duration], dtype=torch.float32)
        for row, begin in enumerate(starts.tolist()):
            out[row] = self.data[begin:begin + duration]
        name_lookup = torch.searchsorted(self.boundary, starts, right=True) - 1
        if not out_name:
            return out
        return out, [self.names[k] for k in name_lookup]

    # x is a 2d array
    def add_noises(self, x, snr_min, snr_max, out_name=False):
        """Mix random noise into each row of ``x`` at a random SNR (in dB).

        Returns the mixture, or ``(mixture, noise names, snr)`` with ``out_name``.
        """
        eps = 1e-12
        chosen = self.random_choose(x.shape[0], x.shape[1], out_name=out_name)
        if out_name:
            noise, noise_name = chosen
        else:
            noise = chosen
        # RMS levels, clamped away from zero
        vol_x = torch.clamp((x ** 2).mean(dim=1), min=eps).sqrt()
        vol_noise = torch.clamp((noise ** 2).mean(dim=1), min=eps).sqrt()
        snr = torch.FloatTensor(x.shape[0]).uniform_(snr_min, snr_max)
        ratio = vol_x / vol_noise
        ratio *= 10 ** -(snr / 20)
        x_aug = x + ratio.unsqueeze(1) * noise
        if out_name:
            return x_aug, noise_name, snr
        return x_aug
@ -0,0 +1,55 @@ |
import csv |
import os |
import numpy as np |
import torch |
from torch.utils.data import DataLoader, Dataset |
import torchaudio |
import tqdm |
from .audio import get_audio |
class Preprocessor(Dataset):
    """Dataset that decodes file ``n`` to a mono int16 tensor at ``sample_rate``."""

    def __init__(self, files, dir, sample_rate):
        self.files = files
        self.dir = dir
        self.sample_rate = sample_rate
        # one Resample transform per source sample rate, created on demand
        self.resampler = {}

    def __getitem__(self, n):
        samples, smprate = get_audio(os.path.join(self.dir, self.files[n]))
        resample = self.resampler.get(smprate)
        if resample is None:
            resample = torchaudio.transforms.Resample(smprate, self.sample_rate)
            self.resampler[smprate] = resample
        # average the channels, then resample
        mono = torch.Tensor(samples).mean(dim=0)
        wav = resample(torch.Tensor(mono))
        # quantize to 16 bit again
        wav *= 32768
        wav.clamp_(-32768, 32767)
        return wav.to(torch.int16)

    def __len__(self):
        return len(self.files)
def preprocess_music(music_dir, music_csv, sample_rate, preprocess_out):
    """Convert every file listed in ``music_csv`` into one flat int16 cache.

    Writes ``preprocess_out + '.bin'`` (all songs concatenated) and, via
    ``np.save``, an int64 array with the number of samples per song.

    Args:
        music_dir: directory the csv's file names are relative to.
        music_csv: csv with a header row and file names in the first column.
        sample_rate: target sample rate handed to ``Preprocessor``.
        preprocess_out: output path without extension.
    """
    print('converting music to wav')
    with open(music_csv) as fin:
        reader = csv.reader(fin)
        next(reader)  # skip the header row
        files = [row[0] for row in reader]
    preprocessor = Preprocessor(files, music_dir, sample_rate)
    loader = DataLoader(preprocessor, num_workers=4, batch_size=None)
    song_lens = []
    # `with` closes the .bin file even if decoding a song raises midway
    with open(preprocess_out + '.bin', 'wb') as out_file:
        for wav in tqdm.tqdm(loader):
            # torch.set_num_threads(1) # default multithreading causes cpu contention
            wav = wav.numpy()
            out_file.write(wav.tobytes())
            song_lens.append(wav.shape[0])
    # written last, so the lengths file only exists after a complete .bin
    np.save(preprocess_out, np.array(song_lens, dtype=np.int64))
@ -0,0 +1,26 @@ |
# every utils that don't use torch |
import hashlib |
import json |
import logging |
import multiprocessing as mp |
import os |
def get_hash(s):
    """Return the hexadecimal MD5 digest of ``s`` encoded as UTF-8."""
    return hashlib.md5(s.encode('utf8')).hexdigest()
def read_config(path):
    """Load and return the JSON document stored at ``path``.

    The file is decoded as UTF-8 regardless of the platform's locale encoding.
    """
    with open(path, 'r', encoding='utf8') as fin:
        return json.load(fin)
def init_logger(app_name):
    """Send the multiprocessing logger's INFO output to ``logs/<app_name>.log``.

    Creates the ``logs`` directory if needed.  Calling this again for the same
    ``app_name`` does not attach a second handler, so lines are not duplicated.
    """
    os.makedirs('logs', exist_ok=True)
    logger = mp.get_logger()
    logger.setLevel(logging.INFO)
    # FileHandler stores the absolute path in baseFilename; compare like with like
    log_path = os.path.abspath('logs/%s.log' % app_name)
    for existing in logger.handlers:
        if isinstance(existing, logging.FileHandler) and existing.baseFilename == log_path:
            return
    handler = logging.FileHandler(log_path, encoding="utf8")
    handler.setFormatter(logging.Formatter('[%(asctime)s] [%(processName)s/%(levelname)s] %(message)s'))
    logger.addHandler(handler)
@ -0,0 +1,42 @@ |
import torch |
class SpecAugment:
    """Random masking of a spectrogram: one rectangle, one frequency band and
    one time band, each sized as a random fraction of the corresponding axis.
    """

    def __init__(self, params):
        lo = params.get('cutout_min', 0.1)
        hi = params.get('cutout_max', 0.5)
        # NOTE(review): all three mask types read the same 'cutout_min' /
        # 'cutout_max' keys; the trailing numbers look like earlier settings.
        # Confirm whether separate freq/time keys were intended.
        self.freq_min = lo  # 5
        self.freq_max = hi  # 20
        self.time_min = lo  # 5
        self.time_max = hi  # 16
        self.cutout_min = lo  # 0.1
        self.cutout_max = hi  # 0.4

    @staticmethod
    def _random_span(size, frac_min, frac_max):
        """Draw a span length (random fraction of ``size``) and a start index."""
        length = int(size * (frac_min + torch.rand(1) * (frac_max - frac_min)))
        start = int(torch.randint(0, size - length + 1, (1,)))
        return start, length

    def get_mask(self, F, T):
        """Return an ``(F, T)`` tensor that is 1 where the input is to be erased."""
        mask = torch.zeros(F, T)
        # cutout: a rectangle limited in both frequency and time
        f0, f = self._random_span(F, self.cutout_min, self.cutout_max)
        t0, t = self._random_span(T, self.cutout_min, self.cutout_max)
        mask[f0:f0+f, t0:t0+t] = 1
        # frequency masking: a band covering every time step
        f0, f = self._random_span(F, self.freq_min, self.freq_max)
        mask[f0:f0+f, :] = 1
        # time masking: a band covering every frequency bin
        t0, t = self._random_span(T, self.time_min, self.time_max)
        mask[:, t0:t0+t] = 1
        return mask

    def augment(self, x):
        """Zero out the masked cells; the mask is built from x's last two dims."""
        mask = self.get_mask(x.shape[-2], x.shape[-1]).to(x.device)
        return x * (1 - mask)
@ -0,0 +1,152 @@ |
## This training script is with reference to https://github.com/stdio2016/pfann |
from typing import Any |
import numpy as np |
import torch |
from torch import nn |
from towhee.trainer.callback import Callback |
from towhee.trainer.trainer import Trainer |
from towhee.trainer.training_config import TrainingConfig |
from .datautil.dataset_v2 import SegmentedDataLoader |
from .datautil.simpleutils import read_config |
from .datautil.specaug import SpecAugment |
from torch.cuda.amp import autocast, GradScaler |
# other requirements: torchvision, torchmetrics==0.7.0 |
def similarity_loss(y, tau):
    """Contrastive loss over embeddings whose rows are ordered in pairs.

    Rows ``2k`` and ``2k+1`` of ``y`` are treated as a positive pair.  For
    each row, the dot products with all OTHER rows (divided by ``tau``) go
    through a log-softmax and the partner's log-probability is taken; the
    result is the negated mean over rows.
    """
    sim = torch.matmul(y, y.T)
    sim /= tau
    count = y.shape[0]
    terms = []
    for row in range(count):
        # drop the self-similarity entry; later columns shift left by one
        others = torch.cat([sim[row, :row], sim[row, row + 1:]])
        log_prob = torch.nn.functional.log_softmax(others, dim=0)
        # after the shift, the partner of an even row (row + 1) sits at index
        # `row`; the partner of an odd row (row - 1) keeps its index
        partner = row if row % 2 == 0 else row - 1
        terms.append(log_prob[partner])
    return torch.sum(torch.stack(terms)) / -count
class SetEpochCallback(Callback):
    """Callback that forwards the epoch number to the data loader's ``set_epoch``."""

    def __init__(self, dataloader):
        super().__init__()
        self.dataloader = dataloader

    def on_epoch_begin(self, epochs, logs):
        # lets the loader reseed its shuffler for the new epoch
        self.dataloader.set_epoch(epochs)
class NNFPTrainer(Trainer):
    """Trainer for the fingerprint model: SpecAugment + contrastive loss + AMP.

    ``params`` supplies the SpecAugment settings and ``tau`` (default 0.05).
    The constructor registers a ``SetEpochCallback`` on the training loader
    and runs one evaluation pass before any training.
    """

    def __init__(self, model, training_config, train_dataset, eval_dataset, train_dataloader, eval_dataloader,
                 model_card, params):
        super().__init__(model, training_config, train_dataset, eval_dataset, train_dataloader, eval_dataloader,
                         model_card)
        self.specaug = SpecAugment(params)
        self.scaler = GradScaler()
        self.tau = params.get('tau', 0.05)
        self.losses = []  # every step loss seen so far (never reset)
        self.set_epoch_callback = SetEpochCallback(self.train_dataloader)
        self.callbacks.add_callback(self.set_epoch_callback)
        print('evaluate before fine-tune...')
        self.evaluate(self.model, dict())

    def train_step(self, model: nn.Module, inputs: Any) -> dict:
        """Run one optimisation step and return the step's log dict."""
        x = inputs
        self.optimizer.zero_grad()
        # merge the batch and pair dims so rows alternate orig, aug, orig, aug, ...
        x = torch.flatten(x, 0, 1)
        x = self.specaug.augment(x)
        with autocast():
            y = model(x.to(self.configs.device))
            loss = similarity_loss(y, self.tau)
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.lr_scheduler.step()
        lossnum = float(loss.item())
        self.losses.append(lossnum)
        # "epoch_loss" is the mean over self.losses, i.e. over all steps so far
        step_logs = {"step_loss": lossnum, "epoch_loss": np.mean(self.losses), "epoch_metric": 0}
        return step_logs

    @torch.no_grad()
    def evaluate(self, model: nn.Module, logs: dict) -> dict:
        """Print the fraction of augmented clips whose own original ranks first.

        ``logs`` is returned unchanged.
        """
        validate_N = 0
        y_embed = []
        minibatch = 40
        # Only query the GPU when there is one: get_device_properties raises on
        # CPU-only hosts, where the small minibatch is kept.
        if torch.cuda.is_available() and torch.cuda.get_device_properties(0).total_memory > 11e9:
            minibatch = 640
        for x in self.eval_dataloader:
            x = torch.flatten(x, 0, 1)
            for xx in torch.split(x, minibatch):
                y = model(xx.to(self.configs.device)).cpu()
                y_embed.append(y)
        y_embed = torch.cat(y_embed)
        # rows alternate original / augmented
        y_embed_org = y_embed[0::2]
        y_embed_aug = y_embed[1::2].to(self.configs.device)

        # compute validation score on the configured device
        # pass 1: similarity of every augmented clip with its own original
        self_score = []
        for embeds in torch.split(y_embed_org, 320):
            A = torch.matmul(y_embed_aug, embeds.T.to(self.configs.device))
            # the chunk starts at row validate_N, so matching pairs sit on that diagonal
            self_score.append(A.diagonal(-validate_N).cpu())
            validate_N += embeds.shape[0]
        self_score = torch.cat(self_score).to(self.configs.device)

        # pass 2: count originals scoring at least as high as the true match
        ranks = torch.zeros(validate_N, dtype=torch.long).to(self.configs.device)
        for embeds in torch.split(y_embed_org, 320):
            A = torch.matmul(y_embed_aug, embeds.T.to(self.configs.device))
            ranks += (A.T >= self_score).sum(dim=0)
        # rank 1 means only the true original reached the self score
        acc = int((ranks == 1).sum())
        print('\nvalidate score: %f' % (acc / validate_N,))
        # logs['epoch_metric'] = acc / validate_N
        return logs
def train_nnfp(model, config_json_path):
    """Fine-tune ``model`` using the settings in the JSON file at ``config_json_path``.

    Builds the training and validation loaders, a cosine-schedule
    ``TrainingConfig`` writing to ``fine_tune_output``, an Adam optimizer, and
    runs ``NNFPTrainer.train()``.
    """
    model.train()
    params = read_config(config_json_path)

    # training data: shuffled, augmented, random time offsets
    train_dataloader = SegmentedDataLoader('train', params, num_workers=4)
    print('training data contains %d samples' % len(train_dataloader.dataset))
    train_dataloader.shuffle = True
    train_dataloader.eval_time_shift = False
    train_dataloader.augmented = True
    train_dataloader.set_epoch(0)

    # validation data: fixed order, evaluation-style time shift
    val_dataloader = SegmentedDataLoader('validate', params, num_workers=4)
    val_dataloader.shuffle = False
    val_dataloader.eval_time_shift = True
    val_dataloader.set_epoch(-1)

    training_config = TrainingConfig(
        batch_size=params['batch_size'],
        epoch_num=params['epoch'],
        output_dir='fine_tune_output',
        lr_scheduler_type='cosine',
        eval_strategy='epoch',
        lr=params['lr']
    )
    trainer_kwargs = dict(
        model=model,
        training_config=training_config,
        train_dataset=train_dataloader.dataset,
        eval_dataset=None,
        train_dataloader=train_dataloader,
        eval_dataloader=val_dataloader,
        model_card=None,
        params=params,
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=training_config.lr)
    nnfp_trainer = NNFPTrainer(**trainer_kwargs)
    nnfp_trainer.set_optimizer(optimizer)
    nnfp_trainer.train()
Reference in new issue