logo
Browse Source

Update

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
482a1caeff
  1. 2
      nn_fingerprint.py

2
nn_fingerprint.py

@ -65,7 +65,7 @@ class NNFingerprint(NNOperator):
log.warning('Using onnx.') log.warning('Using onnx.')
self.model = onnxruntime.InferenceSession( self.model = onnxruntime.InferenceSession(
model_path, model_path,
providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
providers=['CPUExecutionProvider'] if self.device == 'cpu' else ['CUDAExecutionProvider']
) )
else: else:
state_dict = torch.load(model_path, map_location=self.device) state_dict = torch.load(model_path, map_location=self.device)

Loading…
Cancel
Save