logo
Browse Source

Revert "update"

This reverts commit 56e8bc273a.
main
Jael Gu 2 years ago
parent
commit
6584777e94
  1. 4
      __init__.py
  2. 30
      vggish.py

4
__init__.py

@@ -15,5 +15,5 @@
from .vggish import Vggish
def vggish(weights_path: str = None, framework: str = 'pytorch'):
return Vggish(weights_path, framework)
def vggish():
return Vggish()

30
vggish.py

@@ -19,7 +19,7 @@ import os
import sys
import numpy
from pathlib import Path
from typing import Union, List, NamedTuple
from typing import Union
import torch
@@ -34,9 +34,7 @@ warnings.filterwarnings('ignore')
log = logging.getLogger()
AudioOutput = NamedTuple('AudioOutput', [('vec', 'ndarray')])
@register(output_schema=['vec'])
class Vggish(NNOperator):
"""
"""
@@ -53,19 +51,25 @@ class Vggish(NNOperator):
self.model.eval()
self.model.to(self.device)
def __call__(self, datas: List[NamedTuple('data', [('audio', 'ndarray'), ('sample_rate', 'int')])]) -> numpy.ndarray:
audios = numpy.hstack([item.audio for item in datas])
sr = datas[0].sample_rate
audio_array = numpy.stack(audios)
audio_tensors = self.preprocess(audio_array, sr).to(self.device)
def __call__(self, audio: Union[str, numpy.ndarray], sr: int = None) -> numpy.ndarray:
audio_tensors = self.preprocess(audio, sr).to(self.device)
features = self.model(audio_tensors)
outs = features.to("cpu")
return [AudioOutput(outs.detach().numpy())]
return outs.detach().numpy()
def preprocess(self, audio: Union[str, numpy.ndarray], sr: int = None):
audio = audio.transpose()
return vggish_input.waveform_to_examples(audio, sr, return_tensor=True)
if isinstance(audio, str):
audio_tensors = vggish_input.wavfile_to_examples(audio)
elif isinstance(audio, numpy.ndarray):
try:
audio = audio.transpose()
audio_tensors = vggish_input.waveform_to_examples(audio, sr, return_tensor=True)
except Exception as e:
log.error("Fail to load audio data.")
raise e
else:
log.error(f"Invalid input audio: {type(audio)}")
return audio_tensors
# if __name__ == '__main__':

Loading…
Cancel
Save