@@ -19,7 +19,7 @@ import os
 import sys
 
 import numpy
 from pathlib import Path
-from typing import Union
+from typing import Union, List, NamedTuple
 
 import torch
@@ -34,7 +34,9 @@ warnings.filterwarnings('ignore')
 log = logging.getLogger()
 
 
+AudioOutput = NamedTuple('AudioOutput', [('vec', 'ndarray')])
+@register(output_schema=['vec'])
 class Vggish(NNOperator):
     """
     """
 
@@ -51,25 +53,19 @@ class Vggish(NNOperator):
         self.model.eval()
         self.model.to(self.device)
 
-    def __call__(self, audio: Union[str, numpy.ndarray], sr: int = None) -> numpy.ndarray:
-        audio_tensors = self.preprocess(audio, sr).to(self.device)
+    def __call__(self, datas: List[NamedTuple('data', [('audio', 'ndarray'), ('sample_rate', 'int')])]) -> numpy.ndarray:
+        audios = numpy.hstack([item.audio for item in datas])
+        sr = datas[0].sample_rate
+
+        audio_array = numpy.stack(audios)
+        audio_tensors = self.preprocess(audio_array, sr).to(self.device)
         features = self.model(audio_tensors)
         outs = features.to("cpu")
-        return outs.detach().numpy()
+        return [AudioOutput(outs.detach().numpy())]
 
     def preprocess(self, audio: Union[str, numpy.ndarray], sr: int = None):
-        if isinstance(audio, str):
-            audio_tensors = vggish_input.wavfile_to_examples(audio)
-        elif isinstance(audio, numpy.ndarray):
-            try:
-                audio = audio.transpose()
-                audio_tensors = vggish_input.waveform_to_examples(audio, sr, return_tensor=True)
-            except Exception as e:
-                log.error("Fail to load audio data.")
-                raise e
-        else:
-            log.error(f"Invalid input audio: {type(audio)}")
-        return audio_tensors
+        audio = audio.transpose()
+        return vggish_input.waveform_to_examples(audio, sr, return_tensor=True)
 
 
 # if __name__ == '__main__':
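
For reference, a minimal usage sketch of the reworked batched interface. This is an assumption-laden illustration, not part of the diff: the AudioData record, the no-argument Vggish() construction, and the fake 16 kHz mono input are all invented for the example.

# Usage sketch (assumptions: a local vggish.py exposing Vggish as in this
# diff, and an operator whose default constructor can load its weights).
from collections import namedtuple

import numpy

from vggish import Vggish

# Each input record pairs a waveform with its sample rate, matching the
# NamedTuple('data', [('audio', 'ndarray'), ('sample_rate', 'int')]) schema.
AudioData = namedtuple('data', ['audio', 'sample_rate'])

op = Vggish()
clip = numpy.random.rand(16000).astype(numpy.float32)  # 1 s of fake mono audio
outs = op([AudioData(audio=clip, sample_rate=16000)])

# __call__ now returns [AudioOutput(vec)], so the embedding sits at
# outs[0].vec; VGGish emits one 128-d vector per ~0.96 s window of input.
print(outs[0].vec.shape)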