
Add providers to onnxruntime

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>

Branch: main
Author: Jael Gu, 2 years ago
Commit: ea0313a099
Changed files:
  1. nn_fingerprint.py (5 changed lines)
  2. test.py (5 changed lines)

nn_fingerprint.py (5 changed lines)

@@ -60,7 +60,10 @@ class NNFingerprint(NNOperator):
 model_path = os.path.join(path, 'saved_model', 'pfann_fma_m.pt')
 if model_path.endswith('.onnx'):
     log.warning('Using onnx.')
-    self.model = onnxruntime.InferenceSession(model_path)
+    self.model = onnxruntime.InferenceSession(
+        model_path,
+        providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
+    )
 else:
     state_dict = torch.load(model_path, map_location=self.device)
     if isinstance(state_dict, torch.nn.Module):
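
ONNX Runtime falls back through the provider list in order, so requesting TensorRT and CUDA on a CPU-only build merely logs a warning and ends up on CPUExecutionProvider. A minimal sketch of narrowing the requested list to the providers actually compiled into the installed onnxruntime (get_available_providers() is a real onnxruntime call; the model path and variable names here are illustrative, not from the commit):

    import onnxruntime

    # Preference order: TensorRT first, then CUDA, then plain CPU.
    preferred = [
        'TensorrtExecutionProvider',
        'CUDAExecutionProvider',
        'CPUExecutionProvider',
    ]

    # Keep only the providers this onnxruntime build actually ships with.
    available = set(onnxruntime.get_available_providers())
    providers = [p for p in preferred if p in available]

    # Illustrative path; the operator loads its own saved model file.
    session = onnxruntime.InferenceSession('pfann_fma_m.onnx', providers=providers)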

test.py (5 changed lines)

@@ -39,7 +39,10 @@ op = ops.audio_embedding.nnfp()
 onnx_model = onnx.load('./saved/onnx/nnfp.onnx')
 onnx.checker.check_model(onnx_model)
-ort_session = onnxruntime.InferenceSession('./saved/onnx/nnfp.onnx')
+ort_session = onnxruntime.InferenceSession(
+    './saved/onnx/nnfp.onnx',
+    providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'])
 ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(audio)}
 ort_outs = ort_session.run(None, ort_inputs)
 out3 = ort_outs[0]
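
Once constructed, the session keeps only the providers it could actually initialize, so test.py can confirm whether the GPU path is active. A short check along these lines (get_providers() is part of the onnxruntime InferenceSession API; this snippet is a sketch rather than part of the commit):

    import onnxruntime

    ort_session = onnxruntime.InferenceSession(
        './saved/onnx/nnfp.onnx',
        providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'])

    # Providers the session actually registered, in priority order;
    # on a CPU-only install this collapses to ['CPUExecutionProvider'].
    print(ort_session.get_providers())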
