logo
Browse Source

Add onnx test

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
fa0aac7037
  1. 9
      test_onnx.py
  2. 3
      timm_image.py

9
test_onnx.py

@ -11,8 +11,8 @@ import logging
import platform
import psutil
# models = TimmImage.supported_model_names()[:1]
models = ['resnet50']
models = TimmImage.supported_model_names()[:2]
# models = ['resnet50']
atol = 1e-3
log_path = 'timm_onnx.log'
@ -50,7 +50,7 @@ for name in models:
op = TimmImage(model_name=name, device='cpu')
data = torch.rand((1,) + op.config['input_size'])
try:
out1 = op.model(data).detach().numpy()
out1 = op.model.forward_features(data).detach().numpy()
logger.info('OP LOADED.')
status[1] = 'success'
except Exception as e:
@ -74,11 +74,10 @@ for name in models:
status[3] = 'success'
except Exception as e:
logger.error(f'FAIL TO CHECK ONNX: {e}')
continue
pass
try:
sess = onnxruntime.InferenceSession(onnx_path,
providers=onnxruntime.get_available_providers())
out2 = sess.run(None, input_feed={'input_0': data.detach().numpy()})
logger.info('ONNX WORKED.')
status[4] = 'success'

3
timm_image.py

@ -125,13 +125,14 @@ class TimmImage(NNOperator):
raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onnx':
path = path + '.onnx'
self.model.forward = self.model.forward_features
try:
torch.onnx.export(self.model,
dummy_input,
path,
input_names=['input_0'],
output_names=['output_0'],
opset_version=13,
opset_version=12,
dynamic_axes={
'input_0': {0: 'batch_size'},
'output_0': {0: 'batch_size'}

Loading…
Cancel
Save