diff --git a/auto_transformers.py b/auto_transformers.py
index 050cfb1..9182f92 100644
--- a/auto_transformers.py
+++ b/auto_transformers.py
@@ -108,15 +108,18 @@ class AutoTransformers(NNOperator):
                 torch.onnx.export(self.model,
                                   tuple(inputs.values()),
                                   path,
-                                  input_names=["input_ids", "token_type_ids", "attention_mask"], # list(inputs.keys())
+                                  input_names=list(inputs.keys()),
                                   output_names=["last_hidden_state"],
-                                  opset_version=13,
                                   dynamic_axes={
                                       "input_ids": {0: "batch_size", 1: "input_length"},
                                       "token_type_ids": {0: "batch_size", 1: "input_length"},
                                       "attention_mask": {0: "batch_size", 1: "input_length"},
                                       "last_hidden_state": {0: "batch_size"},
-                                  })
+                                  },
+                                  opset_version=13,
+                                  do_constant_folding=True,
+                                  enable_onnx_checker=True,
+                                  )
             except Exception as e:
                 print(e, '\nTrying with 2 outputs...')
                 torch.onnx.export(self.model,
diff --git a/test_onnx.py b/test_onnx.py
index 8a3c5e4..077e73e 100644
--- a/test_onnx.py
+++ b/test_onnx.py
@@ -1,43 +1,79 @@
 from auto_transformers import AutoTransformers
-import onnx
-import warnings
-warnings.filterwarnings('ignore')
+import numpy
+import onnx
+import onnxruntime
 
-f = open('onnx.csv', 'a+')
-f.write('model_name, run_op, save_onnx, check_onnx\n')
+from pathlib import Path
+import warnings
+import logging
 
 # full_models = AutoTransformers.supported_model_names()
 # checked_models = AutoTransformers.supported_model_names(format='onnx')
 # models = [x for x in full_models if x not in checked_models]
-models = ['funnel-transformer/large', 'funnel-transformer/medium', 'funnel-transformer/small', 'funnel-transformer/xlarge']
+models = ['distilbert-base-cased']
+
+warnings.filterwarnings('ignore')
+
+log_path = './transformers_onnx.log'
+
+logger = logging.getLogger()
+
+if Path(log_path).exists():
+    handler = logging.FileHandler(log_path)
+else:
+    handler = logging.StreamHandler()
+logger.handlers.append(handler)
+
+logging.basicConfig(filename=log_path,
+                    filemode='a',
+                    format='%(asctime)s, %(msecs)d %(name)s %(levelname)s %(message)s',
+                    datefmt='%H:%M:%S',
+                    level=logging.INFO)
+
+
+logging.info("Test")
 
 for name in models:
-    f.write(f'{name},')
+    print(f'Model {name}:')
+    saved_name = name.replace('/', '-')
+    logging.info(f'\nConverting model {name} to {saved_name}:')
     try:
         op = AutoTransformers(model_name=name)
         out1 = op('hello, world.')
-        f.write('success,')
+        logging.info('OP LOADED.')
     except Exception as e:
-        f.write('fail')
-        print(f'Fail to load op for {name}: {e}')
-        pass
+        logging.error(f'FAIL TO LOAD OP: {e}')
+        continue
     try:
         op.save_model(format='onnx')
-        f.write('success,')
+        logging.info('ONNX SAVED.')
+    except Exception as e:
+        logging.info(f'FAIL TO SAVE ONNX: {e}')
+        continue
+    try:
+        try:
+            onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx')
+            onnx.checker.check_model(onnx_model)
+        except Exception:
+            saved_onnx = onnx.load(f'saved/onnx/{saved_name}.onnx', load_external_data=False)
+            onnx.checker.check_model(saved_onnx)
+        logging.info('ONNX CHECKED.')
     except Exception as e:
-        f.write('fail')
-        print(f'Fail to save onnx for {name}: {e}')
-        pass
+        logging.error(f'FAIL TO CHECK ONNX: {e}')
+        continue
     try:
-        saved_name = name.replace('/', '-')
-        onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx', load_external_data=False)
-        onnx.checker.check_model(onnx_model)
-        f.write('success')
+        sess = onnxruntime.InferenceSession(f'saved/onnx/{saved_name}.onnx',
+                                            providers=onnxruntime.get_available_providers())
+        inputs = op.tokenizer('hello, world.', return_tensors='np')
+        out2 = sess.run(output_names=["last_hidden_state"], input_feed=dict(inputs))
+        logging.info('ONNX WORKED.')
+        if numpy.allclose(out1, out2, atol=1e-3):
+            logging.info('Check accuracy: OK')
+        else:
+            logging.info('Check accuracy: atol is larger than 1e-3.')
     except Exception as e:
-        f.write('fail')
-        print(f'Fail to check onnx for {name}: {e}')
-        pass
-
-    f.write('\n')
+        logging.error(f'FAIL TO RUN ONNX: {e}')
+        continue
+    print('Finished.')