from towhee import ops
import numpy
import onnx
import onnxruntime

import os
from pathlib import Path
import logging
import platform
import psutil

# full_models = AutoTransformers.supported_model_names()
# checked_models = AutoTransformers.supported_model_names(format='onnx')
# models = [x for x in full_models if x not in checked_models]
models = ['bert-base-cased', 'distilbert-base-cased']
test_txt = 'hello, world.'
atol = 1e-3
log_path = 'transformers_onnx.log'
f = open('onnx.csv', 'w+')
f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')

logger = logging.getLogger('transformers_onnx')
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh = logging.FileHandler(log_path)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
logger.addHandler(fh)
ch = logging.StreamHandler()
ch.setLevel(logging.ERROR)
ch.setFormatter(formatter)
logger.addHandler(ch)

logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}'
             f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}'
             f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB')
logger.debug(f'cpu: {psutil.cpu_count()}')


status = None
for name in models:
    logger.info(f'***{name}***')
    saved_name = name.replace('/', '-')
    onnx_path = f'saved/onnx/{saved_name}.onnx'
    if status:
        f.write(','.join(status) + '\n')
    status = [name] + ['fail'] * 5
    try:
        op = ops.text_embedding.transformers(model_name=name, device='cpu').get_op()
        out1 = op(test_txt)
        logger.info('OP LOADED.')
        status[1] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO LOAD OP: {e}')
        continue
    try:
        op.save_model(format='onnx')
        logger.info('ONNX SAVED.')
        status[2] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO SAVE ONNX: {e}')
        continue
    try:
        try:
            onnx_model = onnx.load(onnx_path)
            onnx.checker.check_model(onnx_model)
        except Exception:
            saved_onnx = onnx.load(onnx_path, load_external_data=False)
            onnx.checker.check_model(saved_onnx)
        logger.info('ONNX CHECKED.')
        status[3] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO CHECK ONNX: {e}')
        continue
    try:
        sess = onnxruntime.InferenceSession(onnx_path,
                                            providers=onnxruntime.get_available_providers())
        inputs = op.tokenizer(test_txt, return_tensors='np')
        out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))
        logger.info('ONNX WORKED.')
        status[4] = 'success'
        if numpy.allclose(out1, out2, atol=atol):
            logger.info('Check accuracy: OK')
            status[5] = 'success'
        else:
            logger.info(f'Check accuracy: atol is larger than {atol}.')
    except Exception as e:
        logger.error(f'FAIL TO RUN ONNX: {e}')
        continue

if status:
    f.write(','.join(status) + '\n')

print('Finished.')