From 56e8bc273ae87e5ae13c9987e7751592651bce03 Mon Sep 17 00:00:00 2001 From: junjiejiangjjj Date: Wed, 20 Apr 2022 16:07:01 +0800 Subject: [PATCH] update --- __init__.py | 4 ++-- vggish.py | 30 +++++++++++++----------------- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/__init__.py b/__init__.py index 7b40d66..0a4a066 100644 --- a/__init__.py +++ b/__init__.py @@ -15,5 +15,5 @@ from .vggish import Vggish -def vggish(): - return Vggish() +def vggish(weights_path: str = None, framework: str = 'pytorch'): + return Vggish(weights_path, framework) diff --git a/vggish.py b/vggish.py index 69c3e3f..593ff5f 100644 --- a/vggish.py +++ b/vggish.py @@ -19,7 +19,7 @@ import os import sys import numpy from pathlib import Path -from typing import Union +from typing import Union, List, NamedTuple import torch @@ -34,7 +34,9 @@ warnings.filterwarnings('ignore') log = logging.getLogger() -@register(output_schema=['vec']) +AudioOutput = NamedTuple('AudioOutput', [('vec', 'ndarray')]) + + class Vggish(NNOperator): """ """ @@ -51,25 +53,19 @@ class Vggish(NNOperator): self.model.eval() self.model.to(self.device) - def __call__(self, audio: Union[str, numpy.ndarray], sr: int = None) -> numpy.ndarray: - audio_tensors = self.preprocess(audio, sr).to(self.device) + def __call__(self, datas: List[NamedTuple('data', [('audio', 'ndarray'), ('sample_rate', 'int')])]) -> numpy.ndarray: + audios = numpy.hstack([item.audio for item in datas]) + sr = datas[0].sample_rate + audio_array = numpy.stack(audios) + audio_tensors = self.preprocess(audio_array, sr).to(self.device) features = self.model(audio_tensors) outs = features.to("cpu") - return outs.detach().numpy() + return [AudioOutput(outs.detach().numpy())] def preprocess(self, audio: Union[str, numpy.ndarray], sr: int = None): - if isinstance(audio, str): - audio_tensors = vggish_input.wavfile_to_examples(audio) - elif isinstance(audio, numpy.ndarray): - try: - audio = audio.transpose() - audio_tensors = vggish_input.waveform_to_examples(audio, sr, return_tensor=True) - except Exception as e: - log.error("Fail to load audio data.") - raise e - else: - log.error(f"Invalid input audio: {type(audio)}") - return audio_tensors + audio = audio.transpose() + return vggish_input.waveform_to_examples(audio, sr, return_tensor=True) + # if __name__ == '__main__':