transformers
copied
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
23 lines
626 B
23 lines
626 B
from auto_transformers import AutoTransformers
|
|
|
|
import torch
|
|
|
|
# Model checkpoints to round-trip through TorchScript export and reload.
models = [
    'bert-base-cased',
    'distilbert-base-cased',
    'distilgpt2',
    'google/fnet-base',
]


def _check_torchscript_roundtrip(name):
    """Export model ``name`` to TorchScript, reload it, and compare outputs.

    Builds an ``AutoTransformers`` operator for ``name`` and runs it on a
    fixed sentence. It then calls ``save_model(format='torchscript')``, loads
    the resulting ``.pt`` file back into ``op.model`` with ``torch.jit.load``,
    and runs the same sentence again.

    Raises:
        RuntimeError: if the output after reload is not element-wise equal
            to the output before export.
        Exception: anything raised while building the operator, running it,
            exporting, or loading the saved file.
    """
    op = AutoTransformers(model_name=name)
    out1 = op('hello, world.')
    op.save_model(format='torchscript')
    # NOTE(review): assumes save_model writes "<name with '/' replaced by '-'>.pt"
    # relative to the current working directory -- confirm against
    # AutoTransformers.save_model.
    op.model = torch.jit.load(name.replace('/', '-') + '.pt')
    out2 = op('hello, world.')
    # Explicit check rather than `assert`, for two reasons:
    # - asserts are stripped under `python -O`, which would make every model
    #   "pass" unchecked;
    # - a bare AssertionError stringifies to '', leaving the error report
    #   below with no reason.
    if not (out1 == out2).all():
        raise RuntimeError('output changed after TorchScript save/reload')


for name in models:
    try:
        _check_torchscript_roundtrip(name)
    except Exception as e:  # best-effort smoke test: report and try the next model
        print(f'[ERROR] Fail for model "{name}": {e}.')
    else:
        print(f'[SUCCESS] Saved torchscript for model "{name}"')