logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

23 lines
626 B

from auto_transformers import AutoTransformers
import torch
# Model names handed to AutoTransformers(model_name=...) one at a time below.
models = [
    'bert-base-cased',
    'distilbert-base-cased',
    'distilgpt2',
    'google/fnet-base',
]

# For each model: run the operator on a fixed sentence, export it with
# save_model(format='torchscript'), replace op.model with the module loaded
# back via torch.jit.load, run the same sentence again and compare outputs.
for name in models:
    # Best-effort sweep: a failure for one model is reported and the loop
    # moves on, so one bad model does not hide the results of the others.
    try:
        op = AutoTransformers(model_name=name)
        out1 = op('hello, world.')
        op.save_model(format='torchscript')
        # Loads '<name with "/" replaced by "-">.pt' relative to the working
        # directory -- presumably where save_model wrote it; confirm against
        # the operator's save_model implementation.
        op.model = torch.jit.load(name.replace('/', '-') + '.pt')
        out2 = op('hello, world.')
        # Explicit raise instead of a bare `assert`: asserts are stripped
        # under `python -O` (which would turn a mismatch into a false
        # success), and a message-less assert prints an empty failure reason.
        if not (out1 == out2).all():
            raise AssertionError(
                'torchscript output differs from the original model output'
            )
        print(f'[SUCCESS] Saved torchscript for model "{name}"')
    except Exception as e:
        print(f'[ERROR] Fail for model "{name}": {e}.')