nnfp
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
90 lines
3.0 KiB
90 lines
3.0 KiB
2 years ago
|
import argparse
|
||
|
import csv
|
||
|
import os
|
||
|
import warnings
|
||
|
|
||
|
import scipy.io
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
import torch.fft
|
||
|
with warnings.catch_warnings():
|
||
|
warnings.simplefilter("ignore")
|
||
|
import torchaudio
|
||
|
|
||
|
from .audio import get_audio
|
||
|
|
||
|
class AIR:
    """Room impulse responses loaded from ``.mat`` files, stored as spectra.

    Each file named in the list CSV is read with ``scipy.io.loadmat``; its
    ``h_air`` signal is resampled to ``sample_rate``, cut to at most
    ``length * sample_rate`` samples and transformed with
    ``torch.fft.rfft(..., fftconv_n)``.
    """

    def __init__(self, air_dir, list_csv, length, fftconv_n, sample_rate=8000):
        """Load and pre-transform every impulse response listed in ``list_csv``.

        Args:
            air_dir: Directory containing the ``.mat`` files.
            list_csv: CSV file whose first row is a header and whose first
                column holds file names relative to ``air_dir``.
            length: Maximum impulse response duration to keep; multiplied by
                ``sample_rate`` to obtain a sample count.
            fftconv_n: FFT size passed to ``torch.fft.rfft`` (shorter signals
                are zero-padded by ``rfft``).
            sample_rate: Target sampling rate for resampling.

        Raises:
            RuntimeError: If ``list_csv`` names no files.
            AssertionError: If a file's ``h_air`` does not have exactly one row.
        """
        print('loading Aachen IR dataset')
        airs = self._read_file_list(list_csv)
        if not airs:
            # torch.stack([]) would raise RuntimeError with an opaque message;
            # keep the exception type but say which file is at fault.
            raise RuntimeError(f'no impulse responses listed in {list_csv}')

        to_len = int(length * sample_rate)
        # One Resample transform per distinct source rate: building the
        # transform is loop-invariant for files sharing a rate.
        resamplers = {}
        data = []
        # Room name of each loaded response, index-aligned with self.data.
        self.names = []
        for name in airs:
            mat = scipy.io.loadmat(os.path.join(air_dir, name))
            h_air = torch.tensor(mat['h_air'].astype(np.float32))
            assert h_air.shape[0] == 1, (
                f'{name}: expected h_air with a single row, '
                f'got shape {tuple(h_air.shape)}'
            )
            h_air = h_air[0]

            # loadmat wraps struct fields in nested arrays, hence the chained
            # [0] indexing to reach the scalar values.
            air_info = mat['air_info']
            fs = int(air_info['fs'][0][0][0][0])
            self.names.append(str(air_info['room'][0][0][0]))

            resampler = resamplers.get(fs)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(fs, sample_rate)
                resamplers[fs] = resampler

            truncated = resampler(h_air)[0:to_len]
            data.append(torch.fft.rfft(truncated, fftconv_n))

        # One spectrum per listed file, stacked along dimension 0.
        self.data = torch.stack(data)

    @staticmethod
    def _read_file_list(list_csv):
        """Return the first column of ``list_csv``, skipping header and blank rows."""
        with open(list_csv, 'r', newline='') as fin:
            reader = csv.reader(fin)
            # Skip the header row; the default avoids StopIteration on an
            # empty file.
            next(reader, None)
            # csv yields [] for blank lines, which would break row[0].
            return [row[0] for row in reader if row]

    def random_choose(self, num):
        """Return ``num`` spectra drawn uniformly at random, with replacement."""
        indices = torch.randint(0, self.data.shape[0], size=(num,), dtype=torch.long)
        return self.data[indices]

    def random_choose_name(self):
        """Return one randomly chosen spectrum together with its room name."""
        index = torch.randint(0, self.data.shape[0], size=(1,), dtype=torch.long).item()
        return self.data[index], self.names[index]
|
||
|
|
||
|
class MicIRP:
    """Microphone impulse responses loaded from audio files, stored as spectra.

    Each file named in the list CSV is decoded with ``get_audio``, averaged
    over dimension 0, resampled to ``sample_rate``, cut to at most
    ``length * sample_rate`` samples and transformed with
    ``torch.fft.rfft(..., fftconv_n)``.
    """

    def __init__(self, mic_dir, list_csv, length, fftconv_n, sample_rate=8000):
        """Load and pre-transform every impulse response listed in ``list_csv``.

        Args:
            mic_dir: Directory containing the audio files.
            list_csv: CSV file whose first row is a header and whose first
                column holds file names relative to ``mic_dir``.
            length: Maximum impulse response duration to keep; multiplied by
                ``sample_rate`` to obtain a sample count.
            fftconv_n: FFT size passed to ``torch.fft.rfft`` (shorter signals
                are zero-padded by ``rfft``).
            sample_rate: Target sampling rate for resampling.

        Raises:
            RuntimeError: If ``list_csv`` names no files.
        """
        print('loading microphone IR dataset')
        mics = self._read_file_list(list_csv)
        if not mics:
            # torch.stack([]) would raise RuntimeError with an opaque message;
            # keep the exception type but say which file is at fault.
            raise RuntimeError(f'no impulse responses listed in {list_csv}')

        to_len = int(length * sample_rate)
        # One Resample transform per distinct source rate: building the
        # transform is loop-invariant for files sharing a rate.
        resamplers = {}
        data = []
        for name in mics:
            smp, smprate = get_audio(os.path.join(mic_dir, name))
            # Collapse dimension 0 by averaging — presumably the channel axis
            # of get_audio's output; TODO confirm against get_audio.
            smp = torch.FloatTensor(smp).mean(dim=0)

            resampler = resamplers.get(smprate)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(smprate, sample_rate)
                resamplers[smprate] = resampler

            truncated = resampler(smp)[0:to_len]
            data.append(torch.fft.rfft(truncated, fftconv_n))

        # One spectrum per listed file, stacked along dimension 0.
        self.data = torch.stack(data)

    @staticmethod
    def _read_file_list(list_csv):
        """Return the first column of ``list_csv``, skipping header and blank rows."""
        with open(list_csv, 'r', newline='') as fin:
            reader = csv.reader(fin)
            # Skip the header row; the default avoids StopIteration on an
            # empty file.
            next(reader, None)
            # csv yields [] for blank lines, which would break row[0].
            return [row[0] for row in reader if row]

    def random_choose(self, num):
        """Return ``num`` spectra drawn uniformly at random, with replacement."""
        indices = torch.randint(0, self.data.shape[0], size=(num,), dtype=torch.long)
        return self.data[indices]
|
||
|
|
||
|
if __name__ == '__main__':
    # Command-line helper: write the names of all ``.mat`` files found in the
    # ``air`` directory to the ``out`` CSV, one per row, under a ``file``
    # header row.
    parser = argparse.ArgumentParser()
    parser.add_argument('air')
    parser.add_argument('out')
    args = parser.parse_args()

    with open(args.out, 'w', encoding='utf8', newline='\n') as fout:
        writer = csv.writer(fout)
        writer.writerow(['file'])
        # Sorted so the output does not depend on directory listing order.
        files = sorted(
            entry for entry in os.listdir(args.air) if entry.endswith('.mat')
        )
        writer.writerows([entry] for entry in files)
|