logo
Browse Source

Replace resampy with torchaudio

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
737c9f9a55
  1. 3
      requirements.txt
  2. 13
      vggish.py
  3. 6
      vggish_input.py

3
requirements.txt

@@ -1,6 +1,5 @@
torch>=1.9.0
numpy>=1.19.5
resampy
torchaudio>=0.9.0
towhee>=0.7.0
towhee.models

13
vggish.py

@@ -87,10 +87,15 @@ class Vggish(NNOperator):
assert dtype.kind == 'f'
if wav.dtype.kind in 'iu':
ii = numpy.iinfo(wav.dtype)
abs_max = 2 ** (ii.bits - 1)
offset = ii.min + abs_max
return (wav.astype(dtype) - offset) / abs_max
# ii = numpy.iinfo(wav.dtype)
# abs_max = 2 ** (ii.bits - 1)
# offset = ii.min + abs_max
# return (wav.astype(dtype) - offset) / abs_max
if wav.dtype != 'int16':
wav = (wav >> 16).astype(numpy.int16)
assert wav.dtype == 'int16'
wav = (wav / 32768.0).astype(dtype)
return wav
else:
log.warning('Converting float dtype from %s to %s.', wav.dtype, dtype)
return wav.astype(dtype)

6
vggish_input.py

@@ -17,8 +17,8 @@
# Modification: Return torch tensors rather than numpy arrays
import torch
import torchaudio
import numpy as np
import resampy
import mel_features
import vggish_params
@@ -46,7 +46,9 @@ def waveform_to_examples(data, sample_rate, return_tensor=True):
data = np.mean(data, axis=1)
# Resample to the rate assumed by VGGish.
if sample_rate != vggish_params.SAMPLE_RATE:
data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)
data = torch.from_numpy(data)
resampler = torchaudio.transforms.Resample(sample_rate, vggish_params.SAMPLE_RATE, dtype=data.dtype)
data = resampler(data).cpu().detach().numpy()
# Compute log mel spectrogram features.
log_mel = mel_features.log_mel_spectrogram(

Loading…
Cancel
Save