import warnings import numpy import torchaudio from typing import NamedTuple from towhee.operator import Operator warnings.filterwarnings("ignore") class TorchaudioAudioEmbedding(Operator): """ PyTorch model for image embedding. """ def __init__(self, name: str, framework: str = 'pytorch') -> None: super().__init__() if framework == 'pytorch': self._bundle = getattr(torchaudio.pipelines, name) self._model = self._bundle.get_model() def __call__(self, audio_path: 'str') -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): waveform, sample_rate = torchaudio.load(audio_path) waveform = torchaudio.functional.resample(waveform, sample_rate, self._bundle.sample_rate) feature_vector, _ = self._model.extract_features(waveform) feature_vector = feature_vector[0].detach().numpy() Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) return Outputs(feature_vector)