From 9dda6285ebe1e38e65b49a5260b51ac8ac1d7b64 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Wed, 22 Jun 2022 10:36:32 +0800 Subject: [PATCH] Update test scripts Signed-off-by: Jael Gu --- test_onnx.py | 24 ++++++++++-------------- test_torchscript.py | 26 ++++++++++++-------------- 2 files changed, 22 insertions(+), 28 deletions(-) diff --git a/test_onnx.py b/test_onnx.py index fd7fdab..0c75b77 100644 --- a/test_onnx.py +++ b/test_onnx.py @@ -2,40 +2,36 @@ from auto_transformers import AutoTransformers import onnx f = open('onnx.csv', 'a+') -f.write('model_name, run op, save_onnx, check_onnx\n') +f.write('model_name, run_op, save_onnx, check_onnx\n') # models = AutoTransformers.supported_model_names()[:1] -models = ['bert-base-cased'] +models = ['bert-base-cased', 'distilbert-base-cased'] for name in models: - line = f'{name}, ' + f.write(f'{name},') try: op = AutoTransformers(model_name=name) out1 = op('hello, world.') - line += 'success, ' + f.write('success,') except Exception as e: - line += 'fail\n' - f.write(line) + f.write('fail') print(f'Fail to load op for {name}: {e}') continue try: op.save_model(format='onnx') - line += 'success, ' + f.write('success,') except Exception as e: - line += 'fail\n' - f.write(line) + f.write('fail') print(f'Fail to save onnx for {name}: {e}') continue try: saved_name = name.replace('/', '-') onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx') onnx.checker.check_model(onnx_model) - line += 'success' + f.write('success') except Exception as e: - line += 'fail\n' - f.write(line) + f.write('fail') print(f'Fail to check onnx for {name}: {e}') continue - line += '\n' - f.write(line) + f.write('\n') print('Finished.') diff --git a/test_torchscript.py b/test_torchscript.py index 12b5a0c..bd344d0 100644 --- a/test_torchscript.py +++ b/test_torchscript.py @@ -2,27 +2,26 @@ from auto_transformers import AutoTransformers import torch f = open('torchscript.csv', 'a+') -f.write('model_name, run op, save_torchscript, check_result\n') +f.write('model_name,run_op,save_torchscript,check_result\n') -models = AutoTransformers.supported_model_names()[:1] +# models = AutoTransformers.supported_model_names()[:1] +models = ['bert-base-cased', 'distilbert-base-cased'] for name in models: - line = f'{name}, ' + f.write(f'{name},') try: op = AutoTransformers(model_name=name) out1 = op('hello, world.') - line += 'success, ' + f.write('success,') except Exception as e: - line += 'fail\n' - f.write(line) + f.write('fail') print(f'Fail to load op for {name}: {e}') continue try: op.save_model(format='torchscript') - line += 'success, ' + f.write('success,') except Exception as e: - line += 'fail\n' - f.write(line) + f.write('fail') print(f'Fail to save onnx for {name}: {e}') continue try: @@ -30,11 +29,10 @@ for name in models: op.model = torch.jit.load(f'saved/torchscript/{saved_name}.pt') out2 = op('hello, world.') assert (out1 == out2).all() - line += 'success' + f.write('success') except Exception as e: - line += 'fail\n' - f.write(line) + f.write('fail') print(f'Fail to check onnx for {name}: {e}') continue - line += '\n' - f.write(line) + f.write('\n') +print('Finished.')