logo
Browse Source

Refactor

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
6e01ca1882
  1. 29
      .gitattributes
  2. 69
      README.md
  3. 19
      __init__.py
  4. 223
      mel_features.py
  5. 4
      requirements.txt
  6. BIN
      vggish.pth
  7. 65
      vggish.py
  8. 99
      vggish_input.py
  9. 53
      vggish_params.py

29
.gitattributes

@ -1,28 +1 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bin.* filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zstandard filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
*vggish.pth filter=lfs diff=lfs merge=lfs -text

69
README.md

@ -1,2 +1,69 @@
# vggish
# Audio Embedding with Vggish
*Author: Jael Gu*
## Desription
The audio embedding operator converts an input audio into a dense vector which can be used to represent the audio clip's semantics.
This operator is built on top of the VGGish model with Pytorch.
It is originally implemented in [Tensorflow](https://github.com/tensorflow/models/tree/master/research/audioset/vggish).
The model is pre-trained with a large scale of audio dataset [AudioSet](https://research.google.com/audioset).
As suggested, it is suitable to extract features at high level or warm up a larger model.
```python
from towhee import ops
audio_encoder = ops.audio_embedding.vggish()
audio_embedding = audio_encoder("/path/to/audio")
```
## Factory Constructor
Create the operator via the following factory method
***ops.audio_embedding.vggish()***
## Interface
An audio embedding operator generates vectors in numpy.ndarray given an audio file path.
**Parameters:**
​ None.
**Returns**: *numpy.ndarray*
​ Audio embeddings.
## Code Example
Generate embeddings for the audio "test.wav".
*Write the pipeline in simplified style*:
```python
from towhee import dc
dc.glob('test.wav')
.audio_embedding.vggish()
.show()
```
*Write a same pipeline with explicit inputs/outputs name specifications:*
```python
from towhee import dc
dc.glob['path']('test.wav')
.audio_embedding.vggish['path', 'vecs']()
.select('vecs')
.show()
```

19
__init__.py

@ -0,0 +1,19 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .vggish import Vggish
def vggish():
return Vggish()

223
mel_features.py

@ -0,0 +1,223 @@
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines routines to compute mel spectrogram features from audio waveform."""
import numpy as np
def frame(data, window_length, hop_length):
"""Convert array into a sequence of successive possibly overlapping frames.
An n-dimensional array of shape (num_samples, ...) is converted into an
(n+1)-D array of shape (num_frames, window_length, ...), where each frame
starts hop_length points after the preceding one.
This is accomplished using stride_tricks, so the original data is not
copied. However, there is no zero-padding, so any incomplete frames at the
end are not included.
Args:
data: np.array of dimension N >= 1.
window_length: Number of samples in each frame.
hop_length: Advance (in samples) between each window.
Returns:
(N+1)-D np.array with as many rows as there are complete frames that can be
extracted.
"""
num_samples = data.shape[0]
num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length))
shape = (num_frames, window_length) + data.shape[1:]
strides = (data.strides[0] * hop_length,) + data.strides
return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)
def periodic_hann(window_length):
"""Calculate a "periodic" Hann window.
The classic Hann window is defined as a raised cosine that starts and
ends on zero, and where every value appears twice, except the middle
point for an odd-length window. Matlab calls this a "symmetric" window
and np.hanning() returns it. However, for Fourier analysis, this
actually represents just over one cycle of a period N-1 cosine, and
thus is not compactly expressed on a length-N Fourier basis. Instead,
it's better to use a raised cosine that ends just before the final
zero value - i.e. a complete cycle of a period-N cosine. Matlab
calls this a "periodic" window. This routine calculates it.
Args:
window_length: The number of points in the returned window.
Returns:
A 1D np.array containing the periodic hann window.
"""
return 0.5 - (0.5 * np.cos(2 * np.pi / window_length *
np.arange(window_length)))
def stft_magnitude(signal, fft_length,
hop_length=None,
window_length=None):
"""Calculate the short-time Fourier transform magnitude.
Args:
signal: 1D np.array of the input time-domain signal.
fft_length: Size of the FFT to apply.
hop_length: Advance (in samples) between each frame passed to FFT.
window_length: Length of each block of samples to pass to FFT.
Returns:
2D np.array where each row contains the magnitudes of the fft_length/2+1
unique values of the FFT for the corresponding frame of input samples.
"""
frames = frame(signal, window_length, hop_length)
# Apply frame window to each frame. We use a periodic Hann (cosine of period
# window_length) instead of the symmetric Hann of np.hanning (period
# window_length-1).
window = periodic_hann(window_length)
windowed_frames = frames * window
return np.abs(np.fft.rfft(windowed_frames, int(fft_length)))
# Mel spectrum constants and functions.
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
_MEL_HIGH_FREQUENCY_Q = 1127.0
def hertz_to_mel(frequencies_hertz):
"""Convert frequencies to mel scale using HTK formula.
Args:
frequencies_hertz: Scalar or np.array of frequencies in hertz.
Returns:
Object of same size as frequencies_hertz containing corresponding values
on the mel scale.
"""
return _MEL_HIGH_FREQUENCY_Q * np.log(
1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))
def spectrogram_to_mel_matrix(num_mel_bins=20,
num_spectrogram_bins=129,
audio_sample_rate=8000,
lower_edge_hertz=125.0,
upper_edge_hertz=3800.0):
"""Return a matrix that can post-multiply spectrogram rows to make mel.
Returns a np.array matrix A that can be used to post-multiply a matrix S of
spectrogram values (STFT magnitudes) arranged as frames x bins to generate a
"mel spectrogram" M of frames x num_mel_bins. M = S A.
The classic HTK algorithm exploits the complementarity of adjacent mel bands
to multiply each FFT bin by only one mel weight, then add it, with positive
and negative signs, to the two adjacent mel bands to which that bin
contributes. Here, by expressing this operation as a matrix multiply, we go
from num_fft multiplies per frame (plus around 2*num_fft adds) to around
num_fft^2 multiplies and adds. However, because these are all presumably
accomplished in a single call to np.dot(), it's not clear which approach is
faster in Python. The matrix multiplication has the attraction of being more
general and flexible, and much easier to read.
Args:
num_mel_bins: How many bands in the resulting mel spectrum. This is
the number of columns in the output matrix.
num_spectrogram_bins: How many bins there are in the source spectrogram
data, which is understood to be fft_size/2 + 1, i.e. the spectrogram
only contains the nonredundant FFT bins.
audio_sample_rate: Samples per second of the audio at the input to the
spectrogram. We need this to figure out the actual frequencies for
each spectrogram bin, which dictates how they are mapped into mel.
lower_edge_hertz: Lower bound on the frequencies to be included in the mel
spectrum. This corresponds to the lower edge of the lowest triangular
band.
upper_edge_hertz: The desired top edge of the highest frequency band.
Returns:
An np.array with shape (num_spectrogram_bins, num_mel_bins).
Raises:
ValueError: if frequency edges are incorrectly ordered or out of range.
"""
nyquist_hertz = audio_sample_rate / 2.
if lower_edge_hertz < 0.0:
raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz)
if lower_edge_hertz >= upper_edge_hertz:
raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
(lower_edge_hertz, upper_edge_hertz))
if upper_edge_hertz > nyquist_hertz:
raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" %
(upper_edge_hertz, nyquist_hertz))
spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins)
spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz)
# The i'th mel band (starting from i=1) has center frequency
# band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge
# band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in
# the band_edges_mel arrays.
band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz),
hertz_to_mel(upper_edge_hertz), num_mel_bins + 2)
# Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins
# of spectrogram values.
mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins))
for i in range(num_mel_bins):
lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3]
# Calculate lower and upper slopes for every spectrogram bin.
# Line segments are linear in the *mel* domain, not hertz.
lower_slope = ((spectrogram_bins_mel - lower_edge_mel) /
(center_mel - lower_edge_mel))
upper_slope = ((upper_edge_mel - spectrogram_bins_mel) /
(upper_edge_mel - center_mel))
# .. then intersect them with each other and zero.
mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope,
upper_slope))
# HTK excludes the spectrogram DC bin; make sure it always gets a zero
# coefficient.
mel_weights_matrix[0, :] = 0.0
return mel_weights_matrix
def log_mel_spectrogram(data,
audio_sample_rate=8000,
log_offset=0.0,
window_length_secs=0.025,
hop_length_secs=0.010,
**kwargs):
"""Convert waveform to a log magnitude mel-frequency spectrogram.
Args:
data: 1D np.array of waveform data.
audio_sample_rate: The sampling rate of data.
log_offset: Add this to values when taking log to avoid -Infs.
window_length_secs: Duration of each window to analyze.
hop_length_secs: Advance between successive analysis windows.
**kwargs: Additional arguments to pass to spectrogram_to_mel_matrix.
Returns:
2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank
magnitudes for successive frames.
"""
window_length_samples = int(round(audio_sample_rate * window_length_secs))
hop_length_samples = int(round(audio_sample_rate * hop_length_secs))
fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0)))
spectrogram = stft_magnitude(
data,
fft_length=fft_length,
hop_length=hop_length_samples,
window_length=window_length_samples)
mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix(
num_spectrogram_bins=spectrogram.shape[1],
audio_sample_rate=audio_sample_rate, **kwargs))
return np.log(mel_spectrogram + log_offset)

4
requirements.txt

@ -0,0 +1,4 @@
torch==1.9.0
numpy==1.19.5
resampy
torchaudio

BIN
vggish.pth (Stored with Git LFS)

Binary file not shown.

65
vggish.py

@ -0,0 +1,65 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import numpy
from pathlib import Path
import torch
from towhee.operator.base import NNOperator
from towhee.models.vggish.torch_vggish import VGG
from towhee import register
import vggish_input
import warnings
warnings.filterwarnings('ignore')
log = logging.getLogger()
@register(output_schema=['vec'])
class Vggish(NNOperator):
"""
"""
def __init__(self, weights_path: str = None, framework: str = 'pytorch') -> None:
super().__init__(framework=framework)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.model = VGG()
if not weights_path:
path = str(Path(__file__).parent)
weights_path = os.path.join(path, 'vggish.pth')
state_dict = torch.load(weights_path, map_location=torch.device('cpu'))
self.model.load_state_dict(state_dict)
self.model.eval()
self.model.to(self.device)
def __call__(self, audio: str) -> numpy.ndarray:
audio_tensors = self.preprocess(audio).to(self.device)
features = self.model(audio_tensors)
outs = features.to("cpu")
return outs.detach().numpy()
def preprocess(self, audio_path: str):
audio_tensors = vggish_input.wavfile_to_examples(audio_path)
return audio_tensors
# if __name__ == '__main__':
# encoder = Vggish()
# audio_path = '/path/to/audio/wav'
# vec = encoder(audio_path)
# print(vec.shape)

99
vggish_input.py

@ -0,0 +1,99 @@
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Compute input examples for VGGish from audio waveform."""
# Modification: Return torch tensors rather than numpy arrays
import torch
import numpy as np
import resampy
import mel_features
import vggish_params
import torchaudio
def waveform_to_examples(data, sample_rate, return_tensor=True):
"""Converts audio waveform into an array of examples for VGGish.
Args:
data: np.array of either one dimension (mono) or two dimensions
(multi-channel, with the outer dimension representing channels).
Each sample is generally expected to lie in the range [-1.0, +1.0],
although this is not required.
sample_rate: Sample rate of data.
return_tensor: Return data as a Pytorch tensor ready for VGGish
Returns:
3-D np.array of shape [num_examples, num_frames, num_bands] which represents
a sequence of examples, each of which contains a patch of log mel
spectrogram, covering num_frames frames of audio and num_bands mel frequency
bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS.
"""
# Convert to mono.
if len(data.shape) > 1:
data = np.mean(data, axis=1)
# Resample to the rate assumed by VGGish.
if sample_rate != vggish_params.SAMPLE_RATE:
data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)
# Compute log mel spectrogram features.
log_mel = mel_features.log_mel_spectrogram(
data,
audio_sample_rate=vggish_params.SAMPLE_RATE,
log_offset=vggish_params.LOG_OFFSET,
window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS,
hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS,
num_mel_bins=vggish_params.NUM_MEL_BINS,
lower_edge_hertz=vggish_params.MEL_MIN_HZ,
upper_edge_hertz=vggish_params.MEL_MAX_HZ)
# Frame features into examples.
features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS
example_window_length = int(round(
vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate))
example_hop_length = int(round(
vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate))
log_mel_examples = mel_features.frame(
log_mel,
window_length=example_window_length,
hop_length=example_hop_length)
if return_tensor:
log_mel_examples = torch.tensor(
log_mel_examples, requires_grad=True)[:, None, :, :].float()
return log_mel_examples
def wavfile_to_examples(wav_file, return_tensor=True):
"""Convenience wrapper around waveform_to_examples() for a common WAV format.
Args:
wav_file: String path to a file, or a file-like object. The file
is assumed to contain WAV audio data with signed 16-bit PCM samples.
torch: Return data as a Pytorch tensor ready for VGGish
Returns:
See waveform_to_examples.
"""
data, sr = torchaudio.load(wav_file)
wav_data = data.short().detach().numpy().transpose()
assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
samples = wav_data / 32768.0 # Convert to [-1.0, +1.0]
return waveform_to_examples(samples, sr, return_tensor)

53
vggish_params.py

@ -0,0 +1,53 @@
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Global parameters for the VGGish model.
See vggish_slim.py for more information.
"""
# Architectural constants.
NUM_FRAMES = 96 # Frames in input mel-spectrogram patch.
NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch.
EMBEDDING_SIZE = 128 # Size of embedding layer.
# Hyperparameters used in feature and example generation.
SAMPLE_RATE = 16000
STFT_WINDOW_LENGTH_SECONDS = 0.025
STFT_HOP_LENGTH_SECONDS = 0.010
NUM_MEL_BINS = NUM_BANDS
MEL_MIN_HZ = 125
MEL_MAX_HZ = 7500
LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram.
EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames
EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap.
# Parameters used for embedding postprocessing.
PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors'
PCA_MEANS_NAME = 'pca_means'
QUANTIZE_MIN_VAL = -2.0
QUANTIZE_MAX_VAL = +2.0
# Hyperparameters used in training.
INIT_STDDEV = 0.01 # Standard deviation used to initialize weights.
LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer.
ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer.
# Names of ops, tensors, and features.
INPUT_OP_NAME = 'vggish/input_features'
INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0'
OUTPUT_OP_NAME = 'vggish/embedding'
OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0'
AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding'
Loading…
Cancel
Save