
Allow string input

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
Branch: main
Jael Gu, 2 years ago
commit 324e08ebf0
1 changed file: nn_fingerprint.py (29 lines changed)
@@ -17,7 +17,7 @@ import warnings
 import os
 from pathlib import Path
-from typing import List
+from typing import List, Union
 
 import torch
 import torchaudio
@@ -92,7 +92,7 @@ class NNFingerprint(NNOperator):
         self.model.eval()
         log.info('Model is loaded.')
 
-    def __call__(self, data: List[AudioFrame]) -> numpy.ndarray:
+    def __call__(self, data: Union[str, List[AudioFrame]]) -> numpy.ndarray:
         audio_tensors = self.preprocess(data)
         if audio_tensors.device != self.device:
             audio_tensors = audio_tensors.to(self.device)
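With this change `__call__` accepts either a path to an audio file or a list of decoded `AudioFrame` objects. A minimal usage sketch, assuming the class can be imported directly from nn_fingerprint.py and that its default constructor arguments are sufficient; the file path is a placeholder:

from nn_fingerprint import NNFingerprint  # assumes a direct import of the class defined in this file

op = NNFingerprint()  # assumes the default constructor arguments

# Before this commit, only a list of decoded AudioFrame objects was accepted:
#   emb = op(frames)
# A plain file path now works as well:
emb = op('path/to/audio.wav')  # placeholder path
print(emb.shape)  # numpy.ndarray of fingerprint embeddings, per the return annotation above
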
@@ -107,19 +107,22 @@ class NNFingerprint(NNOperator):
         outs = features.detach().cpu().numpy()
         return outs
 
-    def preprocess(self, frames: List[AudioFrame]):
-        sr = frames[0].sample_rate
-        layout = frames[0].layout
-        if layout == 'stereo':
-            frames = [frame.reshape(-1, 2) for frame in frames]
-            audio = numpy.vstack(frames).transpose()
+    def preprocess(self, frames: Union[str, List[AudioFrame]]):
+        if isinstance(frames, str):
+            audio, sr = torchaudio.load(frames)
         else:
-            audio = numpy.hstack(frames)
-        if len(audio.shape) == 1:
-            audio = audio[None, :]
+            sr = frames[0].sample_rate
+            layout = frames[0].layout
+            if layout == 'stereo':
+                frames = [frame.reshape(-1, 2) for frame in frames]
+                audio = numpy.vstack(frames).transpose()
+            else:
+                audio = numpy.hstack(frames)
+            if len(audio.shape) == 1:
+                audio = audio[None, :]
+            audio = self.int2float(audio)
+            audio = torch.from_numpy(audio)
         assert len(audio.shape) == 2
-        audio = self.int2float(audio)
-        audio = torch.from_numpy(audio)
         if sr != self.params['sample_rate']:
             resampler = torchaudio.transforms.Resample(sr, self.params['sample_rate'], dtype=audio.dtype)
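In the string branch, `torchaudio.load` already returns a 2-D float tensor of shape (channels, samples) together with its sample rate, which is why the `int2float` and `torch.from_numpy` steps now apply only to the `AudioFrame` branch before the shared resampling code. A small sketch of that behaviour, with a placeholder file path and a placeholder target rate standing in for `self.params['sample_rate']`:

import torch
import torchaudio

waveform, sr = torchaudio.load('example.wav')  # placeholder path
# With the default normalize=True, torchaudio.load decodes to float32,
# already shaped (channels, samples), so no int-to-float conversion is needed.
assert waveform.dtype == torch.float32 and waveform.dim() == 2

target_sr = 8000  # stands in for self.params['sample_rate']
if sr != target_sr:
    resampler = torchaudio.transforms.Resample(sr, target_sr, dtype=waveform.dtype)
    waveform = resampler(waveform)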
