From ebfc9c74e3d92b7d62178e2b69975a8bd12b8aa9 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Wed, 1 Jun 2022 18:48:09 +0800 Subject: [PATCH] Update Signed-off-by: Jael Gu --- auto_transformers.py | 5 ++++- test_save.py | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 test_save.py diff --git a/auto_transformers.py b/auto_transformers.py index 55d76e1..c032711 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -84,7 +84,10 @@ class AutoTransformers(NNOperator): inputs = list(inputs.values()) if jit: try: - traced_model = torch.jit.trace(self.model, inputs, strict=False) + try: + traced_model = torch.jit.script(self.model) + except Exception: + traced_model = torch.jit.trace(self.model, inputs, strict=False) torch.jit.save(traced_model, destination) except Exception as e: log.error(f'Fail to save as torchscript: {e}.') diff --git a/test_save.py b/test_save.py new file mode 100644 index 0000000..3cf21d0 --- /dev/null +++ b/test_save.py @@ -0,0 +1,18 @@ +from auto_transformers import AutoTransformers + +import torch + +models = ['bert-base-cased', 'distilbert-base-cased', 'distilgpt2'] + +for name in models: + try: + op = AutoTransformers(model_name=name) + out1 = op('hello, world.') + op.save_model() + op.model = torch.jit.load(name + '.pt') + out2 = op('hello, world.') + assert (out1 == out2).all() + print(f'[SUCCESS] Saved torchscript for model "{name}"') + except Exception as e: + print(f'[ERROR] Fail for model "{name}": {e}.') + continue