from auto_transformers import AutoTransformers import numpy import onnx import onnxruntime from pathlib import Path import warnings import logging # full_models = AutoTransformers.supported_model_names() # checked_models = AutoTransformers.supported_model_names(format='onnx') # models = [x for x in full_models if x not in checked_models] models = ['distilbert-base-cased'] warnings.filterwarnings('ignore') log_path = './transformers_onnx.log' logger = logging.getLogger() if Path(log_path).exists: handler = logging.FileHandler(log_path) else: handler = logging.StreamHandler() logger.handlers.append(handler) logging.basicConfig(filename=log_path, filemode='a', format='%(asctime)s, %(msecs)d %(name)s %(levelname)s %(message)s', datefmt='%H:%M:%S', level=logging.info) logging.info("Test") for name in models: print(f'Model {name}:') saved_name = name.replace('/', '-') logging.info(f'\nConverting model {name} to {saved_name}:') try: op = AutoTransformers(model_name=name) out1 = op('hello, world.') logging.info('OP LOADED.') except Exception as e: logging.error(f'FAIL TO LOAD OP: {e}') continue try: op.save_model(format='onnx') logging.info('ONNX SAVED.') except Exception as e: logging.info(f'FAIL TO SAVE ONNX: {e}') continue try: try: onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx') onnx.checker.check_model(onnx_model) except Exception: saved_onnx = onnx.load(f'saved/onnx/{saved_name}.onnx', load_external_data=False) onnx.checker.check_model(saved_onnx) logging.info('ONNX CHECKED.') except Exception as e: logging.error(f'FAIL TO CHECK ONNX: {e}') continue try: sess = onnxruntime.InferenceSession(f'saved/onnx/{saved_name}.onnx', providers=onnxruntime.get_available_providers()) inputs = op.tokenizer('hello, world.', return_tensors='np') out2 = sess.run(output_names=["last_hidden_state"], input_feed=dict(inputs)) logging.info('ONNX WORKED.') if numpy.allclose(out1, out2, atol=1e-3): logging.info('Check accuracy: OK') else: logging.info('Check accuracy: atol is larger than 1e-3.') except Exception as e: logging.error(f'FAIL TO RUN ONNX: {e}') continue print('Finished.')