transformers
copied
Jael Gu
3 years ago
2 changed files with 60 additions and 13 deletions
@ -1,23 +1,39 @@ |
|||||
from auto_transformers import AutoTransformers |
from auto_transformers import AutoTransformers |
||||
|
import onnx |
||||
|
|
||||
import torch |
|
||||
|
f = open('onnx.csv', 'a+') |
||||
|
f.write('model_name, run op, save_onnx, check_onnx\n') |
||||
|
|
||||
models = [ |
models = [ |
||||
'bert-base-cased', |
'bert-base-cased', |
||||
'distilbert-base-cased', |
|
||||
'distilgpt2', |
|
||||
'google/fnet-base' |
|
||||
|
'distilbert-base-cased' |
||||
] |
] |
||||
|
|
||||
for name in models: |
for name in models: |
||||
|
line = f'{name}, ' |
||||
try: |
try: |
||||
op = AutoTransformers(model_name=name) |
op = AutoTransformers(model_name=name) |
||||
out1 = op('hello, world.') |
out1 = op('hello, world.') |
||||
op.save_model(format='torchscript') |
|
||||
op.model = torch.jit.load(name.replace('/', '-') + '.pt') |
|
||||
out2 = op('hello, world.') |
|
||||
assert (out1 == out2).all() |
|
||||
print(f'[SUCCESS] Saved torchscript for model "{name}"') |
|
||||
|
line += 'success, ' |
||||
except Exception as e: |
except Exception as e: |
||||
print(f'[ERROR] Fail for model "{name}": {e}.') |
|
||||
|
line += 'fail, ' |
||||
|
print(f'Fail to load op for {name}: {e}.') |
||||
continue |
continue |
||||
|
try: |
||||
|
op.save_model(format='onnx') |
||||
|
line += 'success, ' |
||||
|
except Exception as e: |
||||
|
line += 'fail, ' |
||||
|
print(f'Fail to save onnx for {name}: {e}.') |
||||
|
continue |
||||
|
try: |
||||
|
saved_name = name.replace('/', '-') |
||||
|
onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx') |
||||
|
onnx.checker.check_model(onnx_model) |
||||
|
line += 'success' |
||||
|
except Exception as e: |
||||
|
line += 'fail' |
||||
|
print(f'Fail to check onnx for {name}: {e}.') |
||||
|
continue |
||||
|
line += '\n' |
||||
|
f.write(line) |
||||
|
Loading…
Reference in new issue