towhee
/
torchaudio-audio-embedding
copied
2 changed files with 27 additions and 0 deletions
@ -0,0 +1 @@ |
|||
torchaudio>=0.10.2 |
@ -0,0 +1,26 @@ |
|||
import warnings |
|||
import numpy |
|||
import torchaudio |
|||
from typing import NamedTuple |
|||
|
|||
from towhee.operator import Operator |
|||
|
|||
# Install a blanket "ignore" filter at import time: merely importing this
# module suppresses every warning category for the whole process.
warnings.filterwarnings("ignore") |
|||
|
|||
class TorchaudioAudioEmbedding(Operator):
    """
    Audio embedding operator backed by a pretrained torchaudio pipeline bundle.

    Args:
        name (`str`):
            Attribute name of a bundle in `torchaudio.pipelines`
            (for example `'WAV2VEC2_BASE'`).
        framework (`str`):
            Backend used to run the model. Only `'pytorch'` is supported.

    Raises:
        ValueError: if `framework` is not `'pytorch'`.
        AttributeError: if `torchaudio.pipelines` has no attribute `name`.
    """

    def __init__(self, name: str, framework: str = 'pytorch') -> None:
        super().__init__()
        # Fail fast: previously an unknown framework was silently accepted and
        # the operator only blew up later, on the first call, with an
        # AttributeError about a missing `_bundle`/`_model`.
        if framework != 'pytorch':
            raise ValueError(
                f"Unsupported framework {framework!r}; only 'pytorch' is supported."
            )
        self._bundle = getattr(torchaudio.pipelines, name)
        self._model = self._bundle.get_model()

    def __call__(self, image_file: 'str') -> NamedTuple('Outputs', [('embedding', numpy.ndarray)]):
        """
        Compute an embedding for one audio file.

        Args:
            image_file (`str`):
                Path of the AUDIO file to embed. The parameter name is
                historical and is kept so existing keyword callers keep working.

        Returns:
            (`Outputs`)
                Named tuple with a single field `embedding`: the first entry of
                the feature list returned by the model's `extract_features`,
                converted to a `numpy.ndarray`.
        """
        # Function-scope import: torch is a hard dependency of torchaudio, so it
        # is always installed, but this module does not import it at the top.
        import torch

        waveform, sample_rate = torchaudio.load(image_file)
        # Bring the audio to the sample rate the bundle declares for its model.
        waveform = torchaudio.functional.resample(waveform, sample_rate, self._bundle.sample_rate)

        # Inference only: skip autograd bookkeeping instead of building a graph
        # that is thrown away immediately afterwards.
        with torch.no_grad():
            features, _ = self._model.extract_features(waveform)

        embedding = features[0].detach().numpy()
        Outputs = NamedTuple('Outputs', [('embedding', numpy.ndarray)])
        return Outputs(embedding)
Loading…
Reference in new issue