Browse Source
Add providers to onnxruntime
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
2 changed files with
8 additions and
2 deletions
-
nn_fingerprint.py
-
test.py
|
|
@ -60,7 +60,10 @@ class NNFingerprint(NNOperator): |
|
|
|
model_path = os.path.join(path, 'saved_model', 'pfann_fma_m.pt') |
|
|
|
if model_path.endswith('.onnx'): |
|
|
|
log.warning('Using onnx.') |
|
|
|
self.model = onnxruntime.InferenceSession(model_path) |
|
|
|
self.model = onnxruntime.InferenceSession( |
|
|
|
model_path, |
|
|
|
providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'] |
|
|
|
) |
|
|
|
else: |
|
|
|
state_dict = torch.load(model_path, map_location=self.device) |
|
|
|
if isinstance(state_dict, torch.nn.Module): |
|
|
|
|
|
@ -39,7 +39,10 @@ op = ops.audio_embedding.nnfp() |
|
|
|
onnx_model = onnx.load('./saved/onnx/nnfp.onnx') |
|
|
|
onnx.checker.check_model(onnx_model) |
|
|
|
|
|
|
|
ort_session = onnxruntime.InferenceSession('./saved/onnx/nnfp.onnx') |
|
|
|
ort_session = onnxruntime.InferenceSession( |
|
|
|
'./saved/onnx/nnfp.onnx', |
|
|
|
providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']) |
|
|
|
|
|
|
|
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(audio)} |
|
|
|
ort_outs = ort_session.run(None, ort_inputs) |
|
|
|
out3 = ort_outs[0] |
|
|
|