logo
Browse Source

Update onnx export config

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
3245db6f0c
  1. 9
      auto_transformers.py

9
auto_transformers.py

@ -142,6 +142,7 @@ class AutoTransformers(NNOperator):
@property @property
def onnx_config(self): def onnx_config(self):
from transformers.onnx.features import FeaturesManager from transformers.onnx.features import FeaturesManager
try:
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
self._model, feature='default') self._model, feature='default')
old_config = model_onnx_config(self.model_config) old_config = model_onnx_config(self.model_config)
@ -149,6 +150,14 @@ class AutoTransformers(NNOperator):
'inputs': dict(old_config.inputs), 'inputs': dict(old_config.inputs),
'outputs': {'last_hidden_state': old_config.outputs['last_hidden_state']} 'outputs': {'last_hidden_state': old_config.outputs['last_hidden_state']}
} }
except Exception:
input_dict = {}
for k in self.tokenizer.model_input_names:
input_dict[k] = {0: 'batch_size', 1: 'sequence_length'}
onnx_config = {
'inputs': input_dict,
'outputs': {'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}}
}
return onnx_config return onnx_config
def post_proc(self, token_embeddings, inputs): def post_proc(self, token_embeddings, inputs):

Loading…
Cancel
Save