|
@ -64,8 +64,8 @@ class NNFingerprint(NNOperator): |
|
|
if model_path.endswith('.onnx'): |
|
|
if model_path.endswith('.onnx'): |
|
|
log.warning('Using onnx.') |
|
|
log.warning('Using onnx.') |
|
|
self.model = onnxruntime.InferenceSession( |
|
|
self.model = onnxruntime.InferenceSession( |
|
|
model_path, |
|
|
|
|
|
providers=['CPUExecutionProvider'] if self.device == 'cpu' else ['CUDAExecutionProvider'] |
|
|
|
|
|
|
|
|
model_path, |
|
|
|
|
|
providers=['CPUExecutionProvider'] if self.device == 'cpu' else ['CUDAExecutionProvider'] |
|
|
) |
|
|
) |
|
|
else: |
|
|
else: |
|
|
try: |
|
|
try: |
|
@ -120,7 +120,7 @@ class NNFingerprint(NNOperator): |
|
|
audio = numpy.hstack(frames) |
|
|
audio = numpy.hstack(frames) |
|
|
if len(audio.shape) == 1: |
|
|
if len(audio.shape) == 1: |
|
|
audio = audio[None, :] |
|
|
audio = audio[None, :] |
|
|
audio = self.int2float(audio) |
|
|
|
|
|
|
|
|
audio = self.int2float(audio, dtype='float32') |
|
|
audio = torch.from_numpy(audio) |
|
|
audio = torch.from_numpy(audio) |
|
|
assert len(audio.shape) == 2 |
|
|
assert len(audio.shape) == 2 |
|
|
|
|
|
|
|
@ -158,10 +158,15 @@ class NNFingerprint(NNOperator): |
|
|
assert dtype.kind == 'f' |
|
|
assert dtype.kind == 'f' |
|
|
|
|
|
|
|
|
if wav.dtype.kind in 'iu': |
|
|
if wav.dtype.kind in 'iu': |
|
|
ii = numpy.iinfo(wav.dtype) |
|
|
|
|
|
abs_max = 2 ** (ii.bits - 1) |
|
|
|
|
|
offset = ii.min + abs_max |
|
|
|
|
|
return (wav.astype(dtype) - offset) / abs_max |
|
|
|
|
|
|
|
|
# ii = numpy.iinfo(wav.dtype) |
|
|
|
|
|
# abs_max = 2 ** (ii.bits - 1) |
|
|
|
|
|
# offset = ii.min + abs_max |
|
|
|
|
|
# return (wav.astype(dtype) - offset) / abs_max |
|
|
|
|
|
if wav.dtype != 'int16': |
|
|
|
|
|
wav = (wav >> 16).astype(numpy.int16) |
|
|
|
|
|
assert wav.dtype == 'int16' |
|
|
|
|
|
wav = (wav / 32768.0).astype(dtype) |
|
|
|
|
|
return wav |
|
|
else: |
|
|
else: |
|
|
log.warning('Converting float dtype from %s to %s.', wav.dtype, dtype) |
|
|
log.warning('Converting float dtype from %s to %s.', wav.dtype, dtype) |
|
|
return wav.astype(dtype) |
|
|
return wav.astype(dtype) |
|
|