From 53285855688dcb3c329d66140b8ad439e737ebfc Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Fri, 16 Dec 2022 10:31:42 +0800 Subject: [PATCH] Modify save_model to support tritonserver Signed-off-by: Jael Gu --- auto_transformers.py | 44 ++++++++++++++++++++++++++++++-------------- test_onnx.py | 4 ++-- 2 files changed, 32 insertions(+), 16 deletions(-) diff --git a/auto_transformers.py b/auto_transformers.py index d9662a2..2e15d5a 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -119,47 +119,63 @@ class AutoTransformers(NNOperator): vec = features.cpu().detach().numpy() return vec - def save_model(self, format: str = 'pytorch', path: str = 'default'): - if path == 'default': - path = str(Path(__file__).parent) - path = os.path.join(path, 'saved', format) - os.makedirs(path, exist_ok=True) + def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'): + if output_file == 'default': + output_file = str(Path(__file__).parent) + output_file = os.path.join(output_file, 'saved', model_type) + os.makedirs(output_file, exist_ok=True) name = self.model_name.replace('/', '-') - path = os.path.join(path, name) + output_file = os.path.join(output_file, name) + if model_type in ['pytorch', 'torchscript']: + output_file = output_file + '.pt' + elif model_type == 'onnx': + output_file = output_file + '.onnx' + else: + raise AttributeError('Unsupported model_type.') dummy_input = '[CLS]' inputs = self.tokenizer(dummy_input, return_tensors='pt') # a dictionary - if format == 'pytorch': - torch.save(self.model, path + '.pt') - elif format == 'torchscript': + if model_type == 'pytorch': + torch.save(self.model, output_file) + elif model_type == 'torchscript': inputs = list(inputs.values()) try: try: jit_model = torch.jit.script(self.model) except Exception: jit_model = torch.jit.trace(self.model, inputs, strict=False) - torch.jit.save(jit_model, path + '.pt') + torch.jit.save(jit_model, output_file) except Exception as e: log.error(f'Fail to save as torchscript: {e}.') raise RuntimeError(f'Fail to save as torchscript: {e}.') - elif format == 'onnx': + elif model_type == 'onnx': from transformers.onnx.features import FeaturesManager from transformers.onnx import export model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( self.model, feature='default') onnx_config = model_onnx_config(self.configs) - if os.path.isdir(path): - shutil.rmtree(path) + # if os.path.isdir(output_file[:-5]): + # shutil.rmtree(output_file[:-5]) + # print('********', Path(output_file)) onnx_inputs, onnx_outputs = export( self.tokenizer, self.model, config=onnx_config, opset=13, - output=Path(path+'.onnx') + output=Path(output_file) ) # todo: elif format == 'tensorrt': else: log.error(f'Unsupported format "{format}".') + return True + + @property + def supported_formats(self): + onnxes = self.supported_model_names(format='onnx') + if self.model_name in onnxes: + return ['onnx'] + else: + return ['pytorch'] @staticmethod def supported_model_names(format: str = None): diff --git a/test_onnx.py b/test_onnx.py index cd95354..662045c 100644 --- a/test_onnx.py +++ b/test_onnx.py @@ -12,7 +12,7 @@ import psutil # full_models = AutoTransformers.supported_model_names() # checked_models = AutoTransformers.supported_model_names(format='onnx') # models = [x for x in full_models if x not in checked_models] -models = ['bert-base-cased', 'distilbert-base-cased'] +models = ['distilbert-base-cased'] test_txt = 'hello, world.' atol = 1e-3 log_path = 'transformers_onnx.log' @@ -55,7 +55,7 @@ for name in models: logger.error(f'FAIL TO LOAD OP: {e}') continue try: - op.save_model(format='onnx') + op.save_model(model_type='onnx') logger.info('ONNX SAVED.') status[2] = 'success' except Exception as e: