From 4ea7a7b9ec7326541288ab1d20a397c6669da437 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Wed, 14 Dec 2022 16:58:27 +0800 Subject: [PATCH] Update benchmark/run.py Signed-off-by: Jael Gu --- benchmark/run.py | 63 ++++++++++++++++++------------------------------ 1 file changed, 23 insertions(+), 40 deletions(-) diff --git a/benchmark/run.py b/benchmark/run.py index 5e90f25..b048424 100644 --- a/benchmark/run.py +++ b/benchmark/run.py @@ -124,52 +124,35 @@ elif args.format == 'onnx': collection_name = collection_name + '_onnx' saved_name = model_name.replace('/', '-') if not os.path.exists(onnx_path): - try: - op.save_model(format='onnx', path=onnx_path[:-5]) - except Exception: - inputs = op.tokenizer('This is test.', return_tensors='pt') - input_names = list(inputs.keys()) - dynamic_axes = {} - for i_n in input_names: - dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'} - output_names = ['last_hidden_state'] - for o_n in output_names: - dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'} - torch.onnx.export( - op.model, - tuple(inputs.values()), - onnx_path, - input_names=input_names, - output_names=output_names, - dynamic_axes=dynamic_axes, - opset_version=14, - do_constant_folding=True, - ) + inputs = op.tokenizer('This is test.', return_tensors='pt') + input_names = list(inputs.keys()) + dynamic_axes = {} + for i_n in input_names: + dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'} + output_names = ['last_hidden_state'] + for o_n in output_names: + dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'} + torch.onnx.export( + op.model, + tuple(inputs.values()), + onnx_path, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + opset_version=14, + do_constant_folding=True, + ) sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers()) @towhee.register def run_onnx(txt): inputs = op.tokenizer(txt, return_tensors='np') - try: - model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( - op.model, feature='default') - onnx_config = model_onnx_config(op.model.config) - new_inputs = onnx_config.generate_dummy_inputs_onnxruntime(inputs) - onnx_inputs = {} - for name, value in new_inputs.items(): - if isinstance(value, (list, tuple)): - value = onnx_config.flatten_output_collection_property(name, value) - onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()}) - else: - onnx_inputs[name] = value.numpy() - outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(onnx_inputs)) - except Exception: - onnx_inputs = [x.name for x in sess.get_inputs()] - new_inputs = {} - for k in onnx_inputs: - new_inputs[k] = inputs[k] - outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(new_inputs)) + onnx_inputs = [x.name for x in sess.get_inputs()] + new_inputs = {} + for k in onnx_inputs: + new_inputs[k] = inputs[k] + outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(new_inputs)) return outs[0].squeeze(0) def insert(model_name, collection_name):