logo
Browse Source

Revert "update"

This reverts commit 56e8bc273a.
main
Jael Gu 2 years ago
parent
commit
6584777e94
  1. 4
      __init__.py
  2. 30
      vggish.py

4
__init__.py

@@ -15,5 +15,5 @@
from .vggish import Vggish from .vggish import Vggish
def vggish(weights_path: str = None, framework: str = 'pytorch'):
return Vggish(weights_path, framework)
def vggish():
return Vggish()

30
vggish.py

@@ -19,7 +19,7 @@ import os
import sys import sys
import numpy import numpy
from pathlib import Path from pathlib import Path
from typing import Union, List, NamedTuple
from typing import Union
import torch import torch
@@ -34,9 +34,7 @@ warnings.filterwarnings('ignore')
log = logging.getLogger() log = logging.getLogger()
AudioOutput = NamedTuple('AudioOutput', [('vec', 'ndarray')])
@register(output_schema=['vec'])
class Vggish(NNOperator): class Vggish(NNOperator):
""" """
""" """
@@ -53,19 +51,25 @@ class Vggish(NNOperator):
self.model.eval() self.model.eval()
self.model.to(self.device) self.model.to(self.device)
def __call__(self, datas: List[NamedTuple('data', [('audio', 'ndarray'), ('sample_rate', 'int')])]) -> numpy.ndarray:
audios = numpy.hstack([item.audio for item in datas])
sr = datas[0].sample_rate
audio_array = numpy.stack(audios)
audio_tensors = self.preprocess(audio_array, sr).to(self.device)
def __call__(self, audio: Union[str, numpy.ndarray], sr: int = None) -> numpy.ndarray:
audio_tensors = self.preprocess(audio, sr).to(self.device)
features = self.model(audio_tensors) features = self.model(audio_tensors)
outs = features.to("cpu") outs = features.to("cpu")
return [AudioOutput(outs.detach().numpy())]
return outs.detach().numpy()
def preprocess(self, audio: Union[str, numpy.ndarray], sr: int = None): def preprocess(self, audio: Union[str, numpy.ndarray], sr: int = None):
audio = audio.transpose()
return vggish_input.waveform_to_examples(audio, sr, return_tensor=True)
if isinstance(audio, str):
audio_tensors = vggish_input.wavfile_to_examples(audio)
elif isinstance(audio, numpy.ndarray):
try:
audio = audio.transpose()
audio_tensors = vggish_input.waveform_to_examples(audio, sr, return_tensor=True)
except Exception as e:
log.error("Fail to load audio data.")
raise e
else:
log.error(f"Invalid input audio: {type(audio)}")
return audio_tensors
# if __name__ == '__main__': # if __name__ == '__main__':

Loading…
Cancel
Save