logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

26 lines
1003 B

import warnings
from typing import NamedTuple

import numpy
import torch
import torchaudio
from towhee.operator import Operator
warnings.filterwarnings("ignore")
class TorchaudioAudioEmbedding(Operator):
    """
    Audio embedding operator backed by a pretrained torchaudio pipeline.

    The pipeline bundle is looked up by attribute name on
    ``torchaudio.pipelines`` and its model is obtained with ``get_model()``.
    Calling the operator with an audio file path returns a named tuple whose
    ``feature_vector`` field is a ``numpy.ndarray``.

    Args:
        name: Attribute name of the bundle in ``torchaudio.pipelines``. The
            bundle must expose ``sample_rate`` and ``get_model()``, and the
            returned model must provide ``extract_features``.
        framework: Backend identifier; only ``'pytorch'`` is supported.

    Raises:
        ValueError: If ``framework`` is not ``'pytorch'``.
        AttributeError: If ``torchaudio.pipelines`` has no attribute ``name``.
    """

    # Result type built once so every call returns instances of the same class
    # instead of creating a fresh named-tuple class per invocation.
    _Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])

    def __init__(self, name: str, framework: str = 'pytorch') -> None:
        super().__init__()
        if framework != 'pytorch':
            # Without a bundle/model the operator cannot serve any call, so
            # reject the configuration here rather than failing later with an
            # opaque AttributeError on ``self._model``.
            raise ValueError(
                f"Unsupported framework {framework!r}; only 'pytorch' is supported."
            )
        self._bundle = getattr(torchaudio.pipelines, name)
        self._model = self._bundle.get_model()

    def __call__(self, audio_path: 'str') -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
        """
        Load an audio file and return features extracted by the model.

        Args:
            audio_path: Path of the audio file, passed to ``torchaudio.load``.

        Returns:
            An ``Outputs`` named tuple with a single ``feature_vector`` field
            holding the first element of the model's extracted features,
            converted to a ``numpy.ndarray``.
        """
        waveform, sample_rate = torchaudio.load(audio_path)
        # Bring the audio to the sample rate declared by the bundle.
        waveform = torchaudio.functional.resample(
            waveform, sample_rate, self._bundle.sample_rate
        )
        # Pure inference: no gradients are needed, so skip autograd bookkeeping.
        with torch.no_grad():
            features, _ = self._model.extract_features(waveform)
        # Only the first entry of the returned features is kept.
        # NOTE(review): presumably one entry per layer for wav2vec2-style
        # models -- confirm against the torchaudio documentation.
        feature_vector = features[0].detach().numpy()
        return self._Outputs(feature_vector)