logo
Browse Source

Update save_onnx

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
6d228847cf
  1. 56
      auto_transformers.py
  2. 11
      test_onnx.py

56
auto_transformers.py

@ -15,6 +15,7 @@
import numpy
import os
import torch
import shutil
from pathlib import Path
from transformers import AutoTokenizer, AutoModel
@ -47,6 +48,8 @@ class AutoTransformers(NNOperator):
try:
self.model = AutoModel.from_pretrained(model_name).to(self.device)
self.model.eval()
self.configs = self.model.config
except Exception as e:
model_list = self.supported_model_names()
if model_name not in model_list:
@ -87,57 +90,32 @@ class AutoTransformers(NNOperator):
name = self.model_name.replace('/', '-')
path = os.path.join(path, name)
inputs = self.tokenizer('[CLS]', return_tensors='pt') # a dictionary
dummy_input = '[CLS]'
inputs = self.tokenizer(dummy_input, return_tensors='pt') # a dictionary
if format == 'pytorch':
path = path + '.pt'
torch.save(self.model, path)
torch.save(self.model, path + '.pt')
elif format == 'torchscript':
path = path + '.pt'
inputs = list(inputs.values())
try:
try:
jit_model = torch.jit.script(self.model)
except Exception:
jit_model = torch.jit.trace(self.model, inputs, strict=False)
torch.jit.save(jit_model, path)
torch.jit.save(jit_model, path + '.pt')
except Exception as e:
log.error(f'Fail to save as torchscript: {e}.')
raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onnx':
path = path + '.onnx'
input_names = list(inputs.keys())
dynamic_axes = {}
for i_n in input_names:
dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}
try:
output_names = ['last_hidden_state']
for o_n in output_names:
dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'}
torch.onnx.export(self.model,
tuple(inputs.values()),
path,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=11,
do_constant_folding=True,
# enable_onnx_checker=True,
)
except Exception as e:
print(e, '\nTrying with 2 outputs...')
output_names = ['last_hidden_state', 'pooler_output']
for o_n in output_names:
dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'}
torch.onnx.export(self.model,
tuple(inputs.values()),
path,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=11,
do_constant_folding=True,
# enable_onnx_checker=True,
)
from transformers.convert_graph_to_onnx import convert
if os.path.isdir(path):
shutil.rmtree(path)
path = os.path.join(path, 'model.onnx')
convert(
model=self.model_name,
output=Path(path),
framework='pt',
opset=13
)
# todo: elif format == 'tensorrt':
else:
log.error(f'Unsupported format "{format}".')

11
test_onnx.py

@ -14,7 +14,7 @@ import psutil
# checked_models = AutoTransformers.supported_model_names(format='onnx')
# models = [x for x in full_models if x not in checked_models]
models = ['bert-base-cased', 'distilbert-base-cased']
test_txt = '[UNK]'
test_txt = 'hello, world.'
atol = 1e-3
log_path = 'transformers_onnx.log'
f = open('onnx.csv', 'w+')
@ -43,6 +43,7 @@ status = None
for name in models:
logger.info(f'***{name}***')
saved_name = name.replace('/', '-')
onnx_path = f'saved/onnx/{saved_name}/model.onnx'
if status:
f.write(','.join(status) + '\n')
status = [name] + ['fail'] * 5
@ -63,10 +64,10 @@ for name in models:
continue
try:
try:
onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx')
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
except Exception:
saved_onnx = onnx.load(f'saved/onnx/{saved_name}.onnx', load_external_data=False)
saved_onnx = onnx.load(onnx_path, load_external_data=False)
onnx.checker.check_model(saved_onnx)
logger.info('ONNX CHECKED.')
status[3] = 'success'
@ -74,10 +75,10 @@ for name in models:
logger.error(f'FAIL TO CHECK ONNX: {e}')
continue
try:
sess = onnxruntime.InferenceSession(f'saved/onnx/{saved_name}.onnx',
sess = onnxruntime.InferenceSession(onnx_path,
providers=onnxruntime.get_available_providers())
inputs = op.tokenizer(test_txt, return_tensors='np')
out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))
out2 = sess.run(output_names=['output_0'], input_feed=dict(inputs))
logger.info('ONNX WORKED.')
status[4] = 'success'
if numpy.allclose(out1, out2, atol=atol):

Loading…
Cancel
Save