diff --git a/torchaudio_audio_embedding.py b/torchaudio_audio_embedding.py index b9e3af6..3fdf33b 100644 --- a/torchaudio_audio_embedding.py +++ b/torchaudio_audio_embedding.py @@ -17,10 +17,10 @@ class TorchaudioAudioEmbedding(Operator): self._bundle = getattr(torchaudio.pipelines, name) self._model = self._bundle.get_model() - def __call__(self, audio_path: 'str') -> NamedTuple('Outputs', [('embedding', numpy.ndarray)]): + def __call__(self, audio_path: 'str') -> NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]): waveform, sample_rate = torchaudio.load(audio_path) waveform = torchaudio.functional.resample(waveform, sample_rate, self._bundle.sample_rate) - embedding, _ = self._model.extract_features(waveform) - embedding = embedding[0].detach().numpy() - Outputs = NamedTuple('Outputs', [('embedding', numpy.ndarray)]) - return Outputs(embedding) + feature_vector, _ = self._model.extract_features(waveform) + feature_vector = feature_vector[0].detach().numpy() + Outputs = NamedTuple('Outputs', [('feature_vector', numpy.ndarray)]) + return Outputs(feature_vector)