logo
Browse Source

Update benchmark/run.py

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
4ea7a7b9ec
  1. 63
      benchmark/run.py

63
benchmark/run.py

@ -124,52 +124,35 @@ elif args.format == 'onnx':
collection_name = collection_name + '_onnx'
saved_name = model_name.replace('/', '-')
if not os.path.exists(onnx_path):
try:
op.save_model(format='onnx', path=onnx_path[:-5])
except Exception:
inputs = op.tokenizer('This is test.', return_tensors='pt')
input_names = list(inputs.keys())
dynamic_axes = {}
for i_n in input_names:
dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}
output_names = ['last_hidden_state']
for o_n in output_names:
dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'}
torch.onnx.export(
op.model,
tuple(inputs.values()),
onnx_path,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=14,
do_constant_folding=True,
)
inputs = op.tokenizer('This is test.', return_tensors='pt')
input_names = list(inputs.keys())
dynamic_axes = {}
for i_n in input_names:
dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}
output_names = ['last_hidden_state']
for o_n in output_names:
dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'}
torch.onnx.export(
op.model,
tuple(inputs.values()),
onnx_path,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=14,
do_constant_folding=True,
)
sess = onnxruntime.InferenceSession(onnx_path,
providers=onnxruntime.get_available_providers())
@towhee.register
def run_onnx(txt):
inputs = op.tokenizer(txt, return_tensors='np')
try:
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
op.model, feature='default')
onnx_config = model_onnx_config(op.model.config)
new_inputs = onnx_config.generate_dummy_inputs_onnxruntime(inputs)
onnx_inputs = {}
for name, value in new_inputs.items():
if isinstance(value, (list, tuple)):
value = onnx_config.flatten_output_collection_property(name, value)
onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
else:
onnx_inputs[name] = value.numpy()
outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(onnx_inputs))
except Exception:
onnx_inputs = [x.name for x in sess.get_inputs()]
new_inputs = {}
for k in onnx_inputs:
new_inputs[k] = inputs[k]
outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(new_inputs))
onnx_inputs = [x.name for x in sess.get_inputs()]
new_inputs = {}
for k in onnx_inputs:
new_inputs[k] = inputs[k]
outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(new_inputs))
return outs[0].squeeze(0)
def insert(model_name, collection_name):

Loading…
Cancel
Save