diff --git a/auto_transformers.py b/auto_transformers.py
index 9182f92..de10457 100644
--- a/auto_transformers.py
+++ b/auto_transformers.py
@@ -22,11 +22,8 @@
 from towhee.operator import NNOperator
 from towhee import register
 import warnings
-import logging
 
 warnings.filterwarnings('ignore')
-logging.getLogger("transformers").setLevel(logging.ERROR)
-log = logging.getLogger()
 
 
 @register(output_schema=['vec'])
diff --git a/test_onnx.py b/test_onnx.py
index 077e73e..b586b03 100644
--- a/test_onnx.py
+++ b/test_onnx.py
@@ -4,52 +4,52 @@
 import numpy
 import onnx
 import onnxruntime
+import os
 from pathlib import Path
-import warnings
 import logging
+import platform
+import psutil
 
 # full_models = AutoTransformers.supported_model_names()
 # checked_models = AutoTransformers.supported_model_names(format='onnx')
 # models = [x for x in full_models if x not in checked_models]
 models = ['distilbert-base-cased']
+atol = 1e-3
+log_path = 'transformers_onnx.log'
 
-warnings.filterwarnings('ignore')
+logger = logging.getLogger('transformers_onnx')
+logger.setLevel(logging.DEBUG)
+formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+fh = logging.FileHandler(log_path)
+fh.setLevel(logging.DEBUG)
+fh.setFormatter(formatter)
+logger.addHandler(fh)
+ch = logging.StreamHandler()
+ch.setLevel(logging.ERROR)
+ch.setFormatter(formatter)
+logger.addHandler(ch)
 
-log_path = './transformers_onnx.log'
-
-logger = logging.getLogger()
-
-if Path(log_path).exists:
-    handler = logging.FileHandler(log_path)
-else:
-    handler = logging.StreamHandler()
-logger.handlers.append(handler)
-
-logging.basicConfig(filename=log_path,
-                    filemode='a',
-                    format='%(asctime)s, %(msecs)d %(name)s %(levelname)s %(message)s',
-                    datefmt='%H:%M:%S',
-                    level=logging.info)
-
-
-logging.info("Test")
+logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
+logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 **3))}' + 
+             f'/{round(psutil.virtual_memory().available / (1024.0 **3))}' +
+             f'/{round(psutil.virtual_memory().total / (1024.0 **3))} GB')
+logger.debug(f'cpu: {psutil.cpu_count()}')
 
 for name in models:
-    print(f'Model {name}:')
+    logger.info(f'***{name}***')
     saved_name = name.replace('/', '-')
-    logging.info(f'\nConverting model {name} to {saved_name}:')
     try:
         op = AutoTransformers(model_name=name)
         out1 = op('hello, world.')
-        logging.info('OP LOADED.')
+        logger.info('OP LOADED.')
     except Exception as e:
-        logging.error(f'FAIL TO LOAD OP: {e}')
+        logger.error(f'FAIL TO LOAD OP: {e}')
         continue
     try:
         op.save_model(format='onnx')
-        logging.info('ONNX SAVED.')
+        logger.info('ONNX SAVED.')
     except Exception as e:
-        logging.info(f'FAIL TO SAVE ONNX: {e}')
+        logger.error(f'FAIL TO SAVE ONNX: {e}')
         continue
     try:
         try:
@@ -58,22 +58,22 @@ for name in models:
         except Exception:
             saved_onnx = onnx.load(f'saved/onnx/{saved_name}.onnx', load_external_data=False)
         onnx.checker.check_model(saved_onnx)
-        logging.info('ONNX CHECKED.')
+        logger.info('ONNX CHECKED.')
     except Exception as e:
-        logging.error(f'FAIL TO CHECK ONNX: {e}')
+        logger.error(f'FAIL TO CHECK ONNX: {e}')
         continue
     try:
         sess = onnxruntime.InferenceSession(f'saved/onnx/{saved_name}.onnx',
                                             providers=onnxruntime.get_available_providers())
         inputs = op.tokenizer('hello, world.', return_tensors='np')
         out2 = sess.run(output_names=["last_hidden_state"], input_feed=dict(inputs))
-        logging.info('ONNX WORKED.')
-        if numpy.allclose(out1, out2, atol=1e-3):
-            logging.info('Check accuracy: OK')
+        logger.info('ONNX WORKED.')
+        if numpy.allclose(out1, out2, atol=atol):
+            logger.info('Check accuracy: OK')
         else:
-            logging.info('Check accuracy: atol is larger than 1e-3.')
+            logger.info(f'Check accuracy: atol is larger than {atol}.')
     except Exception as e:
-        logging.error(f'FAIL TO RUN ONNX: {e}')
+        logger.error(f'FAIL TO RUN ONNX: {e}')
         continue
 
 print('Finished.')