diff --git a/nn_fingerprint.py b/nn_fingerprint.py
index 23f9338..ac07b0b 100644
--- a/nn_fingerprint.py
+++ b/nn_fingerprint.py
@@ -17,7 +17,7 @@
 import warnings
 import os
 from pathlib import Path
-from typing import List
+from typing import List, Union
 
 import torch
 import torchaudio
@@ -92,7 +92,7 @@ class NNFingerprint(NNOperator):
         self.model.eval()
         log.info('Model is loaded.')
 
-    def __call__(self, data: List[AudioFrame]) -> numpy.ndarray:
+    def __call__(self, data: Union[str, List[AudioFrame]]) -> numpy.ndarray:
         audio_tensors = self.preprocess(data)
         if audio_tensors.device != self.device:
             audio_tensors = audio_tensors.to(self.device)
@@ -107,19 +107,22 @@ class NNFingerprint(NNOperator):
         outs = features.detach().cpu().numpy()
         return outs
 
-    def preprocess(self, frames: List[AudioFrame]):
-        sr = frames[0].sample_rate
-        layout = frames[0].layout
-        if layout == 'stereo':
-            frames = [frame.reshape(-1, 2) for frame in frames]
-            audio = numpy.vstack(frames).transpose()
+    def preprocess(self, frames: Union[str, List[AudioFrame]]):
+        if isinstance(frames, str):
+            audio, sr = torchaudio.load(frames)
         else:
-            audio = numpy.hstack(frames)
-        if len(audio.shape) == 1:
-            audio = audio[None, :]
+            sr = frames[0].sample_rate
+            layout = frames[0].layout
+            if layout == 'stereo':
+                frames = [frame.reshape(-1, 2) for frame in frames]
+                audio = numpy.vstack(frames).transpose()
+            else:
+                audio = numpy.hstack(frames)
+            if len(audio.shape) == 1:
+                audio = audio[None, :]
+            audio = self.int2float(audio)
+            audio = torch.from_numpy(audio)
         assert len(audio.shape) == 2
-        audio = self.int2float(audio)
-        audio = torch.from_numpy(audio)
         if sr != self.params['sample_rate']:
             resampler = torchaudio.transforms.Resample(sr, self.params['sample_rate'], dtype=audio.dtype)
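
This change widens `__call__` and `preprocess` to accept either a path to an audio file or the original `List[AudioFrame]`. On the string branch, `torchaudio.load` decodes the file and returns a normalized float waveform together with its sample rate, so the `int2float`/`torch.from_numpy` conversion now runs only on the frame branch; both branches then share the same shape assertion and resampling tail. Below is a minimal usage sketch; the default-argument constructor call and the `frames` variable are assumptions for illustration, since neither appears in this diff:

```python
# Minimal usage sketch for the extended operator. Assumes NNFingerprint()
# can be constructed with its defaults; the constructor is not part of
# this diff.
from nn_fingerprint import NNFingerprint

op = NNFingerprint()

# New input type: a file path. preprocess() decodes it with
# torchaudio.load(), which already yields a normalized float waveform,
# so no int2float conversion is needed on this branch.
emb_from_path = op('sample.wav')

# Original input type: a list of AudioFrame objects, e.g. produced by an
# upstream audio-decode operator (hypothetical variable `frames`).
# emb_from_frames = op(frames)

print(emb_from_path.shape)  # numpy.ndarray of fingerprint embeddings
```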