nnfp (copied) · ChengZi · 2 years ago
11 changed files with 997 additions and 0 deletions
@@ -0,0 +1,152 @@
# datautil/audio.py -- file name inferred from the "from .audio import get_audio" imports elsewhere in this change
import json
import os
import subprocess

import numpy as np
import wave
import io


# because builtin wave won't read wav files with more than 2 channels
class HackExtensibleWave:
    def __init__(self, stream):
        self.stream = stream
        self.pos = 0
    def read(self, n):
        r = self.stream.read(n)
        new_pos = self.pos + len(r)
        # bytes 20-21 of a wav header hold the format tag; rewrite
        # WAVE_FORMAT_EXTENSIBLE to plain PCM (0x0001) on the fly
        if self.pos < 20 and self.pos + n >= 20:
            r = r[:20-self.pos] + b'\x01\x00'[:new_pos-20] + r[22-self.pos:]
        elif 20 <= self.pos < 22:
            r = b'\x01\x00'[self.pos-20:new_pos-20] + r[22-self.pos:]
        self.pos = new_pos
        return r

def ffmpeg_get_audio(filename):
    error_log = open(os.devnull, 'w')
    proc = subprocess.Popen(['ffmpeg', '-i', filename, '-f', 'wav', 'pipe:1'],
        stderr=error_log,
        stdin=open(os.devnull),
        stdout=subprocess.PIPE,
        bufsize=1000000)
    try:
        dat = proc.stdout.read()
        wav = wave.open(HackExtensibleWave(io.BytesIO(dat)))
        ch = wav.getnchannels()
        rate = wav.getframerate()
        n = wav.getnframes()
        dat = wav.readframes(n)
        del wav
        samples = np.frombuffer(dat, dtype=np.int16) / 32768
        samples = samples.reshape([-1, ch]).T
        return samples, rate
    except (wave.Error, EOFError):
        print('failed to decode %s. maybe the file is broken!' % filename)
        return np.zeros([1, 0]), 44100

def wave_get_audio(filename):
    with open(filename, 'rb') as fin:
        wav = wave.open(HackExtensibleWave(fin))
        smpwidth = wav.getsampwidth()
        if smpwidth not in {1, 2, 3}:
            return None
        n = wav.getnframes()
        if smpwidth == 1:
            samples = np.frombuffer(wav.readframes(n), dtype=np.uint8) / 128 - 1
        elif smpwidth == 2:
            samples = np.frombuffer(wav.readframes(n), dtype=np.int16) / 32768
        elif smpwidth == 3:
            a = np.frombuffer(wav.readframes(n), dtype=np.uint8)
            # sign-extend 24 bit little-endian samples to 32 bit
            samples = np.stack([a[0::3], a[1::3], a[2::3], -(a[2::3]>>7)], axis=1).view(np.int32).squeeze(1)
            del a
            samples = samples / 8388608
        samples = samples.reshape([-1, wav.getnchannels()]).T
        return samples, wav.getframerate()

def get_audio(filename):
    if str(filename).endswith('.wav'):
        try:
            a = wave_get_audio(filename)
            if a: return a
        except Exception:
            pass
    return ffmpeg_get_audio(filename)

class FfmpegStream:
    def __init__(self, proc, sample_rate, nchannels, tmpfile):
        self.proc = proc
        self.sample_rate = sample_rate
        self.nchannels = nchannels
        # set to None first so __del__ won't unlink if gen_stream fails
        self.tmpfile = None
        self.stream = self.gen_stream()
        self.tmpfile = tmpfile
    def __del__(self):
        self.proc.terminate()
        self.proc.communicate()
        del self.proc
        if self.tmpfile:
            os.unlink(self.tmpfile)
    def gen_stream(self):
        num = yield np.array([], dtype=np.int16)
        if not num: num = 1024
        while True:
            to_read = num * self.nchannels * 2
            dat = self.proc.stdout.read(to_read)
            num = yield np.frombuffer(dat, dtype=np.int16)
            if not num: num = 1024
            if len(dat) < to_read:
                break

def ffmpeg_stream_audio(filename, is_tmp=False):
    while 1:
        try:
            stderr = open(os.devnull, 'w')
            stdin = open(os.devnull)
            break
        except PermissionError:
            print('PermissionError occurred, trying again')
    proc = subprocess.Popen(['ffprobe', '-i', filename, '-show_streams',
            '-select_streams', 'a', '-print_format', 'json'],
        stderr=stderr,
        stdin=stdin,
        stdout=subprocess.PIPE)
    prop = json.loads(proc.stdout.read())
    if 'streams' not in prop:
        raise RuntimeError('FFmpeg cannot decode audio')
    sample_rate = int(prop['streams'][0]['sample_rate'])
    nchannels = prop['streams'][0]['channels']
    proc = subprocess.Popen(['ffmpeg', '-i', filename,
            '-f', 's16le', '-acodec', 'pcm_s16le', 'pipe:1'],
        stderr=stderr,
        stdin=stdin,
        stdout=subprocess.PIPE)
    tmpfile = None
    if is_tmp:
        tmpfile = filename
    return FfmpegStream(proc, sample_rate, nchannels, tmpfile=tmpfile)

class WaveStream:
    def __init__(self, filename, is_tmp=False):
        # set to None first so __del__ won't unlink if opening fails
        self.is_tmp = None
        self.file = open(filename, 'rb')
        self.wave = wave.open(HackExtensibleWave(self.file))
        self.smpsize = self.wave.getnchannels() * self.wave.getsampwidth()
        self.sample_rate = self.wave.getframerate()
        self.nchannels = self.wave.getnchannels()
        if self.wave.getsampwidth() != 2:
            raise NotImplementedError('wave stream currently only supports 16bit wav')
        self.stream = self.gen_stream()
        self.is_tmp = filename if is_tmp else None
    def gen_stream(self):
        num = yield np.array([], dtype=np.int16)
        if not num: num = 1024
        while True:
            dat = self.wave.readframes(num)
            num = yield np.frombuffer(dat, dtype=np.int16)
            if not num: num = 1024
            if len(dat) < num * self.smpsize:
                break
    def __del__(self):
        if self.is_tmp:
            os.unlink(self.is_tmp)
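For reference, a minimal usage sketch of the decoders above (not part of the diff; the file name and frame count are placeholders, and ffmpeg/ffprobe must be on PATH):

# hypothetical quick check of the helpers above
samples, rate = get_audio('example.mp3')   # float array shaped [channels, frames], values in [-1, 1]
stream = ffmpeg_stream_audio('example.mp3')
next(stream.stream)                        # prime the generator
block = stream.stream.send(4096)           # 4096 frames of interleaved int16 samples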
@@ -0,0 +1,301 @@
# datautil/dataset_v2.py -- file name inferred from "from .datautil.dataset_v2 import SegmentedDataLoader" in the training script
import os

import numpy as np
import torch
import torch.fft
from torch.utils.data import DataLoader, Dataset, Sampler, BatchSampler

from .melspec import build_mel_spec_layer
from .noise import NoiseData
from .ir import AIR, MicIRP
from .preprocess import preprocess_music

class NumpyMemmapDataset(Dataset):
    def __init__(self, path, dtype):
        self.path = path
        self.f = np.memmap(self.path, dtype=dtype)

    def __len__(self):
        return len(self.f)

    def __getitem__(self, idx):
        return self.f[idx]

    # need these, otherwise the pickler will read the whole data into memory
    def __getstate__(self):
        return {'path': self.path, 'dtype': self.f.dtype}

    def __setstate__(self, d):
        self.__init__(d['path'], d['dtype'])

# data augmentation on the music dataset
class MusicSegmentDataset(Dataset):
    def __init__(self, params, train_val):
        # load some configs
        assert train_val in {'train', 'validate'}
        sample_rate = params['sample_rate']
        self.augmented = True
        self.eval_time_shift = True
        self.segment_size = int(params['segment_size'] * sample_rate)
        self.hop_size = int(params['hop_size'] * sample_rate)
        self.time_offset = int(params['time_offset'] * sample_rate) # e.g. select 1.2 s of audio, then take two random 1 s crops from it
        self.pad_start = int(params['pad_start'] * sample_rate) # include more audio to the left of a segment to simulate reverb
        self.params = params

        # get the fft size needed for reverb
        fftconv_n = 1024
        air_len = int(params['air']['length'] * sample_rate)
        ir_len = int(params['micirp']['length'] * sample_rate)
        fft_needed = self.segment_size + self.pad_start + air_len + ir_len
        while fftconv_n < fft_needed:
            fftconv_n *= 2
        self.fftconv_n = fftconv_n

        # datasets for data augmentation
        cache_dir = params['cache_dir']
        os.makedirs(cache_dir, exist_ok=True)
        if params['noise'][train_val]:
            self.noise = NoiseData(noise_dir=params['noise']['dir'], list_csv=params['noise'][train_val], sample_rate=sample_rate, cache_dir=cache_dir)
        else: self.noise = None
        if params['air'][train_val]:
            self.air = AIR(air_dir=params['air']['dir'], list_csv=params['air'][train_val], length=params['air']['length'], fftconv_n=fftconv_n, sample_rate=sample_rate)
        else: self.air = None
        if params['micirp'][train_val]:
            self.micirp = MicIRP(mic_dir=params['micirp']['dir'], list_csv=params['micirp'][train_val], length=params['micirp']['length'], fftconv_n=fftconv_n, sample_rate=sample_rate)
        else: self.micirp = None

        # load the music dataset as a memory mapped file
        file_name = os.path.splitext(os.path.split(params[train_val + '_csv'])[1])[0]
        file_name = os.path.join(cache_dir, '1' + file_name)
        if os.path.exists(file_name + '.npy'):
            print('load cached music from %s.bin' % file_name)
        else:
            preprocess_music(params['music_dir'], params[train_val + '_csv'], sample_rate, file_name)
        self.f = NumpyMemmapDataset(file_name + '.bin', np.int16)

        # some segmentation settings
        song_len = np.load(file_name + '.npy')
        self.cues = [] # start location of segment i
        self.offset_left = [] # allowed left shift of segment i
        self.offset_right = [] # allowed right shift of segment i
        self.song_range = [] # range of song i is (start time, end time, start idx, end idx)
        t = 0
        for duration in song_len:
            num_segs = (duration - self.segment_size + self.hop_size) // self.hop_size
            start_cue = len(self.cues)
            for idx in range(num_segs):
                my_time = idx * self.hop_size
                self.cues.append(t + my_time)
                self.offset_left.append(my_time)
                self.offset_right.append(duration - my_time)
            end_cue = len(self.cues)
            self.song_range.append((t, t + duration, start_cue, end_cue))
            t += duration

        # convert to torch Tensors to allow sharing memory across workers
        self.cues = torch.LongTensor(self.cues)
        self.offset_left = torch.LongTensor(self.offset_left)
        self.offset_right = torch.LongTensor(self.offset_right)

        # set up the mel spectrogram layer
        self.mel = build_mel_spec_layer(params)

    def get_single_segment(self, idx, offset, length):
        cue = int(self.cues[idx]) + offset
        left = int(self.offset_left[idx]) + offset
        right = int(self.offset_right[idx]) - offset

        # choose a segment
        # segment looks like:
        #             cue
        #              v
        # [ pad_start | segment_size | ... ]
        #  \--- time_offset ---/
        segment = self.f[cue - min(left, self.pad_start): cue + min(right, length)]
        segment = np.pad(segment, [max(0, self.pad_start-left), max(0, length-right)])

        # convert 16 bit to float
        return segment * np.float32(1/32768)

    def __getitem__(self, indices):
        if self.eval_time_shift:
            # collect all segments. It is faster to do batch processing
            shift_range = self.segment_size//2
            x = [self.get_single_segment(i, -self.segment_size//4, self.segment_size+shift_range) for i in indices]

            # random time offset only on the augmented segment, which acts as the query audio
            segment_size = self.pad_start + self.segment_size
            offset1 = [self.segment_size//4] * len(x)
            offset2 = torch.randint(high=shift_range+1, size=(len(x),)).tolist()
        else:
            # collect all segments. It is faster to do batch processing
            x = [self.get_single_segment(i, 0, self.time_offset) for i in indices]

            # random time offset
            shift_range = self.time_offset - self.segment_size
            segment_size = self.pad_start + self.segment_size
            offset1 = torch.randint(high=shift_range+1, size=(len(x),)).tolist()
            offset2 = torch.randint(high=shift_range+1, size=(len(x),)).tolist()

        x_orig = [xi[off + self.pad_start : off + segment_size] for xi, off in zip(x, offset1)]
        x_orig = torch.Tensor(np.stack(x_orig).astype(np.float32))
        x_aug = [xi[off : off + segment_size] for xi, off in zip(x, offset2)]
        x_aug = torch.Tensor(np.stack(x_aug).astype(np.float32))

        if self.augmented:
            # background noise
            if self.noise is not None:
                x_aug = self.noise.add_noises(x_aug, self.params['noise']['snr_min'], self.params['noise']['snr_max'])

            # impulse response
            spec = torch.fft.rfft(x_aug, self.fftconv_n)
            if self.air is not None:
                spec *= self.air.random_choose(spec.shape[0])
            if self.micirp is not None:
                spec *= self.micirp.random_choose(spec.shape[0])
            x_aug = torch.fft.irfft(spec, self.fftconv_n)
            x_aug = x_aug[..., self.pad_start:segment_size]

        # output [x1_orig, x1_aug, x2_orig, x2_aug, ...]
        x = [x_orig, x_aug] if self.augmented else [x_orig]
        x = torch.stack(x, dim=1)
        if self.mel is not None:
            return self.mel(x)
        return x

    def fan_si_le(self):
        raise NotImplementedError('so annoying!')  # original message: 煩死了

    def zuo_bu_chu_lai(self):
        raise NotImplementedError("can't get it to work")  # original message: 做不起來

    def __len__(self):
        return len(self.cues)

    def get_num_songs(self):
        return len(self.song_range)

    def get_song_segments(self, song_id):
        return self.song_range[song_id][2:4]

    def preload_song(self, song_id):
        start, end, _, _ = self.song_range[song_id]
        return self.f[start : end].copy()
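A quick worked example of the cue bookkeeping above, with illustrative numbers rather than values from any shipped config:

sample_rate = 8000                      # illustrative only
segment_size, hop_size = 8000, 4000     # 1.0 s segments, 0.5 s hop
duration = 80000                        # one 10-second song
num_segs = (duration - segment_size + hop_size) // hop_size
print(num_segs)                         # 19 segments, cues at t+0, t+4000, ..., t+72000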
class TwoStageShuffler(Sampler):
    def __init__(self, music_data: MusicSegmentDataset, shuffle_size):
        self.music_data = music_data
        self.shuffle_size = shuffle_size
        self.shuffle = True
        self.loaded = set()
        self.generator = torch.Generator()
        self.generator2 = torch.Generator()

    def set_epoch(self, epoch):
        self.generator.manual_seed(42 + epoch)
        self.generator2.manual_seed(42 + epoch)

    def __len__(self):
        return len(self.music_data)

    def preload(self, song_id):
        if song_id not in self.loaded:
            self.music_data.preload_song(song_id)
            self.loaded.add(song_id)

    def baseline_shuffle(self):
        # the same as DataLoader with shuffle=True, but with a music preloader
        # 2021/10/18 sorry, I had to remove the preloader
        #for song in range(self.music_data.get_num_songs()):
        #    self.preload(song)

        yield from torch.randperm(len(self), generator=self.generator).tolist()

    def shuffling_iter(self):
        # shuffle the song list first
        shuffle_song = torch.randperm(self.music_data.get_num_songs(), generator=self.generator)

        # split the song list into chunks
        chunks = torch.split(shuffle_song, self.shuffle_size)
        for nc, songs in enumerate(chunks):
            # sort songs so the preloader reads more sequentially
            songs = torch.sort(songs)[0].tolist()

            buf = []
            for song in songs:
                # collect segment ids
                seg_start, seg_end = self.music_data.get_song_segments(song)
                buf += list(range(seg_start, seg_end))

            if nc == 0:
                # load the first chunk
                for song in songs:
                    self.preload(song)

            # shuffle segments within the song chunk
            shuffle_segs = torch.randperm(len(buf), generator=self.generator2)
            shuffle_segs = [buf[x] for x in shuffle_segs]
            preload_cnt = 0
            for i, idx in enumerate(shuffle_segs):
                # output shuffled segment idx
                yield idx

                # preload the next chunk
                while len(self.loaded) < len(shuffle_song) and preload_cnt * len(shuffle_segs) < (i+1) * self.shuffle_size:
                    song = shuffle_song[len(self.loaded)].item()
                    self.preload(song)
                    preload_cnt += 1

    def non_shuffling_iter(self):
        # just yield 0 ... len(dataset)-1
        yield from range(len(self))

    def __iter__(self):
        if self.shuffle:
            if self.shuffle_size is None:
                return self.baseline_shuffle()
            else:
                return self.shuffling_iter()
        return self.non_shuffling_iter()

# since instantiating a DataLoader of MusicSegmentDataset is hard, I provide a data loader builder
class SegmentedDataLoader:
    def __init__(self, train_val, configs, num_workers=4, pin_memory=False, prefetch_factor=2):
        assert train_val in {'train', 'validate'}
        self.dataset = MusicSegmentDataset(configs, train_val)
        assert configs['batch_size'] % 2 == 0
        self.batch_size = configs['batch_size']
        self.shuffler = TwoStageShuffler(self.dataset, configs['shuffle_size'])
        self.sampler = BatchSampler(self.shuffler, self.batch_size//2, False)
        self.num_workers = num_workers
        self.configs = configs
        self.pin_memory = pin_memory
        self.prefetch_factor = prefetch_factor

        # you can change shuffle to True/False
        self.shuffle = True
        # you can change augmented to True/False
        self.augmented = True
        # you can change eval_time_shift to True/False
        self.eval_time_shift = False

        self.loader = DataLoader(
            self.dataset,
            sampler=self.sampler,
            batch_size=None,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            prefetch_factor=self.prefetch_factor
        )

    def set_epoch(self, epoch):
        self.shuffler.set_epoch(epoch)

    def __iter__(self):
        self.dataset.augmented = self.augmented
        self.dataset.eval_time_shift = self.eval_time_shift
        self.shuffler.shuffle = self.shuffle
        return iter(self.loader)

    def __len__(self):
        return len(self.loader)
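A usage sketch for SegmentedDataLoader (not part of the diff). Every key below is read by this module or the modules it imports; the values and paths are illustrative only:

# hypothetical config; all values are placeholders
configs = {
    'sample_rate': 8000, 'segment_size': 1.0, 'hop_size': 0.5,
    'time_offset': 1.2, 'pad_start': 0.2,
    'cache_dir': 'cache', 'music_dir': 'music',
    'train_csv': 'lists/train.csv', 'validate_csv': 'lists/val.csv',
    'noise': {'dir': 'noise', 'train': 'lists/noise_train.csv', 'validate': '', 'snr_min': 0, 'snr_max': 10},
    'air': {'dir': 'air', 'train': 'lists/air_train.csv', 'validate': '', 'length': 0.5},
    'micirp': {'dir': 'micirp', 'train': 'lists/mic_train.csv', 'validate': '', 'length': 0.5},
    'batch_size': 64, 'shuffle_size': 100,
    'stft_n': 1024, 'stft_hop': 256, 'f_min': 300, 'f_max': 4000, 'n_mels': 64,
}
loader = SegmentedDataLoader('train', configs, num_workers=0)
loader.set_epoch(0)
for batch in loader:   # mel batches shaped [batch_size//2, 2, n_mels, frames]
    break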
@@ -0,0 +1,89 @@
# datautil/ir.py -- file name inferred from "from .ir import AIR, MicIRP"
import argparse
import csv
import os
import warnings

import scipy.io
import numpy as np
import torch
import torch.fft
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    import torchaudio

from .audio import get_audio

class AIR:
    def __init__(self, air_dir, list_csv, length, fftconv_n, sample_rate=8000):
        print('loading Aachen IR dataset')
        with open(list_csv, 'r') as fin:
            reader = csv.reader(fin)
            airs = []
            firstrow = next(reader)  # skip the header row
            for row in reader:
                airs.append(row[0])
        data = []
        to_len = int(length * sample_rate)
        self.names = []
        for name in airs:
            mat = scipy.io.loadmat(os.path.join(air_dir, name))
            h_air = torch.tensor(mat['h_air'].astype(np.float32))
            assert h_air.shape[0] == 1
            h_air = h_air[0]
            air_info = mat['air_info']
            fs = int(air_info['fs'][0][0][0][0])
            self.names.append(str(air_info['room'][0][0][0]))
            resampled = torchaudio.transforms.Resample(fs, sample_rate)(h_air)
            truncated = resampled[0:to_len]
            freqd = torch.fft.rfft(truncated, fftconv_n)
            data.append(freqd)
        self.data = torch.stack(data)

    def random_choose(self, num):
        indices = torch.randint(0, self.data.shape[0], size=(num,), dtype=torch.long)
        return self.data[indices]

    def random_choose_name(self):
        index = torch.randint(0, self.data.shape[0], size=(1,), dtype=torch.long).item()
        return self.data[index], self.names[index]

class MicIRP:
    def __init__(self, mic_dir, list_csv, length, fftconv_n, sample_rate=8000):
        print('loading microphone IR dataset')
        with open(list_csv, 'r') as fin:
            reader = csv.reader(fin)
            mics = []
            firstrow = next(reader)  # skip the header row
            for row in reader:
                mics.append(row[0])
        data = []
        to_len = int(length * sample_rate)
        for name in mics:
            smp, smprate = get_audio(os.path.join(mic_dir, name))
            smp = torch.FloatTensor(smp).mean(dim=0)
            resampled = torchaudio.transforms.Resample(smprate, sample_rate)(smp)
            truncated = resampled[0:to_len]
            freqd = torch.fft.rfft(truncated, fftconv_n)
            data.append(freqd)
        self.data = torch.stack(data)

    def random_choose(self, num):
        indices = torch.randint(0, self.data.shape[0], size=(num,), dtype=torch.long)
        return self.data[indices]

if __name__ == '__main__':
    args = argparse.ArgumentParser()
    args.add_argument('air')
    args.add_argument('out')
    args = args.parse_args()

    with open(args.out, 'w', encoding='utf8', newline='\n') as fout:
        writer = csv.writer(fout)
        writer.writerow(['file'])
        files = []
        for name in os.listdir(args.air):
            if name.endswith('.mat'):
                files.append(name)
        files.sort()
        for name in files:
            writer.writerow([name])
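Both classes store impulse responses already transformed with rfft at length fftconv_n, so applying reverb later reduces to a pointwise multiply in the frequency domain. A minimal sketch of that fast-convolution step, with illustrative sizes:

import torch, torch.fft
fftconv_n = 32768                      # illustrative; the dataset picks the next power of two
x = torch.randn(4, 8192)               # a batch of 4 audio segments
ir_freq = torch.fft.rfft(torch.randn(4, 4000), fftconv_n)   # shaped like what random_choose() returns
y = torch.fft.irfft(torch.fft.rfft(x, fftconv_n) * ir_freq, fftconv_n)
y = y[..., :8192]                      # keep the segment-length part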
@@ -0,0 +1,63 @@
# datautil/melspec.py -- file name inferred from "from .melspec import build_mel_spec_layer"
import torch
import torchaudio

class MelSpec(torch.nn.Module):
    def __init__(self,
            sample_rate=8000,
            stft_n=1024,
            stft_hop=256,
            f_min=300,
            f_max=4000,
            n_mels=256,
            naf_mode=False,
            mel_log='log',
            spec_norm='l2'):
        super(MelSpec, self).__init__()
        self.naf_mode = naf_mode
        self.mel_log = mel_log
        self.spec_norm = spec_norm
        self.mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=stft_n,
            hop_length=stft_hop,
            f_min=f_min,
            f_max=f_max,
            n_mels=n_mels,
            window_fn=torch.hann_window,
            power = 1 if naf_mode else 2,
            pad_mode = 'constant' if naf_mode else 'reflect',
            norm = 'slaney' if naf_mode else None,
            mel_scale = 'slaney' if naf_mode else 'htk'
        )

    def forward(self, x):
        # normalize volume (the 1e999 literal overflows to float inf, i.e. max-norm)
        p = 1e999 if self.spec_norm == 'max' else 2
        x = torch.nn.functional.normalize(x, p=p, dim=-1)

        if self.naf_mode:
            x = self.mel(x) + 0.06
        else:
            x = self.mel(x) + 1e-8

        if self.mel_log == 'log10':
            x = torch.log10(x)
        elif self.mel_log == 'log':
            x = torch.log(x)

        if self.spec_norm == 'max':
            x = x - torch.amax(x, dim=(-2,-1), keepdim=True)
        return x

def build_mel_spec_layer(params):
    return MelSpec(
        sample_rate = params['sample_rate'],
        stft_n = params['stft_n'],
        stft_hop = params['stft_hop'],
        f_min = params['f_min'],
        f_max = params['f_max'],
        n_mels = params['n_mels'],
        naf_mode = params.get('naf_mode', False),
        mel_log = params.get('mel_log', 'log'),
        spec_norm = params.get('spec_norm', 'l2')
    )
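A small usage sketch assuming the default constructor arguments above (the frame count follows from n_fft=1024, hop_length=256 with center padding):

layer = MelSpec()
x = torch.randn(8, 2, 8000)   # a batch of 1-second, 8 kHz clip pairs
mel = layer(x)                # -> torch.Size([8, 2, 256, 32]) log-mel features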
@@ -0,0 +1,109 @@
# datautil/noise.py -- file name inferred from "from .noise import NoiseData"
import csv
import os
import warnings

import tqdm
import numpy as np
import torch
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    import torchaudio

from .simpleutils import get_hash
from .audio import get_audio

class NoiseData:
    def __init__(self, noise_dir, list_csv, sample_rate, cache_dir):
        print('loading noise dataset')
        hashes = []
        with open(list_csv, 'r') as fin:
            reader = csv.reader(fin)
            noises = []
            firstrow = next(reader)  # skip the header row
            for row in reader:
                noises.append(row[0])
                hashes.append(get_hash(row[0]))
        hash = get_hash(''.join(hashes))
        #self.data = self.load_from_cache(list_csv, cache_dir, hash)
        #if self.data is not None:
        #    print(self.data.shape)
        #    return
        data = []
        silence_threshold = 0
        self.names = []
        for name in tqdm.tqdm(noises):
            smp, smprate = get_audio(os.path.join(noise_dir, name))
            smp = torch.from_numpy(smp.astype(np.float32))

            # convert to mono
            smp = smp.mean(dim=0)

            # strip silence at the start/end
            abs_smp = torch.abs(smp)
            if torch.max(abs_smp) <= silence_threshold:
                print('%s too silent' % name)
                continue
            has_sound = (abs_smp > silence_threshold).to(torch.int)
            start = int(torch.argmax(has_sound))
            end = has_sound.shape[0] - int(torch.argmax(has_sound.flip(0)))
            smp = smp[max(start, 0) : end]

            resampled = torchaudio.transforms.Resample(smprate, sample_rate)(smp)
            resampled = torch.nn.functional.normalize(resampled, dim=0, p=1e999)  # peak-normalize (p overflows to inf)
            data.append(resampled)
            self.names.append(name)
        self.data = torch.cat(data)
        self.boundary = [0] + [x.shape[0] for x in data]
        self.boundary = torch.LongTensor(self.boundary).cumsum(0)
        del data
        #self.save_to_cache(list_csv, cache_dir, hash, self.data)
        print(self.data.shape)

    def load_from_cache(self, list_csv, cache_dir, hash):
        loc = os.path.join(cache_dir, os.path.basename(list_csv) + '.npy')
        loc2 = os.path.join(cache_dir, os.path.basename(list_csv) + '.hash')
        if os.path.exists(loc) and os.path.exists(loc2):
            with open(loc2, 'r') as fin:
                read_hash = fin.read().strip()
            if read_hash != hash:
                return None
            print('cache hit!')
            return torch.from_numpy(np.fromfile(loc, dtype=np.float32))
        return None

    def save_to_cache(self, list_csv, cache_dir, hash, audio):
        os.makedirs(cache_dir, exist_ok=True)
        loc = os.path.join(cache_dir, os.path.basename(list_csv) + '.npy')
        loc2 = os.path.join(cache_dir, os.path.basename(list_csv) + '.hash')
        with open(loc2, 'w') as fout:
            fout.write(hash)
        print('save to cache')
        audio.numpy().tofile(loc)

    def random_choose(self, num, duration, out_name=False):
        indices = torch.randint(0, self.data.shape[0] - duration, size=(num,), dtype=torch.long)
        out = torch.zeros([num, duration], dtype=torch.float32)
        for i in range(num):
            start = int(indices[i])
            end = start + duration
            out[i] = self.data[start:end]
        name_lookup = torch.searchsorted(self.boundary, indices, right=True) - 1
        if out_name:
            return out, [self.names[x] for x in name_lookup]
        return out

    # x is a 2d array
    def add_noises(self, x, snr_min, snr_max, out_name=False):
        eps = 1e-12
        noise = self.random_choose(x.shape[0], x.shape[1], out_name=out_name)
        if out_name:
            noise, noise_name = noise
        vol_x = torch.clamp((x ** 2).mean(dim=1), min=eps).sqrt()
        vol_noise = torch.clamp((noise ** 2).mean(dim=1), min=eps).sqrt()
        snr = torch.FloatTensor(x.shape[0]).uniform_(snr_min, snr_max)
        ratio = vol_x / vol_noise
        ratio *= 10 ** -(snr / 20)
        x_aug = x + ratio.unsqueeze(1) * noise
        if out_name:
            return x_aug, noise_name, snr
        return x_aug
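The gain applied in add_noises follows ratio = (vol_x / vol_noise) * 10^(-snr/20), so a 20 dB SNR leaves the added noise at one tenth of the signal RMS. A standalone sketch of the same computation:

# worked example of the SNR gain (snr fixed at 20 dB for clarity)
x = torch.randn(1, 8000)
noise = torch.randn(1, 8000)
vol_x = x.pow(2).mean(dim=1).clamp(min=1e-12).sqrt()
vol_noise = noise.pow(2).mean(dim=1).clamp(min=1e-12).sqrt()
x_aug = x + (vol_x / vol_noise * 10 ** -(20 / 20)).unsqueeze(1) * noise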
@@ -0,0 +1,55 @@
# datautil/preprocess.py -- file name inferred from "from .preprocess import preprocess_music"
import csv
import os

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import torchaudio
import tqdm

from .audio import get_audio

class Preprocessor(Dataset):
    def __init__(self, files, dir, sample_rate):
        self.files = files
        self.dir = dir
        self.resampler = {}
        self.sample_rate = sample_rate

    def __getitem__(self, n):
        dat = get_audio(os.path.join(self.dir, self.files[n]))
        wav, smprate = dat
        if smprate not in self.resampler:
            self.resampler[smprate] = torchaudio.transforms.Resample(smprate, self.sample_rate)
        wav = torch.Tensor(wav)
        wav = wav.mean(dim=0)
        wav = self.resampler[smprate](torch.Tensor(wav))

        # quantize to 16 bit again
        wav *= 32768
        torch.clamp(wav, -32768, 32767, out=wav)
        wav = wav.to(torch.int16)
        return wav

    def __len__(self):
        return len(self.files)

def preprocess_music(music_dir, music_csv, sample_rate, preprocess_out):
    print('converting music to wav')
    with open(music_csv) as fin:
        reader = csv.reader(fin)
        next(reader)
        files = [row[0] for row in reader]

    preprocessor = Preprocessor(files, music_dir, sample_rate)
    loader = DataLoader(preprocessor, num_workers=4, batch_size=None)
    out_file = open(preprocess_out + '.bin', 'wb')
    song_lens = []
    for wav in tqdm.tqdm(loader):
        # torch.set_num_threads(1) # default multithreading causes cpu contention

        wav = wav.numpy()
        out_file.write(wav.tobytes())
        song_lens.append(wav.shape[0])
    out_file.close()
    np.save(preprocess_out, np.array(song_lens, dtype=np.int64))
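preprocess_music writes two files: preprocess_out + '.bin' with all songs concatenated as raw int16 samples, and preprocess_out + '.npy' with the per-song lengths. A sketch of reading the cache back (paths illustrative):

song_lens = np.load('cache/1train.npy')             # per-song lengths in samples
data = np.memmap('cache/1train.bin', dtype=np.int16)
assert len(data) == song_lens.sum()                 # songs are laid out back to back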
@@ -0,0 +1,26 @@
# datautil/simpleutils.py -- file name inferred from "from .simpleutils import get_hash"
# every util that doesn't use torch
import hashlib
import json
import logging
import multiprocessing as mp
import os


def get_hash(s):
    m = hashlib.md5()
    m.update(s.encode('utf8'))
    return m.hexdigest()


def read_config(path):
    with open(path, 'r') as fin:
        return json.load(fin)


def init_logger(app_name):
    os.makedirs('logs', exist_ok=True)
    logger = mp.get_logger()
    logger.setLevel(logging.INFO)
    handler = logging.FileHandler('logs/%s.log' % app_name, encoding="utf8")
    handler.setFormatter(logging.Formatter('[%(asctime)s] [%(processName)s/%(levelname)s] %(message)s'))
    logger.addHandler(handler)
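A quick check of the two pure helpers (the config path is a placeholder):

assert get_hash('abc') == '900150983cd24fb0d6963f7d28e17f72'   # md5 of 'abc'
params = read_config('config.json')                            # plain dict parsed from a JSON file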
@@ -0,0 +1,42 @@
# datautil/specaug.py -- file name inferred from "from .datautil.specaug import SpecAugment"
import torch

class SpecAugment:
    def __init__(self, params):
        # note: all four bounds below read the cutout keys as fractions;
        # the trailing numbers are the absolute sizes they replaced
        self.freq_min = params.get('cutout_min', 0.1) # 5
        self.freq_max = params.get('cutout_max', 0.5) # 20
        self.time_min = params.get('cutout_min', 0.1) # 5
        self.time_max = params.get('cutout_max', 0.5) # 16

        self.cutout_min = params.get('cutout_min', 0.1) # 0.1
        self.cutout_max = params.get('cutout_max', 0.5) # 0.4

    def get_mask(self, F, T):
        mask = torch.zeros(F, T)
        # cutout: one random rectangle
        cutout_max = self.cutout_max
        cutout_min = self.cutout_min
        f = F * (cutout_min + torch.rand(1) * (cutout_max-cutout_min))
        f = int(f)
        f0 = torch.randint(0, F - f + 1, (1,))
        t = T * (cutout_min + torch.rand(1) * (cutout_max-cutout_min))
        t = int(t)
        t0 = torch.randint(0, T - t + 1, (1,))
        mask[f0:f0+f, t0:t0+t] = 1

        # frequency masking
        f = F * (self.freq_min + torch.rand(1) * (self.freq_max - self.freq_min))
        f = int(f)
        f0 = torch.randint(0, F - f + 1, (1,))
        mask[f0:f0+f, :] = 1

        # time masking
        t = T * (self.time_min + torch.rand(1) * (self.time_max - self.time_min))
        t = int(t)
        t0 = torch.randint(0, T - t + 1, (1,))
        mask[:, t0:t0+t] = 1
        return mask

    def augment(self, x):
        mask = self.get_mask(x.shape[-2], x.shape[-1]).to(x.device)
        x = x * (1 - mask)
        return x
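A minimal usage sketch: one call to augment zeroes a random cutout rectangle plus one frequency band and one time band across the whole batch:

aug = SpecAugment({})             # falls back to the 0.1 / 0.5 defaults above
x = torch.randn(8, 2, 256, 32)    # [batch, pair, n_mels, frames]
x = aug.augment(x)                # masked entries are set to zero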
@@ -0,0 +1,152 @@
## This training script is adapted from https://github.com/stdio2016/pfann
from typing import Any

import numpy as np
import torch
from torch import nn
from towhee.trainer.callback import Callback
from towhee.trainer.trainer import Trainer
from towhee.trainer.training_config import TrainingConfig

from .datautil.dataset_v2 import SegmentedDataLoader
from .datautil.simpleutils import read_config
from .datautil.specaug import SpecAugment
from torch.cuda.amp import autocast, GradScaler


# other requirements: torchvision, torchmetrics==0.7.0


def similarity_loss(y, tau):
    # y holds embeddings in pairs: rows 2k and 2k+1 are two views of clip k
    a = torch.matmul(y, y.T)
    a /= tau
    Ls = []
    for i in range(y.shape[0]):
        nn_self = torch.cat([a[i, :i], a[i, i + 1:]])  # drop the self-similarity term
        softmax = torch.nn.functional.log_softmax(nn_self, dim=0)
        Ls.append(softmax[i if i % 2 == 0 else i - 1])  # log-prob of the paired view
    Ls = torch.stack(Ls)

    loss = torch.sum(Ls) / -y.shape[0]
    return loss
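The indexing in similarity_loss relies on rows of y being ordered [x1_orig, x1_aug, x2_orig, x2_aug, ...]: after the diagonal entry is dropped, the positive for an even row i shifts from column i+1 down to index i, while for an odd row it stays at index i-1. A standalone sketch (random embeddings, L2-normalized only for illustration):

y = torch.nn.functional.normalize(torch.randn(8, 128), dim=1)
loss = similarity_loss(y, tau=0.05)   # scalar NT-Xent-style contrastive loss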
class SetEpochCallback(Callback):
    def __init__(self, dataloader):
        super().__init__()
        self.dataloader = dataloader

    def on_epoch_begin(self, epochs, logs):
        self.dataloader.set_epoch(epochs)


class NNFPTrainer(Trainer):
    def __init__(self, model, training_config, train_dataset, eval_dataset, train_dataloader, eval_dataloader,
                 model_card, params):
        super().__init__(model, training_config, train_dataset, eval_dataset, train_dataloader, eval_dataloader,
                         model_card)
        self.specaug = SpecAugment(params)
        self.scaler = GradScaler()
        self.tau = params.get('tau', 0.05)
        self.losses = []
        self.set_epoch_callback = SetEpochCallback(self.train_dataloader)
        self.callbacks.add_callback(self.set_epoch_callback)
        print('evaluate before fine-tuning...')
        self.evaluate(self.model, dict())

    def train_step(self, model: nn.Module, inputs: Any) -> dict:
        x = inputs
        self.optimizer.zero_grad()

        # [batch, 2, ...] -> [batch*2, ...] so orig/aug pairs sit in adjacent rows
        x = torch.flatten(x, 0, 1)
        x = self.specaug.augment(x)

        with autocast():
            y = model(x.to(self.configs.device))
            loss = similarity_loss(y, self.tau)
        self.scaler.scale(loss).backward()

        self.scaler.step(self.optimizer)
        self.scaler.update()

        self.lr_scheduler.step()

        lossnum = float(loss.item())
        self.losses.append(lossnum)
        step_logs = {"step_loss": lossnum, "epoch_loss": np.mean(self.losses), "epoch_metric": 0}
        return step_logs

    @torch.no_grad()
    def evaluate(self, model: nn.Module, logs: dict) -> dict:
        validate_N = 0
        y_embed = []
        minibatch = 40
        if torch.cuda.get_device_properties(0).total_memory > 11e9:
            minibatch = 640
        for x in self.eval_dataloader:
            x = torch.flatten(x, 0, 1)
            for xx in torch.split(x, minibatch):
                y = model(xx.to(self.configs.device)).cpu()
                y_embed.append(y)
        y_embed = torch.cat(y_embed)
        y_embed_org = y_embed[0::2]
        y_embed_aug = y_embed[1::2].to(self.configs.device)

        # compute the validation score on GPU
        self_score = []
        for embeds in torch.split(y_embed_org, 320):
            A = torch.matmul(y_embed_aug, embeds.T.to(self.configs.device))
            self_score.append(A.diagonal(-validate_N).cpu())
            validate_N += embeds.shape[0]
        self_score = torch.cat(self_score).to(self.configs.device)

        ranks = torch.zeros(validate_N, dtype=torch.long).to(self.configs.device)

        for embeds in torch.split(y_embed_org, 320):
            A = torch.matmul(y_embed_aug, embeds.T.to(self.configs.device))
            ranks += (A.T >= self_score).sum(dim=0)
        acc = int((ranks == 1).sum())
        print('\nvalidate score: %f' % (acc / validate_N,))
        # logs['epoch_metric'] = acc / validate_N
        return logs


def train_nnfp(model, config_json_path):
    model.train()
    params = read_config(config_json_path)
    train_dataloader = SegmentedDataLoader('train', params, num_workers=4)
    print('training data contains %d samples' % len(train_dataloader.dataset))
    train_dataloader.shuffle = True
    train_dataloader.eval_time_shift = False
    train_dataloader.augmented = True
    train_dataloader.set_epoch(0)

    val_dataloader = SegmentedDataLoader('validate', params, num_workers=4)
    val_dataloader.shuffle = False
    val_dataloader.eval_time_shift = True
    val_dataloader.set_epoch(-1)

    training_config = TrainingConfig(
        batch_size=params['batch_size'],
        epoch_num=params['epoch'],
        output_dir='fine_tune_output',
        lr_scheduler_type='cosine',
        eval_strategy='epoch',
        lr=params['lr']
    )

    optimizer = torch.optim.Adam(model.parameters(), lr=training_config.lr)

    nnfp_trainer = NNFPTrainer(
        model=model,
        training_config=training_config,
        train_dataset=train_dataloader.dataset,
        eval_dataset=None,
        train_dataloader=train_dataloader,
        eval_dataloader=val_dataloader,
        model_card=None,
        params=params
    )
    nnfp_trainer.set_optimizer(optimizer)

    nnfp_trainer.train()
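A hypothetical entry point, sketched as comments because the model construction and config path are outside this diff:

# model = ...                            # placeholder: any network mapping mel batches to embeddings
# train_nnfp(model, 'nnfp_config.json')  # placeholder path; config keys as used by SegmentedDataLoader above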