nnfp
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
56 lines
1.7 KiB
56 lines
1.7 KiB
2 years ago
|
import csv
|
||
|
import os
|
||
|
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
from torch.utils.data import DataLoader, Dataset
|
||
|
import torchaudio
|
||
|
import tqdm
|
||
|
|
||
|
from .audio import get_audio
|
||
|
|
||
|
class Preprocessor(Dataset):
    """Dataset yielding one audio file per index as a 16-bit integer waveform.

    Each item is loaded with ``get_audio``, averaged over its first
    dimension when it has more than one, resampled to ``sample_rate`` and
    finally scaled and cast to ``torch.int16``.
    """

    def __init__(self, files, dir, sample_rate):
        """Store the file list and the resampling target.

        Args:
            files: Sequence of file names, joined onto ``dir`` when loading.
            dir: Directory the file names are relative to. (The name shadows
                the builtin, but it is part of the public signature.)
            sample_rate: Target sample rate handed to the resampler.
        """
        super().__init__()
        self.files = files
        self.dir = dir
        # Cache of Resample transforms keyed by the *source* sample rate, so
        # each distinct input rate builds its transform only once.
        self.resampler = {}
        self.sample_rate = sample_rate

    @staticmethod
    def _to_mono(wav):
        """Average ``wav`` over dim 0, leaving tensors below 2-D untouched."""
        # Averaging a 1-D waveform over dim 0 would collapse it to a single
        # scalar, so only reduce when there is an extra leading axis.
        # NOTE(review): assumes get_audio returns (channels, samples) for
        # multi-channel audio -- confirm against get_audio.
        if wav.dim() < 2:
            return wav
        return wav.mean(dim=0)

    @staticmethod
    def _quantize_int16(wav):
        """Scale ``wav`` by 32768, clamp to the int16 range and cast to int16."""
        # Multiply out-of-place so the caller's tensor is never modified.
        # NOTE(review): the 32768 factor presumes samples in [-1.0, 1.0) --
        # confirm against get_audio.
        scaled = wav * 32768
        scaled.clamp_(-32768, 32767)
        return scaled.to(torch.int16)

    def __getitem__(self, n):
        """Load, down-mix, resample and quantize the ``n``-th file.

        Returns:
            A ``torch.int16`` tensor holding the processed waveform.
        """
        raw, smprate = get_audio(os.path.join(self.dir, self.files[n]))

        resampler = self.resampler.get(smprate)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(smprate, self.sample_rate)
            self.resampler[smprate] = resampler

        wav = self._to_mono(torch.Tensor(raw))
        wav = resampler(wav)

        # quantize to 16 bit again
        return self._quantize_int16(wav)

    def __len__(self):
        """Return the number of files in the dataset."""
        return len(self.files)
|
||
|
|
||
|
def preprocess_music(music_dir, music_csv, sample_rate, preprocess_out):
    """Convert every listed music file to raw int16 samples in one binary file.

    The first column of ``music_csv`` (after skipping its first row) gives
    the file names. Each file is processed by ``Preprocessor`` and its
    samples are appended to ``preprocess_out + '.bin'``. The per-song sample
    counts are then saved with ``np.save`` under ``preprocess_out`` as an
    int64 array (``np.save`` appends ``.npy`` when the name lacks it).

    Args:
        music_dir: Directory containing the audio files.
        music_csv: Path to a CSV file whose first column holds file names.
        sample_rate: Target sample rate passed to ``Preprocessor``.
        preprocess_out: Output path prefix for the ``.bin`` and length files.
    """
    print('converting music to wav')

    # newline='' is what the csv module requires for files handed to a reader.
    with open(music_csv, newline='') as fin:
        reader = csv.reader(fin)
        # Skip the first row; the default keeps an empty file from raising
        # StopIteration.
        next(reader, None)
        # Blank lines come back as empty rows and are skipped rather than
        # failing on row[0].
        files = [row[0] for row in reader if row]

    preprocessor = Preprocessor(files, music_dir, sample_rate)
    loader = DataLoader(preprocessor, num_workers=4, batch_size=None)

    song_lens = []
    # The context manager closes the output even if a worker or a write
    # raises part-way through the loop.
    with open(preprocess_out + '.bin', 'wb') as out_file:
        for wav in tqdm.tqdm(loader):
            # NOTE: the original author left `torch.set_num_threads(1)`
            # disabled at this point, with the remark that default
            # multithreading causes cpu contention.
            samples = wav.numpy()
            out_file.write(samples.tobytes())
            song_lens.append(samples.shape[0])

    np.save(preprocess_out, np.array(song_lens, dtype=np.int64))
|