logo
Browse Source

Update

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
ebfc9c74e3
  1. 3
      auto_transformers.py
  2. 18
      test_save.py

3
auto_transformers.py

@ -84,6 +84,9 @@ class AutoTransformers(NNOperator):
inputs = list(inputs.values())
if jit:
try:
try:
traced_model = torch.jit.script(self.model)
except Exception:
traced_model = torch.jit.trace(self.model, inputs, strict=False)
torch.jit.save(traced_model, destination)
except Exception as e:

18
test_save.py

@ -0,0 +1,18 @@
from auto_transformers import AutoTransformers
import torch
models = ['bert-base-cased', 'distilbert-base-cased', 'distilgpt2']
for name in models:
try:
op = AutoTransformers(model_name=name)
out1 = op('hello, world.')
op.save_model()
op.model = torch.jit.load(name + '.pt')
out2 = op('hello, world.')
assert (out1 == out2).all()
print(f'[SUCCESS] Saved torchscript for model "{name}"')
except Exception as e:
print(f'[ERROR] Fail for model "{name}": {e}.')
continue
Loading…
Cancel
Save