Browse Source
Update
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
2 changed files with
2 additions and
2 deletions
-
nn_fingerprint.py
-
requirements.txt
|
|
@ -22,7 +22,6 @@ from typing import List, Union |
|
|
|
import torch |
|
|
|
import torchaudio |
|
|
|
import numpy |
|
|
|
import onnxruntime |
|
|
|
|
|
|
|
from towhee.operator.base import NNOperator |
|
|
|
from towhee import register |
|
|
@ -67,6 +66,7 @@ class NNFingerprint(NNOperator): |
|
|
|
model_path = os.path.join(path, 'saved_model', 'nnfp_fma.pt') |
|
|
|
if model_path.endswith('.onnx'): |
|
|
|
log.warning('Using onnx.') |
|
|
|
import onnxruntime |
|
|
|
self.model = onnxruntime.InferenceSession( |
|
|
|
model_path, |
|
|
|
providers=['CPUExecutionProvider'] if self.device == 'cpu' else ['CUDAExecutionProvider'] |
|
|
|
|
|
@ -1,4 +1,4 @@ |
|
|
|
towhee.models>=0.8.0 |
|
|
|
torch |
|
|
|
torchaudio |
|
|
|
onnxruntime |
|
|
|
# onnxruntime |
|
|
|