diff --git a/auto_transformers.py b/auto_transformers.py index 95b62a5..a253784 100644 --- a/auto_transformers.py +++ b/auto_transformers.py @@ -106,15 +106,19 @@ class AutoTransformers(NNOperator): log.error(f'Fail to save as torchscript: {e}.') raise RuntimeError(f'Fail to save as torchscript: {e}.') elif format == 'onnx': - from transformers.convert_graph_to_onnx import convert + from transformers.onnx.features import FeaturesManager + from transformers.onnx import export + model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( + self.model, feature='default') + onnx_config = model_onnx_config(self.model.config) if os.path.isdir(path): shutil.rmtree(path) - path = os.path.join(path, 'model.onnx') - convert( - model=self.model_name, - output=Path(path), - framework='pt', - opset=13 + onnx_inputs, onnx_outputs = export( + self.tokenizer, + self.model, + config=onnx_config, + opset=13, + output=Path(path+'.onnx') ) # todo: elif format == 'tensorrt': else: diff --git a/test_onnx.py b/test_onnx.py index 9443eb1..f08121c 100644 --- a/test_onnx.py +++ b/test_onnx.py @@ -43,7 +43,7 @@ status = None for name in models: logger.info(f'***{name}***') saved_name = name.replace('/', '-') - onnx_path = f'saved/onnx/{saved_name}/model.onnx' + onnx_path = f'saved/onnx/{saved_name}.onnx' if status: f.write(','.join(status) + '\n') status = [name] + ['fail'] * 5 @@ -78,7 +78,7 @@ for name in models: sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers()) inputs = op.tokenizer(test_txt, return_tensors='np') - out2 = sess.run(output_names=['output_0'], input_feed=dict(inputs)) + out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs)) logger.info('ONNX WORKED.') status[4] = 'success' if numpy.allclose(out1, out2, atol=atol):