"""Check ONNX export for Towhee transformers sentence-embedding operators.

For each model: load the operator, export it to ONNX, validate the graph,
run it with onnxruntime, and compare the result with the PyTorch output.
Per-model results go to onnx.csv; detailed logs go to transformers_onnx.log.
"""
import logging
import os
import platform
import warnings

import numpy
import onnx
import onnxruntime
import psutil
import torch
from towhee import ops
from transformers import logging as t_logging

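# Silence TensorFlow, Python, and transformers warnings to keep the log readable.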
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
warnings.filterwarnings("ignore")
t_logging.set_verbosity_error()

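# To sweep every supported model not yet verified for ONNX, uncomment the
# lines below (requires importing AutoTransformers):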
# full_models = AutoTransformers.supported_model_names()
# checked_models = AutoTransformers.supported_model_names(format='onnx')
# models = [x for x in full_models if x not in checked_models]
models = ['distilbert-base-cased', 'sentence-transformers/paraphrase-albert-small-v2']
test_txt = 'hello, world.'
atol = 1e-3  # absolute tolerance when comparing PyTorch and ONNX outputs
log_path = 'transformers_onnx.log'
f = open('onnx.csv', 'w+')
f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')

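# Log DEBUG and above to the file; surface only ERROR and above on the console.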
logger = logging.getLogger('transformers_onnx')
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh = logging.FileHandler(log_path)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
logger.addHandler(fh)
ch = logging.StreamHandler()
ch.setLevel(logging.ERROR)
ch.setFormatter(formatter)
logger.addHandler(ch)

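# Record host specs so log entries can be compared across machines.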
mem = psutil.virtual_memory()
logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
logger.debug(f'free/available/total mem: {round(mem.free / (1024.0 ** 3))}'
             f'/{round(mem.available / (1024.0 ** 3))}'
             f'/{round(mem.total / (1024.0 ** 3))} GB')
logger.debug(f'cpu: {psutil.cpu_count()}')


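# For each model: load the op, export to ONNX, check the graph, run with
# onnxruntime, and compare accuracy; one CSV row is recorded per model.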
status = None
for name in models:
    logger.info(f'***{name}***')
    saved_name = name.replace('/', '-')
    onnx_path = f'saved/onnx/{saved_name}.onnx'
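    # Write the previous model's row before starting this one.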
    if status:
        f.write(','.join(status) + '\n')
    status = [name] + ['fail'] * 5
    try:
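        # Step 1: load the operator and compute a reference embedding with the PyTorch model.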
        op = ops.sentence_embedding.transformers(model_name=name).get_op()
        out1 = op(test_txt)
        logger.info('OP LOADED.')
        status[1] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO LOAD OP: {e}')
        continue
    try:
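        # Step 2: export the model to ONNX via the op's save_model.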
        op.save_model(model_type='onnx')
        logger.info('ONNX SAVED.')
        status[2] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO SAVE ONNX: {e}')
        continue
    try:
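        # Step 3: validate the exported graph. If the full model cannot be
        # loaded (e.g. weights kept as external data), check the graph alone.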
        try:
            onnx_model = onnx.load(onnx_path)
            onnx.checker.check_model(onnx_model)
        except Exception:
            saved_onnx = onnx.load(onnx_path, load_external_data=False)
            onnx.checker.check_model(saved_onnx)
        logger.info('ONNX CHECKED.')
        status[3] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO CHECK ONNX: {e}')
        continue
    try:
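        # Step 4: run the exported model with onnxruntime, then apply the
        # op's post-processing to the raw hidden states.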
        sess = onnxruntime.InferenceSession(onnx_path,
                                            providers=onnxruntime.get_available_providers())
        inputs = op.tokenizer(test_txt, return_tensors='np')
        out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))[0]
        new_inputs = op.tokenizer(test_txt, return_tensors='pt')
        out2 = op.post_proc(torch.from_numpy(out2), new_inputs)
        logger.info('ONNX WORKED.')
        status[4] = 'success'
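        # Step 5: compare the ONNX output against the PyTorch reference.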
        if numpy.allclose(out1, out2, atol=atol):
            logger.info('Check accuracy: OK')
            status[5] = 'success'
        else:
            logger.info(f'Check accuracy: outputs differ by more than atol={atol}.')
    except Exception as e:
        logger.error(f'FAIL TO RUN ONNX: {e}')
        continue

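# Flush the last model's row and close the CSV.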
if status:
    f.write(','.join(status) + '\n')
f.close()

print('Finished.')