Browse Source
Add onnx test
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
2 changed files with
6 additions and
6 deletions
-
test_onnx.py
-
timm_image.py
|
|
@ -11,8 +11,8 @@ import logging |
|
|
|
import platform |
|
|
|
import psutil |
|
|
|
|
|
|
|
# models = TimmImage.supported_model_names()[:1] |
|
|
|
models = ['resnet50'] |
|
|
|
models = TimmImage.supported_model_names()[:2] |
|
|
|
# models = ['resnet50'] |
|
|
|
|
|
|
|
atol = 1e-3 |
|
|
|
log_path = 'timm_onnx.log' |
|
|
@ -50,7 +50,7 @@ for name in models: |
|
|
|
op = TimmImage(model_name=name, device='cpu') |
|
|
|
data = torch.rand((1,) + op.config['input_size']) |
|
|
|
try: |
|
|
|
out1 = op.model(data).detach().numpy() |
|
|
|
out1 = op.model.forward_features(data).detach().numpy() |
|
|
|
logger.info('OP LOADED.') |
|
|
|
status[1] = 'success' |
|
|
|
except Exception as e: |
|
|
@ -74,11 +74,10 @@ for name in models: |
|
|
|
status[3] = 'success' |
|
|
|
except Exception as e: |
|
|
|
logger.error(f'FAIL TO CHECK ONNX: {e}') |
|
|
|
continue |
|
|
|
pass |
|
|
|
try: |
|
|
|
sess = onnxruntime.InferenceSession(onnx_path, |
|
|
|
providers=onnxruntime.get_available_providers()) |
|
|
|
|
|
|
|
out2 = sess.run(None, input_feed={'input_0': data.detach().numpy()}) |
|
|
|
logger.info('ONNX WORKED.') |
|
|
|
status[4] = 'success' |
|
|
|
|
|
@ -125,13 +125,14 @@ class TimmImage(NNOperator): |
|
|
|
raise RuntimeError(f'Fail to save as torchscript: {e}.') |
|
|
|
elif format == 'onnx': |
|
|
|
path = path + '.onnx' |
|
|
|
self.model.forward = self.model.forward_features |
|
|
|
try: |
|
|
|
torch.onnx.export(self.model, |
|
|
|
dummy_input, |
|
|
|
path, |
|
|
|
input_names=['input_0'], |
|
|
|
output_names=['output_0'], |
|
|
|
opset_version=13, |
|
|
|
opset_version=12, |
|
|
|
dynamic_axes={ |
|
|
|
'input_0': {0: 'batch_size'}, |
|
|
|
'output_0': {0: 'batch_size'} |
|
|
|