diff --git a/nn_fingerprint.py b/nn_fingerprint.py
index cef8c08..58b9205 100644
--- a/nn_fingerprint.py
+++ b/nn_fingerprint.py
@@ -60,7 +60,10 @@ class NNFingerprint(NNOperator):
         model_path = os.path.join(path, 'saved_model', 'pfann_fma_m.pt')
         if model_path.endswith('.onnx'):
             log.warning('Using onnx.')
-            self.model = onnxruntime.InferenceSession(model_path)
+            self.model = onnxruntime.InferenceSession(
+                model_path,
+                providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
+            )
         else:
             state_dict = torch.load(model_path, map_location=self.device)
             if isinstance(state_dict, torch.nn.Module):
diff --git a/test.py b/test.py
index b230402..87cd27d 100644
--- a/test.py
+++ b/test.py
@@ -39,7 +39,10 @@ op = ops.audio_embedding.nnfp()
 onnx_model = onnx.load('./saved/onnx/nnfp.onnx')
 onnx.checker.check_model(onnx_model)
 
-ort_session = onnxruntime.InferenceSession('./saved/onnx/nnfp.onnx')
+ort_session = onnxruntime.InferenceSession(
+    './saved/onnx/nnfp.onnx',
+    providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'])
+
 ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(audio)}
 ort_outs = ort_session.run(None, ort_inputs)
 out3 = ort_outs[0]
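
For reference: the providers list is a priority order, and onnxruntime falls back to the next entry whenever a provider is not available in the installed build (a CPU-only wheel exposes only CPUExecutionProvider). Since onnxruntime 1.9, GPU builds also require the providers argument to be set explicitly, which is likely what motivates this change. A minimal sketch, assuming the exported ./saved/onnx/nnfp.onnx model from test.py, to check which providers the session actually selects:

import onnxruntime

# Providers compiled into the installed onnxruntime package.
print(onnxruntime.get_available_providers())

# Open the exported model with the same priority list as the patch;
# unavailable providers are skipped and the session falls back in order.
sess = onnxruntime.InferenceSession(
    './saved/onnx/nnfp.onnx',
    providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
)

# Providers the session ended up using, in fallback order.
print(sess.get_providers())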