logo
Browse Source

operator

main
Filip 4 years ago
parent
commit
c5e4fe9aec
  1. 1
      requirements.txt
  2. 26
      torchaudio_audio_embedding.py

1
requirements.txt

@ -0,0 +1 @@
torchaudio>=0.10.2

26
torchaudio_audio_embedding.py

@ -0,0 +1,26 @@
import warnings
import numpy
import torchaudio
from typing import NamedTuple
from towhee.operator import Operator
warnings.filterwarnings("ignore")
class TorchaudioAudioEmbedding(Operator):
"""
PyTorch model for image embedding.
"""
def __init__(self, name: str, framework: str = 'pytorch') -> None:
super().__init__()
if framework == 'pytorch':
self._bundle = getattr(torchaudio.pipelines, name)
self._model = self._bundle.get_model()
def __call__(self, image_file: 'str') -> NamedTuple('Outputs', [('embedding', numpy.ndarray)]):
waveform, sample_rate = torchaudio.load(image_file)
waveform = torchaudio.functional.resample(waveform, sample_rate, self._bundle.sample_rate)
embedding, _ = self._model.extract_features(waveform)
embedding = embedding[0].detach().numpy()
Outputs = NamedTuple('Outputs', [('embedding', numpy.ndarray)])
return Outputs(embedding)
Loading…
Cancel
Save