logo
Browse Source

Update

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
6e5da12be2
  1. 6
      nn_fingerprint.py

6
nn_fingerprint.py

@ -62,7 +62,7 @@ class NNFingerprint(NNOperator):
log.warning('Using onnx.') log.warning('Using onnx.')
self.model = onnxruntime.InferenceSession( self.model = onnxruntime.InferenceSession(
model_path, model_path,
providers=['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider']
providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
) )
else: else:
state_dict = torch.load(model_path, map_location=self.device) state_dict = torch.load(model_path, map_location=self.device)
@ -165,6 +165,10 @@ class NNFingerprint(NNOperator):
path = path + '.pt' path = path + '.pt'
torch.save(self.model, path) torch.save(self.model, path)
elif format == 'torchscript': elif format == 'torchscript':
self.device = 'cpu'
log.warning('Switched to CPU in order to support torchscript.')
dummy_input = dummy_input.to('cpu')
self.model = self.model.to('cpu')
path = path + '.pt' path = path + '.pt'
try: try:
try: try:

Loading…
Cancel
Save