|
|
@ -17,7 +17,7 @@ import warnings |
|
|
|
|
|
|
|
import os |
|
|
|
from pathlib import Path |
|
|
|
from typing import List |
|
|
|
from typing import List, Union |
|
|
|
|
|
|
|
import torch |
|
|
|
import torchaudio |
|
|
@ -92,7 +92,7 @@ class NNFingerprint(NNOperator): |
|
|
|
self.model.eval() |
|
|
|
log.info('Model is loaded.') |
|
|
|
|
|
|
|
def __call__(self, data: List[AudioFrame]) -> numpy.ndarray: |
|
|
|
def __call__(self, data: Union[str, List[AudioFrame]]) -> numpy.ndarray: |
|
|
|
audio_tensors = self.preprocess(data) |
|
|
|
if audio_tensors.device != self.device: |
|
|
|
audio_tensors = audio_tensors.to(self.device) |
|
|
@ -107,7 +107,10 @@ class NNFingerprint(NNOperator): |
|
|
|
outs = features.detach().cpu().numpy() |
|
|
|
return outs |
|
|
|
|
|
|
|
def preprocess(self, frames: List[AudioFrame]): |
|
|
|
def preprocess(self, frames: Union[str, List[AudioFrame]]): |
|
|
|
if isinstance(frames, str): |
|
|
|
audio, sr = torchaudio.load(frames) |
|
|
|
else: |
|
|
|
sr = frames[0].sample_rate |
|
|
|
layout = frames[0].layout |
|
|
|
if layout == 'stereo': |
|
|
@ -117,9 +120,9 @@ class NNFingerprint(NNOperator): |
|
|
|
audio = numpy.hstack(frames) |
|
|
|
if len(audio.shape) == 1: |
|
|
|
audio = audio[None, :] |
|
|
|
assert len(audio.shape) == 2 |
|
|
|
audio = self.int2float(audio) |
|
|
|
audio = torch.from_numpy(audio) |
|
|
|
assert len(audio.shape) == 2 |
|
|
|
|
|
|
|
if sr != self.params['sample_rate']: |
|
|
|
resampler = torchaudio.transforms.Resample(sr, self.params['sample_rate'], dtype=audio.dtype) |
|
|
|