From 49c8aab5acab4a78a8e64a692ee91a7023bd1b75 Mon Sep 17 00:00:00 2001
From: Jael Gu
Date: Tue, 31 May 2022 15:01:52 +0800
Subject: [PATCH] Adapt audio-decode/ffmpeg

Signed-off-by: Jael Gu
---
 README.md        | 30 +++++++++++++++---------------
 requirements.txt |  4 ++--
 vggish.py        | 34 ++++++++++++++++++----------------
 vggish_input.py  | 21 ++++++++++++---------
 4 files changed, 47 insertions(+), 42 deletions(-)

diff --git a/README.md b/README.md
index e639c0d..0285841 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
 # Audio Embedding with Vggish
 
-*Author: Jael Gu*
+*Author: [Jael Gu](https://github.com/jaelgu)*
@@ -23,11 +23,12 @@ Generate embeddings for the audio "test.wav".
 
 ```python
 import towhee
 
-towhee.glob('test.wav') \
-      .audio_decode() \
-      .time_window(range=10) \
-      .audio_embedding.vggish() \
-      .show()
+(
+    towhee.glob('test.wav')
+    .audio_decode.ffmpeg()
+    .audio_embedding.vggish()
+    .show()
+)
 ```
 
 | [-0.4931737, -0.40068552, -0.032327592, ...] shape=(10, 128) |
 
@@ -36,12 +37,12 @@
 ```python
 import towhee
 
-towhee.glob['path']('test.wav') \
-      .audio_decode['path', 'audio']() \
-      .time_window['audio', 'frames'](range=10) \
-      .audio_embedding.vggish['frames', 'vecs']() \
-      .select('vecs') \
-      .to_vec()
+(
+    towhee.glob['path']('test.wav')
+    .audio_decode.ffmpeg['path', 'frames']()
+    .audio_embedding.vggish['frames', 'vecs']()
+    .show()
+)
 ```
 
 [array([[-0.4931737 , -0.40068552, -0.03232759, ..., -0.33428153,
          0.1333081 , -0.25221825],
 
@@ -84,10 +85,9 @@ An audio embedding operator generates vectors in numpy.ndarray given an audio fi
 
 **Parameters:**
 
-*Union[str, towhee.types.Audio (a sub-class of numpy.ndarray)]*
+*data: List[towhee.types.audio_frame.AudioFrame]*
 
-The audio path or link in string.
-Or audio input data in towhee audio frames.
+Input audio data is a list of towhee audio frames.
 The input data should represent an audio clip longer than 0.9s.
 
diff --git a/requirements.txt b/requirements.txt
index 2ae07a0..d7c19de 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-torch==1.9.0
-numpy==1.19.5
+torch>=1.9.0
+numpy>=1.19.5
 resampy
 torchaudio
diff --git a/vggish.py b/vggish.py
index 69c3e3f..ab8d014 100644
--- a/vggish.py
+++ b/vggish.py
@@ -19,13 +19,14 @@ import os
 import sys
 import numpy
 from pathlib import Path
-from typing import Union
+from typing import List
 
 import torch
 
 from towhee.operator.base import NNOperator
 from towhee.models.vggish.torch_vggish import VGG
 from towhee import register
+from towhee.types.audio_frame import AudioFrame
 
 sys.path.append(str(Path(__file__).parent))
 import vggish_input
@@ -51,25 +52,26 @@ class Vggish(NNOperator):
         self.model.eval()
         self.model.to(self.device)
 
-    def __call__(self, audio: Union[str, numpy.ndarray], sr: int = None) -> numpy.ndarray:
-        audio_tensors = self.preprocess(audio, sr).to(self.device)
+    def __call__(self, data: List[AudioFrame]) -> numpy.ndarray:
+        audio_tensors = self.preprocess(data).to(self.device)
         features = self.model(audio_tensors)
         outs = features.to("cpu")
         return outs.detach().numpy()
 
-    def preprocess(self, audio: Union[str, numpy.ndarray], sr: int = None):
-        if isinstance(audio, str):
-            audio_tensors = vggish_input.wavfile_to_examples(audio)
-        elif isinstance(audio, numpy.ndarray):
-            try:
-                audio = audio.transpose()
-                audio_tensors = vggish_input.waveform_to_examples(audio, sr, return_tensor=True)
-            except Exception as e:
-                log.error("Fail to load audio data.")
-                raise e
-        else:
-            log.error(f"Invalid input audio: {type(audio)}")
-        return audio_tensors
+    def preprocess(self, frames: List[AudioFrame]):
+        sr = frames[0].sample_rate
+        audio = numpy.hstack(frames)
+        if audio.dtype == numpy.int32:
+            audio = audio / 2147483648.0
+        elif audio.dtype == numpy.int16:
+            audio = audio / 32768.0
+        try:
+            audio = audio.transpose()
+            audio_tensors = vggish_input.waveform_to_examples(audio, sr, return_tensor=True)
+            return audio_tensors
+        except Exception as e:
+            log.error("Failed to load audio data.")
+            raise e
 
 
 # if __name__ == '__main__':
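Note on the new `preprocess` above: it concatenates the decoded `AudioFrame`s with `numpy.hstack` and rescales signed integer PCM samples into the [-1.0, 1.0) float range before handing the waveform to `vggish_input.waveform_to_examples`. A minimal standalone sketch of that normalization step (the `pcm_to_float` helper and the synthetic frames are illustrative, not part of this patch):

```python
import numpy as np

def pcm_to_float(audio: np.ndarray) -> np.ndarray:
    # Mirror the dtype checks in Vggish.preprocess: scale signed
    # integer PCM into the [-1.0, 1.0) range the model expects.
    if audio.dtype == np.int32:
        return audio / 2147483648.0  # 2 ** 31
    if audio.dtype == np.int16:
        return audio / 32768.0       # 2 ** 15
    return audio                     # already floating point

# Two synthetic mono int16 "frames", standing in for decoder output.
frames = [np.array([0, 16384, -32768], dtype=np.int16),
          np.array([32767, -16384, 0], dtype=np.int16)]
waveform = np.hstack(frames)   # concatenate frames into one waveform
print(pcm_to_float(waveform))  # floats: 0.0, 0.5, -1.0, ~0.99997, -0.5, 0.0
```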
diff --git a/vggish_input.py b/vggish_input.py
index 856a406..0bf432c 100644
--- a/vggish_input.py
+++ b/vggish_input.py
@@ -44,9 +44,9 @@ def waveform_to_examples(data, sample_rate, return_tensor=True):
     bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS.
   """
-  # Convert to mono.
-  if len(data.shape) > 1:
-    data = np.mean(data, axis=1)
+  # TODO: convert stereo to mono.
+  # if len(data.shape) > 1:
+  #     data = np.mean(data, axis=1)
   # Resample to the rate assumed by VGGish.
   if sample_rate != vggish_params.SAMPLE_RATE:
     data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)
@@ -81,12 +81,15 @@ def waveform_to_examples(data, sample_rate, return_tensor=True):
 
 
 def wavfile_to_examples(wav_file, return_tensor=True):
-  """Convenience wrapper around waveform_to_examples() for a common WAV format.
-
-  Args:
-    wav_file: String path to a file, or a file-like object. The file
-      is assumed to contain WAV audio data with signed 16-bit PCM samples.
-    torch: Return data as a Pytorch tensor ready for VGGish
+  """
+  Convenience wrapper around waveform_to_examples() for a common WAV format.
+
+  Args:
+    wav_file:
+      String path to a file, or a file-like object.
+      The file is assumed to contain WAV audio data with signed 16-bit PCM samples.
+    return_tensor:
+      Return data as a PyTorch tensor ready for VGGish
 
   Returns:
     See waveform_to_examples.
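For context, the resampling that `waveform_to_examples` still performs (`resampy.resample` to `vggish_params.SAMPLE_RATE`) is untouched by this patch. A small sketch of that step in isolation, assuming VGGish's 16 kHz target rate and a synthetic input waveform:

```python
import numpy as np
import resampy

VGGISH_SAMPLE_RATE = 16000  # vggish_params.SAMPLE_RATE in the repo

# One second of a 440 Hz sine wave at 44.1 kHz, mono float64.
src_rate = 44100
t = np.arange(src_rate) / src_rate
waveform = np.sin(2 * np.pi * 440 * t)

# Resample to the rate VGGish assumes, as waveform_to_examples does.
resampled = resampy.resample(waveform, src_rate, VGGISH_SAMPLE_RATE)
print(resampled.shape)  # (16000,)
```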