From 3245db6f0c27c87d4de74f8c14aa2941dd956d17 Mon Sep 17 00:00:00 2001
From: Jael Gu
Date: Tue, 17 Jan 2023 11:07:22 +0800
Subject: [PATCH] Update onnx export config

Signed-off-by: Jael Gu
---
 auto_transformers.py | 23 ++++++++++++++++-------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/auto_transformers.py b/auto_transformers.py
index a6a47bb..763be88 100644
--- a/auto_transformers.py
+++ b/auto_transformers.py
@@ -142,13 +142,22 @@ class AutoTransformers(NNOperator):
     @property
     def onnx_config(self):
         from transformers.onnx.features import FeaturesManager
-        model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
-            self._model, feature='default')
-        old_config = model_onnx_config(self.model_config)
-        onnx_config = {
-            'inputs': dict(old_config.inputs),
-            'outputs': {'last_hidden_state': old_config.outputs['last_hidden_state']}
-        }
+        try:
+            model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
+                self._model, feature='default')
+            old_config = model_onnx_config(self.model_config)
+            onnx_config = {
+                'inputs': dict(old_config.inputs),
+                'outputs': {'last_hidden_state': old_config.outputs['last_hidden_state']}
+            }
+        except Exception:
+            input_dict = {}
+            for k in self.tokenizer.model_input_names:
+                input_dict[k] = {0: 'batch_size', 1: 'sequence_length'}
+            onnx_config = {
+                'inputs': input_dict,
+                'outputs': {'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}}
+            }
         return onnx_config
 
     def post_proc(self, token_embeddings, inputs):
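
Note: the new except branch builds the dynamic-axes mapping purely from tokenizer.model_input_names, so models that FeaturesManager does not recognize can still be exported. Below is a minimal, standalone sketch of how a fallback config of this shape could drive torch.onnx.export; the checkpoint name, output path, dummy text, and opset version are illustrative assumptions, not part of the patch or the operator's export code.

    # Sketch only: exporting with a fallback config derived from the tokenizer,
    # mirroring the structure produced by the except branch above.
    import torch
    from transformers import AutoModel, AutoTokenizer

    name = 'bert-base-uncased'                      # assumed example checkpoint
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModel.from_pretrained(name)
    model.eval()

    # Same shape of config as the fallback branch produces.
    onnx_config = {
        'inputs': {k: {0: 'batch_size', 1: 'sequence_length'}
                   for k in tokenizer.model_input_names},
        'outputs': {'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}},
    }

    dummy = dict(tokenizer('a dummy sentence', return_tensors='pt'))
    torch.onnx.export(
        model,
        (dummy,),                                   # trailing dict is passed as keyword args
        'model.onnx',                               # assumed output path
        input_names=list(onnx_config['inputs']),
        output_names=list(onnx_config['outputs']),
        dynamic_axes={**onnx_config['inputs'], **onnx_config['outputs']},
        opset_version=13,                           # assumed opset
    )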