@@ -15,6 +15,7 @@
 import numpy
 import os
 import torch
+import shutil
 from pathlib import Path
 from transformers import AutoTokenizer, AutoModel
 
@@ -47,6 +48,8 @@ class AutoTransformers(NNOperator):
         try:
             self.model = AutoModel.from_pretrained(model_name).to(self.device)
             self.model.eval()
+
+            self.configs = self.model.config
         except Exception as e:
             model_list = self.supported_model_names()
             if model_name not in model_list:
@@ -87,57 +90,32 @@ class AutoTransformers(NNOperator):
         name = self.model_name.replace('/', '-')
         path = os.path.join(path, name)
 
-        inputs = self.tokenizer('[CLS]', return_tensors='pt') # a dictionary
+        dummy_input = '[CLS]'
+        inputs = self.tokenizer(dummy_input, return_tensors='pt') # a dictionary
         if format == 'pytorch':
-            path = path + '.pt'
-            torch.save(self.model, path)
+            torch.save(self.model, path + '.pt')
         elif format == 'torchscript':
-            path = path + '.pt'
             inputs = list(inputs.values())
             try:
                 try:
                     jit_model = torch.jit.script(self.model)
                 except Exception:
                     jit_model = torch.jit.trace(self.model, inputs, strict=False)
-                torch.jit.save(jit_model, path)
+                torch.jit.save(jit_model, path + '.pt')
             except Exception as e:
                 log.error(f'Fail to save as torchscript: {e}.')
                 raise RuntimeError(f'Fail to save as torchscript: {e}.')
         elif format == 'onnx':
-            path = path + '.onnx'
-            input_names = list(inputs.keys())
-            dynamic_axes = {}
-            for i_n in input_names:
-                dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}
-            try:
-                output_names = ['last_hidden_state']
-                for o_n in output_names:
-                    dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'}
-                torch.onnx.export(self.model,
-                                  tuple(inputs.values()),
-                                  path,
-                                  input_names=input_names,
-                                  output_names=output_names,
-                                  dynamic_axes=dynamic_axes,
-                                  opset_version=11,
-                                  do_constant_folding=True,
-                                  # enable_onnx_checker=True,
-                                  )
-            except Exception as e:
-                print(e, '\nTrying with 2 outputs...')
-                output_names = ['last_hidden_state', 'pooler_output']
-                for o_n in output_names:
-                    dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'}
-                torch.onnx.export(self.model,
-                                  tuple(inputs.values()),
-                                  path,
-                                  input_names=input_names,
-                                  output_names=output_names,
-                                  dynamic_axes=dynamic_axes,
-                                  opset_version=11,
-                                  do_constant_folding=True,
-                                  # enable_onnx_checker=True,
-                                  )
+            from transformers.convert_graph_to_onnx import convert
+            if os.path.isdir(path):
+                shutil.rmtree(path)
+            path = os.path.join(path, 'model.onnx')
+            convert(
+                model=self.model_name,
+                output=Path(path),
+                framework='pt',
+                opset=13
+            )
         # todo: elif format == 'tensorrt':
         else:
             log.error(f'Unsupported format "{format}".')