transformers
copied
Jael Gu
2 years ago
2 changed files with 47 additions and 7 deletions
@ -0,0 +1,40 @@ |
|||
from auto_transformers import AutoTransformers |
|||
import torch |
|||
|
|||
f = open('torchscript.csv', 'a+') |
|||
f.write('model_name, run op, save_torchscript, check_result\n') |
|||
|
|||
models = AutoTransformers.supported_model_names()[:1] |
|||
|
|||
for name in models: |
|||
line = f'{name}, ' |
|||
try: |
|||
op = AutoTransformers(model_name=name) |
|||
out1 = op('hello, world.') |
|||
line += 'success, ' |
|||
except Exception as e: |
|||
line += 'fail\n' |
|||
f.write(line) |
|||
print(f'Fail to load op for {name}: {e}.') |
|||
continue |
|||
try: |
|||
op.save_model(format='torchscript') |
|||
line += 'success, ' |
|||
except Exception as e: |
|||
line += 'fail\n' |
|||
f.write(line) |
|||
print(f'Fail to save onnx for {name}: {e}.') |
|||
continue |
|||
try: |
|||
saved_name = name.replace('/', '-') |
|||
op.model = torch.jit.load(f'saved/torchscript/{saved_name}.pt') |
|||
out2 = op('hello, world.') |
|||
assert (out1 == out2).all() |
|||
line += 'success' |
|||
except Exception as e: |
|||
line += 'fail\n' |
|||
f.write(line) |
|||
print(f'Fail to check onnx for {name}: {e}.') |
|||
continue |
|||
line += '\n' |
|||
f.write(line) |
Loading…
Reference in new issue