"""Export towhee sentence-embedding transformer ops to ONNX and verify them.

For every model name in ``models`` the script runs five stages in order and
records ``success``/``fail`` for each one as a row of ``onnx.csv``:

1. ``load_op``    - build the towhee op and run it on ``test_txt``.
2. ``save_onnx``  - call ``op.save_model(model_type='onnx')``.
3. ``check_onnx`` - load the saved file and run ``onnx.checker.check_model``.
4. ``run_onnx``   - run the saved file with onnxruntime and post-process it.
5. ``accuracy``   - compare the op output and the ONNX output within ``atol``.

A failing stage stops the pipeline for that model; the remaining columns stay
``fail``. Detailed messages go to ``transformers_onnx.log``.
"""
from towhee import ops
import torch
import numpy
import onnx
import onnxruntime
import os
from pathlib import Path
import logging
import platform
import psutil
import warnings
from transformers import logging as t_logging

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
warnings.filterwarnings("ignore")
t_logging.set_verbosity_error()

# Kept for reference (disabled): derive the list from the not-yet-checked models.
# full_models = AutoTransformers.supported_model_names()
# checked_models = AutoTransformers.supported_model_names(format='onnx')
# models = [x for x in full_models if x not in checked_models]
models = ['distilbert-base-cased', 'sentence-transformers/paraphrase-albert-small-v2']
test_txt = 'hello, world.'
atol = 1e-3
log_path = 'transformers_onnx.log'

logger = logging.getLogger('transformers_onnx')
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Everything (DEBUG and up) goes to the log file ...
fh = logging.FileHandler(log_path)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
logger.addHandler(fh)
# ... while only errors are echoed to the console.
ch = logging.StreamHandler()
ch.setLevel(logging.ERROR)
ch.setFormatter(formatter)
logger.addHandler(ch)

# Record the environment once so results from different machines can be compared.
logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}'
             f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}'
             f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB')
logger.debug(f'cpu: {psutil.cpu_count()}')


def _check_model(name):
    """Run the load/save/check/run/accuracy pipeline for one model.

    Args:
        name: Model name passed to ``ops.sentence_embedding.transformers``.

    Returns:
        A list of six strings: the model name followed by ``'success'`` or
        ``'fail'`` for load_op, save_onnx, check_onnx, run_onnx and accuracy.
        Stages after the first failing one are left as ``'fail'``.
    """
    status = [name] + ['fail'] * 5

    try:
        op = ops.sentence_embedding.transformers(model_name=name).get_op()
        out1 = op(test_txt)
        logger.info('OP LOADED.')
        status[1] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO LOAD OP: {e}')
        return status

    # Path where the saved ONNX file is looked up below; '/' in hub-style
    # names is replaced so the name is a single path component.
    saved_name = op.model_name.replace('/', '-')
    onnx_path = f'saved/onnx/{saved_name}.onnx'

    try:
        op.save_model(model_type='onnx')
        logger.info('ONNX SAVED.')
        status[2] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO SAVE ONNX: {e}')
        return status

    try:
        try:
            onnx_model = onnx.load(onnx_path)
            onnx.checker.check_model(onnx_model)
        except Exception:
            # Retry without pulling external tensor data into memory.
            saved_onnx = onnx.load(onnx_path, load_external_data=False)
            onnx.checker.check_model(saved_onnx)
        logger.info('ONNX CHECKED.')
        status[3] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO CHECK ONNX: {e}')
        return status

    try:
        sess = onnxruntime.InferenceSession(onnx_path,
                                            providers=onnxruntime.get_available_providers())
        inputs = op.tokenizer(test_txt, return_tensors='np')
        out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))[0]
        # post_proc is given torch inputs, so tokenize again as 'pt' tensors.
        new_inputs = op.tokenizer(test_txt, return_tensors='pt')
        out2 = op.post_proc(torch.from_numpy(out2), new_inputs)
        logger.info('ONNX WORKED.')
        status[4] = 'success'
        if numpy.allclose(out1, out2, atol=atol):
            logger.info('Check accuracy: OK')
            status[5] = 'success'
        else:
            logger.info(f'Check accuracy: atol is larger than {atol}.')
    except Exception as e:
        logger.error(f'FAIL TO RUN ONNX: {e}')

    return status


# The context manager guarantees the CSV is closed (and therefore fully
# written) even if the run is interrupted part-way through the model list.
with open('onnx.csv', 'w+') as f:
    f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')
    for name in models:
        logger.info(f'***{name}***')
        status = _check_model(name)
        f.write(','.join(status) + '\n')
        # Flush per row so finished results survive a hard crash on a later model.
        f.flush()

print('Finished.')