logo
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions

18 lines
550 B

from auto_transformers import AutoTransformers
import torch
# Round-trip check for TorchScript export.
# For each model: run the eager AutoTransformers op on a fixed input, export it
# with save_model(), reload the exported file with torch.jit.load, and confirm
# the reloaded model produces the same output on the same input.
models = ['bert-base-cased', 'distilbert-base-cased', 'distilgpt2']
failed = []  # names of models whose export/reload/compare step raised
for name in models:
    try:
        op = AutoTransformers(model_name=name)
        out1 = op('hello, world.')
        op.save_model()
        # NOTE(review): assumes save_model() writes "<name>.pt" into the current
        # working directory -- confirm against AutoTransformers.save_model.
        op.model = torch.jit.load(name + '.pt')
        out2 = op('hello, world.')
        # Explicit check rather than `assert`: assert statements are stripped
        # under `python -O`, which would report success without comparing.
        # Exact element-wise equality is required, as in the original check.
        if not (out1 == out2).all():
            raise AssertionError('TorchScript output differs from eager output')
    except Exception as e:
        # Top-level boundary: report the failure and move on to the next model.
        print(f'[ERROR] Fail for model "{name}": {e}.')
        failed.append(name)
    else:
        print(f'[SUCCESS] Saved torchscript for model "{name}"')

# Exit non-zero so callers (e.g. CI) can tell that at least one model failed.
if failed:
    raise SystemExit(1)