diff --git a/nn_fingerprint.py b/nn_fingerprint.py index 58b9205..1432e82 100644 --- a/nn_fingerprint.py +++ b/nn_fingerprint.py @@ -62,7 +62,7 @@ class NNFingerprint(NNOperator): log.warning('Using onnx.') self.model = onnxruntime.InferenceSession( model_path, - providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'] + providers=['CUDAExecutionProvider', 'CPUExecutionProvider'] ) else: state_dict = torch.load(model_path, map_location=self.device) @@ -165,6 +165,10 @@ class NNFingerprint(NNOperator): path = path + '.pt' torch.save(self.model, path) elif format == 'torchscript': + self.device = 'cpu' + log.warning('Switched to CPU in order to support torchscript.') + dummy_input = dummy_input.to('cpu') + self.model = self.model.to('cpu') path = path + '.pt' try: try: