|
|
@ -22,6 +22,9 @@ from towhee.operator import NNOperator |
|
|
|
from towhee import register |
|
|
|
|
|
|
|
import warnings |
|
|
|
import logging |
|
|
|
|
|
|
|
# Logger used by this operator for export/runtime diagnostics.
log = logging.getLogger('run_op')

# Suppress noisy library warnings raised while loading/exporting models.
warnings.filterwarnings('ignore')
|
|
|
|
|
|
@ -83,6 +86,7 @@ class AutoTransformers(NNOperator): |
|
|
|
os.makedirs(path, exist_ok=True) |
|
|
|
name = self.model_name.replace('/', '-') |
|
|
|
path = os.path.join(path, name) |
|
|
|
|
|
|
|
inputs = self.tokenizer('[CLS]', return_tensors='pt') # a dictionary |
|
|
|
if format == 'pytorch': |
|
|
|
path = path + '.pt' |
|
|
@ -101,37 +105,39 @@ class AutoTransformers(NNOperator): |
|
|
|
raise RuntimeError(f'Fail to save as torchscript: {e}.') |
|
|
|
elif format == 'onnx': |
|
|
|
path = path + '.onnx' |
|
|
|
input_names = list(inputs.keys()) |
|
|
|
dynamic_axes = {} |
|
|
|
for i_n in input_names: |
|
|
|
dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'} |
|
|
|
try: |
|
|
|
output_names = ['last_hidden_state'] |
|
|
|
for o_n in output_names: |
|
|
|
dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'} |
|
|
|
torch.onnx.export(self.model,
                  tuple(inputs.values()),
                  path,
                  input_names=input_names,
                  output_names=output_names,
                  dynamic_axes=dynamic_axes,
                  opset_version=11,
                  do_constant_folding=True,
                  # enable_onnx_checker=True,
                  )
|
|
|
except Exception as e: |
|
|
|
print(e, '\nTrying with 2 outputs...') |
|
|
|
output_names = ['last_hidden_state', 'pooler_output'] |
|
|
|
for o_n in output_names: |
|
|
|
dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'} |
|
|
|
torch.onnx.export(self.model,
                  tuple(inputs.values()),
                  path,
                  input_names=input_names,
                  output_names=output_names,
                  dynamic_axes=dynamic_axes,
                  opset_version=11,
                  do_constant_folding=True,
                  # enable_onnx_checker=True,
                  )
|
|
|
# todo: elif format == 'tensorrt': |
|
|
|
else: |
|
|
|
log.error(f'Unsupported format "{format}".') |
|
|
|