Browse Source
Replace resampy with torchaudio
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
3 changed files with
14 additions and
8 deletions
-
requirements.txt
-
vggish.py
-
vggish_input.py
|
|
@@ -1,6 +1,5 @@ |
|
|
|
torch>=1.9.0 |
|
|
|
numpy>=1.19.5 |
|
|
|
resampy |
|
|
|
torchaudio>=0.9.0 |
|
|
|
|
|
|
|
towhee>=0.7.0 |
|
|
|
towhee.models |
|
|
|
|
|
@@ -87,10 +87,15 @@ class Vggish(NNOperator): |
|
|
|
assert dtype.kind == 'f' |
|
|
|
|
|
|
|
if wav.dtype.kind in 'iu': |
|
|
|
ii = numpy.iinfo(wav.dtype) |
|
|
|
abs_max = 2 ** (ii.bits - 1) |
|
|
|
offset = ii.min + abs_max |
|
|
|
return (wav.astype(dtype) - offset) / abs_max |
|
|
|
# ii = numpy.iinfo(wav.dtype) |
|
|
|
# abs_max = 2 ** (ii.bits - 1) |
|
|
|
# offset = ii.min + abs_max |
|
|
|
# return (wav.astype(dtype) - offset) / abs_max |
|
|
|
if wav.dtype != 'int16': |
|
|
|
wav = (wav >> 16).astype(numpy.int16) |
|
|
|
assert wav.dtype == 'int16' |
|
|
|
wav = (wav / 32768.0).astype(dtype) |
|
|
|
return wav |
|
|
|
else: |
|
|
|
log.warning('Converting float dtype from %s to %s.', wav.dtype, dtype) |
|
|
|
return wav.astype(dtype) |
|
|
|
|
|
@@ -17,8 +17,8 @@ |
|
|
|
|
|
|
|
# Modification: Return torch tensors rather than numpy arrays |
|
|
|
import torch |
|
|
|
import torchaudio |
|
|
|
import numpy as np |
|
|
|
import resampy |
|
|
|
|
|
|
|
import mel_features |
|
|
|
import vggish_params |
|
|
@@ -46,7 +46,9 @@ def waveform_to_examples(data, sample_rate, return_tensor=True): |
|
|
|
data = np.mean(data, axis=1) |
|
|
|
# Resample to the rate assumed by VGGish. |
|
|
|
if sample_rate != vggish_params.SAMPLE_RATE: |
|
|
|
data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE) |
|
|
|
data = torch.from_numpy(data) |
|
|
|
resampler = torchaudio.transforms.Resample(sample_rate, vggish_params.SAMPLE_RATE, dtype=data.dtype) |
|
|
|
data = resampler(data).cpu().detach().numpy() |
|
|
|
|
|
|
|
# Compute log mel spectrogram features. |
|
|
|
log_mel = mel_features.log_mel_spectrogram( |
|
|
|