logo
Browse Source

update

main
junjiejiangjjj 2 years ago
parent
commit
56e8bc273a
  1. 4
      __init__.py
  2. 30
      vggish.py

4
__init__.py

@ -15,5 +15,5 @@
from .vggish import Vggish from .vggish import Vggish
def vggish():
return Vggish()
def vggish(weights_path: str = None, framework: str = 'pytorch'):
return Vggish(weights_path, framework)

30
vggish.py

@ -19,7 +19,7 @@ import os
import sys import sys
import numpy import numpy
from pathlib import Path from pathlib import Path
from typing import Union
from typing import Union, List, NamedTuple
import torch import torch
@ -34,7 +34,9 @@ warnings.filterwarnings('ignore')
log = logging.getLogger() log = logging.getLogger()
@register(output_schema=['vec'])
AudioOutput = NamedTuple('AudioOutput', [('vec', 'ndarray')])
class Vggish(NNOperator): class Vggish(NNOperator):
""" """
""" """
@ -51,25 +53,19 @@ class Vggish(NNOperator):
self.model.eval() self.model.eval()
self.model.to(self.device) self.model.to(self.device)
def __call__(self, audio: Union[str, numpy.ndarray], sr: int = None) -> numpy.ndarray:
audio_tensors = self.preprocess(audio, sr).to(self.device)
def __call__(self, datas: List[NamedTuple('data', [('audio', 'ndarray'), ('sample_rate', 'int')])]) -> numpy.ndarray:
audios = numpy.hstack([item.audio for item in datas])
sr = datas[0].sample_rate
audio_array = numpy.stack(audios)
audio_tensors = self.preprocess(audio_array, sr).to(self.device)
features = self.model(audio_tensors) features = self.model(audio_tensors)
outs = features.to("cpu") outs = features.to("cpu")
return outs.detach().numpy()
return [AudioOutput(outs.detach().numpy())]
def preprocess(self, audio: Union[str, numpy.ndarray], sr: int = None): def preprocess(self, audio: Union[str, numpy.ndarray], sr: int = None):
if isinstance(audio, str):
audio_tensors = vggish_input.wavfile_to_examples(audio)
elif isinstance(audio, numpy.ndarray):
try:
audio = audio.transpose()
audio_tensors = vggish_input.waveform_to_examples(audio, sr, return_tensor=True)
except Exception as e:
log.error("Fail to load audio data.")
raise e
else:
log.error(f"Invalid input audio: {type(audio)}")
return audio_tensors
audio = audio.transpose()
return vggish_input.waveform_to_examples(audio, sr, return_tensor=True)
# if __name__ == '__main__': # if __name__ == '__main__':

Loading…
Cancel
Save