|
|
@@ -124,52 +124,35 @@ elif args.format == 'onnx':
|
|
|
# NOTE(review): this chunk looks like a unified-diff hunk whose '+'/'-' markers
# and indentation were stripped; the torch.onnx.export block below appears
# twice (presumably the removed and added sides of the hunk) — confirm against
# the original patch before treating this as runnable code.
# Use a separate collection for ONNX-format results.
collection_name = collection_name + '_onnx'



# Model ids may contain '/', which is not filesystem-safe.
saved_name = model_name.replace('/', '-')



# Export only once; skip if the .onnx file already exists on disk.
if not os.path.exists(onnx_path):



try:



# First try the operator's built-in ONNX export.
# onnx_path[:-5] strips the '.onnx' suffix — presumably save_model
# appends it itself; TODO confirm.
op.save_model(format='onnx', path=onnx_path[:-5])



# Fall back to a manual torch.onnx.export when the built-in export fails.
# NOTE(review): broad `except Exception` silently swallows the original
# error — consider at least logging it.
except Exception:



# Build a sample input so torch.onnx.export can trace the model.
inputs = op.tokenizer('This is test.', return_tensors='pt')



input_names = list(inputs.keys())



# Mark batch and sequence dimensions as dynamic for every input/output.
dynamic_axes = {}



for i_n in input_names:



dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}



output_names = ['last_hidden_state']



for o_n in output_names:



dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'}



torch.onnx.export(



op.model,



# Positional tensors must match input_names order; both come from the
# same tokenizer dict, so they line up.
tuple(inputs.values()),



onnx_path,



input_names=input_names,



output_names=output_names,



dynamic_axes=dynamic_axes,



opset_version=14,



do_constant_folding=True,



)



# NOTE(review): duplicate of the export block directly above — likely the
# other side of the stripped diff hunk; verify which copy the final file
# should keep.
inputs = op.tokenizer('This is test.', return_tensors='pt')



input_names = list(inputs.keys())



dynamic_axes = {}



for i_n in input_names:



dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}



output_names = ['last_hidden_state']



for o_n in output_names:



dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'}



torch.onnx.export(



op.model,



tuple(inputs.values()),



onnx_path,



input_names=input_names,



output_names=output_names,



dynamic_axes=dynamic_axes,



opset_version=14,



do_constant_folding=True,



)



# Create one ONNX Runtime session for the exported model, using whatever
# execution providers are available in this environment.
sess = onnxruntime.InferenceSession(onnx_path,



providers=onnxruntime.get_available_providers())
|
|
|
|
|
|
|
# Run one text through the ONNX Runtime session and return its embedding
# (last_hidden_state with the leading axis squeezed away).
# NOTE(review): relies on module-level `op` (tokenizer/model wrapper) and
# `sess` (onnxruntime.InferenceSession) created in the export branch above.
@towhee.register



def run_onnx(txt):



# Tokenize to numpy tensors for ONNX Runtime.
inputs = op.tokenizer(txt, return_tensors='np')



try:



# Preferred path: let transformers' FeaturesManager build ORT-ready
# inputs that match the model's ONNX config.
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(



op.model, feature='default')



onnx_config = model_onnx_config(op.model.config)



new_inputs = onnx_config.generate_dummy_inputs_onnxruntime(inputs)



onnx_inputs = {}



for name, value in new_inputs.items():



# Nested inputs (lists/tuples of tensors) are flattened to individual
# named tensors before being handed to the session.
if isinstance(value, (list, tuple)):



value = onnx_config.flatten_output_collection_property(name, value)



onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})



else:



onnx_inputs[name] = value.numpy()



outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(onnx_inputs))



# Fallback: feed the tokenizer outputs directly, keyed by the input names
# the session actually declares.
# NOTE(review): broad `except Exception` hides the real failure reason.
except Exception:



onnx_inputs = [x.name for x in sess.get_inputs()]



new_inputs = {}



for k in onnx_inputs:



new_inputs[k] = inputs[k]



outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(new_inputs))



# NOTE(review): the four statements below duplicate the except-branch above
# and would unconditionally re-run the session — likely the stripped '-'
# side of this diff hunk; confirm which copy the final file keeps.
onnx_inputs = [x.name for x in sess.get_inputs()]



new_inputs = {}



for k in onnx_inputs:



new_inputs[k] = inputs[k]



outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(new_inputs))



# Drop axis 0 — presumably the batch dimension (batch=1 for a single text);
# TODO confirm the session output layout.
return outs[0].squeeze(0)
|
|
|
|
|
|
|
def insert(model_name, collection_name): |
|
|
|