logo
Browse Source

Add log to test_onnx

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
f0131944b4
  1. 3
      auto_transformers.py
  2. 66
      test_onnx.py

3
auto_transformers.py

@ -22,11 +22,8 @@ from towhee.operator import NNOperator
from towhee import register
import warnings
import logging
warnings.filterwarnings('ignore')
logging.getLogger("transformers").setLevel(logging.ERROR)
log = logging.getLogger()
@register(output_schema=['vec'])

66
test_onnx.py

@ -4,52 +4,52 @@ import numpy
import onnx
import onnxruntime
import os
from pathlib import Path
import warnings
import logging
import platform
import psutil
# full_models = AutoTransformers.supported_model_names()
# checked_models = AutoTransformers.supported_model_names(format='onnx')
# models = [x for x in full_models if x not in checked_models]
models = ['distilbert-base-cased']
atol = 1e-3
log_path = 'transformers_onnx.log'
warnings.filterwarnings('ignore')
logger = logging.getLogger('transformers_onnx')
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh = logging.FileHandler(log_path)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
logger.addHandler(fh)
ch = logging.StreamHandler()
ch.setLevel(logging.ERROR)
ch.setFormatter(formatter)
logger.addHandler(ch)
log_path = './transformers_onnx.log'
logger = logging.getLogger()
if Path(log_path).exists:
handler = logging.FileHandler(log_path)
else:
handler = logging.StreamHandler()
logger.handlers.append(handler)
logging.basicConfig(filename=log_path,
filemode='a',
format='%(asctime)s, %(msecs)d %(name)s %(levelname)s %(message)s',
datefmt='%H:%M:%S',
level=logging.info)
logging.info("Test")
logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 **3))}'
f'/{round(psutil.virtual_memory().available / (1024.0 **3))}'
f'/{round(psutil.virtual_memory().total / (1024.0 **3))} GB')
logger.debug(f'cpu: {psutil.cpu_count()}')
for name in models:
print(f'Model {name}:')
logger.info(f'***{name}***')
saved_name = name.replace('/', '-')
logging.info(f'\nConverting model {name} to {saved_name}:')
try:
op = AutoTransformers(model_name=name)
out1 = op('hello, world.')
logging.info('OP LOADED.')
logger.info('OP LOADED.')
except Exception as e:
logging.error(f'FAIL TO LOAD OP: {e}')
logger.error(f'FAIL TO LOAD OP: {e}')
continue
try:
op.save_model(format='onnx')
logging.info('ONNX SAVED.')
logger.info('ONNX SAVED.')
except Exception as e:
logging.info(f'FAIL TO SAVE ONNX: {e}')
logger.error(f'FAIL TO SAVE ONNX: {e}')
continue
try:
try:
@ -58,22 +58,22 @@ for name in models:
except Exception:
saved_onnx = onnx.load(f'saved/onnx/{saved_name}.onnx', load_external_data=False)
onnx.checker.check_model(saved_onnx)
logging.info('ONNX CHECKED.')
logger.info('ONNX CHECKED.')
except Exception as e:
logging.error(f'FAIL TO CHECK ONNX: {e}')
logger.error(f'FAIL TO CHECK ONNX: {e}')
continue
try:
sess = onnxruntime.InferenceSession(f'saved/onnx/{saved_name}.onnx',
providers=onnxruntime.get_available_providers())
inputs = op.tokenizer('hello, world.', return_tensors='np')
out2 = sess.run(output_names=["last_hidden_state"], input_feed=dict(inputs))
logging.info('ONNX WORKED.')
if numpy.allclose(out1, out2, atol=1e-3):
logging.info('Check accuracy: OK')
logger.info('ONNX WORKED.')
if numpy.allclose(out1, out2, atol=atol):
logger.info('Check accuracy: OK')
else:
logging.info('Check accuracy: atol is larger than 1e-3.')
logger.info(f'Check accuracy: atol is larger than {atol}.')
except Exception as e:
logging.error(f'FAIL TO RUN ONNX: {e}')
logger.error(f'FAIL TO RUN ONNX: {e}')
continue
print('Finished.')

Loading…
Cancel
Save