logo
Browse Source

Update

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
a55e58b29b
  1. 4
      auto_transformers.py
  2. 4
      test_onnx.py

4
auto_transformers.py

@ -107,7 +107,7 @@ class AutoTransformers(NNOperator):
path,
input_names=["input_ids", "token_type_ids", "attention_mask"], # list(inputs.keys())
output_names=["last_hidden_state"],
opset_version=10,
opset_version=12,
dynamic_axes={
"input_ids": {0: "batch_size", 1: "input_length"},
"token_type_ids": {0: "batch_size", 1: "input_length"},
@ -120,7 +120,7 @@ class AutoTransformers(NNOperator):
path,
input_names=["input_ids", "token_type_ids", "attention_mask"], # list(inputs.keys())
output_names=["last_hidden_state", "pooler_output"],
opset_version=10,
opset_version=12,
dynamic_axes={
"input_ids": {0: "batch_size", 1: "input_length"},
"token_type_ids": {0: "batch_size", 1: "input_length"},

4
test_onnx.py

@ -4,7 +4,8 @@ import onnx
f = open('onnx.csv', 'a+')
f.write('model_name, run op, save_onnx, check_onnx\n')
models = AutoTransformers.supported_model_names()[:1]
# models = AutoTransformers.supported_model_names()[:1]
models = ['bert-base-cased']
for name in models:
line = f'{name}, '
@ -37,3 +38,4 @@ for name in models:
continue
line += '\n'
f.write(line)
print('Finished.')

Loading…
Cancel
Save