logo
Browse Source

more changes

Signed-off-by: Filip Haltmayer <filip.haltmayer@zilliz.com>
main
Filip 4 years ago
parent
commit
84fd4f6753
  1. 10
      torchaudio_audio_embedding.py

10
torchaudio_audio_embedding.py

@@ -17,10 +17,10 @@ class TorchaudioAudioEmbedding(Operator):
         self._bundle = getattr(torchaudio.pipelines, name)
         self._model = self._bundle.get_model()
 
-    def __call__(self, audio_path: 'str') -> NamedTuple('Outputs', [('embedding', numpy.ndarray)]):
+    def __call__(self, audio_path: 'str') -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]):
         waveform, sample_rate = torchaudio.load(audio_path)
         waveform = torchaudio.functional.resample(waveform, sample_rate, self._bundle.sample_rate)
-        embedding, _ = self._model.extract_features(waveform)
-        embedding = embedding[0].detach().numpy()
-        Outputs = NamedTuple('Outputs', [('embedding', numpy.ndarray)])
-        return Outputs(embedding)
+        feature_vector, _ = self._model.extract_features(waveform)
+        feature_vector = feature_vector[0].detach().numpy()
+        Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)])
+        return Outputs(feature_vector)

Loading…
Cancel
Save