transformers
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
38 lines
1.1 KiB
38 lines
1.1 KiB
from auto_transformers import AutoTransformers
|
|
import torch
|
|
|
|
import csv
from pathlib import Path

# Results file; rows are appended so successive runs accumulate.
CSV_PATH = Path('torchscript.csv')
HEADER = ('model_name', 'run_op', 'save_torchscript', 'check_result')
# To test every supported model, use AutoTransformers.supported_model_names().
MODELS = ('bert-base-cased', 'distilbert-base-cased')
SAMPLE_TEXT = 'hello, world.'


def check_model(name):
    """Run the TorchScript export checks for one model.

    Stages, in order: build the op and run it on ``SAMPLE_TEXT``; save it
    with ``save_model(format='torchscript')``; reload the saved file with
    ``torch.jit.load`` and compare its output with the first run.

    Returns a list of 'success' / 'fail' strings, one per stage attempted.
    Later stages depend on earlier ones, so checking stops at the first
    failure and the list may hold fewer than three entries.
    """
    # Each stage deliberately catches Exception: this is a per-model test
    # harness, and one broken model must not abort the whole run.
    try:
        op = AutoTransformers(model_name=name)
        out1 = op(SAMPLE_TEXT)
    except Exception as e:
        print(f'Fail to load op for {name}: {e}')
        return ['fail']

    try:
        op.save_model(format='torchscript')
    except Exception as e:
        print(f'Fail to save torchscript for {name}: {e}')
        return ['success', 'fail']

    try:
        # Assumes save_model writes saved/torchscript/<name, '/' -> '-'>.pt
        # — TODO confirm against AutoTransformers.save_model.
        saved_name = name.replace('/', '-')
        op.model = torch.jit.load(f'saved/torchscript/{saved_name}.pt')
        out2 = op(SAMPLE_TEXT)
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if not (out1 == out2).all():
            raise AssertionError('TorchScript output differs from the original')
    except Exception as e:
        print(f'Fail to check torchscript for {name}: {e}')
        return ['success', 'success', 'fail']

    return ['success', 'success', 'success']


def format_row(name, statuses):
    """Return a full CSV row for *name*, padding stages that never ran with ''.

    Every row therefore has ``len(HEADER)`` fields, even when an early
    stage failed and the remaining stages were skipped.
    """
    missing = len(HEADER) - 1 - len(statuses)
    return [name, *statuses] + [''] * missing


def main(csv_path=CSV_PATH, models=MODELS):
    """Check every model in *models* and append one row per model to *csv_path*.

    The header row is written only when the file is new or empty, so
    repeated runs do not interleave duplicate headers with results.
    """
    csv_path = Path(csv_path)
    write_header = not csv_path.exists() or csv_path.stat().st_size == 0
    # newline='' plus lineterminator='\n' keeps the '\n' row endings the
    # file has always used (csv.writer defaults to '\r\n').
    with csv_path.open('a', newline='', encoding='utf-8') as f:
        writer = csv.writer(f, lineterminator='\n')
        if write_header:
            writer.writerow(HEADER)
        for name in models:
            writer.writerow(format_row(name, check_model(name)))
            # Flush per model so completed rows survive a later hard crash.
            f.flush()
    print('Finished.')


if __name__ == '__main__':
    main()
|