diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..1b15f15 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +torchaudio>=0.10.2 \ No newline at end of file diff --git a/torchaudio_audio_embedding.py b/torchaudio_audio_embedding.py new file mode 100644 index 0000000..f835b59 --- /dev/null +++ b/torchaudio_audio_embedding.py @@ -0,0 +1,26 @@ +import warnings +import numpy +import torchaudio +from typing import NamedTuple + +from towhee.operator import Operator + +warnings.filterwarnings("ignore") + +class TorchaudioAudioEmbedding(Operator): + """ + PyTorch model for image embedding. + """ + def __init__(self, name: str, framework: str = 'pytorch') -> None: + super().__init__() + if framework == 'pytorch': + self._bundle = getattr(torchaudio.pipelines, name) + self._model = self._bundle.get_model() + + def __call__(self, image_file: 'str') -> NamedTuple('Outputs', [('embedding', numpy.ndarray)]): + waveform, sample_rate = torchaudio.load(image_file) + waveform = torchaudio.functional.resample(waveform, sample_rate, self._bundle.sample_rate) + embedding, _ = self._model.extract_features(waveform) + embedding = embedding[0].detach().numpy() + Outputs = NamedTuple('Outputs', [('embedding', numpy.ndarray)]) + return Outputs(embedding)