logo
Browse Source

Update

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
dbbee2277e
  1. 11
      nn_fingerprint.py

11
nn_fingerprint.py

@ -68,7 +68,10 @@ class NNFingerprint(NNOperator):
providers=['CPUExecutionProvider'] if self.device == 'cpu' else ['CUDAExecutionProvider']
)
else:
state_dict = torch.load(model_path, map_location=self.device)
try:
state_dict = torch.jit.load(model_path, map_location=self.device)
except Exception:
state_dict = torch.load(model_path, map_location=self.device)
if isinstance(state_dict, torch.nn.Module):
self.model = state_dict
else:
@ -170,10 +173,6 @@ class NNFingerprint(NNOperator):
path = path + '.pt'
torch.save(self.model, path)
elif format == 'torchscript':
self.device = 'cpu'
log.warning('Switched to CPU in order to support torchscript.')
dummy_input = dummy_input.to('cpu')
self.model = self.model.to('cpu')
path = path + '.pt'
try:
try:
@ -185,7 +184,7 @@ class NNFingerprint(NNOperator):
jit_model = torch.jit.trace(self.model, dummy_input, strict=False)
torch.jit.save(jit_model, path)
except Exception as e:
log.error(f'Fail to save as torchscript: {e}.')
log.error('Fail to save as torchscript: %s.', e)
raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onnx':
path = path + '.onnx'

Loading…
Cancel
Save