logo
Browse Source

Update benchmark/run.py

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
b883739db2
  1. 26
      benchmark/run.py

26
benchmark/run.py

@ -1,5 +1,7 @@
import os
import onnxruntime
from transformers.onnx.features import FeaturesManager
from transformers.onnx import validate_model_outputs
import towhee
from towhee import ops
@ -149,11 +151,25 @@ elif args.format == 'onnx':
@towhee.register
def run_onnx(txt):
inputs = op.tokenizer(txt, return_tensors='np')
onnx_inputs = [x.name for x in sess.get_inputs()]
new_inputs = {}
for k in onnx_inputs:
new_inputs[k] = inputs[k]
outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(new_inputs))
try:
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
op.model, feature='default')
onnx_config = model_onnx_config(op.model.config)
new_inputs = onnx_config.generate_dummy_inputs_onnxruntime(inputs)
onnx_inputs = {}
for name, value in new_inputs.items():
if isinstance(value, (list, tuple)):
value = onnx_config.flatten_output_collection_property(name, value)
onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
else:
onnx_inputs[name] = value.numpy()
outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(onnx_inputs))
except Exception:
onnx_inputs = [x.name for x in sess.get_inputs()]
new_inputs = {}
for k in onnx_inputs:
new_inputs[k] = inputs[k]
outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(new_inputs))
return outs[0].squeeze(0)
def insert(model_name, collection_name):

Loading…
Cancel
Save