diff --git a/requirements.txt b/requirements.txt
index e55d6e3..992311d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,5 @@
 torch>=1.9.0
-numpy>=1.19.5
-resampy
+torchaudio>=0.9.0
 towhee>=0.7.0
 towhee.models
diff --git a/vggish.py b/vggish.py
index c27066d..0139d70 100644
--- a/vggish.py
+++ b/vggish.py
@@ -87,10 +87,15 @@ class Vggish(NNOperator):
         assert dtype.kind == 'f'
         if wav.dtype.kind in 'iu':
-            ii = numpy.iinfo(wav.dtype)
-            abs_max = 2 ** (ii.bits - 1)
-            offset = ii.min + abs_max
-            return (wav.astype(dtype) - offset) / abs_max
+            # ii = numpy.iinfo(wav.dtype)
+            # abs_max = 2 ** (ii.bits - 1)
+            # offset = ii.min + abs_max
+            # return (wav.astype(dtype) - offset) / abs_max
+            if wav.dtype != 'int16':
+                wav = (wav >> 16).astype(numpy.int16)
+            assert wav.dtype == 'int16'
+            wav = (wav / 32768.0).astype(dtype)
+            return wav
         else:
             log.warning('Converting float dtype from %s to %s.', wav.dtype, dtype)
             return wav.astype(dtype)
diff --git a/vggish_input.py b/vggish_input.py
index 09e1bf2..333d741 100644
--- a/vggish_input.py
+++ b/vggish_input.py
@@ -17,8 +17,8 @@
 # Modification: Return torch tensors rather than numpy arrays
 
 import torch
+import torchaudio
 import numpy as np
-import resampy
 
 import mel_features
 import vggish_params
@@ -46,7 +46,9 @@ def waveform_to_examples(data, sample_rate, return_tensor=True):
     data = np.mean(data, axis=1)
   # Resample to the rate assumed by VGGish.
   if sample_rate != vggish_params.SAMPLE_RATE:
-    data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)
+    data = torch.from_numpy(data)
+    resampler = torchaudio.transforms.Resample(sample_rate, vggish_params.SAMPLE_RATE, dtype=data.dtype)
+    data = resampler(data).cpu().detach().numpy()
 
   # Compute log mel spectrogram features.
   log_mel = mel_features.log_mel_spectrogram(
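
For context, a minimal standalone sketch of the two new code paths in this patch: the integer-to-float conversion now used in vggish.py and the torchaudio-based resampling that replaces resampy in vggish_input.py. The 44.1 kHz source rate and the random int32 buffer are placeholders for illustration only, not part of the patch.

    import numpy as np
    import torch
    import torchaudio

    SOURCE_SR, TARGET_SR = 44100, 16000                            # TARGET_SR mirrors vggish_params.SAMPLE_RATE
    wav = (np.random.uniform(-1, 1, SOURCE_SR) * 2**30).astype(np.int32)   # one second of dummy int32 audio

    # Integer samples -> float in [-1.0, 1.0], as in the patched vggish.py:
    # wider integer data is first narrowed to int16, then scaled by 1/32768.
    if wav.dtype != np.int16:
        wav = (wav >> 16).astype(np.int16)
    data = (wav / 32768.0).astype(np.float32)

    # Resampling, as in the patched waveform_to_examples: wrap the array in a
    # tensor, resample with torchaudio, then hand a numpy array back to mel_features.
    waveform = torch.from_numpy(data)
    resampler = torchaudio.transforms.Resample(SOURCE_SR, TARGET_SR, dtype=waveform.dtype)
    data = resampler(waveform).cpu().detach().numpy()              # shape ~(TARGET_SR,)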