|
@ -11,8 +11,8 @@ import logging |
|
|
import platform |
|
|
import platform |
|
|
import psutil |
|
|
import psutil |
|
|
|
|
|
|
|
|
# models = TimmImage.supported_model_names()[:1] |
|
|
|
|
|
models = ['resnet50'] |
|
|
|
|
|
|
|
|
models = TimmImage.supported_model_names()[:2] |
|
|
|
|
|
# models = ['resnet50'] |
|
|
|
|
|
|
|
|
atol = 1e-3 |
|
|
atol = 1e-3 |
|
|
log_path = 'timm_onnx.log' |
|
|
log_path = 'timm_onnx.log' |
|
@ -50,7 +50,7 @@ for name in models: |
|
|
op = TimmImage(model_name=name, device='cpu') |
|
|
op = TimmImage(model_name=name, device='cpu') |
|
|
data = torch.rand((1,) + op.config['input_size']) |
|
|
data = torch.rand((1,) + op.config['input_size']) |
|
|
try: |
|
|
try: |
|
|
out1 = op.model(data).detach().numpy() |
|
|
|
|
|
|
|
|
out1 = op.model.forward_features(data).detach().numpy() |
|
|
logger.info('OP LOADED.') |
|
|
logger.info('OP LOADED.') |
|
|
status[1] = 'success' |
|
|
status[1] = 'success' |
|
|
except Exception as e: |
|
|
except Exception as e: |
|
@ -74,11 +74,10 @@ for name in models: |
|
|
status[3] = 'success' |
|
|
status[3] = 'success' |
|
|
except Exception as e: |
|
|
except Exception as e: |
|
|
logger.error(f'FAIL TO CHECK ONNX: {e}') |
|
|
logger.error(f'FAIL TO CHECK ONNX: {e}') |
|
|
continue |
|
|
|
|
|
|
|
|
pass |
|
|
try: |
|
|
try: |
|
|
sess = onnxruntime.InferenceSession(onnx_path, |
|
|
sess = onnxruntime.InferenceSession(onnx_path, |
|
|
providers=onnxruntime.get_available_providers()) |
|
|
providers=onnxruntime.get_available_providers()) |
|
|
|
|
|
|
|
|
out2 = sess.run(None, input_feed={'input_0': data.detach().numpy()}) |
|
|
out2 = sess.run(None, input_feed={'input_0': data.detach().numpy()}) |
|
|
logger.info('ONNX WORKED.') |
|
|
logger.info('ONNX WORKED.') |
|
|
status[4] = 'success' |
|
|
status[4] = 'success' |
|
|