From 6e5da12be237ccef13f1c48b680e7256b626dad9 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Mon, 15 Aug 2022 19:01:37 +0800 Subject: [PATCH] Update Signed-off-by: Jael Gu --- nn_fingerprint.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/nn_fingerprint.py b/nn_fingerprint.py index 58b9205..1432e82 100644 --- a/nn_fingerprint.py +++ b/nn_fingerprint.py @@ -62,7 +62,7 @@ class NNFingerprint(NNOperator): log.warning('Using onnx.') self.model = onnxruntime.InferenceSession( model_path, - providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'] + providers=['CUDAExecutionProvider', 'CPUExecutionProvider'] ) else: state_dict = torch.load(model_path, map_location=self.device) @@ -165,6 +165,10 @@ class NNFingerprint(NNOperator): path = path + '.pt' torch.save(self.model, path) elif format == 'torchscript': + self.device = 'cpu' + log.warning('Switched to CPU in order to support torchscript.') + dummy_input = dummy_input.to('cpu') + self.model = self.model.to('cpu') path = path + '.pt' try: try: