"""Export transformer models to ONNX and sanity-check the exported files.

For every name in ``models`` the script:

1. builds the ``AutoTransformers`` operator and runs it on a sample sentence,
2. asks the operator to save itself in ONNX format,
3. loads ``saved/onnx/<name>.onnx`` and validates it with the ONNX checker,
4. runs the file with ONNX Runtime on the same sentence and compares the
   result with the operator's output using ``numpy.allclose``.

Progress is written to ``transformers_onnx.log`` (DEBUG and above); only
errors are echoed to the console. A failure in any stage is logged and the
script moves on to the next model.
"""
from auto_transformers import AutoTransformers
import numpy
import onnx
import onnxruntime
import os
from pathlib import Path
import logging
import platform
import psutil

# To sweep every model that has no verified ONNX export yet, use instead:
# full_models = AutoTransformers.supported_model_names()
# checked_models = AutoTransformers.supported_model_names(format='onnx')
# models = [x for x in full_models if x not in checked_models]
models = ['distilbert-base-cased']
# Absolute tolerance passed to numpy.allclose when comparing outputs.
atol = 1e-3
log_path = 'transformers_onnx.log'
# Both the operator and the ONNX session must see the same input, so the
# sample sentence is defined once.
sample_text = 'hello, world.'

logger = logging.getLogger('transformers_onnx')
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Full detail goes to the log file ...
fh = logging.FileHandler(log_path)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
logger.addHandler(fh)

# ... while the console only shows errors.
ch = logging.StreamHandler()
ch.setLevel(logging.ERROR)
ch.setFormatter(formatter)
logger.addHandler(ch)


def _log_failure(message, exc):
    """Log a stage failure at ERROR and keep its traceback at DEBUG.

    The traceback is emitted at DEBUG level so it lands in the log file
    without adding noise to the console handler (which is ERROR-only).
    """
    logger.error(f'{message}: {exc}')
    logger.debug(f'{message} (traceback)', exc_info=exc)


# Record the environment the check ran on. Memory is sampled once so the three
# figures come from the same snapshot.
gib = 1024.0 ** 3
mem = psutil.virtual_memory()
logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
logger.debug(f'free/available/total mem: {round(mem.free / gib)}'
             f'/{round(mem.available / gib)}'
             f'/{round(mem.total / gib)} GB')
logger.debug(f'cpu: {psutil.cpu_count()}')

for name in models:
    logger.info(f'***{name}***')
    # Model names may contain '/', which is replaced by '-' in the file name.
    saved_name = name.replace('/', '-')
    onnx_path = f'saved/onnx/{saved_name}.onnx'

    # Stage 1: build the operator and compute the reference output.
    try:
        op = AutoTransformers(model_name=name)
        out1 = op(sample_text)
        logger.info('OP LOADED.')
    except Exception as e:
        _log_failure('FAIL TO LOAD OP', e)
        continue

    # Stage 2: export to ONNX.
    try:
        op.save_model(format='onnx')
        logger.info('ONNX SAVED.')
    except Exception as e:
        _log_failure('FAIL TO SAVE ONNX', e)
        continue

    # Stage 3: validate the exported file.
    try:
        try:
            onnx_model = onnx.load(onnx_path)
            onnx.checker.check_model(onnx_model)
        except Exception as first_error:
            # Retry without pulling external tensor data into memory.
            # NOTE(review): presumably a workaround for models whose checker
            # run fails once external data is loaded - confirm.
            logger.debug(f'ONNX check with external data failed, '
                         f'retrying without it: {first_error}')
            saved_onnx = onnx.load(onnx_path, load_external_data=False)
            onnx.checker.check_model(saved_onnx)
        logger.info('ONNX CHECKED.')
    except Exception as e:
        _log_failure('FAIL TO CHECK ONNX', e)
        continue

    # Stage 4: run the exported model and compare with the reference output.
    try:
        sess = onnxruntime.InferenceSession(
            onnx_path,
            providers=onnxruntime.get_available_providers())
        inputs = op.tokenizer(sample_text, return_tensors='np')
        # sess.run returns a list with one entry per requested output name;
        # numpy.allclose accepts it as array_like.
        out2 = sess.run(output_names=["last_hidden_state"], input_feed=dict(inputs))
        logger.info('ONNX WORKED.')
        if numpy.allclose(out1, out2, atol=atol):
            logger.info('Check accuracy: OK')
        else:
            # A mismatch is not fatal, but it should stand out from the
            # success line in the log.
            logger.warning(f'Check accuracy: atol is larger than {atol}.')
    except Exception as e:
        _log_failure('FAIL TO RUN ONNX', e)

print('Finished.')