logo
Browse Source

Update test scripts

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 3 years ago
parent
commit
9dda6285eb
  1. 24
      test_onnx.py
  2. 26
      test_torchscript.py

24
test_onnx.py

@ -2,40 +2,36 @@ from auto_transformers import AutoTransformers
import onnx import onnx
f = open('onnx.csv', 'a+') f = open('onnx.csv', 'a+')
f.write('model_name, run op, save_onnx, check_onnx\n')
f.write('model_name, run_op, save_onnx, check_onnx\n')
# models = AutoTransformers.supported_model_names()[:1] # models = AutoTransformers.supported_model_names()[:1]
models = ['bert-base-cased']
models = ['bert-base-cased', 'distilbert-base-cased']
for name in models: for name in models:
line = f'{name}, '
f.write(f'{name},')
try: try:
op = AutoTransformers(model_name=name) op = AutoTransformers(model_name=name)
out1 = op('hello, world.') out1 = op('hello, world.')
line += 'success, '
f.write('success,')
except Exception as e: except Exception as e:
line += 'fail\n'
f.write(line)
f.write('fail')
print(f'Fail to load op for {name}: {e}') print(f'Fail to load op for {name}: {e}')
continue continue
try: try:
op.save_model(format='onnx') op.save_model(format='onnx')
line += 'success, '
f.write('success,')
except Exception as e: except Exception as e:
line += 'fail\n'
f.write(line)
f.write('fail')
print(f'Fail to save onnx for {name}: {e}') print(f'Fail to save onnx for {name}: {e}')
continue continue
try: try:
saved_name = name.replace('/', '-') saved_name = name.replace('/', '-')
onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx') onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx')
onnx.checker.check_model(onnx_model) onnx.checker.check_model(onnx_model)
line += 'success'
f.write('success')
except Exception as e: except Exception as e:
line += 'fail\n'
f.write(line)
f.write('fail')
print(f'Fail to check onnx for {name}: {e}') print(f'Fail to check onnx for {name}: {e}')
continue continue
line += '\n'
f.write(line)
f.write('\n')
print('Finished.') print('Finished.')

26
test_torchscript.py

@ -2,27 +2,26 @@ from auto_transformers import AutoTransformers
import torch import torch
f = open('torchscript.csv', 'a+') f = open('torchscript.csv', 'a+')
f.write('model_name, run op, save_torchscript, check_result\n')
f.write('model_name,run_op,save_torchscript,check_result\n')
models = AutoTransformers.supported_model_names()[:1]
# models = AutoTransformers.supported_model_names()[:1]
models = ['bert-base-cased', 'distilbert-base-cased']
for name in models: for name in models:
line = f'{name}, '
f.write(f'{name},')
try: try:
op = AutoTransformers(model_name=name) op = AutoTransformers(model_name=name)
out1 = op('hello, world.') out1 = op('hello, world.')
line += 'success, '
f.write('success,')
except Exception as e: except Exception as e:
line += 'fail\n'
f.write(line)
f.write('fail')
print(f'Fail to load op for {name}: {e}') print(f'Fail to load op for {name}: {e}')
continue continue
try: try:
op.save_model(format='torchscript') op.save_model(format='torchscript')
line += 'success, '
f.write('success,')
except Exception as e: except Exception as e:
line += 'fail\n'
f.write(line)
f.write('fail')
print(f'Fail to save onnx for {name}: {e}') print(f'Fail to save onnx for {name}: {e}')
continue continue
try: try:
@ -30,11 +29,10 @@ for name in models:
op.model = torch.jit.load(f'saved/torchscript/{saved_name}.pt') op.model = torch.jit.load(f'saved/torchscript/{saved_name}.pt')
out2 = op('hello, world.') out2 = op('hello, world.')
assert (out1 == out2).all() assert (out1 == out2).all()
line += 'success'
f.write('success')
except Exception as e: except Exception as e:
line += 'fail\n'
f.write(line)
f.write('fail')
print(f'Fail to check onnx for {name}: {e}') print(f'Fail to check onnx for {name}: {e}')
continue continue
line += '\n'
f.write(line)
f.write('\n')
print('Finished.')

Loading…
Cancel
Save