logo
Browse Source

Allow string input

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
324e08ebf0
  1. 11
      nn_fingerprint.py

11
nn_fingerprint.py

@ -17,7 +17,7 @@ import warnings
import os
from pathlib import Path
from typing import List
from typing import List, Union
import torch
import torchaudio
@ -92,7 +92,7 @@ class NNFingerprint(NNOperator):
self.model.eval()
log.info('Model is loaded.')
def __call__(self, data: List[AudioFrame]) -> numpy.ndarray:
def __call__(self, data: Union[str, List[AudioFrame]]) -> numpy.ndarray:
audio_tensors = self.preprocess(data)
if audio_tensors.device != self.device:
audio_tensors = audio_tensors.to(self.device)
@ -107,7 +107,10 @@ class NNFingerprint(NNOperator):
outs = features.detach().cpu().numpy()
return outs
def preprocess(self, frames: List[AudioFrame]):
def preprocess(self, frames: Union[str, List[AudioFrame]]):
if isinstance(frames, str):
audio, sr = torchaudio.load(frames)
else:
sr = frames[0].sample_rate
layout = frames[0].layout
if layout == 'stereo':
@ -117,9 +120,9 @@ class NNFingerprint(NNOperator):
audio = numpy.hstack(frames)
if len(audio.shape) == 1:
audio = audio[None, :]
assert len(audio.shape) == 2
audio = self.int2float(audio)
audio = torch.from_numpy(audio)
assert len(audio.shape) == 2
if sr != self.params['sample_rate']:
resampler = torchaudio.transforms.Resample(sr, self.params['sample_rate'], dtype=audio.dtype)

Loading…
Cancel
Save