logo
Browse Source

Update save onnx in evaluation

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
2a39a1697d
  1. 22
      benchmark/run.py

22
benchmark/run.py

@ -122,7 +122,27 @@ elif args.format == 'onnx':
collection_name = collection_name + '_onnx' collection_name = collection_name + '_onnx'
saved_name = model_name.replace('/', '-') saved_name = model_name.replace('/', '-')
if not os.path.exists(onnx_path): if not os.path.exists(onnx_path):
op.save_model(format='onnx', path=onnx_path[:-5])
try:
op.save_model(format='onnx', path=onnx_path[:-5])
except Exception:
inputs = op.tokenizer('This is test.', return_tensors='pt')
input_names = list(inputs.keys())
dynamic_axes = {}
for i_n in input_names:
dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}
output_names = ['last_hidden_state']
for o_n in output_names:
dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'}
torch.onnx.export(
op.model,
tuple(inputs.values()),
onnx_path,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=14,
do_constant_folding=True,
)
sess = onnxruntime.InferenceSession(onnx_path, sess = onnxruntime.InferenceSession(onnx_path,
providers=onnxruntime.get_available_providers()) providers=onnxruntime.get_available_providers())

Loading…
Cancel
Save