From d2f54a94c0942a584aedab91803ba092cc1098ee Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Mon, 20 Jun 2022 17:16:06 +0800 Subject: [PATCH] update Signed-off-by: Jael Gu --- test_save.py => test_onnx.py | 14 ++++++------- test_torchscript.py | 40 ++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 7 deletions(-) rename test_save.py => test_onnx.py (82%) create mode 100644 test_torchscript.py diff --git a/test_save.py b/test_onnx.py similarity index 82% rename from test_save.py rename to test_onnx.py index f809e18..4311655 100644 --- a/test_save.py +++ b/test_onnx.py @@ -4,10 +4,7 @@ import onnx f = open('onnx.csv', 'a+') f.write('model_name, run op, save_onnx, check_onnx\n') -models = [ - 'bert-base-cased', - 'distilbert-base-cased' -] +models = AutoTransformers.supported_model_names()[:1] for name in models: line = f'{name}, ' @@ -16,14 +13,16 @@ for name in models: out1 = op('hello, world.') line += 'success, ' except Exception as e: - line += 'fail, ' + line += 'fail\n' + f.write(line) print(f'Fail to load op for {name}: {e}.') continue try: op.save_model(format='onnx') line += 'success, ' except Exception as e: - line += 'fail, ' + line += 'fail\n' + f.write(line) print(f'Fail to save onnx for {name}: {e}.') continue try: @@ -32,7 +31,8 @@ for name in models: onnx.checker.check_model(onnx_model) line += 'success' except Exception as e: - line += 'fail' + line += 'fail\n' + f.write(line) print(f'Fail to check onnx for {name}: {e}.') continue line += '\n' diff --git a/test_torchscript.py b/test_torchscript.py new file mode 100644 index 0000000..10eac79 --- /dev/null +++ b/test_torchscript.py @@ -0,0 +1,40 @@ +from auto_transformers import AutoTransformers +import torch + +f = open('torchscript.csv', 'a+') +f.write('model_name, run op, save_torchscript, check_result\n') + +models = AutoTransformers.supported_model_names()[:1] + +for name in models: + line = f'{name}, ' + try: + op = AutoTransformers(model_name=name) + out1 = op('hello, world.') + line += 'success, ' + except Exception as e: + line += 'fail\n' + f.write(line) + print(f'Fail to load op for {name}: {e}.') + continue + try: + op.save_model(format='torchscript') + line += 'success, ' + except Exception as e: + line += 'fail\n' + f.write(line) + print(f'Fail to save onnx for {name}: {e}.') + continue + try: + saved_name = name.replace('/', '-') + op.model = torch.jit.load(f'saved/torchscript/{saved_name}.pt') + out2 = op('hello, world.') + assert (out1 == out2).all() + line += 'success' + except Exception as e: + line += 'fail\n' + f.write(line) + print(f'Fail to check onnx for {name}: {e}.') + continue + line += '\n' + f.write(line)