@@ -122,7 +122,27 @@ elif args.format == 'onnx':
     collection_name = collection_name + '_onnx'
     saved_name = model_name.replace('/', '-')
     if not os.path.exists(onnx_path):
-        op.save_model(format='onnx', path=onnx_path[:-5])
+        try:
+            op.save_model(format='onnx', path=onnx_path[:-5])
+        except Exception:
+            inputs = op.tokenizer('This is test.', return_tensors='pt')
+            input_names = list(inputs.keys())
+            dynamic_axes = {}
+            for i_n in input_names:
+                dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}
+            output_names = ['last_hidden_state']
+            for o_n in output_names:
+                dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'}
+            torch.onnx.export(
+                op.model,
+                tuple(inputs.values()),
+                onnx_path,
+                input_names=input_names,
+                output_names=output_names,
+                dynamic_axes=dynamic_axes,
+                opset_version=14,
+                do_constant_folding=True,
+            )
     sess = onnxruntime.InferenceSession(onnx_path,
                                         providers=onnxruntime.get_available_providers())
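
As a quick sanity check (a hedged sketch, not part of this patch), the exported graph could be compared against the eager PyTorch output along the lines below. It assumes `op` wraps a Hugging Face-style transformer whose forward pass exposes `last_hidden_state`, that the tensors live on CPU, and that `onnx_path` and `sess` are defined as above.

    import numpy as np
    import torch

    # Tokenize the same dummy sentence used during export (CPU tensors assumed).
    inputs = op.tokenizer('This is test.', return_tensors='pt')
    with torch.no_grad():
        torch_out = op.model(**inputs).last_hidden_state.numpy()

    # Feed the identical inputs to the ONNX Runtime session built above.
    onnx_inputs = {name: tensor.numpy() for name, tensor in inputs.items()}
    onnx_out = sess.run(['last_hidden_state'], onnx_inputs)[0]

    # The exported graph should closely track the eager PyTorch output.
    np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-3, atol=1e-4)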