logo
Browse Source

update

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
d2f54a94c0
  1. 14
      test_onnx.py
  2. 40
      test_torchscript.py

14
test_save.py → test_onnx.py

@ -4,10 +4,7 @@ import onnx
f = open('onnx.csv', 'a+')
f.write('model_name, run op, save_onnx, check_onnx\n')
models = [
'bert-base-cased',
'distilbert-base-cased'
]
models = AutoTransformers.supported_model_names()[:1]
for name in models:
line = f'{name}, '
@ -16,14 +13,16 @@ for name in models:
out1 = op('hello, world.')
line += 'success, '
except Exception as e:
line += 'fail, '
line += 'fail\n'
f.write(line)
print(f'Fail to load op for {name}: {e}.')
continue
try:
op.save_model(format='onnx')
line += 'success, '
except Exception as e:
line += 'fail, '
line += 'fail\n'
f.write(line)
print(f'Fail to save onnx for {name}: {e}.')
continue
try:
@ -32,7 +31,8 @@ for name in models:
onnx.checker.check_model(onnx_model)
line += 'success'
except Exception as e:
line += 'fail'
line += 'fail\n'
f.write(line)
print(f'Fail to check onnx for {name}: {e}.')
continue
line += '\n'

40
test_torchscript.py

@ -0,0 +1,40 @@
from auto_transformers import AutoTransformers
import torch
f = open('torchscript.csv', 'a+')
f.write('model_name, run op, save_torchscript, check_result\n')
models = AutoTransformers.supported_model_names()[:1]
for name in models:
line = f'{name}, '
try:
op = AutoTransformers(model_name=name)
out1 = op('hello, world.')
line += 'success, '
except Exception as e:
line += 'fail\n'
f.write(line)
print(f'Fail to load op for {name}: {e}.')
continue
try:
op.save_model(format='torchscript')
line += 'success, '
except Exception as e:
line += 'fail\n'
f.write(line)
print(f'Fail to save onnx for {name}: {e}.')
continue
try:
saved_name = name.replace('/', '-')
op.model = torch.jit.load(f'saved/torchscript/{saved_name}.pt')
out2 = op('hello, world.')
assert (out1 == out2).all()
line += 'success'
except Exception as e:
line += 'fail\n'
f.write(line)
print(f'Fail to check onnx for {name}: {e}.')
continue
line += '\n'
f.write(line)
Loading…
Cancel
Save