logo
Browse Source

Update benchmark/run.py

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
4ea7a7b9ec
  1. 17
      benchmark/run.py

17
benchmark/run.py

@ -124,9 +124,6 @@ elif args.format == 'onnx':
collection_name = collection_name + '_onnx'
saved_name = model_name.replace('/', '-')
if not os.path.exists(onnx_path):
try:
op.save_model(format='onnx', path=onnx_path[:-5])
except Exception:
inputs = op.tokenizer('This is test.', return_tensors='pt')
input_names = list(inputs.keys())
dynamic_axes = {}
@ -151,20 +148,6 @@ elif args.format == 'onnx':
@towhee.register
def run_onnx(txt):
inputs = op.tokenizer(txt, return_tensors='np')
try:
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
op.model, feature='default')
onnx_config = model_onnx_config(op.model.config)
new_inputs = onnx_config.generate_dummy_inputs_onnxruntime(inputs)
onnx_inputs = {}
for name, value in new_inputs.items():
if isinstance(value, (list, tuple)):
value = onnx_config.flatten_output_collection_property(name, value)
onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
else:
onnx_inputs[name] = value.numpy()
outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(onnx_inputs))
except Exception:
onnx_inputs = [x.name for x in sess.get_inputs()]
new_inputs = {}
for k in onnx_inputs:

Loading…
Cancel
Save