logo
Browse Source

Update

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
d4b82c7f9b
  1. 2
      nn_fingerprint.py
  2. 2
      requirements.txt

2
nn_fingerprint.py

@ -22,7 +22,6 @@ from typing import List, Union
import torch import torch
import torchaudio import torchaudio
import numpy import numpy
import onnxruntime
from towhee.operator.base import NNOperator from towhee.operator.base import NNOperator
from towhee import register from towhee import register
@ -67,6 +66,7 @@ class NNFingerprint(NNOperator):
model_path = os.path.join(path, 'saved_model', 'nnfp_fma.pt') model_path = os.path.join(path, 'saved_model', 'nnfp_fma.pt')
if model_path.endswith('.onnx'): if model_path.endswith('.onnx'):
log.warning('Using onnx.') log.warning('Using onnx.')
import onnxruntime
self.model = onnxruntime.InferenceSession( self.model = onnxruntime.InferenceSession(
model_path, model_path,
providers=['CPUExecutionProvider'] if self.device == 'cpu' else ['CUDAExecutionProvider'] providers=['CPUExecutionProvider'] if self.device == 'cpu' else ['CUDAExecutionProvider']

2
requirements.txt

@ -1,4 +1,4 @@
towhee.models>=0.8.0 towhee.models>=0.8.0
torch torch
torchaudio torchaudio
onnxruntime
# onnxruntime

Loading…
Cancel
Save