From 108a0022e8233d4bca5bd94f9bb1133eafca2e32 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Thu, 29 Sep 2022 15:51:12 +0800 Subject: [PATCH] Update dtype change Signed-off-by: Jael Gu --- nn_fingerprint.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/nn_fingerprint.py b/nn_fingerprint.py index ac07b0b..27925f3 100644 --- a/nn_fingerprint.py +++ b/nn_fingerprint.py @@ -64,8 +64,8 @@ class NNFingerprint(NNOperator): if model_path.endswith('.onnx'): log.warning('Using onnx.') self.model = onnxruntime.InferenceSession( - model_path, - providers=['CPUExecutionProvider'] if self.device == 'cpu' else ['CUDAExecutionProvider'] + model_path, + providers=['CPUExecutionProvider'] if self.device == 'cpu' else ['CUDAExecutionProvider'] ) else: try: @@ -120,7 +120,7 @@ class NNFingerprint(NNOperator): audio = numpy.hstack(frames) if len(audio.shape) == 1: audio = audio[None, :] - audio = self.int2float(audio) + audio = self.int2float(audio, dtype='float32') audio = torch.from_numpy(audio) assert len(audio.shape) == 2 @@ -158,10 +158,15 @@ class NNFingerprint(NNOperator): assert dtype.kind == 'f' if wav.dtype.kind in 'iu': - ii = numpy.iinfo(wav.dtype) - abs_max = 2 ** (ii.bits - 1) - offset = ii.min + abs_max - return (wav.astype(dtype) - offset) / abs_max + # ii = numpy.iinfo(wav.dtype) + # abs_max = 2 ** (ii.bits - 1) + # offset = ii.min + abs_max + # return (wav.astype(dtype) - offset) / abs_max + if wav.dtype != 'int16': + wav = (wav >> 16).astype(numpy.int16) + assert wav.dtype == 'int16' + wav = (wav / 32768.0).astype(dtype) + return wav else: log.warning('Converting float dtype from %s to %s.', wav.dtype, dtype) return wav.astype(dtype)