|
@ -2,27 +2,26 @@ from auto_transformers import AutoTransformers |
|
|
import torch |
|
|
import torch |
|
|
|
|
|
|
|
|
f = open('torchscript.csv', 'a+') |
|
|
f = open('torchscript.csv', 'a+') |
|
|
f.write('model_name, run op, save_torchscript, check_result\n') |
|
|
|
|
|
|
|
|
f.write('model_name,run_op,save_torchscript,check_result\n') |
|
|
|
|
|
|
|
|
models = AutoTransformers.supported_model_names()[:1] |
|
|
|
|
|
|
|
|
# models = AutoTransformers.supported_model_names()[:1] |
|
|
|
|
|
models = ['bert-base-cased', 'distilbert-base-cased'] |
|
|
|
|
|
|
|
|
for name in models: |
|
|
for name in models: |
|
|
line = f'{name}, ' |
|
|
|
|
|
|
|
|
f.write(f'{name},') |
|
|
try: |
|
|
try: |
|
|
op = AutoTransformers(model_name=name) |
|
|
op = AutoTransformers(model_name=name) |
|
|
out1 = op('hello, world.') |
|
|
out1 = op('hello, world.') |
|
|
line += 'success, ' |
|
|
|
|
|
|
|
|
f.write('success,') |
|
|
except Exception as e: |
|
|
except Exception as e: |
|
|
line += 'fail\n' |
|
|
|
|
|
f.write(line) |
|
|
|
|
|
|
|
|
f.write('fail') |
|
|
print(f'Fail to load op for {name}: {e}') |
|
|
print(f'Fail to load op for {name}: {e}') |
|
|
continue |
|
|
continue |
|
|
try: |
|
|
try: |
|
|
op.save_model(format='torchscript') |
|
|
op.save_model(format='torchscript') |
|
|
line += 'success, ' |
|
|
|
|
|
|
|
|
f.write('success,') |
|
|
except Exception as e: |
|
|
except Exception as e: |
|
|
line += 'fail\n' |
|
|
|
|
|
f.write(line) |
|
|
|
|
|
|
|
|
f.write('fail') |
|
|
print(f'Fail to save onnx for {name}: {e}') |
|
|
print(f'Fail to save onnx for {name}: {e}') |
|
|
continue |
|
|
continue |
|
|
try: |
|
|
try: |
|
@ -30,11 +29,10 @@ for name in models: |
|
|
op.model = torch.jit.load(f'saved/torchscript/{saved_name}.pt') |
|
|
op.model = torch.jit.load(f'saved/torchscript/{saved_name}.pt') |
|
|
out2 = op('hello, world.') |
|
|
out2 = op('hello, world.') |
|
|
assert (out1 == out2).all() |
|
|
assert (out1 == out2).all() |
|
|
line += 'success' |
|
|
|
|
|
|
|
|
f.write('success') |
|
|
except Exception as e: |
|
|
except Exception as e: |
|
|
line += 'fail\n' |
|
|
|
|
|
f.write(line) |
|
|
|
|
|
|
|
|
f.write('fail') |
|
|
print(f'Fail to check onnx for {name}: {e}') |
|
|
print(f'Fail to check onnx for {name}: {e}') |
|
|
continue |
|
|
continue |
|
|
line += '\n' |
|
|
|
|
|
f.write(line) |
|
|
|
|
|
|
|
|
f.write('\n') |
|
|
|
|
|
print('Finished.') |
|
|