Browse Source
Update test scripts
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
2 changed files with
22 additions and
28 deletions
-
test_onnx.py
-
test_torchscript.py
|
|
@ -2,40 +2,36 @@ from auto_transformers import AutoTransformers |
|
|
|
import onnx |
|
|
|
|
|
|
|
f = open('onnx.csv', 'a+') |
|
|
|
f.write('model_name, run op, save_onnx, check_onnx\n') |
|
|
|
f.write('model_name, run_op, save_onnx, check_onnx\n') |
|
|
|
|
|
|
|
# models = AutoTransformers.supported_model_names()[:1] |
|
|
|
models = ['bert-base-cased'] |
|
|
|
models = ['bert-base-cased', 'distilbert-base-cased'] |
|
|
|
|
|
|
|
for name in models: |
|
|
|
line = f'{name}, ' |
|
|
|
f.write(f'{name},') |
|
|
|
try: |
|
|
|
op = AutoTransformers(model_name=name) |
|
|
|
out1 = op('hello, world.') |
|
|
|
line += 'success, ' |
|
|
|
f.write('success,') |
|
|
|
except Exception as e: |
|
|
|
line += 'fail\n' |
|
|
|
f.write(line) |
|
|
|
f.write('fail') |
|
|
|
print(f'Fail to load op for {name}: {e}') |
|
|
|
continue |
|
|
|
try: |
|
|
|
op.save_model(format='onnx') |
|
|
|
line += 'success, ' |
|
|
|
f.write('success,') |
|
|
|
except Exception as e: |
|
|
|
line += 'fail\n' |
|
|
|
f.write(line) |
|
|
|
f.write('fail') |
|
|
|
print(f'Fail to save onnx for {name}: {e}') |
|
|
|
continue |
|
|
|
try: |
|
|
|
saved_name = name.replace('/', '-') |
|
|
|
onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx') |
|
|
|
onnx.checker.check_model(onnx_model) |
|
|
|
line += 'success' |
|
|
|
f.write('success') |
|
|
|
except Exception as e: |
|
|
|
line += 'fail\n' |
|
|
|
f.write(line) |
|
|
|
f.write('fail') |
|
|
|
print(f'Fail to check onnx for {name}: {e}') |
|
|
|
continue |
|
|
|
line += '\n' |
|
|
|
f.write(line) |
|
|
|
f.write('\n') |
|
|
|
print('Finished.') |
|
|
|
|
|
@ -2,27 +2,26 @@ from auto_transformers import AutoTransformers |
|
|
|
import torch |
|
|
|
|
|
|
|
f = open('torchscript.csv', 'a+') |
|
|
|
f.write('model_name, run op, save_torchscript, check_result\n') |
|
|
|
f.write('model_name,run_op,save_torchscript,check_result\n') |
|
|
|
|
|
|
|
models = AutoTransformers.supported_model_names()[:1] |
|
|
|
# models = AutoTransformers.supported_model_names()[:1] |
|
|
|
models = ['bert-base-cased', 'distilbert-base-cased'] |
|
|
|
|
|
|
|
for name in models: |
|
|
|
line = f'{name}, ' |
|
|
|
f.write(f'{name},') |
|
|
|
try: |
|
|
|
op = AutoTransformers(model_name=name) |
|
|
|
out1 = op('hello, world.') |
|
|
|
line += 'success, ' |
|
|
|
f.write('success,') |
|
|
|
except Exception as e: |
|
|
|
line += 'fail\n' |
|
|
|
f.write(line) |
|
|
|
f.write('fail') |
|
|
|
print(f'Fail to load op for {name}: {e}') |
|
|
|
continue |
|
|
|
try: |
|
|
|
op.save_model(format='torchscript') |
|
|
|
line += 'success, ' |
|
|
|
f.write('success,') |
|
|
|
except Exception as e: |
|
|
|
line += 'fail\n' |
|
|
|
f.write(line) |
|
|
|
f.write('fail') |
|
|
|
print(f'Fail to save onnx for {name}: {e}') |
|
|
|
continue |
|
|
|
try: |
|
|
@ -30,11 +29,10 @@ for name in models: |
|
|
|
op.model = torch.jit.load(f'saved/torchscript/{saved_name}.pt') |
|
|
|
out2 = op('hello, world.') |
|
|
|
assert (out1 == out2).all() |
|
|
|
line += 'success' |
|
|
|
f.write('success') |
|
|
|
except Exception as e: |
|
|
|
line += 'fail\n' |
|
|
|
f.write(line) |
|
|
|
f.write('fail') |
|
|
|
print(f'Fail to check onnx for {name}: {e}') |
|
|
|
continue |
|
|
|
line += '\n' |
|
|
|
f.write(line) |
|
|
|
f.write('\n') |
|
|
|
print('Finished.') |
|
|
|