@@ -124,9 +124,6 @@ elif args.format == 'onnx':
    collection_name = collection_name + '_onnx'
    saved_name = model_name.replace('/', '-')
    if not os.path.exists(onnx_path):
        try:
            op.save_model(format='onnx', path=onnx_path[:-5])
        except Exception:
            inputs = op.tokenizer('This is test.', return_tensors='pt')
            input_names = list(inputs.keys())
            dynamic_axes = {}
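The fallback branch in this hunk stops at dynamic_axes = {}; the rest of the manual export lies outside the hunk. A minimal sketch of how such a fallback typically continues, assuming torch.onnx.export is called directly and that the axis labels and opset_version below are illustrative rather than taken from this patch:

    import torch

    # Hypothetical continuation: mark batch and sequence dimensions as dynamic
    # for every tokenizer input, then export the wrapped transformer directly.
    for name in input_names:
        dynamic_axes[name] = {0: 'batch_size', 1: 'sequence_length'}
    dynamic_axes['last_hidden_state'] = {0: 'batch_size', 1: 'sequence_length'}

    torch.onnx.export(
        op.model,                # underlying transformers model held by the operator
        tuple(inputs.values()),  # example tensors from op.tokenizer(..., return_tensors='pt')
        onnx_path,               # target .onnx file
        input_names=input_names,
        output_names=['last_hidden_state'],
        dynamic_axes=dynamic_axes,
        opset_version=13,
    )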
@@ -151,20 +148,6 @@ elif args.format == 'onnx':
    @towhee.register
    def run_onnx(txt):
        inputs = op.tokenizer(txt, return_tensors='np')
        try:
            model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
                op.model, feature='default')
            onnx_config = model_onnx_config(op.model.config)
            new_inputs = onnx_config.generate_dummy_inputs_onnxruntime(inputs)
            onnx_inputs = {}
            for name, value in new_inputs.items():
                if isinstance(value, (list, tuple)):
                    value = onnx_config.flatten_output_collection_property(name, value)
                    onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
                else:
                    onnx_inputs[name] = value.numpy()
            outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(onnx_inputs))
        except Exception:
            onnx_inputs = [x.name for x in sess.get_inputs()]
            new_inputs = {}
            for k in onnx_inputs:
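The hunk ends inside the fallback loop, so its body is not shown. A hedged sketch of how this pattern is commonly completed, assuming sess is an onnxruntime.InferenceSession created elsewhere in the script and that the filtering logic below is illustrative rather than the patch's actual code:

    import onnxruntime

    # Assumed elsewhere in the script: session loaded from the exported file.
    sess = onnxruntime.InferenceSession(onnx_path)

    # Hypothetical completion of the fallback: keep only the tensors the ONNX
    # graph actually declares as inputs, then run inference on them.
    for k in onnx_inputs:
        if k in inputs:
            new_inputs[k] = inputs[k]
    outs = sess.run(output_names=['last_hidden_state'], input_feed=new_inputs)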