nnfp copied
ChengZi · 2 years ago
11 changed files with 997 additions and 0 deletions
@@ -0,0 +1,152 @@
import json
import os
import subprocess

import numpy as np
import wave
import io


# because the builtin wave module won't read wav files with more than 2 channels
class HackExtensibleWave:
    def __init__(self, stream):
        self.stream = stream
        self.pos = 0
    def read(self, n):
        r = self.stream.read(n)
        new_pos = self.pos + len(r)
        # bytes 20-21 of a canonical wav file hold the fmt chunk's format tag;
        # patch WAVE_FORMAT_EXTENSIBLE (0xFFFE) to plain PCM (0x0001) on the fly
        if self.pos < 20 and self.pos + n >= 20:
            r = r[:20-self.pos] + b'\x01\x00'[:new_pos-20] + r[22-self.pos:]
        elif 20 <= self.pos < 22:
            r = b'\x01\x00'[self.pos-20:new_pos-20] + r[22-self.pos:]
        self.pos = new_pos
        return r

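# A minimal usage sketch of the wrapper above (the file name is hypothetical);
# it is transparent to wave.open, so multichannel files just work:
#
#   with open('multichannel.wav', 'rb') as f:
#       w = wave.open(HackExtensibleWave(f))
#       print(w.getnchannels(), w.getframerate(), w.getnframes())
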
def ffmpeg_get_audio(filename):
    error_log = open(os.devnull, 'w')
    proc = subprocess.Popen(['ffmpeg', '-i', filename, '-f', 'wav', 'pipe:1'],
        stderr=error_log,
        stdin=open(os.devnull),
        stdout=subprocess.PIPE,
        bufsize=1000000)
    try:
        dat = proc.stdout.read()
        wav = wave.open(HackExtensibleWave(io.BytesIO(dat)))
        ch = wav.getnchannels()
        rate = wav.getframerate()
        n = wav.getnframes()
        dat = wav.readframes(n)
        del wav
        samples = np.frombuffer(dat, dtype=np.int16) / 32768
        samples = samples.reshape([-1, ch]).T
        return samples, rate
    except (wave.Error, EOFError):
        print('failed to decode %s. maybe the file is broken!' % filename)
        return np.zeros([1, 0]), 44100

def wave_get_audio(filename):
    with open(filename, 'rb') as fin:
        wav = wave.open(HackExtensibleWave(fin))
        smpwidth = wav.getsampwidth()
        if smpwidth not in {1, 2, 3}:
            return None
        n = wav.getnframes()
        if smpwidth == 1:
            samples = np.frombuffer(wav.readframes(n), dtype=np.uint8) / 128 - 1
        elif smpwidth == 2:
            samples = np.frombuffer(wav.readframes(n), dtype=np.int16) / 32768
        elif smpwidth == 3:
            # sign-extend 24-bit little-endian samples to int32
            a = np.frombuffer(wav.readframes(n), dtype=np.uint8)
            samples = np.stack([a[0::3], a[1::3], a[2::3], -(a[2::3]>>7)], axis=1).view(np.int32).squeeze(1)
            del a
            samples = samples / 8388608
        samples = samples.reshape([-1, wav.getnchannels()]).T
        return samples, wav.getframerate()

def get_audio(filename):
    if str(filename).endswith('.wav'):
        try:
            a = wave_get_audio(filename)
            if a: return a
        except Exception:
            pass
    return ffmpeg_get_audio(filename)

class FfmpegStream:
    def __init__(self, proc, sample_rate, nchannels, tmpfile):
        self.proc = proc
        self.sample_rate = sample_rate
        self.nchannels = nchannels
        self.tmpfile = None  # set last so __del__ is safe if setup fails
        self.stream = self.gen_stream()
        self.tmpfile = tmpfile
    def __del__(self):
        self.proc.terminate()
        self.proc.communicate()
        del self.proc
        if self.tmpfile:
            os.unlink(self.tmpfile)
    def gen_stream(self):
        num = yield np.array([], dtype=np.int16)
        if not num: num = 1024
        while True:
            to_read = num * self.nchannels * 2
            dat = self.proc.stdout.read(to_read)
            num = yield np.frombuffer(dat, dtype=np.int16)
            if not num: num = 1024
            if len(dat) < to_read:
                break

def ffmpeg_stream_audio(filename, is_tmp=False):
    while 1:
        try:
            stderr = open(os.devnull, 'w')
            stdin = open(os.devnull)
            break
        except PermissionError:
            print('PermissionError occurred, trying again')
    proc = subprocess.Popen(['ffprobe', '-i', filename, '-show_streams',
            '-select_streams', 'a', '-print_format', 'json'],
        stderr=stderr,
        stdin=stdin,
        stdout=subprocess.PIPE)
    prop = json.loads(proc.stdout.read())
    if 'streams' not in prop:
        raise RuntimeError('FFmpeg cannot decode audio')
    sample_rate = int(prop['streams'][0]['sample_rate'])
    nchannels = prop['streams'][0]['channels']
    proc = subprocess.Popen(['ffmpeg', '-i', filename,
            '-f', 's16le', '-acodec', 'pcm_s16le', 'pipe:1'],
        stderr=stderr,
        stdin=stdin,
        stdout=subprocess.PIPE)
    tmpfile = filename if is_tmp else None
    return FfmpegStream(proc, sample_rate, nchannels, tmpfile=tmpfile)

class WaveStream:
    def __init__(self, filename, is_tmp=False):
        self.is_tmp = None  # set last so __del__ is safe if setup fails
        self.file = open(filename, 'rb')
        self.wave = wave.open(HackExtensibleWave(self.file))
        self.smpsize = self.wave.getnchannels() * self.wave.getsampwidth()
        self.sample_rate = self.wave.getframerate()
        self.nchannels = self.wave.getnchannels()
        if self.wave.getsampwidth() != 2:
            raise NotImplementedError('wave stream currently only supports 16-bit wav')
        self.stream = self.gen_stream()
        self.is_tmp = filename if is_tmp else None
    def gen_stream(self):
        num = yield np.array([], dtype=np.int16)
        if not num: num = 1024
        while True:
            # remember the requested frame count before it is overwritten by
            # the next send(), so the end-of-file check uses the right size
            to_read = num
            dat = self.wave.readframes(to_read)
            num = yield np.frombuffer(dat, dtype=np.int16)
            if not num: num = 1024
            if len(dat) < to_read * self.smpsize:
                break
    def __del__(self):
        if self.is_tmp:
            os.unlink(self.is_tmp)

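# Usage note: both stream classes expose a pull-style generator in .stream.
# It must be primed with one next() call (which yields an empty array); after
# that, send(n) returns up to n frames of interleaved int16 samples. A hedged
# sketch with a hypothetical path:
#
#   s = WaveStream('/tmp/example.wav')
#   next(s.stream)               # prime the generator
#   chunk = s.stream.send(4096)  # int16 samples for the next 4096 frames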
@@ -0,0 +1,301 @@
import os

import numpy as np
import torch
import torch.fft
from torch.utils.data import DataLoader, Dataset, Sampler, BatchSampler

from .melspec import build_mel_spec_layer
from .noise import NoiseData
from .ir import AIR, MicIRP
from .preprocess import preprocess_music

class NumpyMemmapDataset(Dataset):
    def __init__(self, path, dtype):
        self.path = path
        self.f = np.memmap(self.path, dtype=dtype)

    def __len__(self):
        return len(self.f)

    def __getitem__(self, idx):
        return self.f[idx]

    # need these, otherwise the pickler will read the whole file into memory
    def __getstate__(self):
        return {'path': self.path, 'dtype': self.f.dtype}

    def __setstate__(self, d):
        self.__init__(d['path'], d['dtype'])

# data augmentation on a music dataset
class MusicSegmentDataset(Dataset):
    def __init__(self, params, train_val):
        # load some configs
        assert train_val in {'train', 'validate'}
        sample_rate = params['sample_rate']
        self.augmented = True
        self.eval_time_shift = True
        self.segment_size = int(params['segment_size'] * sample_rate)
        self.hop_size = int(params['hop_size'] * sample_rate)
        self.time_offset = int(params['time_offset'] * sample_rate) # select 1.2s of audio, then choose two random 1s clips from it
        self.pad_start = int(params['pad_start'] * sample_rate) # include more audio to the left of a segment to simulate reverb
        self.params = params

        # get the fft size needed for reverb
        fftconv_n = 1024
        air_len = int(params['air']['length'] * sample_rate)
        ir_len = int(params['micirp']['length'] * sample_rate)
        fft_needed = self.segment_size + self.pad_start + air_len + ir_len
        while fftconv_n < fft_needed:
            fftconv_n *= 2
        self.fftconv_n = fftconv_n

        # datasets for data augmentation
        cache_dir = params['cache_dir']
        os.makedirs(cache_dir, exist_ok=True)
        if params['noise'][train_val]:
            self.noise = NoiseData(noise_dir=params['noise']['dir'], list_csv=params['noise'][train_val], sample_rate=sample_rate, cache_dir=cache_dir)
        else: self.noise = None
        if params['air'][train_val]:
            self.air = AIR(air_dir=params['air']['dir'], list_csv=params['air'][train_val], length=params['air']['length'], fftconv_n=fftconv_n, sample_rate=sample_rate)
        else: self.air = None
        if params['micirp'][train_val]:
            self.micirp = MicIRP(mic_dir=params['micirp']['dir'], list_csv=params['micirp'][train_val], length=params['micirp']['length'], fftconv_n=fftconv_n, sample_rate=sample_rate)
        else: self.micirp = None

        # load the music dataset as a memory-mapped file
        file_name = os.path.splitext(os.path.split(params[train_val + '_csv'])[1])[0]
        file_name = os.path.join(cache_dir, '1' + file_name)
        if os.path.exists(file_name + '.npy'):
            print('loading cached music from %s.bin' % file_name)
        else:
            preprocess_music(params['music_dir'], params[train_val + '_csv'], sample_rate, file_name)
        self.f = NumpyMemmapDataset(file_name + '.bin', np.int16)

        # some segmentation settings
        song_len = np.load(file_name + '.npy')
        self.cues = [] # start location of segment i
        self.offset_left = [] # allowed left shift of segment i
        self.offset_right = [] # allowed right shift of segment i
        self.song_range = [] # range of song i is (start time, end time, start idx, end idx)
        t = 0
        for duration in song_len:
            num_segs = (duration - self.segment_size + self.hop_size) // self.hop_size
            start_cue = len(self.cues)
            for idx in range(num_segs):
                my_time = idx * self.hop_size
                self.cues.append(t + my_time)
                self.offset_left.append(my_time)
                self.offset_right.append(duration - my_time)
            end_cue = len(self.cues)
            self.song_range.append((t, t + duration, start_cue, end_cue))
            t += duration

        # convert to torch Tensors to allow sharing memory across workers
        self.cues = torch.LongTensor(self.cues)
        self.offset_left = torch.LongTensor(self.offset_left)
        self.offset_right = torch.LongTensor(self.offset_right)

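        # Worked example of the bookkeeping above (illustrative numbers): at
        # sample_rate=8000 with segment_size=1.0s (8000 samples) and
        # hop_size=0.5s (4000 samples), a 10s song of 80000 samples gives
        #   num_segs = (80000 - 8000 + 4000) // 4000 = 19
        # with cues at t+0, t+4000, ..., t+72000; offset_left grows and
        # offset_right shrinks along the song, bounding how far each segment
        # may be shifted during augmentation without leaving the song.
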
        # set up the mel spectrogram layer
        self.mel = build_mel_spec_layer(params)

    def get_single_segment(self, idx, offset, length):
        cue = int(self.cues[idx]) + offset
        left = int(self.offset_left[idx]) + offset
        right = int(self.offset_right[idx]) - offset

        # choose a segment
        # segment looks like:
        #              cue
        #               v
        # [ pad_start | segment_size | ... ]
        #  \--- time_offset ---/
        segment = self.f[cue - min(left, self.pad_start): cue + min(right, length)]
        segment = np.pad(segment, [max(0, self.pad_start-left), max(0, length-right)])

        # convert 16 bit to float
        return segment * np.float32(1/32768)

    def __getitem__(self, indices):
        if self.eval_time_shift:
            # collect all segments; batch processing is faster
            shift_range = self.segment_size//2
            x = [self.get_single_segment(i, -self.segment_size//4, self.segment_size+shift_range) for i in indices]

            # random time offset only on the augmented segment, which acts as the query audio
            segment_size = self.pad_start + self.segment_size
            offset1 = [self.segment_size//4] * len(x)
            offset2 = torch.randint(high=shift_range+1, size=(len(x),)).tolist()
        else:
            # collect all segments; batch processing is faster
            x = [self.get_single_segment(i, 0, self.time_offset) for i in indices]

            # random time offset
            shift_range = self.time_offset - self.segment_size
            segment_size = self.pad_start + self.segment_size
            offset1 = torch.randint(high=shift_range+1, size=(len(x),)).tolist()
            offset2 = torch.randint(high=shift_range+1, size=(len(x),)).tolist()

        x_orig = [xi[off + self.pad_start : off + segment_size] for xi, off in zip(x, offset1)]
        x_orig = torch.Tensor(np.stack(x_orig).astype(np.float32))
        x_aug = [xi[off : off + segment_size] for xi, off in zip(x, offset2)]
        x_aug = torch.Tensor(np.stack(x_aug).astype(np.float32))

        if self.augmented:
            # background noise
            if self.noise is not None:
                x_aug = self.noise.add_noises(x_aug, self.params['noise']['snr_min'], self.params['noise']['snr_max'])

            # impulse response
            spec = torch.fft.rfft(x_aug, self.fftconv_n)
            if self.air is not None:
                spec *= self.air.random_choose(spec.shape[0])
            if self.micirp is not None:
                spec *= self.micirp.random_choose(spec.shape[0])
            x_aug = torch.fft.irfft(spec, self.fftconv_n)
            x_aug = x_aug[..., self.pad_start:segment_size]

        # output [x1_orig, x1_aug, x2_orig, x2_aug, ...]
        x = [x_orig, x_aug] if self.augmented else [x_orig]
        x = torch.stack(x, dim=1)
        if self.mel is not None:
            return self.mel(x)
        return x

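    # Note on fftconv_n: multiplying rfft spectra implements circular
    # convolution, so fftconv_n was rounded up in __init__ to a power of two at
    # least pad_start + segment + air + mic-IR samples long, which keeps the
    # reverb tail from wrapping around into the start of the segment. A
    # minimal sketch of the same trick on toy tensors:
    #
    #   sig, ir = torch.randn(1, 8000), torch.randn(1, 2000)
    #   n = 16384                                  # >= 8000 + 2000 - 1
    #   out = torch.fft.irfft(torch.fft.rfft(sig, n) * torch.fft.rfft(ir, n), n)
    #   out = out[..., :8000 + 2000 - 1]           # linear convolution result
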
    def fan_si_le(self):
        raise NotImplementedError('so annoying')

    def zuo_bu_chu_lai(self):
        raise NotImplementedError("can't get it to work")

    def __len__(self):
        return len(self.cues)

    def get_num_songs(self):
        return len(self.song_range)

    def get_song_segments(self, song_id):
        return self.song_range[song_id][2:4]

    def preload_song(self, song_id):
        start, end, _, _ = self.song_range[song_id]
        return self.f[start : end].copy()

class TwoStageShuffler(Sampler):
    def __init__(self, music_data: MusicSegmentDataset, shuffle_size):
        self.music_data = music_data
        self.shuffle_size = shuffle_size
        self.shuffle = True
        self.loaded = set()
        self.generator = torch.Generator()
        self.generator2 = torch.Generator()

    def set_epoch(self, epoch):
        self.generator.manual_seed(42 + epoch)
        self.generator2.manual_seed(42 + epoch)

    def __len__(self):
        return len(self.music_data)

    def preload(self, song_id):
        if song_id not in self.loaded:
            self.music_data.preload_song(song_id)
            self.loaded.add(song_id)

    def baseline_shuffle(self):
        # the same as DataLoader with shuffle=True, but with a music preloader
        # 2021/10/18 sorry, I had to remove the preloader
        #for song in range(self.music_data.get_num_songs()):
        #    self.preload(song)

        yield from torch.randperm(len(self), generator=self.generator).tolist()

    def shuffling_iter(self):
        # shuffle the song list first
        shuffle_song = torch.randperm(self.music_data.get_num_songs(), generator=self.generator)

        # split the song list into chunks
        chunks = torch.split(shuffle_song, self.shuffle_size)
        for nc, songs in enumerate(chunks):
            # sort songs so the preloader reads more sequentially
            songs = torch.sort(songs)[0].tolist()

            buf = []
            for song in songs:
                # collect segment ids
                seg_start, seg_end = self.music_data.get_song_segments(song)
                buf += list(range(seg_start, seg_end))

            if nc == 0:
                # load the first chunk
                for song in songs:
                    self.preload(song)

            # shuffle segments from the song chunk
            shuffle_segs = torch.randperm(len(buf), generator=self.generator2)
            shuffle_segs = [buf[x] for x in shuffle_segs]
            preload_cnt = 0
            for i, idx in enumerate(shuffle_segs):
                # output a shuffled segment idx
                yield idx

                # preload the next chunk
                while len(self.loaded) < len(shuffle_song) and preload_cnt * len(shuffle_segs) < (i+1) * self.shuffle_size:
                    song = shuffle_song[len(self.loaded)].item()
                    self.preload(song)
                    preload_cnt += 1

    def non_shuffling_iter(self):
        # just yield 0 ... len(dataset)-1
        yield from range(len(self))

    def __iter__(self):
        if self.shuffle:
            if self.shuffle_size is None:
                return self.baseline_shuffle()
            else:
                return self.shuffling_iter()
        return self.non_shuffling_iter()

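# The two-stage scheme trades full randomness for locality: songs are shuffled
# first, then segments are shuffled only within a chunk of shuffle_size songs,
# so the memory-mapped data is read roughly sequentially while the next chunk
# is preloaded. A hedged usage sketch (shuffle_size is illustrative):
#
#   shuffler = TwoStageShuffler(dataset, shuffle_size=100)
#   shuffler.set_epoch(0)
#   indices = list(shuffler)  # segment ids, shuffled chunk by chunk
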
# since instantiating a DataLoader for MusicSegmentDataset is tricky, a data loader builder is provided here
class SegmentedDataLoader:
    def __init__(self, train_val, configs, num_workers=4, pin_memory=False, prefetch_factor=2):
        assert train_val in {'train', 'validate'}
        self.dataset = MusicSegmentDataset(configs, train_val)
        assert configs['batch_size'] % 2 == 0
        self.batch_size = configs['batch_size']
        self.shuffler = TwoStageShuffler(self.dataset, configs['shuffle_size'])
        self.sampler = BatchSampler(self.shuffler, self.batch_size//2, False)
        self.num_workers = num_workers
        self.configs = configs
        self.pin_memory = pin_memory
        self.prefetch_factor = prefetch_factor

        # you can change shuffle to True/False
        self.shuffle = True
        # you can change augmented to True/False
        self.augmented = True
        # you can change eval_time_shift to True/False
        self.eval_time_shift = False

        self.loader = DataLoader(
            self.dataset,
            sampler=self.sampler,
            batch_size=None,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            prefetch_factor=self.prefetch_factor
        )

    def set_epoch(self, epoch):
        self.shuffler.set_epoch(epoch)

    def __iter__(self):
        self.dataset.augmented = self.augmented
        self.dataset.eval_time_shift = self.eval_time_shift
        self.shuffler.shuffle = self.shuffle
        return iter(self.loader)

    def __len__(self):
        return len(self.loader)
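
# A minimal usage sketch (the config path is hypothetical; read_config comes
# from .simpleutils, and the keys follow the params dict used above):
#
#   configs = read_config('configs/default.json')
#   loader = SegmentedDataLoader('train', configs, num_workers=4)
#   loader.set_epoch(0)
#   for batch in loader:
#       ...  # (batch, 2, n_mels, frames) when augmented, with the mel layer on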
@@ -0,0 +1,89 @@
import argparse
import csv
import os
import warnings

import scipy.io
import numpy as np
import torch
import torch.fft
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    import torchaudio

from .audio import get_audio

class AIR:
    def __init__(self, air_dir, list_csv, length, fftconv_n, sample_rate=8000):
        print('loading Aachen IR dataset')
        with open(list_csv, 'r') as fin:
            reader = csv.reader(fin)
            airs = []
            firstrow = next(reader)  # skip the csv header
            for row in reader:
                airs.append(row[0])
        data = []
        to_len = int(length * sample_rate)
        self.names = []
        for name in airs:
            mat = scipy.io.loadmat(os.path.join(air_dir, name))
            h_air = torch.tensor(mat['h_air'].astype(np.float32))
            assert h_air.shape[0] == 1
            h_air = h_air[0]
            air_info = mat['air_info']
            fs = int(air_info['fs'][0][0][0][0])  # unpack the nested MATLAB struct
            self.names.append(str(air_info['room'][0][0][0]))
            resampled = torchaudio.transforms.Resample(fs, sample_rate)(h_air)
            truncated = resampled[0:to_len]
            freqd = torch.fft.rfft(truncated, fftconv_n)
            data.append(freqd)
        self.data = torch.stack(data)

    def random_choose(self, num):
        indices = torch.randint(0, self.data.shape[0], size=(num,), dtype=torch.long)
        return self.data[indices]

    def random_choose_name(self):
        index = torch.randint(0, self.data.shape[0], size=(1,), dtype=torch.long).item()
        return self.data[index], self.names[index]

class MicIRP:
    def __init__(self, mic_dir, list_csv, length, fftconv_n, sample_rate=8000):
        print('loading microphone IR dataset')
        with open(list_csv, 'r') as fin:
            reader = csv.reader(fin)
            mics = []
            firstrow = next(reader)  # skip the csv header
            for row in reader:
                mics.append(row[0])
        data = []
        to_len = int(length * sample_rate)
        for name in mics:
            smp, smprate = get_audio(os.path.join(mic_dir, name))
            smp = torch.FloatTensor(smp).mean(dim=0)
            resampled = torchaudio.transforms.Resample(smprate, sample_rate)(smp)
            truncated = resampled[0:to_len]
            freqd = torch.fft.rfft(truncated, fftconv_n)
            data.append(freqd)
        self.data = torch.stack(data)

    def random_choose(self, num):
        indices = torch.randint(0, self.data.shape[0], size=(num,), dtype=torch.long)
        return self.data[indices]

if __name__ == '__main__':
    args = argparse.ArgumentParser()
    args.add_argument('air')
    args.add_argument('out')
    args = args.parse_args()

    with open(args.out, 'w', encoding='utf8', newline='\n') as fout:
        writer = csv.writer(fout)
        writer.writerow(['file'])
        files = []
        for name in os.listdir(args.air):
            if name.endswith('.mat'):
                files.append(name)
        files.sort()
        for name in files:
            writer.writerow([name])
@@ -0,0 +1,63 @@
import torch
import torchaudio


class MelSpec(torch.nn.Module):
    def __init__(self,
            sample_rate=8000,
            stft_n=1024,
            stft_hop=256,
            f_min=300,
            f_max=4000,
            n_mels=256,
            naf_mode=False,
            mel_log='log',
            spec_norm='l2'):
        super(MelSpec, self).__init__()
        self.naf_mode = naf_mode
        self.mel_log = mel_log
        self.spec_norm = spec_norm
        self.mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=stft_n,
            hop_length=stft_hop,
            f_min=f_min,
            f_max=f_max,
            n_mels=n_mels,
            window_fn=torch.hann_window,
            power = 1 if naf_mode else 2,
            pad_mode = 'constant' if naf_mode else 'reflect',
            norm = 'slaney' if naf_mode else None,
            mel_scale = 'slaney' if naf_mode else 'htk'
        )

    def forward(self, x):
        # normalize volume; 1e999 parses to float('inf'), i.e. max-norm
        p = 1e999 if self.spec_norm == 'max' else 2
        x = torch.nn.functional.normalize(x, p=p, dim=-1)

        if self.naf_mode:
            x = self.mel(x) + 0.06
        else:
            x = self.mel(x) + 1e-8

        if self.mel_log == 'log10':
            x = torch.log10(x)
        elif self.mel_log == 'log':
            x = torch.log(x)

        if self.spec_norm == 'max':
            x = x - torch.amax(x, dim=(-2,-1), keepdim=True)
        return x

def build_mel_spec_layer(params):
    return MelSpec(
        sample_rate = params['sample_rate'],
        stft_n = params['stft_n'],
        stft_hop = params['stft_hop'],
        f_min = params['f_min'],
        f_max = params['f_max'],
        n_mels = params['n_mels'],
        naf_mode = params.get('naf_mode', False),
        mel_log = params.get('mel_log', 'log'),
        spec_norm = params.get('spec_norm', 'l2')
    )
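
# A quick shape check as a hedged sketch: with the defaults above, a batch of
# 1-second clips at 8 kHz maps to 256 mel bins and 8000 // 256 + 1 = 32 frames
# (torchaudio's MelSpectrogram pads with center=True by default):
#
#   mel = MelSpec()
#   x = torch.randn(4, 8000)
#   y = mel(x)  # -> shape (4, 256, 32), log-magnitude, volume-normalized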
@@ -0,0 +1,109 @@
import csv
import os
import warnings

import tqdm
import numpy as np
import torch
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    import torchaudio

from .simpleutils import get_hash
from .audio import get_audio

class NoiseData:
    def __init__(self, noise_dir, list_csv, sample_rate, cache_dir):
        print('loading noise dataset')
        hashes = []
        with open(list_csv, 'r') as fin:
            reader = csv.reader(fin)
            noises = []
            firstrow = next(reader)  # skip the csv header
            for row in reader:
                noises.append(row[0])
                hashes.append(get_hash(row[0]))
        hash = get_hash(''.join(hashes))
        #self.data = self.load_from_cache(list_csv, cache_dir, hash)
        #if self.data is not None:
        #    print(self.data.shape)
        #    return
        data = []
        silence_threshold = 0
        self.names = []
        for name in tqdm.tqdm(noises):
            smp, smprate = get_audio(os.path.join(noise_dir, name))
            smp = torch.from_numpy(smp.astype(np.float32))

            # convert to mono
            smp = smp.mean(dim=0)

            # strip silence at the start/end
            abs_smp = torch.abs(smp)
            if torch.max(abs_smp) <= silence_threshold:
                print('%s is too silent' % name)
                continue
            has_sound = (abs_smp > silence_threshold).to(torch.int)
            start = int(torch.argmax(has_sound))
            end = has_sound.shape[0] - int(torch.argmax(has_sound.flip(0)))
            smp = smp[max(start, 0) : end]

            resampled = torchaudio.transforms.Resample(smprate, sample_rate)(smp)
            # p=1e999 parses to inf, i.e. peak normalization
            resampled = torch.nn.functional.normalize(resampled, dim=0, p=1e999)
            data.append(resampled)
            self.names.append(name)
        self.data = torch.cat(data)
        self.boundary = [0] + [x.shape[0] for x in data]
        self.boundary = torch.LongTensor(self.boundary).cumsum(0)
        del data
        #self.save_to_cache(list_csv, cache_dir, hash, self.data)
        print(self.data.shape)

    def load_from_cache(self, list_csv, cache_dir, hash):
        loc = os.path.join(cache_dir, os.path.basename(list_csv) + '.npy')
        loc2 = os.path.join(cache_dir, os.path.basename(list_csv) + '.hash')
        if os.path.exists(loc) and os.path.exists(loc2):
            with open(loc2, 'r') as fin:
                read_hash = fin.read().strip()
            if read_hash != hash:
                return None
            print('cache hit!')
            return torch.from_numpy(np.fromfile(loc, dtype=np.float32))
        return None

    def save_to_cache(self, list_csv, cache_dir, hash, audio):
        os.makedirs(cache_dir, exist_ok=True)
        loc = os.path.join(cache_dir, os.path.basename(list_csv) + '.npy')
        loc2 = os.path.join(cache_dir, os.path.basename(list_csv) + '.hash')
        with open(loc2, 'w') as fout:
            fout.write(hash)
        print('saving to cache')
        audio.numpy().tofile(loc)

    def random_choose(self, num, duration, out_name=False):
        indices = torch.randint(0, self.data.shape[0] - duration, size=(num,), dtype=torch.long)
        out = torch.zeros([num, duration], dtype=torch.float32)
        for i in range(num):
            start = int(indices[i])
            end = start + duration
            out[i] = self.data[start:end]
        name_lookup = torch.searchsorted(self.boundary, indices, right=True) - 1
        if out_name:
            return out, [self.names[x] for x in name_lookup]
        return out

    # x is a 2d array of audio segments
    def add_noises(self, x, snr_min, snr_max, out_name=False):
        eps = 1e-12
        noise = self.random_choose(x.shape[0], x.shape[1], out_name=out_name)
        if out_name:
            noise, noise_name = noise
        # scale the noise so that 20*log10(rms(x) / rms(scaled noise)) == snr
        vol_x = torch.clamp((x ** 2).mean(dim=1), min=eps).sqrt()
        vol_noise = torch.clamp((noise ** 2).mean(dim=1), min=eps).sqrt()
        snr = torch.FloatTensor(x.shape[0]).uniform_(snr_min, snr_max)
        ratio = vol_x / vol_noise
        ratio *= 10 ** -(snr / 20)
        x_aug = x + ratio.unsqueeze(1) * noise
        if out_name:
            return x_aug, noise_name, snr
        return x_aug
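
# A hedged sanity check of the scaling above (noise_data and x are assumed to
# exist): the added noise has RMS equal to rms(x) * 10**(-snr/20), so the
# achieved SNR should match the sampled one, up to the eps clamps:
#
#   x_aug, names, snr = noise_data.add_noises(x, 0, 10, out_name=True)
#   n = x_aug - x
#   achieved = 20 * torch.log10(
#       x.pow(2).mean(1).sqrt() / n.pow(2).mean(1).sqrt())
#   # achieved ≈ snr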
@@ -0,0 +1,55 @@
import csv
import os

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import torchaudio
import tqdm

from .audio import get_audio

class Preprocessor(Dataset):
    def __init__(self, files, dir, sample_rate):
        self.files = files
        self.dir = dir
        self.resampler = {}  # cache one Resample transform per source rate
        self.sample_rate = sample_rate

    def __getitem__(self, n):
        dat = get_audio(os.path.join(self.dir, self.files[n]))
        wav, smprate = dat
        if smprate not in self.resampler:
            self.resampler[smprate] = torchaudio.transforms.Resample(smprate, self.sample_rate)
        wav = torch.Tensor(wav)
        wav = wav.mean(dim=0)  # downmix to mono
        wav = self.resampler[smprate](wav)

        # quantize to 16 bit again
        wav *= 32768
        torch.clamp(wav, -32768, 32767, out=wav)
        wav = wav.to(torch.int16)
        return wav

    def __len__(self):
        return len(self.files)

def preprocess_music(music_dir, music_csv, sample_rate, preprocess_out):
    print('converting music to wav')
    with open(music_csv) as fin:
        reader = csv.reader(fin)
        next(reader)  # skip the csv header
        files = [row[0] for row in reader]

    preprocessor = Preprocessor(files, music_dir, sample_rate)
    loader = DataLoader(preprocessor, num_workers=4, batch_size=None)
    out_file = open(preprocess_out + '.bin', 'wb')
    song_lens = []
    for wav in tqdm.tqdm(loader):
        # torch.set_num_threads(1) # default multithreading causes cpu contention

        wav = wav.numpy()
        out_file.write(wav.tobytes())
        song_lens.append(wav.shape[0])
    out_file.close()
    np.save(preprocess_out, np.array(song_lens, dtype=np.int64))
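
# The two outputs form a small random-access format: <out>.bin holds the int16
# mono PCM of every song back to back, and <out>.npy holds the per-song sample
# counts. A hedged reader sketch (file names are illustrative):
#
#   lens = np.load('out.npy')
#   offsets = np.concatenate([[0], np.cumsum(lens)])
#   pcm = np.memmap('out.bin', dtype=np.int16)
#   song_i = pcm[offsets[i]:offsets[i + 1]]  # all samples of song i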
@@ -0,0 +1,26 @@
# utilities that don't depend on torch
import hashlib
import json
import logging
import multiprocessing as mp
import os


def get_hash(s):
    m = hashlib.md5()
    m.update(s.encode('utf8'))
    return m.hexdigest()


def read_config(path):
    with open(path, 'r') as fin:
        return json.load(fin)


def init_logger(app_name):
    os.makedirs('logs', exist_ok=True)
    logger = mp.get_logger()
    logger.setLevel(logging.INFO)
    handler = logging.FileHandler('logs/%s.log' % app_name, encoding="utf8")
    handler.setFormatter(logging.Formatter('[%(asctime)s] [%(processName)s/%(levelname)s] %(message)s'))
    logger.addHandler(handler)
@@ -0,0 +1,42 @@
import torch

class SpecAugment:
    def __init__(self, params):
        # all three mask types currently share the cutout fraction settings
        self.freq_min = params.get('cutout_min', 0.1) # 5
        self.freq_max = params.get('cutout_max', 0.5) # 20
        self.time_min = params.get('cutout_min', 0.1) # 5
        self.time_max = params.get('cutout_max', 0.5) # 16

        self.cutout_min = params.get('cutout_min', 0.1) # 0.1
        self.cutout_max = params.get('cutout_max', 0.5) # 0.4

    def get_mask(self, F, T):
        mask = torch.zeros(F, T)
        # cutout: mask a random rectangle
        cutout_max = self.cutout_max
        cutout_min = self.cutout_min
        f = F * (cutout_min + torch.rand(1) * (cutout_max-cutout_min))
        f = int(f)
        f0 = torch.randint(0, F - f + 1, (1,))
        t = T * (cutout_min + torch.rand(1) * (cutout_max-cutout_min))
        t = int(t)
        t0 = torch.randint(0, T - t + 1, (1,))
        mask[f0:f0+f, t0:t0+t] = 1

        # frequency masking: mask a random band of mel bins
        f = F * (self.freq_min + torch.rand(1) * (self.freq_max - self.freq_min))
        f = int(f)
        f0 = torch.randint(0, F - f + 1, (1,))
        mask[f0:f0+f, :] = 1

        # time masking: mask a random range of frames
        t = T * (self.time_min + torch.rand(1) * (self.time_max - self.time_min))
        t = int(t)
        t0 = torch.randint(0, T - t + 1, (1,))
        mask[:, t0:t0+t] = 1
        return mask

    def augment(self, x):
        mask = self.get_mask(x.shape[-2], x.shape[-1]).to(x.device)
        x = x * (1 - mask)
        return x
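
# A hedged usage sketch on a batch of log-mel spectrograms (shapes are
# illustrative); one mask is drawn per call and shared across the batch:
#
#   aug = SpecAugment({'cutout_min': 0.1, 'cutout_max': 0.5})
#   x = torch.randn(8, 2, 256, 32)  # (batch, views, mel bins, frames)
#   x = aug.augment(x)              # masked regions are zeroed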
@@ -0,0 +1,152 @@
## This training script is written with reference to https://github.com/stdio2016/pfann
from typing import Any

import numpy as np
import torch
from torch import nn
from towhee.trainer.callback import Callback
from towhee.trainer.trainer import Trainer
from towhee.trainer.training_config import TrainingConfig

from .datautil.dataset_v2 import SegmentedDataLoader
from .datautil.simpleutils import read_config
from .datautil.specaug import SpecAugment
from torch.cuda.amp import autocast, GradScaler


# other requirements: torchvision, torchmetrics==0.7.0


def similarity_loss(y, tau):
    a = torch.matmul(y, y.T)
    a /= tau
    Ls = []
    for i in range(y.shape[0]):
        nn_self = torch.cat([a[i, :i], a[i, i + 1:]])  # drop the self-similarity entry
        softmax = torch.nn.functional.log_softmax(nn_self, dim=0)
        Ls.append(softmax[i if i % 2 == 0 else i - 1])  # log-prob of the paired view
    Ls = torch.stack(Ls)

    loss = torch.sum(Ls) / -y.shape[0]
    return loss

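# The batch rows are laid out as [x1_orig, x1_aug, x2_orig, x2_aug, ...], so
# the positive for row i is its neighbour (i+1 for even i, i-1 for odd i).
# After dropping the diagonal entry from row i, that positive lands at index i
# for even rows and i-1 for odd rows, which is exactly the index picked above.
# A hedged smoke test on random embeddings:
#
#   y = torch.nn.functional.normalize(torch.randn(8, 128), dim=1)
#   print(similarity_loss(y, tau=0.05))  # a positive scalar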

class SetEpochCallback(Callback):
    def __init__(self, dataloader):
        super().__init__()
        self.dataloader = dataloader

    def on_epoch_begin(self, epochs, logs):
        self.dataloader.set_epoch(epochs)


class NNFPTrainer(Trainer):
    def __init__(self, model, training_config, train_dataset, eval_dataset, train_dataloader, eval_dataloader,
                 model_card, params):
        super().__init__(model, training_config, train_dataset, eval_dataset, train_dataloader, eval_dataloader,
                         model_card)
        self.specaug = SpecAugment(params)
        self.scaler = GradScaler()
        self.tau = params.get('tau', 0.05)
        self.losses = []
        self.set_epoch_callback = SetEpochCallback(self.train_dataloader)
        self.callbacks.add_callback(self.set_epoch_callback)
        print('evaluating before fine-tuning...')
        self.evaluate(self.model, dict())

    def train_step(self, model: nn.Module, inputs: Any) -> dict:
        x = inputs
        self.optimizer.zero_grad()

        # (batch, 2, ...) -> (2*batch, ...) so orig/aug views interleave
        x = torch.flatten(x, 0, 1)
        x = self.specaug.augment(x)

        with autocast():
            y = model(x.to(self.configs.device))
            loss = similarity_loss(y, self.tau)
            self.scaler.scale(loss).backward()

        self.scaler.step(self.optimizer)
        self.scaler.update()

        self.lr_scheduler.step()

        lossnum = float(loss.item())
        self.losses.append(lossnum)
        step_logs = {"step_loss": lossnum, "epoch_loss": np.mean(self.losses), "epoch_metric": 0}
        return step_logs

    @torch.no_grad()
    def evaluate(self, model: nn.Module, logs: dict) -> dict:
        validate_N = 0
        y_embed = []
        minibatch = 40
        if torch.cuda.get_device_properties(0).total_memory > 11e9:
            minibatch = 640  # larger eval batches on GPUs with more than 11 GB
        for x in self.eval_dataloader:
            x = torch.flatten(x, 0, 1)
            for xx in torch.split(x, minibatch):
                y = model(xx.to(self.configs.device)).cpu()
                y_embed.append(y)
        y_embed = torch.cat(y_embed)
        y_embed_org = y_embed[0::2]
        y_embed_aug = y_embed[1::2].to(self.configs.device)

        # compute the validation score on GPU;
        # self_score[i] is the similarity of augmented query i to its own original
        self_score = []
        for embeds in torch.split(y_embed_org, 320):
            A = torch.matmul(y_embed_aug, embeds.T.to(self.configs.device))
            self_score.append(A.diagonal(-validate_N).cpu())
            validate_N += embeds.shape[0]
        self_score = torch.cat(self_score).to(self.configs.device)

        # rank of the true match among all originals; rank 1 means a top-1 hit
        ranks = torch.zeros(validate_N, dtype=torch.long).to(self.configs.device)

        for embeds in torch.split(y_embed_org, 320):
            A = torch.matmul(y_embed_aug, embeds.T.to(self.configs.device))
            ranks += (A.T >= self_score).sum(dim=0)
        acc = int((ranks == 1).sum())
        print('\nvalidation score: %f' % (acc / validate_N,))
        # logs['epoch_metric'] = acc / validate_N
        return logs


def train_nnfp(model, config_json_path):
    model.train()
    params = read_config(config_json_path)
    train_dataloader = SegmentedDataLoader('train', params, num_workers=4)
    print('training data contains %d samples' % len(train_dataloader.dataset))
    train_dataloader.shuffle = True
    train_dataloader.eval_time_shift = False
    train_dataloader.augmented = True
    train_dataloader.set_epoch(0)

    val_dataloader = SegmentedDataLoader('validate', params, num_workers=4)
    val_dataloader.shuffle = False
    val_dataloader.eval_time_shift = True
    val_dataloader.set_epoch(-1)

    training_config = TrainingConfig(
        batch_size=params['batch_size'],
        epoch_num=params['epoch'],
        output_dir='fine_tune_output',
        lr_scheduler_type='cosine',
        eval_strategy='epoch',
        lr=params['lr']
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=training_config.lr)

    nnfp_trainer = NNFPTrainer(
        model=model,
        training_config=training_config,
        train_dataset=train_dataloader.dataset,
        eval_dataset=None,
        train_dataloader=train_dataloader,
        eval_dataloader=val_dataloader,
        model_card=None,
        params=params
    )
    nnfp_trainer.set_optimizer(optimizer)

    nnfp_trainer.train()