logo
Browse Source

Adapt audio-decode/ffmpeg

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 3 years ago
parent
commit
49c8aab5ac
  1. 30
      README.md
  2. 4
      requirements.txt
  3. 34
      vggish.py
  4. 21
      vggish_input.py

30
README.md

@ -1,6 +1,6 @@
# Audio Embedding with Vggish # Audio Embedding with Vggish
*Author: Jael Gu*
*Author: [Jael Gu](https://github.com/jaelgu)*
<br /> <br />
@ -23,11 +23,12 @@ Generate embeddings for the audio "test.wav".
```python ```python
import towhee import towhee
towhee.glob('test.wav') \
.audio_decode() \
.time_window(range=10) \
.audio_embedding.vggish() \
.show()
(
towhee.glob('test.wav')
.audio_decode.ffmpeg()
.audio_embedding.vggish()
.show()
)
``` ```
| [-0.4931737, -0.40068552, -0.032327592, ...] shape=(10, 128) | | [-0.4931737, -0.40068552, -0.032327592, ...] shape=(10, 128) |
@ -36,12 +37,12 @@ towhee.glob('test.wav') \
```python ```python
import towhee import towhee
towhee.glob['path']('test.wav') \
.audio_decode['path', 'audio']() \
.time_window['audio', 'frames'](range=10) \
.audio_embedding.vggish['frames', 'vecs']() \
.select('vecs') \
.to_vec()
(
towhee.glob['path']('test.wav')
.audio_decode.ffmpeg['path', 'frames']()
.audio_embedding.vggish['frames', 'vecs']()
.show()
)
``` ```
[array([[-0.4931737 , -0.40068552, -0.03232759, ..., -0.33428153, [array([[-0.4931737 , -0.40068552, -0.03232759, ..., -0.33428153,
0.1333081 , -0.25221825], 0.1333081 , -0.25221825],
@ -84,10 +85,9 @@ An audio embedding operator generates vectors in numpy.ndarray given an audio fi
**Parameters:** **Parameters:**
*Union[str, towhee.types.Audio (a sub-class of numpy.ndarray)]*
*data: List[towhee.types.audio_frame.AudioFrame]*
The audio path or link in string.
Or audio input data in towhee audio frames.
Input audio data is a list of towhee audio frames.
The input data should represent for an audio longer than 0.9s. The input data should represent for an audio longer than 0.9s.

4
requirements.txt

@ -1,4 +1,4 @@
torch==1.9.0
numpy==1.19.5
torch>=1.9.0
numpy>=1.19.5
resampy resampy
torchaudio torchaudio

34
vggish.py

@ -19,13 +19,14 @@ import os
import sys import sys
import numpy import numpy
from pathlib import Path from pathlib import Path
from typing import Union
from typing import List
import torch import torch
from towhee.operator.base import NNOperator from towhee.operator.base import NNOperator
from towhee.models.vggish.torch_vggish import VGG from towhee.models.vggish.torch_vggish import VGG
from towhee import register from towhee import register
from towhee.types.audio_frame import AudioFrame
sys.path.append(str(Path(__file__).parent)) sys.path.append(str(Path(__file__).parent))
import vggish_input import vggish_input
@ -51,25 +52,26 @@ class Vggish(NNOperator):
self.model.eval() self.model.eval()
self.model.to(self.device) self.model.to(self.device)
def __call__(self, audio: Union[str, numpy.ndarray], sr: int = None) -> numpy.ndarray:
audio_tensors = self.preprocess(audio, sr).to(self.device)
def __call__(self, data: List[AudioFrame]) -> numpy.ndarray:
audio_tensors = self.preprocess(data).to(self.device)
features = self.model(audio_tensors) features = self.model(audio_tensors)
outs = features.to("cpu") outs = features.to("cpu")
return outs.detach().numpy() return outs.detach().numpy()
def preprocess(self, audio: Union[str, numpy.ndarray], sr: int = None):
if isinstance(audio, str):
audio_tensors = vggish_input.wavfile_to_examples(audio)
elif isinstance(audio, numpy.ndarray):
try:
audio = audio.transpose()
audio_tensors = vggish_input.waveform_to_examples(audio, sr, return_tensor=True)
except Exception as e:
log.error("Fail to load audio data.")
raise e
else:
log.error(f"Invalid input audio: {type(audio)}")
return audio_tensors
def preprocess(self, frames: List[AudioFrame]):
sr = frames[0].sample_rate
audio = numpy.hstack(frames)
if audio.dtype == numpy.int32:
audio = audio / 2147483648.0
elif audio.dtype == numpy.int16:
audio = audio / 32768.0
try:
audio = audio.transpose()
audio_tensors = vggish_input.waveform_to_examples(audio, sr, return_tensor=True)
return audio_tensors
except Exception as e:
log.error("Fail to load audio data.")
raise e
# if __name__ == '__main__': # if __name__ == '__main__':

21
vggish_input.py

@ -44,9 +44,9 @@ def waveform_to_examples(data, sample_rate, return_tensor=True):
bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS. bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS.
""" """
# Convert to mono.
if len(data.shape) > 1:
data = np.mean(data, axis=1)
# Todo: convert stereo to mono.
# if len(data.shape) > 1:
# data = np.mean(data, axis=1)
# Resample to the rate assumed by VGGish. # Resample to the rate assumed by VGGish.
if sample_rate != vggish_params.SAMPLE_RATE: if sample_rate != vggish_params.SAMPLE_RATE:
data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE) data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)
@ -81,12 +81,15 @@ def waveform_to_examples(data, sample_rate, return_tensor=True):
def wavfile_to_examples(wav_file, return_tensor=True): def wavfile_to_examples(wav_file, return_tensor=True):
"""Convenience wrapper around waveform_to_examples() for a common WAV format.
Args:
wav_file: String path to a file, or a file-like object. The file
is assumed to contain WAV audio data with signed 16-bit PCM samples.
torch: Return data as a Pytorch tensor ready for VGGish
"""
Convenience wrapper around waveform_to_examples() for a common WAV format.
Args:
wav_file:
String path to a file, or a file-like object.
The file is assumed to contain WAV audio data with signed 16-bit PCM samples.
return_tensor:
Return data as a Pytorch tensor ready for VGGish
Returns: Returns:
See waveform_to_examples. See waveform_to_examples.

Loading…
Cancel
Save