import onnx from isc import Isc from towhee import ops import torch import numpy import onnx import onnxruntime import os from pathlib import Path import logging import platform import psutil models = ['tf_efficientnetv2_m_in21ft1k'] atol = 1e-3 log_path = 'isc_onnx.log' f = open('onnx.csv', 'w+') f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n') logger = logging.getLogger('isc_onnx') logger.setLevel(logging.DEBUG) formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') fh = logging.FileHandler(log_path) fh.setLevel(logging.DEBUG) fh.setFormatter(formatter) logger.addHandler(fh) ch = logging.StreamHandler() ch.setLevel(logging.ERROR) ch.setFormatter(formatter) logger.addHandler(ch) logger.debug(f'machine: {platform.platform()}-{platform.processor()}') logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}' f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}' f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB') logger.debug(f'cpu: {psutil.cpu_count()}') status = None for name in models: logger.info(f'***{name}***') saved_name = name.replace('/', '-') onnx_path = f'saved/onnx/{saved_name}.onnx' try: op = Isc(timm_backbone=name, device='cpu') except Exception as e: logger.error(f'Fail to load model {name}. Please check weights.') data = torch.rand(1, 3, 224, 224) if status: f.write(','.join(status) + '\n') status = [name] + ['fail'] * 5 try: out1 = op.model(data).detach().numpy() logger.info('OP LOADED.') status[1] = 'success' except Exception as e: logger.error(f'FAIL TO LOAD OP: {e}') continue try: op.save_model(format='onnx') logger.info('ONNX SAVED.') status[2] = 'success' except Exception as e: logger.error(f'FAIL TO SAVE ONNX: {e}') continue try: try: onnx_model = onnx.load(onnx_path) onnx.checker.check_model(onnx_model) except Exception: saved_onnx = onnx.load(onnx_path, load_external_data=False) onnx.checker.check_model(saved_onnx) logger.info('ONNX CHECKED.') status[3] = 'success' except Exception as e: logger.error(f'FAIL TO CHECK ONNX: {e}') pass try: sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers()) out2 = sess.run(None, input_feed={'input_0': data.detach().numpy()}) logger.info('ONNX WORKED.') status[4] = 'success' if numpy.allclose(out1, out2, atol=atol): logger.info('Check accuracy: OK') status[5] = 'success' else: logger.info(f'Check accuracy: atol is larger than {atol}.') except Exception as e: logger.error(f'FAIL TO RUN ONNX: {e}') continue if status: f.write(','.join(status) + '\n') print('Finished.')