diff --git a/auto_transformers.py b/auto_transformers.py
index a6a47bb..763be88 100644
--- a/auto_transformers.py
+++ b/auto_transformers.py
@@ -142,13 +142,22 @@ class AutoTransformers(NNOperator):
     @property
     def onnx_config(self):
         from transformers.onnx.features import FeaturesManager
-        model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
-            self._model, feature='default')
-        old_config = model_onnx_config(self.model_config)
-        onnx_config = {
-            'inputs': dict(old_config.inputs),
-            'outputs': {'last_hidden_state': old_config.outputs['last_hidden_state']}
-        }
+        try:
+            model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
+                self._model, feature='default')
+            old_config = model_onnx_config(self.model_config)
+            onnx_config = {
+                'inputs': dict(old_config.inputs),
+                'outputs': {'last_hidden_state': old_config.outputs['last_hidden_state']}
+            }
+        except Exception:
+            input_dict = {}
+            for k in self.tokenizer.model_input_names:
+                input_dict[k] = {0: 'batch_size', 1: 'sequence_length'}
+            onnx_config = {
+                'inputs': input_dict,
+                'outputs': {'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}}
+            }
         return onnx_config
 
     def post_proc(self, token_embeddings, inputs):