diff --git a/nn_fingerprint.py b/nn_fingerprint.py index 2bb88c2..8ff3fcd 100644 --- a/nn_fingerprint.py +++ b/nn_fingerprint.py @@ -65,7 +65,7 @@ class NNFingerprint(NNOperator): log.warning('Using onnx.') self.model = onnxruntime.InferenceSession( model_path, - providers=['CUDAExecutionProvider', 'CPUExecutionProvider'] + providers=['CPUExecutionProvider'] if self.device == 'cpu' else ['CUDAExecutionProvider'] ) else: state_dict = torch.load(model_path, map_location=self.device)