
Update save_onnx

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
Branch: main
Author: Jael Gu, 2 years ago
Commit: 6d228847cf
Changed files:
  1. auto_transformers.py (54 lines changed)
  2. test_onnx.py (11 lines changed)

auto_transformers.py (54 lines changed)

@@ -15,6 +15,7 @@
 import numpy
 import os
 import torch
+import shutil
 from pathlib import Path
 from transformers import AutoTokenizer, AutoModel
@@ -47,6 +48,8 @@ class AutoTransformers(NNOperator):
         try:
             self.model = AutoModel.from_pretrained(model_name).to(self.device)
             self.model.eval()
+            self.configs = self.model.config
         except Exception as e:
             model_list = self.supported_model_names()
             if model_name not in model_list:
@@ -87,56 +90,31 @@ class AutoTransformers(NNOperator):
         name = self.model_name.replace('/', '-')
         path = os.path.join(path, name)
-        inputs = self.tokenizer('[CLS]', return_tensors='pt')  # a dictionary
+        dummy_input = '[CLS]'
+        inputs = self.tokenizer(dummy_input, return_tensors='pt')  # a dictionary
         if format == 'pytorch':
-            path = path + '.pt'
-            torch.save(self.model, path)
+            torch.save(self.model, path + '.pt')
         elif format == 'torchscript':
-            path = path + '.pt'
             inputs = list(inputs.values())
             try:
                 try:
                     jit_model = torch.jit.script(self.model)
                 except Exception:
                     jit_model = torch.jit.trace(self.model, inputs, strict=False)
-                torch.jit.save(jit_model, path)
+                torch.jit.save(jit_model, path + '.pt')
             except Exception as e:
                 log.error(f'Fail to save as torchscript: {e}.')
                 raise RuntimeError(f'Fail to save as torchscript: {e}.')
         elif format == 'onnx':
-            path = path + '.onnx'
-            input_names = list(inputs.keys())
-            dynamic_axes = {}
-            for i_n in input_names:
-                dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}
-            try:
-                output_names = ['last_hidden_state']
-                for o_n in output_names:
-                    dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'}
-                torch.onnx.export(self.model,
-                                  tuple(inputs.values()),
-                                  path,
-                                  input_names=input_names,
-                                  output_names=output_names,
-                                  dynamic_axes=dynamic_axes,
-                                  opset_version=11,
-                                  do_constant_folding=True,
-                                  # enable_onnx_checker=True,
-                                  )
-            except Exception as e:
-                print(e, '\nTrying with 2 outputs...')
-                output_names = ['last_hidden_state', 'pooler_output']
-                for o_n in output_names:
-                    dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'}
-                torch.onnx.export(self.model,
-                                  tuple(inputs.values()),
-                                  path,
-                                  input_names=input_names,
-                                  output_names=output_names,
-                                  dynamic_axes=dynamic_axes,
-                                  opset_version=11,
-                                  do_constant_folding=True,
-                                  # enable_onnx_checker=True,
+            from transformers.convert_graph_to_onnx import convert
+            if os.path.isdir(path):
+                shutil.rmtree(path)
+            path = os.path.join(path, 'model.onnx')
+            convert(
+                model=self.model_name,
+                output=Path(path),
+                framework='pt',
+                opset=13
             )
         # todo: elif format == 'tensorrt':
         else:
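The new onnx branch delegates the export to transformers' own helper, transformers.convert_graph_to_onnx.convert, which wraps the model in a feature-extraction pipeline, infers input/output names and dynamic axes by itself, and writes model.onnx into a directory; since convert aborts when the target folder is non-empty, the code removes any stale export with shutil.rmtree first. A minimal standalone sketch of the same call, with the model name and output path as example values only:

from pathlib import Path
from transformers.convert_graph_to_onnx import convert

# Example values: any supported model name and an empty (or absent)
# output directory behave the same way as in save_model above.
convert(
    model='bert-base-cased',
    output=Path('saved/onnx/bert-base-cased/model.onnx'),
    framework='pt',  # export from the PyTorch weights
    opset=13,
)

Note that convert reloads the model from self.model_name rather than reusing the in-memory self.model, which keeps the helper self-contained at the cost of a second model load.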

test_onnx.py (11 lines changed)

@@ -14,7 +14,7 @@ import psutil
 # checked_models = AutoTransformers.supported_model_names(format='onnx')
 # models = [x for x in full_models if x not in checked_models]
 models = ['bert-base-cased', 'distilbert-base-cased']
-test_txt = '[UNK]'
+test_txt = 'hello, world.'
 atol = 1e-3
 log_path = 'transformers_onnx.log'
 f = open('onnx.csv', 'w+')
@@ -43,6 +43,7 @@ status = None
 for name in models:
     logger.info(f'***{name}***')
     saved_name = name.replace('/', '-')
+    onnx_path = f'saved/onnx/{saved_name}/model.onnx'
     if status:
         f.write(','.join(status) + '\n')
     status = [name] + ['fail'] * 5
@@ -63,10 +64,10 @@ for name in models:
         continue
     try:
         try:
-            onnx_model = onnx.load(f'saved/onnx/{saved_name}.onnx')
+            onnx_model = onnx.load(onnx_path)
             onnx.checker.check_model(onnx_model)
         except Exception:
-            saved_onnx = onnx.load(f'saved/onnx/{saved_name}.onnx', load_external_data=False)
+            saved_onnx = onnx.load(onnx_path, load_external_data=False)
             onnx.checker.check_model(saved_onnx)
         logger.info('ONNX CHECKED.')
         status[3] = 'success'
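The nested load above has a fallback presumably for large exports: ONNX protobufs are capped at 2 GB, so big models keep their weights as external data files next to model.onnx, and loading the full proto can fail even when the graph itself is valid. Passing load_external_data=False validates the graph structure alone; the same fallback in isolation (the path is an example):

import onnx

# Load only the graph, leaving externally stored weights on disk,
# then run the structural checker on the in-memory proto.
graph_only = onnx.load('saved/onnx/bert-base-cased/model.onnx',
                       load_external_data=False)
onnx.checker.check_model(graph_only)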
@@ -74,10 +75,10 @@ for name in models:
         logger.error(f'FAIL TO CHECK ONNX: {e}')
         continue
     try:
-        sess = onnxruntime.InferenceSession(f'saved/onnx/{saved_name}.onnx',
+        sess = onnxruntime.InferenceSession(onnx_path,
                                             providers=onnxruntime.get_available_providers())
         inputs = op.tokenizer(test_txt, return_tensors='np')
-        out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))
+        out2 = sess.run(output_names=['output_0'], input_feed=dict(inputs))
         logger.info('ONNX WORKED.')
         status[4] = 'success'
         if numpy.allclose(out1, out2, atol=atol):
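The expected output name changes because convert labels graph outputs generically (output_0, output_1, ...) instead of the last_hidden_state name the manual torch.onnx.export assigned. When unsure, the names can be read straight off the runtime session; a short sketch with an example path:

import onnxruntime

sess = onnxruntime.InferenceSession(
    'saved/onnx/bert-base-cased/model.onnx',
    providers=onnxruntime.get_available_providers())
# Output names are baked into the exported graph; a BERT
# feature-extraction export typically shows ['output_0', 'output_1'].
print([o.name for o in sess.get_outputs()])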
