diff --git a/auto_transformers.py b/auto_transformers.py index 797d8ed..95832d6 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -107,7 +107,7 @@ class AutoTransformers(NNOperator): path, input_names=["input_ids", "token_type_ids", "attention_mask"], # list(inputs.keys()) output_names=["last_hidden_state"], - opset_version=10, + opset_version=12, dynamic_axes={ "input_ids": {0: "batch_size", 1: "input_length"}, "token_type_ids": {0: "batch_size", 1: "input_length"}, @@ -120,7 +120,7 @@ class AutoTransformers(NNOperator): path, input_names=["input_ids", "token_type_ids", "attention_mask"], # list(inputs.keys()) output_names=["last_hidden_state", "pooler_output"], - opset_version=10, + opset_version=12, dynamic_axes={ "input_ids": {0: "batch_size", 1: "input_length"}, "token_type_ids": {0: "batch_size", 1: "input_length"}, diff --git a/test_onnx.py b/test_onnx.py index bc92e41..fd7fdab 100644 --- a/test_onnx.py +++ b/test_onnx.py @@ -4,7 +4,8 @@ import onnx f = open('onnx.csv', 'a+') f.write('model_name, run op, save_onnx, check_onnx\n') -models = AutoTransformers.supported_model_names()[:1] +# models = AutoTransformers.supported_model_names()[:1] +models = ['bert-base-cased'] for name in models: line = f'{name}, ' @@ -37,3 +38,4 @@ for name in models: continue line += '\n' f.write(line) +print('Finished.')