From 6584777e947abf3128984697afcf3a171a0f2f38 Mon Sep 17 00:00:00 2001
From: Jael Gu
Date: Thu, 21 Apr 2022 17:14:52 +0800
Subject: [PATCH] Revert "update"

This reverts commit 56e8bc273ae87e5ae13c9987e7751592651bce03.
---
 __init__.py |  4 ++--
 vggish.py   | 30 +++++++++++++++++-------------
 2 files changed, 19 insertions(+), 15 deletions(-)

diff --git a/__init__.py b/__init__.py
index 0a4a066..7b40d66 100644
--- a/__init__.py
+++ b/__init__.py
@@ -15,5 +15,5 @@
 from .vggish import Vggish
 
 
-def vggish(weights_path: str = None, framework: str = 'pytorch'):
-    return Vggish(weights_path, framework)
+def vggish():
+    return Vggish()
diff --git a/vggish.py b/vggish.py
index 593ff5f..69c3e3f 100644
--- a/vggish.py
+++ b/vggish.py
@@ -19,7 +19,7 @@ import os
 import sys
 import numpy
 from pathlib import Path
-from typing import Union, List, NamedTuple
+from typing import Union
 
 import torch
 
@@ -34,9 +34,7 @@ warnings.filterwarnings('ignore')
 
 log = logging.getLogger()
 
-AudioOutput = NamedTuple('AudioOutput', [('vec', 'ndarray')])
-
-
+@register(output_schema=['vec'])
 class Vggish(NNOperator):
     """
     """
@@ -53,19 +51,25 @@ class Vggish(NNOperator):
         self.model.eval()
         self.model.to(self.device)
 
-    def __call__(self, datas: List[NamedTuple('data', [('audio', 'ndarray'), ('sample_rate', 'int')])]) -> numpy.ndarray:
-        audios = numpy.hstack([item.audio for item in datas])
-        sr = datas[0].sample_rate
-        audio_array = numpy.stack(audios)
-        audio_tensors = self.preprocess(audio_array, sr).to(self.device)
+    def __call__(self, audio: Union[str, numpy.ndarray], sr: int = None) -> numpy.ndarray:
+        audio_tensors = self.preprocess(audio, sr).to(self.device)
         features = self.model(audio_tensors)
         outs = features.to("cpu")
-        return [AudioOutput(outs.detach().numpy())]
+        return outs.detach().numpy()
 
     def preprocess(self, audio: Union[str, numpy.ndarray], sr: int = None):
-        audio = audio.transpose()
-        return vggish_input.waveform_to_examples(audio, sr, return_tensor=True)
-
+        if isinstance(audio, str):
+            audio_tensors = vggish_input.wavfile_to_examples(audio)
+        elif isinstance(audio, numpy.ndarray):
+            try:
+                audio = audio.transpose()
+                audio_tensors = vggish_input.waveform_to_examples(audio, sr, return_tensor=True)
+            except Exception as e:
+                log.error("Fail to load audio data.")
+                raise e
+        else:
+            log.error(f"Invalid input audio: {type(audio)}")
+        return audio_tensors
 
 
 # if __name__ == '__main__':
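
Note (not part of the patch): the snippet below is a minimal usage sketch of the operator interface this revert restores, i.e. calling the operator with either a WAV path or a raw waveform plus sample rate. It assumes the script runs from the operator's repo directory so `from vggish import Vggish` resolves to the vggish.py patched above; the file path, waveform shape, and sample rate are placeholders.

    import numpy

    from vggish import Vggish  # assumes the operator repo directory is on the Python path

    op = Vggish()

    # Option 1: pass a path to a WAV file (routed to vggish_input.wavfile_to_examples).
    vec_from_file = op('path/to/audio.wav')  # placeholder path

    # Option 2: pass a raw waveform plus its sample rate; preprocess() transposes the
    # array before vggish_input.waveform_to_examples, so a (channels, samples) layout
    # is assumed here.
    waveform = numpy.zeros((1, 16000), dtype=numpy.float32)  # one second of silence at 16 kHz
    vec_from_array = op(waveform, sr=16000)

    print(vec_from_file.shape, vec_from_array.shape)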