diff --git a/vggish.py b/vggish.py index 59fe732..86d327b 100644 --- a/vggish.py +++ b/vggish.py @@ -19,6 +19,7 @@ import os import sys import numpy from pathlib import Path +from typing import Union import torch @@ -37,6 +38,7 @@ log = logging.getLogger() class Vggish(NNOperator): """ """ + def __init__(self, weights_path: str = None, framework: str = 'pytorch') -> None: super().__init__(framework=framework) self.device = "cuda" if torch.cuda.is_available() else "cpu" @@ -49,19 +51,26 @@ class Vggish(NNOperator): self.model.eval() self.model.to(self.device) - def __call__(self, audio: str) -> numpy.ndarray: - audio_tensors = self.preprocess(audio).to(self.device) + def __call__(self, audio: Union[str, numpy.ndarray], sr: int = None) -> numpy.ndarray: + audio_tensors = self.preprocess(audio, sr).to(self.device) features = self.model(audio_tensors) outs = features.to("cpu") return outs.detach().numpy() - def preprocess(self, audio_path: str): - audio_tensors = vggish_input.wavfile_to_examples(audio_path) + def preprocess(self, audio: Union[str, numpy.ndarray], sr: int = None): + if isinstance(audio, str): + audio_tensors = vggish_input.wavfile_to_examples(audio) + elif isinstance(audio, numpy.ndarray): + try: + audio_tensors = vggish_input.waveform_to_examples(audio, sr, return_tensor=True) + except Exception as e: + log.error("Fail to load audio data.") + raise e return audio_tensors # if __name__ == '__main__': # encoder = Vggish() -# audio_path = '/path/to/audio/wav' +# audio_path = '/path/to/audio' # vec = encoder(audio_path) -# print(vec.shape) +# print(vec) diff --git a/vggish_input.py b/vggish_input.py index 49368d4..56e6176 100644 --- a/vggish_input.py +++ b/vggish_input.py @@ -93,7 +93,5 @@ def wavfile_to_examples(wav_file, return_tensor=True): See waveform_to_examples. """ data, sr = torchaudio.load(wav_file) - wav_data = data.short().detach().numpy().transpose() - assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype - samples = wav_data / 32768.0 # Convert to [-1.0, +1.0] - return waveform_to_examples(samples, sr, return_tensor) + wav_data = data.detach().numpy().transpose() + return waveform_to_examples(wav_data, sr, return_tensor)