logo
Browse Source

Update

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
dbbee2277e
  1. 9
      nn_fingerprint.py

9
nn_fingerprint.py

@ -68,6 +68,9 @@ class NNFingerprint(NNOperator):
providers=['CPUExecutionProvider'] if self.device == 'cpu' else ['CUDAExecutionProvider'] providers=['CPUExecutionProvider'] if self.device == 'cpu' else ['CUDAExecutionProvider']
) )
else: else:
try:
state_dict = torch.jit.load(model_path, map_location=self.device)
except Exception:
state_dict = torch.load(model_path, map_location=self.device) state_dict = torch.load(model_path, map_location=self.device)
if isinstance(state_dict, torch.nn.Module): if isinstance(state_dict, torch.nn.Module):
self.model = state_dict self.model = state_dict
@ -170,10 +173,6 @@ class NNFingerprint(NNOperator):
path = path + '.pt' path = path + '.pt'
torch.save(self.model, path) torch.save(self.model, path)
elif format == 'torchscript': elif format == 'torchscript':
self.device = 'cpu'
log.warning('Switched to CPU in order to support torchscript.')
dummy_input = dummy_input.to('cpu')
self.model = self.model.to('cpu')
path = path + '.pt' path = path + '.pt'
try: try:
try: try:
@ -185,7 +184,7 @@ class NNFingerprint(NNOperator):
jit_model = torch.jit.trace(self.model, dummy_input, strict=False) jit_model = torch.jit.trace(self.model, dummy_input, strict=False)
torch.jit.save(jit_model, path) torch.jit.save(jit_model, path)
except Exception as e: except Exception as e:
log.error(f'Fail to save as torchscript: {e}.')
log.error('Fail to save as torchscript: %s.', e)
raise RuntimeError(f'Fail to save as torchscript: {e}.') raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onnx': elif format == 'onnx':
path = path + '.onnx' path = path + '.onnx'

Loading…
Cancel
Save