From 516372f339d7e387a2e9ef8d8a3c4102a3070990 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Mon, 20 Jun 2022 15:50:10 +0800 Subject: [PATCH] Allow to save model as onnx Signed-off-by: Jael Gu --- auto_transformers.py | 37 ++++++++++++++++++++++++++++++++++--- test_save.py | 36 ++++++++++++++++++++++++++---------- 2 files changed, 60 insertions(+), 13 deletions(-) diff --git a/auto_transformers.py b/auto_transformers.py index 90dc590..eb0a46b 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -79,10 +79,13 @@ class AutoTransformers(NNOperator): def save_model(self, format: str = 'pytorch', path: str = 'default'): if path == 'default': path = str(Path(__file__).parent) + path = os.path.join(path, 'saved', format) + os.makedirs(path, exist_ok=True) name = self.model_name.replace('/', '-') path = os.path.join(path, name) - inputs = self.tokenizer('[CLS]', return_tensors='pt') + inputs = self.tokenizer('[CLS]', return_tensors='pt') # a dictionary if format == 'pytorch': + path = path + '.pt' torch.save(self.model, path) elif format == 'torchscript': path = path + '.pt' @@ -96,8 +99,36 @@ class AutoTransformers(NNOperator): except Exception as e: log.error(f'Fail to save as torchscript: {e}.') raise RuntimeError(f'Fail to save as torchscript: {e}.') - elif format == 'onxx': - pass # todo + elif format == 'onnx': + path = path + '.onnx' + + try: + torch.onnx.export(self.model, + tuple(inputs.values()), + path, + input_names=["input_ids", "token_type_ids", "attention_mask"], # list(inputs.keys()) + output_names=["last_hidden_state"], + opset_version=10, + dynamic_axes={ + "input_ids": {0: "batch_size", 1: "input_length"}, + "token_type_ids": {0: "batch_size", 1: "input_length"}, + "attention_mask": {0: "batch_size", 1: "input_length"}, + "last_hidden_state": {0: "batch_size"}, + }) + except Exception: + torch.onnx.export(self.model, + tuple(inputs.values()), + path, + input_names=["input_ids", "token_type_ids", "attention_mask"], # list(inputs.keys()) + 
output_names=["last_hidden_state", "pooler_output"], + opset_version=10, + dynamic_axes={ + "input_ids": {0: "batch_size", 1: "input_length"}, + "token_type_ids": {0: "batch_size", 1: "input_length"}, + "attention_mask": {0: "batch_size", 1: "input_length"}, + "last_hidden_state": {0: "batch_size"}, + "pooler_output": {0: "batch_size"} + }) else: log.error(f'Unsupported format "{format}".') diff --git a/test_save.py b/test_save.py index 5b4d2ae..f809e18 100644 --- a/test_save.py +++ b/test_save.py @@ -1,23 +1,39 @@ from auto_transformers import AutoTransformers +import onnx -import torch +f = open('onnx.csv', 'a+') +f.write('model_name, run op, save_onnx, check_onnx\n') models = [ 'bert-base-cased', - 'distilbert-base-cased', - 'distilgpt2', - 'google/fnet-base' + 'distilbert-base-cased' ] for name in models: + line = f'{name}, ' try: op = AutoTransformers(model_name=name) out1 = op('hello, world.') - op.save_model(format='torchscript') - op.model = torch.jit.load(name.replace('/', '-') + '.pt') - out2 = op('hello, world.') - assert (out1 == out2).all() - print(f'[SUCCESS] Saved torchscript for model "{name}"') + line += 'success, ' except Exception as e: - print(f'[ERROR] Fail for model "{name}": {e}.') + line += 'fail, ' + print(f'Fail to load op for {name}: {e}.') continue + try: + op.save_model(format='onnx') + line += 'success, ' + except Exception as e: + line += 'fail, ' + print(f'Fail to save onnx for {name}: {e}.') + continue + try: + saved_name = name.replace('/', '-') + onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx') + onnx.checker.check_model(onnx_model) + line += 'success' + except Exception as e: + line += 'fail' + print(f'Fail to check onnx for {name}: {e}.') + continue + line += '\n' + f.write(line)