logo
Browse Source

Update

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
8a2f373b91
  1. 21
      vggish.py
  2. 6
      vggish_input.py

21
vggish.py

@ -19,6 +19,7 @@ import os
import sys
import numpy
from pathlib import Path
from typing import Union
import torch
@ -37,6 +38,7 @@ log = logging.getLogger()
class Vggish(NNOperator):
"""
"""
def __init__(self, weights_path: str = None, framework: str = 'pytorch') -> None:
super().__init__(framework=framework)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
@ -49,19 +51,26 @@ class Vggish(NNOperator):
self.model.eval()
self.model.to(self.device)
def __call__(self, audio: str) -> numpy.ndarray:
audio_tensors = self.preprocess(audio).to(self.device)
def __call__(self, audio: Union[str, numpy.ndarray], sr: int = None) -> numpy.ndarray:
audio_tensors = self.preprocess(audio, sr).to(self.device)
features = self.model(audio_tensors)
outs = features.to("cpu")
return outs.detach().numpy()
def preprocess(self, audio_path: str):
audio_tensors = vggish_input.wavfile_to_examples(audio_path)
def preprocess(self, audio: Union[str, numpy.ndarray], sr: int = None):
if isinstance(audio, str):
audio_tensors = vggish_input.wavfile_to_examples(audio)
elif isinstance(audio, numpy.ndarray):
try:
audio_tensors = vggish_input.waveform_to_examples(audio, sr, return_tensor=True)
except Exception as e:
log.error("Fail to load audio data.")
raise e
return audio_tensors
# if __name__ == '__main__':
# encoder = Vggish()
# audio_path = '/path/to/audio/wav'
# audio_path = '/path/to/audio'
# vec = encoder(audio_path)
# print(vec.shape)
# print(vec)

6
vggish_input.py

@ -93,7 +93,5 @@ def wavfile_to_examples(wav_file, return_tensor=True):
See waveform_to_examples.
"""
data, sr = torchaudio.load(wav_file)
wav_data = data.short().detach().numpy().transpose()
assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
samples = wav_data / 32768.0 # Convert to [-1.0, +1.0]
return waveform_to_examples(samples, sr, return_tensor)
wav_data = data.detach().numpy().transpose()
return waveform_to_examples(wav_data, sr, return_tensor)

Loading…
Cancel
Save