diff --git a/nn_fingerprint.py b/nn_fingerprint.py index 8ff3fcd..680fc23 100644 --- a/nn_fingerprint.py +++ b/nn_fingerprint.py @@ -68,7 +68,10 @@ class NNFingerprint(NNOperator): providers=['CPUExecutionProvider'] if self.device == 'cpu' else ['CUDAExecutionProvider'] ) else: - state_dict = torch.load(model_path, map_location=self.device) + try: + state_dict = torch.jit.load(model_path, map_location=self.device) + except Exception: + state_dict = torch.load(model_path, map_location=self.device) if isinstance(state_dict, torch.nn.Module): self.model = state_dict else: @@ -170,10 +173,6 @@ class NNFingerprint(NNOperator): path = path + '.pt' torch.save(self.model, path) elif format == 'torchscript': - self.device = 'cpu' - log.warning('Switched to CPU in order to support torchscript.') - dummy_input = dummy_input.to('cpu') - self.model = self.model.to('cpu') path = path + '.pt' try: try: @@ -185,7 +184,7 @@ class NNFingerprint(NNOperator): jit_model = torch.jit.trace(self.model, dummy_input, strict=False) torch.jit.save(jit_model, path) except Exception as e: - log.error(f'Fail to save as torchscript: {e}.') + log.error('Fail to save as torchscript: %s.', e) raise RuntimeError(f'Fail to save as torchscript: {e}.') elif format == 'onnx': path = path + '.onnx'