diff --git a/benchmark/run.py b/benchmark/run.py
index 51016ca..5e4f907 100644
--- a/benchmark/run.py
+++ b/benchmark/run.py
@@ -122,7 +122,27 @@ elif args.format == 'onnx':
     collection_name = collection_name + '_onnx'
     saved_name = model_name.replace('/', '-')
     if not os.path.exists(onnx_path):
-        op.save_model(format='onnx', path=onnx_path[:-5])
+        try:
+            op.save_model(format='onnx', path=onnx_path[:-5])
+        except Exception:
+            inputs = op.tokenizer('This is test.', return_tensors='pt')
+            input_names = list(inputs.keys())
+            dynamic_axes = {}
+            for i_n in input_names:
+                dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}
+            output_names = ['last_hidden_state']
+            for o_n in output_names:
+                dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'}
+            torch.onnx.export(
+                op.model,
+                tuple(inputs.values()),
+                onnx_path,
+                input_names=input_names,
+                output_names=output_names,
+                dynamic_axes=dynamic_axes,
+                opset_version=14,
+                do_constant_folding=True,
+            )
     sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers())
 
 