logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

89 lines
3.0 KiB

import argparse
import csv
import os
import warnings
import scipy.io
import numpy as np
import torch
import torch.fft
with warnings.catch_warnings():
warnings.simplefilter("ignore")
import torchaudio
from .audio import get_audio
class AIR:
    """Aachen Impulse Response (AIR) dataset, held as real-FFT spectra.

    Every IR listed in ``list_csv`` is loaded from its ``.mat`` file in
    ``air_dir``, resampled to ``sample_rate``, truncated to ``length``
    seconds and stored as its ``fftconv_n``-point ``torch.fft.rfft``.

    Attributes:
        data: stacked complex spectra, one row per IR, each of length
            ``fftconv_n // 2 + 1``.
        names: room name of each IR (from ``air_info['room']``);
            ``names[i]`` belongs to ``data[i]``.
    """

    def __init__(self, air_dir, list_csv, length, fftconv_n, sample_rate=8000):
        """Load and pre-transform every IR listed in ``list_csv``.

        Args:
            air_dir: directory holding the ``.mat`` files.
            list_csv: CSV file with a header row, then one ``.mat`` file name
                (relative to ``air_dir``) in the first column of each row.
                Blank rows are ignored.
            length: IR duration to keep, in seconds.
            fftconv_n: FFT size passed to ``torch.fft.rfft``.
            sample_rate: target sample rate every IR is resampled to.

        Raises:
            ValueError: if ``list_csv`` names no files, or a file's ``h_air``
                does not have exactly one row.
        """
        print('loading Aachen IR dataset')
        airs = self._read_list_csv(list_csv)
        if not airs:
            # Fail early with a clear message instead of torch.stack([]).
            raise ValueError(f'no impulse responses listed in {list_csv!r}')
        to_len = int(length * sample_rate)
        data = []
        self.names = []
        for name in airs:
            mat = scipy.io.loadmat(os.path.join(air_dir, name))
            h_air = torch.tensor(mat['h_air'].astype(np.float32))
            # Explicit check: an ``assert`` would vanish under ``python -O``.
            if h_air.shape[0] != 1:
                raise ValueError(
                    f'{name}: expected a single-row h_air, '
                    f'got shape {tuple(h_air.shape)}')
            h_air = h_air[0]
            air_info = mat['air_info']
            # loadmat (default squeeze_me=False) wraps each MATLAB struct
            # field in nested arrays, hence the repeated [0] indexing.
            fs = int(air_info['fs'][0][0][0][0])
            self.names.append(str(air_info['room'][0][0][0]))
            resampled = torchaudio.transforms.Resample(fs, sample_rate)(h_air)
            truncated = resampled[0:to_len]
            data.append(torch.fft.rfft(truncated, fftconv_n))
        self.data = torch.stack(data)

    @staticmethod
    def _read_list_csv(list_csv):
        """Return the first-column entries of ``list_csv``.

        The header row and blank rows are skipped. An empty file yields ``[]``.
        """
        # newline='' is what the csv module expects. utf8 matches the
        # encoding this module's __main__ script writes list files with.
        with open(list_csv, 'r', encoding='utf8', newline='') as fin:
            reader = csv.reader(fin)
            next(reader, None)  # skip the header; tolerate an empty file
            return [row[0] for row in reader if row and row[0]]

    def random_choose(self, num):
        """Return ``num`` spectra drawn uniformly at random, with replacement."""
        indices = torch.randint(0, self.data.shape[0], size=(num,), dtype=torch.long)
        return self.data[indices]

    def random_choose_name(self):
        """Return one random ``(spectrum, room_name)`` pair."""
        index = torch.randint(0, self.data.shape[0], size=(1,), dtype=torch.long).item()
        return self.data[index], self.names[index]
class MicIRP:
    """Microphone impulse-response dataset, held as real-FFT spectra.

    Every recording listed in ``list_csv`` is loaded from ``mic_dir``, mixed
    down to one channel, resampled to ``sample_rate``, truncated to
    ``length`` seconds and stored as its ``fftconv_n``-point
    ``torch.fft.rfft``.

    Attributes:
        data: stacked complex spectra, one row per IR, each of length
            ``fftconv_n // 2 + 1``.
    """

    def __init__(self, mic_dir, list_csv, length, fftconv_n, sample_rate=8000):
        """Load and pre-transform every IR listed in ``list_csv``.

        Args:
            mic_dir: directory holding the audio files.
            list_csv: CSV file with a header row, then one file name
                (relative to ``mic_dir``) in the first column of each row.
                Blank rows are ignored.
            length: IR duration to keep, in seconds.
            fftconv_n: FFT size passed to ``torch.fft.rfft``.
            sample_rate: target sample rate every IR is resampled to.

        Raises:
            ValueError: if ``list_csv`` names no files.
        """
        print('loading microphone IR dataset')
        mics = self._read_list_csv(list_csv)
        if not mics:
            # Fail early with a clear message instead of torch.stack([]).
            raise ValueError(f'no impulse responses listed in {list_csv!r}')
        to_len = int(length * sample_rate)
        data = []
        for name in mics:
            smp, smprate = get_audio(os.path.join(mic_dir, name))
            smp = torch.FloatTensor(smp)
            # NOTE(review): assumes get_audio returns (channels, samples).
            # Mix down to mono. A 1-D signal is already mono; averaging it
            # over dim 0 would collapse it to a scalar.
            if smp.dim() > 1:
                smp = smp.mean(dim=0)
            resampled = torchaudio.transforms.Resample(smprate, sample_rate)(smp)
            truncated = resampled[0:to_len]
            data.append(torch.fft.rfft(truncated, fftconv_n))
        self.data = torch.stack(data)

    @staticmethod
    def _read_list_csv(list_csv):
        """Return the first-column entries of ``list_csv``.

        The header row and blank rows are skipped. An empty file yields ``[]``.
        """
        # newline='' is what the csv module expects. utf8 matches the
        # encoding this module's __main__ script writes list files with.
        with open(list_csv, 'r', encoding='utf8', newline='') as fin:
            reader = csv.reader(fin)
            next(reader, None)  # skip the header; tolerate an empty file
            return [row[0] for row in reader if row and row[0]]

    def random_choose(self, num):
        """Return ``num`` spectra drawn uniformly at random, with replacement."""
        indices = torch.randint(0, self.data.shape[0], size=(num,), dtype=torch.long)
        return self.data[indices]
def main(argv=None):
    """Write a CSV listing the ``.mat`` files of an AIR directory.

    The output has a ``file`` header row followed by one row per entry
    directly inside the directory whose name ends in ``.mat``, in sorted
    order. This is the header-plus-first-column layout ``AIR`` reads.

    Args:
        argv: argument list without the program name. ``None`` makes
            argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description='List the .mat files of an Aachen IR directory as CSV.')
    parser.add_argument('air', help='directory containing the AIR .mat files')
    parser.add_argument('out', help='path of the CSV file to write')
    args = parser.parse_args(argv)
    # List before opening the output, so a bad input directory fails without
    # truncating an existing output file.
    files = sorted(name for name in os.listdir(args.air) if name.endswith('.mat'))
    with open(args.out, 'w', encoding='utf8', newline='\n') as fout:
        writer = csv.writer(fout)
        writer.writerow(['file'])
        writer.writerows([name] for name in files)


if __name__ == '__main__':
    main()