towhee
/
torchaudio-audio-embedding
copied
2 changed files with 27 additions and 0 deletions
@ -0,0 +1 @@ |
|||||
|
torchaudio>=0.10.2 |
@ -0,0 +1,26 @@ |
|||||
|
import warnings |
||||
|
import numpy |
||||
|
import torchaudio |
||||
|
from typing import NamedTuple |
||||
|
|
||||
|
from towhee.operator import Operator |
||||
|
|
||||
|
warnings.filterwarnings("ignore") |
||||
|
|
||||
|
class TorchaudioAudioEmbedding(Operator): |
||||
|
""" |
||||
|
PyTorch model for image embedding. |
||||
|
""" |
||||
|
def __init__(self, name: str, framework: str = 'pytorch') -> None: |
||||
|
super().__init__() |
||||
|
if framework == 'pytorch': |
||||
|
self._bundle = getattr(torchaudio.pipelines, name) |
||||
|
self._model = self._bundle.get_model() |
||||
|
|
||||
|
def __call__(self, image_file: 'str') -> NamedTuple('Outputs', [('embedding', numpy.ndarray)]): |
||||
|
waveform, sample_rate = torchaudio.load(image_file) |
||||
|
waveform = torchaudio.functional.resample(waveform, sample_rate, self._bundle.sample_rate) |
||||
|
embedding, _ = self._model.extract_features(waveform) |
||||
|
embedding = embedding[0].detach().numpy() |
||||
|
Outputs = NamedTuple('Outputs', [('embedding', numpy.ndarray)]) |
||||
|
return Outputs(embedding) |
Loading…
Reference in new issue