|
|
@ -1,5 +1,7 @@ |
|
|
|
import os |
|
|
|
import onnxruntime |
|
|
|
from transformers.onnx.features import FeaturesManager |
|
|
|
from transformers.onnx import validate_model_outputs |
|
|
|
|
|
|
|
import towhee |
|
|
|
from towhee import ops |
|
|
@ -149,11 +151,25 @@ elif args.format == 'onnx': |
|
|
|
@towhee.register |
|
|
|
def run_onnx(txt): |
|
|
|
inputs = op.tokenizer(txt, return_tensors='np') |
|
|
|
onnx_inputs = [x.name for x in sess.get_inputs()] |
|
|
|
new_inputs = {} |
|
|
|
for k in onnx_inputs: |
|
|
|
new_inputs[k] = inputs[k] |
|
|
|
outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(new_inputs)) |
|
|
|
try: |
|
|
|
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise( |
|
|
|
op.model, feature='default') |
|
|
|
onnx_config = model_onnx_config(op.model.config) |
|
|
|
new_inputs = onnx_config.generate_dummy_inputs_onnxruntime(inputs) |
|
|
|
onnx_inputs = {} |
|
|
|
for name, value in new_inputs.items(): |
|
|
|
if isinstance(value, (list, tuple)): |
|
|
|
value = onnx_config.flatten_output_collection_property(name, value) |
|
|
|
onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()}) |
|
|
|
else: |
|
|
|
onnx_inputs[name] = value.numpy() |
|
|
|
outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(onnx_inputs)) |
|
|
|
except Exception: |
|
|
|
onnx_inputs = [x.name for x in sess.get_inputs()] |
|
|
|
new_inputs = {} |
|
|
|
for k in onnx_inputs: |
|
|
|
new_inputs[k] = inputs[k] |
|
|
|
outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(new_inputs)) |
|
|
|
return outs[0].squeeze(0) |
|
|
|
|
|
|
|
def insert(model_name, collection_name): |
|
|
|