diff --git a/benchmark/run.py b/benchmark/run.py index 5e4f907..5e90f25 100644 --- a/benchmark/run.py +++ b/benchmark/run.py @@ -1,5 +1,7 @@ import os import onnxruntime +from transformers.onnx.features import FeaturesManager +from transformers.onnx import validate_model_outputs import towhee from towhee import ops @@ -149,11 +151,25 @@ elif args.format == 'onnx': @towhee.register def run_onnx(txt): inputs = op.tokenizer(txt, return_tensors='np') - onnx_inputs = [x.name for x in sess.get_inputs()] - new_inputs = {} - for k in onnx_inputs: - new_inputs[k] = inputs[k] - outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(new_inputs)) + try: + model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( + op.model, feature='default') + onnx_config = model_onnx_config(op.model.config) + new_inputs = onnx_config.generate_dummy_inputs_onnxruntime(inputs) + onnx_inputs = {} + for name, value in new_inputs.items(): + if isinstance(value, (list, tuple)): + value = onnx_config.flatten_output_collection_property(name, value) + onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()}) + else: + onnx_inputs[name] = value.numpy() + outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(onnx_inputs)) + except Exception: + onnx_inputs = [x.name for x in sess.get_inputs()] + new_inputs = {} + for k in onnx_inputs: + new_inputs[k] = inputs[k] + outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(new_inputs)) return outs[0].squeeze(0) def insert(model_name, collection_name):