logo
Browse Source

Allow string input

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
324e08ebf0
  1. 11
      nn_fingerprint.py

11
nn_fingerprint.py

@@ -17,7 +17,7 @@ import warnings
import os import os
from pathlib import Path from pathlib import Path
from typing import List
from typing import List, Union
import torch import torch
import torchaudio import torchaudio
@@ -92,7 +92,7 @@ class NNFingerprint(NNOperator):
self.model.eval() self.model.eval()
log.info('Model is loaded.') log.info('Model is loaded.')
def __call__(self, data: List[AudioFrame]) -> numpy.ndarray:
def __call__(self, data: Union[str, List[AudioFrame]]) -> numpy.ndarray:
audio_tensors = self.preprocess(data) audio_tensors = self.preprocess(data)
if audio_tensors.device != self.device: if audio_tensors.device != self.device:
audio_tensors = audio_tensors.to(self.device) audio_tensors = audio_tensors.to(self.device)
@@ -107,7 +107,10 @@ class NNFingerprint(NNOperator):
outs = features.detach().cpu().numpy() outs = features.detach().cpu().numpy()
return outs return outs
def preprocess(self, frames: List[AudioFrame]):
def preprocess(self, frames: Union[str, List[AudioFrame]]):
if isinstance(frames, str):
audio, sr = torchaudio.load(frames)
else:
sr = frames[0].sample_rate sr = frames[0].sample_rate
layout = frames[0].layout layout = frames[0].layout
if layout == 'stereo': if layout == 'stereo':
@@ -117,9 +120,9 @@ class NNFingerprint(NNOperator):
audio = numpy.hstack(frames) audio = numpy.hstack(frames)
if len(audio.shape) == 1: if len(audio.shape) == 1:
audio = audio[None, :] audio = audio[None, :]
assert len(audio.shape) == 2
audio = self.int2float(audio) audio = self.int2float(audio)
audio = torch.from_numpy(audio) audio = torch.from_numpy(audio)
assert len(audio.shape) == 2
if sr != self.params['sample_rate']: if sr != self.params['sample_rate']:
resampler = torchaudio.transforms.Resample(sr, self.params['sample_rate'], dtype=audio.dtype) resampler = torchaudio.transforms.Resample(sr, self.params['sample_rate'], dtype=audio.dtype)

Loading…
Cancel
Save