import time

import numpy
import torch
from towhee import ops
import onnx
import onnxruntime
from tqdm import tqdm

from rerank import ReRank

# Fixed inputs used to check that ONNX scores match the PyTorch operator's scores.
test_query = 'abc'
test_docs = ['123', 'ABC', 'ABCabc']
# Absolute tolerance for that score comparison.
atol = 0.001

# Number of calls made per (device, batch size) configuration.
try_times = 500
# Batch sizes to benchmark: 2, 4, 8, 16, 32, 64.
batch_size_list = [2 ** exp for exp in range(1, 7)]

# Hugging Face model ids to benchmark.
model_name_list = [
    # --- query-to-passage (q2p) rerankers ---
    'cross-encoder/ms-marco-TinyBERT-L-2-v2',
    'cross-encoder/ms-marco-MiniLM-L-2-v2',
    'cross-encoder/ms-marco-MiniLM-L-4-v2',
    'cross-encoder/ms-marco-MiniLM-L-6-v2',
    'cross-encoder/ms-marco-MiniLM-L-12-v2',
    'cross-encoder/ms-marco-TinyBERT-L-2',
    'cross-encoder/ms-marco-TinyBERT-L-4',
    'cross-encoder/ms-marco-TinyBERT-L-6',
    'cross-encoder/ms-marco-electra-base',
    'nboost/pt-tinybert-msmarco',
    'nboost/pt-bert-base-uncased-msmarco',
    'nboost/pt-bert-large-msmarco',
    'Capreolus/electra-base-msmarco',
    'amberoad/bert-multilingual-passage-reranking-msmarco',
    # --- query-to-query (q2q) rerankers ---
    'cross-encoder/quora-distilroberta-base',
    'cross-encoder/quora-roberta-base',
    'cross-encoder/quora-roberta-large',
]


# GPU used for the PyTorch-operator benchmark. It is skipped (rather than
# crashing) when the machine exposes fewer CUDA devices than this index needs.
gpu_index = 3


def _elapsed_qps(durations):
    """Return calls per second over a list of per-call ``durations`` (seconds).

    Computed as total calls / total elapsed time rather than the mean of
    per-call rates: the latter is skewed by a few unusually fast calls and
    divides by zero when a call completes within one clock tick. Returns
    ``nan`` when there are no samples (e.g. ``try_times <= 1``) or no
    measurable elapsed time.
    """
    total = sum(durations)
    if not durations or total <= 0:
        return float('nan')
    return len(durations) / total


def _pair_texts(query, docs):
    """Return the two parallel text columns (queries, docs) for the tokenizer, whitespace-stripped."""
    return [[query.strip()] * len(docs), [doc.strip() for doc in docs]]


def _onnx_scores(sess, op, input_names, query, docs):
    """Tokenize ``(query, doc)`` pairs, run ``sess`` and post-process to scores.

    Args:
        sess: onnxruntime session for the exported model.
        op: the rerank operator, providing ``tokenizer``, ``max_length`` and ``post_proc``.
        input_names: set of input names declared by ``sess``'s graph.
        query: query text.
        docs: list of document texts.

    Returns:
        ``(scores, model_seconds)``: ``scores`` sorted from high to low, and
        ``model_seconds`` covering only the ``sess.run`` call.
    """
    inputs = op.tokenizer(*_pair_texts(query, docs), padding=True, truncation='longest_first',
                          return_tensors="np", max_length=op.max_length)
    # onnxruntime rejects feed names the graph does not declare (e.g. a
    # tokenizer-produced token_type_ids the export dropped), so pass only those.
    feed = {key: value for key, value in inputs.items() if key in input_names}
    model_start = time.perf_counter()
    logits = sess.run(output_names=['last_hidden_state'], input_feed=feed)[0]
    model_seconds = time.perf_counter() - model_start
    scores = sorted(op.post_proc(torch.from_numpy(logits)), reverse=True)
    return scores, model_seconds


devices = ['cpu']
if torch.cuda.device_count() > gpu_index:
    devices.append(f'cuda:{gpu_index}')
else:
    print(f'cuda:{gpu_index} not available, skipping GPU benchmark')

for name in tqdm(model_name_list):
    print('######################\nname=', name)

    ### Test python qps
    for device in devices:
        op = ops.rerank(model_name=name, threshold=0, device=device)
        for batch_size in batch_size_list:
            docs = ['dump input'] * batch_size
            # Warm-up call (first-call setup cost), excluded from timing.
            op(test_query, docs)
            durations = []
            for _ in range(try_times - 1):
                # NOTE(review): no explicit CUDA synchronize here; timing assumes
                # op() returns host-side results - confirm for the GPU path.
                start = time.perf_counter()
                op(test_query, docs)
                durations.append(time.perf_counter() - start)
            print(f'device = {device}, batch_size = {batch_size}, mean qps = {_elapsed_qps(durations)}')

    ### Test onnx checking
    op = ops.rerank(model_name=name, threshold=0, device='cpu').get_op()
    scores1 = op(test_query, test_docs)[1]
    onnx_path = str(op.save_model(format='onnx'))
    onnx.checker.check_model(onnx.load(onnx_path))

    sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers())
    input_names = {node.name for node in sess.get_inputs()}
    scores2, _ = _onnx_scores(sess, op, input_names, test_query, test_docs)
    assert numpy.allclose(scores1, scores2, atol=atol), \
        f'{name}: onnx scores {scores2} differ from torch scores {scores1}'

    ### Test onnx qps
    for batch_size in batch_size_list:
        docs = ['dump input'] * batch_size
        # Warm-up call, excluded from timing.
        _onnx_scores(sess, op, input_names, test_query, docs)
        durations = []
        model_durations = []
        for _ in range(try_times - 1):
            # End-to-end time includes tokenization and post-processing;
            # model time covers only sess.run.
            start = time.perf_counter()
            _, model_seconds = _onnx_scores(sess, op, input_names, test_query, docs)
            durations.append(time.perf_counter() - start)
            model_durations.append(model_seconds)
        print(f'onnx, batch_size = {batch_size}, mean qps = {_elapsed_qps(durations)}')
        print(f'model onnx, batch_size = {batch_size}, mean qps = {_elapsed_qps(model_durations)}')