logo
Browse Source

Update

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 3 years ago
parent
commit
ebfc9c74e3
  1. 5
      auto_transformers.py
  2. 18
      test_save.py

5
auto_transformers.py

@ -84,7 +84,10 @@ class AutoTransformers(NNOperator):
inputs = list(inputs.values()) inputs = list(inputs.values())
if jit: if jit:
try: try:
traced_model = torch.jit.trace(self.model, inputs, strict=False)
try:
traced_model = torch.jit.script(self.model)
except Exception:
traced_model = torch.jit.trace(self.model, inputs, strict=False)
torch.jit.save(traced_model, destination) torch.jit.save(traced_model, destination)
except Exception as e: except Exception as e:
log.error(f'Fail to save as torchscript: {e}.') log.error(f'Fail to save as torchscript: {e}.')

18
test_save.py

@ -0,0 +1,18 @@
from auto_transformers import AutoTransformers
import torch
models = ['bert-base-cased', 'distilbert-base-cased', 'distilgpt2']
for name in models:
try:
op = AutoTransformers(model_name=name)
out1 = op('hello, world.')
op.save_model()
op.model = torch.jit.load(name + '.pt')
out2 = op('hello, world.')
assert (out1 == out2).all()
print(f'[SUCCESS] Saved torchscript for model "{name}"')
except Exception as e:
print(f'[ERROR] Fail for model "{name}": {e}.')
continue
Loading…
Cancel
Save