logo
Browse Source

Fix log

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
9acb9072e1
  1. 44
      auto_transformers.py
  2. 30
      test_onnx.py

44
auto_transformers.py

@@ -22,6 +22,9 @@ from towhee.operator import NNOperator
from towhee import register from towhee import register
import warnings import warnings
import logging
log = logging.getLogger('run_op')
warnings.filterwarnings('ignore') warnings.filterwarnings('ignore')
@@ -83,6 +86,7 @@ class AutoTransformers(NNOperator):
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
name = self.model_name.replace('/', '-') name = self.model_name.replace('/', '-')
path = os.path.join(path, name) path = os.path.join(path, name)
inputs = self.tokenizer('[CLS]', return_tensors='pt') # a dictionary inputs = self.tokenizer('[CLS]', return_tensors='pt') # a dictionary
if format == 'pytorch': if format == 'pytorch':
path = path + '.pt' path = path + '.pt'
@@ -101,37 +105,39 @@ class AutoTransformers(NNOperator):
raise RuntimeError(f'Fail to save as torchscript: {e}.') raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onnx': elif format == 'onnx':
path = path + '.onnx' path = path + '.onnx'
input_names = list(inputs.keys())
dynamic_axes = {}
for i_n in input_names:
dynamic_axes[i_n] = {0: 'batch_size', 1: 'sequence_length'}
try: try:
output_names = ['last_hidden_state']
for o_n in output_names:
dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'}
torch.onnx.export(self.model, torch.onnx.export(self.model,
tuple(inputs.values()), tuple(inputs.values()),
path, path,
input_names=list(inputs.keys()),
output_names=["last_hidden_state"],
dynamic_axes={
"input_ids": {0: "batch_size", 1: "input_length"},
"token_type_ids": {0: "batch_size", 1: "input_length"},
"attention_mask": {0: "batch_size", 1: "input_length"},
"last_hidden_state": {0: "batch_size"},
},
opset_version=13,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=11,
do_constant_folding=True, do_constant_folding=True,
# enable_onnx_checker=True, # enable_onnx_checker=True,
) )
except Exception as e: except Exception as e:
print(e, '\nTrying with 2 outputs...') print(e, '\nTrying with 2 outputs...')
output_names = ['last_hidden_state', 'pooler_output']
for o_n in output_names:
dynamic_axes[o_n] = {0: 'batch_size', 1: 'sequence_length'}
torch.onnx.export(self.model, torch.onnx.export(self.model,
tuple(inputs.values()), tuple(inputs.values()),
path, path,
input_names=["input_ids", "token_type_ids", "attention_mask"], # list(inputs.keys())
output_names=["last_hidden_state", "pooler_output"],
opset_version=13,
dynamic_axes={
"input_ids": {0: "batch_size", 1: "input_length"},
"token_type_ids": {0: "batch_size", 1: "input_length"},
"attention_mask": {0: "batch_size", 1: "input_length"},
"last_hidden_state": {0: "batch_size"},
"pooler_outputs": {0: "batch_size"}
})
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=11,
do_constant_folding=True,
# enable_onnx_checker=True,
)
# todo: elif format == 'tensorrt': # todo: elif format == 'tensorrt':
else: else:
log.error(f'Unsupported format "{format}".') log.error(f'Unsupported format "{format}".')

30
test_onnx.py

@@ -13,9 +13,12 @@ import psutil
# full_models = AutoTransformers.supported_model_names() # full_models = AutoTransformers.supported_model_names()
# checked_models = AutoTransformers.supported_model_names(format='onnx') # checked_models = AutoTransformers.supported_model_names(format='onnx')
# models = [x for x in full_models if x not in checked_models] # models = [x for x in full_models if x not in checked_models]
models = ['distilbert-base-cased']
models = ['bert-base-cased', 'distilbert-base-cased']
test_txt = '[UNK]'
atol = 1e-3 atol = 1e-3
log_path = 'transformers_onnx.log' log_path = 'transformers_onnx.log'
f = open('onnx.csv', 'w+')
f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')
logger = logging.getLogger('transformers_onnx') logger = logging.getLogger('transformers_onnx')
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
@@ -30,24 +33,31 @@ ch.setFormatter(formatter)
logger.addHandler(ch) logger.addHandler(ch)
logger.debug(f'machine: {platform.platform()}-{platform.processor()}') logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 **3))}'
f'/{round(psutil.virtual_memory().available / (1024.0 **3))}'
f'/{round(psutil.virtual_memory().total / (1024.0 **3))} GB')
logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}'
f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}'
f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB')
logger.debug(f'cpu: {psutil.cpu_count()}') logger.debug(f'cpu: {psutil.cpu_count()}')
status = None
for name in models: for name in models:
logger.info(f'***{name}***') logger.info(f'***{name}***')
saved_name = name.replace('/', '-') saved_name = name.replace('/', '-')
if status:
f.write(','.join(status) + '\n')
status = [name] + ['fail'] * 5
try: try:
op = AutoTransformers(model_name=name) op = AutoTransformers(model_name=name)
out1 = op('hello, world.')
out1 = op(test_txt)
logger.info('OP LOADED.') logger.info('OP LOADED.')
status[1] = 'success'
except Exception as e: except Exception as e:
logger.error(f'FAIL TO LOAD OP: {e}') logger.error(f'FAIL TO LOAD OP: {e}')
continue continue
try: try:
op.save_model(format='onnx') op.save_model(format='onnx')
logger.info('ONNX SAVED.') logger.info('ONNX SAVED.')
status[2] = 'success'
except Exception as e: except Exception as e:
logger.error(f'FAIL TO SAVE ONNX: {e}') logger.error(f'FAIL TO SAVE ONNX: {e}')
continue continue
@@ -59,21 +69,27 @@ for name in models:
saved_onnx = onnx.load(f'saved/onnx/{saved_name}.onnx', load_external_data=False) saved_onnx = onnx.load(f'saved/onnx/{saved_name}.onnx', load_external_data=False)
onnx.checker.check_model(saved_onnx) onnx.checker.check_model(saved_onnx)
logger.info('ONNX CHECKED.') logger.info('ONNX CHECKED.')
status[3] = 'success'
except Exception as e: except Exception as e:
logger.error(f'FAIL TO CHECK ONNX: {e}') logger.error(f'FAIL TO CHECK ONNX: {e}')
continue continue
try: try:
sess = onnxruntime.InferenceSession(f'saved/onnx/{saved_name}.onnx', sess = onnxruntime.InferenceSession(f'saved/onnx/{saved_name}.onnx',
providers=onnxruntime.get_available_providers()) providers=onnxruntime.get_available_providers())
inputs = op.tokenizer('hello, world.', return_tensors='np')
out2 = sess.run(output_names=["last_hidden_state"], input_feed=dict(inputs))
inputs = op.tokenizer(test_txt, return_tensors='np')
out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))
logger.info('ONNX WORKED.') logger.info('ONNX WORKED.')
status[4] = 'success'
if numpy.allclose(out1, out2, atol=atol): if numpy.allclose(out1, out2, atol=atol):
logger.info('Check accuracy: OK') logger.info('Check accuracy: OK')
status[5] = 'success'
else: else:
logger.info(f'Check accuracy: atol is larger than {atol}.') logger.info(f'Check accuracy: atol is larger than {atol}.')
except Exception as e: except Exception as e:
logger.error(f'FAIL TO RUN ONNX: {e}') logger.error(f'FAIL TO RUN ONNX: {e}')
continue continue
if status:
f.write(','.join(status) + '\n')
print('Finished.') print('Finished.')

Loading…
Cancel
Save