From fa0aac70379bb1af9e67eb199c637967edc97898 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Thu, 8 Dec 2022 17:17:44 +0800 Subject: [PATCH] Add onnx test Signed-off-by: Jael Gu --- test_onnx.py | 9 ++++----- timm_image.py | 3 ++- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/test_onnx.py b/test_onnx.py index 286e07b..5498278 100644 --- a/test_onnx.py +++ b/test_onnx.py @@ -11,8 +11,8 @@ import logging import platform import psutil -# models = TimmImage.supported_model_names()[:1] -models = ['resnet50'] +models = TimmImage.supported_model_names()[:2] +# models = ['resnet50'] atol = 1e-3 log_path = 'timm_onnx.log' @@ -50,7 +50,7 @@ for name in models: op = TimmImage(model_name=name, device='cpu') data = torch.rand((1,) + op.config['input_size']) try: - out1 = op.model(data).detach().numpy() + out1 = op.model.forward_features(data).detach().numpy() logger.info('OP LOADED.') status[1] = 'success' except Exception as e: @@ -74,11 +74,10 @@ for name in models: status[3] = 'success' except Exception as e: logger.error(f'FAIL TO CHECK ONNX: {e}') - continue + pass try: sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers()) - out2 = sess.run(None, input_feed={'input_0': data.detach().numpy()}) logger.info('ONNX WORKED.') status[4] = 'success' diff --git a/timm_image.py b/timm_image.py index c8fd104..6c75833 100644 --- a/timm_image.py +++ b/timm_image.py @@ -125,13 +125,14 @@ class TimmImage(NNOperator): raise RuntimeError(f'Fail to save as torchscript: {e}.') elif format == 'onnx': path = path + '.onnx' + self.model.forward = self.model.forward_features try: torch.onnx.export(self.model, dummy_input, path, input_names=['input_0'], output_names=['output_0'], - opset_version=13, + opset_version=12, dynamic_axes={ 'input_0': {0: 'batch_size'}, 'output_0': {0: 'batch_size'}