logo
Browse Source

Update onnx export config

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
3245db6f0c
  1. 23
      auto_transformers.py

23
auto_transformers.py

@ -142,13 +142,22 @@ class AutoTransformers(NNOperator):
@property
def onnx_config(self):
from transformers.onnx.features import FeaturesManager
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
self._model, feature='default')
old_config = model_onnx_config(self.model_config)
onnx_config = {
'inputs': dict(old_config.inputs),
'outputs': {'last_hidden_state': old_config.outputs['last_hidden_state']}
}
try:
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
self._model, feature='default')
old_config = model_onnx_config(self.model_config)
onnx_config = {
'inputs': dict(old_config.inputs),
'outputs': {'last_hidden_state': old_config.outputs['last_hidden_state']}
}
except Exception:
input_dict = {}
for k in self.tokenizer.model_input_names:
input_dict[k] = {0: 'batch_size', 1: 'sequence_length'}
onnx_config = {
'inputs': input_dict,
'outputs': {'last_hidden_state': {0: 'batch_size', 1: 'sequence_length'}}
}
return onnx_config
def post_proc(self, token_embeddings, inputs):

Loading…
Cancel
Save