Browse Source
Update onnx export config
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
1 changed files with
16 additions and
7 deletions
-
auto_transformers.py
|
@ -142,13 +142,22 @@ class AutoTransformers(NNOperator): |
|
|
@property |
|
|
@property |
|
|
def onnx_config(self): |
|
|
def onnx_config(self): |
|
|
from transformers.onnx.features import FeaturesManager |
|
|
from transformers.onnx.features import FeaturesManager |
|
|
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( |
|
|
|
|
|
self._model, feature='default') |
|
|
|
|
|
old_config = model_onnx_config(self.model_config) |
|
|
|
|
|
onnx_config = { |
|
|
|
|
|
'inputs': dict(old_config.inputs), |
|
|
|
|
|
'outputs': {'last_hidden_state': old_config.outputs['last_hidden_state']} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( |
|
|
|
|
|
self._model, feature='default') |
|
|
|
|
|
old_config = model_onnx_config(self.model_config) |
|
|
|
|
|
onnx_config = { |
|
|
|
|
|
'inputs': dict(old_config.inputs), |
|
|
|
|
|
'outputs': {'last_hidden_state': old_config.outputs['last_hidden_state']} |
|
|
|
|
|
} |
|
|
|
|
|
except Exception: |
|
|
|
|
|
input_dict = {} |
|
|
|
|
|
for k in self.tokenizer.model_input_names: |
|
|
|
|
|
input_dict[k] = {0: 'batch_size', 1: 'sequence_length'} |
|
|
|
|
|
onnx_config = { |
|
|
|
|
|
'inputs': input_dict, |
|
|
|
|
|
'outputs': {'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}} |
|
|
|
|
|
} |
|
|
return onnx_config |
|
|
return onnx_config |
|
|
|
|
|
|
|
|
def post_proc(self, token_embeddings, inputs): |
|
|
def post_proc(self, token_embeddings, inputs): |
|
|