From 6d228847cfa1623e43b4470c425e550694a4a8a5 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Wed, 7 Dec 2022 16:53:50 +0800 Subject: [PATCH] Update save_onnx Signed-off-by: Jael Gu --- auto_transformers.py | 56 ++++++++++++++------------------------------ test_onnx.py | 11 +++++---- 2 files changed, 23 insertions(+), 44 deletions(-) diff --git a/auto_transformers.py b/auto_transformers.py index 9ff2488..95b62a5 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -15,6 +15,7 @@ import numpy import os import torch +import shutil from pathlib import Path from transformers import AutoTokenizer, AutoModel @@ -47,6 +48,8 @@ class AutoTransformers(NNOperator): try: self.model = AutoModel.from_pretrained(model_name).to(self.device) self.model.eval() + + self.configs = self.model.config except Exception as e: model_list = self.supported_model_names() if model_name not in model_list: @@ -87,57 +90,32 @@ class AutoTransformers(NNOperator): name = self.model_name.replace('/', '-') path = os.path.join(path, name) - inputs = self.tokenizer('[CLS]', return_tensors='pt') # a dictionary + dummy_input = '[CLS]' + inputs = self.tokenizer(dummy_input, return_tensors='pt') # a dictionary if format == 'pytorch': - path = path + '.pt' - torch.save(self.model, path) + torch.save(self.model, path + '.pt') elif format == 'torchscript': - path = path + '.pt' inputs = list(inputs.values()) try: try: jit_model = torch.jit.script(self.model) except Exception: jit_model = torch.jit.trace(self.model, inputs, strict=False) - torch.jit.save(jit_model, path) + torch.jit.save(jit_model, path + '.pt') except Exception as e: log.error(f'Fail to save as torchscript: {e}.') raise RuntimeError(f'Fail to save as torchscript: {e}.') elif format == 'onnx': - path = path + '.onnx' - input_names = list(inputs.keys()) - dynamic_axes = {} - for i_n in input_names: - dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'} - try: - output_names = ['last_hidden_state'] - for o_n in output_names: - dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'} - torch.onnx.export(self.model, - tuple(inputs.values()), - path, - input_names=input_names, - output_names=output_names, - dynamic_axes=dynamic_axes, - opset_version=11, - do_constant_folding=True, - # enable_onnx_checker=True, - ) - except Exception as e: - print(e, '\nTrying with 2 outputs...') - output_names = ['last_hidden_state', 'pooler_output'] - for o_n in output_names: - dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'} - torch.onnx.export(self.model, - tuple(inputs.values()), - path, - input_names=input_names, - output_names=output_names, - dynamic_axes=dynamic_axes, - opset_version=11, - do_constant_folding=True, - # enable_onnx_checker=True, - ) + from transformers.convert_graph_to_onnx import convert + if os.path.isdir(path): + shutil.rmtree(path) + path = os.path.join(path, 'model.onnx') + convert( + model=self.model_name, + output=Path(path), + framework='pt', + opset=13 + ) # todo: elif format == 'tensorrt': else: log.error(f'Unsupported format "{format}".') diff --git a/test_onnx.py b/test_onnx.py index 239d2b5..9443eb1 100644 --- a/test_onnx.py +++ b/test_onnx.py @@ -14,7 +14,7 @@ import psutil # checked_models = AutoTransformers.supported_model_names(format='onnx') # models = [x for x in full_models if x not in checked_models] models = ['bert-base-cased', 'distilbert-base-cased'] -test_txt = '[UNK]' +test_txt = 'hello, world.' atol = 1e-3 log_path = 'transformers_onnx.log' f = open('onnx.csv', 'w+') @@ -43,6 +43,7 @@ status = None for name in models: logger.info(f'***{name}***') saved_name = name.replace('/', '-') + onnx_path = f'saved/onnx/{saved_name}/model.onnx' if status: f.write(','.join(status) + '\n') status = [name] + ['fail'] * 5 @@ -63,10 +64,10 @@ for name in models: continue try: try: - onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx') + onnx_model = onnx.load(onnx_path) onnx.checker.check_model(onnx_model) except Exception: - saved_onnx = onnx.load(f'saved/onnx/{saved_name}.onnx', load_external_data=False) + saved_onnx = onnx.load(onnx_path, load_external_data=False) onnx.checker.check_model(saved_onnx) logger.info('ONNX CHECKED.') status[3] = 'success' @@ -74,10 +75,10 @@ for name in models: logger.error(f'FAIL TO CHECK ONNX: {e}') continue try: - sess = onnxruntime.InferenceSession(f'saved/onnx/{saved_name}.onnx', + sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers()) inputs = op.tokenizer(test_txt, return_tensors='np') - out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs)) + out2 = sess.run(output_names=['output_0'], input_feed=dict(inputs)) logger.info('ONNX WORKED.') status[4] = 'success' if numpy.allclose(out1, out2, atol=atol):