logo
Browse Source

Update dtype change

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
108a0022e8
  1. 19
      nn_fingerprint.py

19
nn_fingerprint.py

@ -64,8 +64,8 @@ class NNFingerprint(NNOperator):
if model_path.endswith('.onnx'):
log.warning('Using onnx.')
self.model = onnxruntime.InferenceSession(
model_path,
providers=['CPUExecutionProvider'] if self.device == 'cpu' else ['CUDAExecutionProvider']
model_path,
providers=['CPUExecutionProvider'] if self.device == 'cpu' else ['CUDAExecutionProvider']
)
else:
try:
@ -120,7 +120,7 @@ class NNFingerprint(NNOperator):
audio = numpy.hstack(frames)
if len(audio.shape) == 1:
audio = audio[None, :]
audio = self.int2float(audio)
audio = self.int2float(audio, dtype='float32')
audio = torch.from_numpy(audio)
assert len(audio.shape) == 2
@ -158,10 +158,15 @@ class NNFingerprint(NNOperator):
assert dtype.kind == 'f'
if wav.dtype.kind in 'iu':
ii = numpy.iinfo(wav.dtype)
abs_max = 2 ** (ii.bits - 1)
offset = ii.min + abs_max
return (wav.astype(dtype) - offset) / abs_max
# ii = numpy.iinfo(wav.dtype)
# abs_max = 2 ** (ii.bits - 1)
# offset = ii.min + abs_max
# return (wav.astype(dtype) - offset) / abs_max
if wav.dtype != 'int16':
wav = (wav >> 16).astype(numpy.int16)
assert wav.dtype == 'int16'
wav = (wav / 32768.0).astype(dtype)
return wav
else:
log.warning('Converting float dtype from %s to %s.', wav.dtype, dtype)
return wav.astype(dtype)

Loading…
Cancel
Save