diff --git a/torchaudio_audio_embedding.py b/torchaudio_audio_embedding.py index f835b59..b9e3af6 100644 --- a/torchaudio_audio_embedding.py +++ b/torchaudio_audio_embedding.py @@ -17,8 +17,8 @@ class TorchaudioAudioEmbedding(Operator): self._bundle = getattr(torchaudio.pipelines, name) self._model = self._bundle.get_model() - def __call__(self, image_file: 'str') -> NamedTuple('Outputs', [('embedding', numpy.ndarray)]): - waveform, sample_rate = torchaudio.load(image_file) + def __call__(self, audio_path: 'str') -> NamedTuple('Outputs', [('embedding', numpy.ndarray)]): + waveform, sample_rate = torchaudio.load(audio_path) waveform = torchaudio.functional.resample(waveform, sample_rate, self._bundle.sample_rate) embedding, _ = self._model.extract_features(waveform) embedding = embedding[0].detach().numpy()