Browse Source
Update
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
1 changed files with
5 additions and
6 deletions
-
nn_fingerprint.py
|
|
@ -68,6 +68,9 @@ class NNFingerprint(NNOperator): |
|
|
|
providers=['CPUExecutionProvider'] if self.device == 'cpu' else ['CUDAExecutionProvider'] |
|
|
|
) |
|
|
|
else: |
|
|
|
try: |
|
|
|
state_dict = torch.jit.load(model_path, map_location=self.device) |
|
|
|
except Exception: |
|
|
|
state_dict = torch.load(model_path, map_location=self.device) |
|
|
|
if isinstance(state_dict, torch.nn.Module): |
|
|
|
self.model = state_dict |
|
|
@ -170,10 +173,6 @@ class NNFingerprint(NNOperator): |
|
|
|
path = path + '.pt' |
|
|
|
torch.save(self.model, path) |
|
|
|
elif format == 'torchscript': |
|
|
|
self.device = 'cpu' |
|
|
|
log.warning('Switched to CPU in order to support torchscript.') |
|
|
|
dummy_input = dummy_input.to('cpu') |
|
|
|
self.model = self.model.to('cpu') |
|
|
|
path = path + '.pt' |
|
|
|
try: |
|
|
|
try: |
|
|
@ -185,7 +184,7 @@ class NNFingerprint(NNOperator): |
|
|
|
jit_model = torch.jit.trace(self.model, dummy_input, strict=False) |
|
|
|
torch.jit.save(jit_model, path) |
|
|
|
except Exception as e: |
|
|
|
log.error(f'Fail to save as torchscript: {e}.') |
|
|
|
log.error('Fail to save as torchscript: %s.', e) |
|
|
|
raise RuntimeError(f'Fail to save as torchscript: {e}.') |
|
|
|
elif format == 'onnx': |
|
|
|
path = path + '.onnx' |
|
|
|