"""Benchmark rerank models: Python (towhee op) QPS, ONNX export correctness, and ONNX QPS."""
import time

import numpy
import onnx
import onnxruntime
import torch
from towhee import ops
from tqdm import tqdm

from rerank import ReRank  # noqa: F401 -- kept for the direct-instantiation path commented out below

test_query = 'abc'
test_docs = ['123', 'ABC', 'ABCabc']
atol = 1e-3  # tolerance when comparing Python scores against ONNX scores
try_times = 500  # timing iterations per configuration; the first is treated as warm-up
batch_size_list = [2, 4, 8, 16, 32, 64]
model_name_list = [
    # q2p (query-to-passage) models:
    'cross-encoder/ms-marco-TinyBERT-L-2-v2',
    'cross-encoder/ms-marco-MiniLM-L-2-v2',
    'cross-encoder/ms-marco-MiniLM-L-4-v2',
    'cross-encoder/ms-marco-MiniLM-L-6-v2',
    'cross-encoder/ms-marco-MiniLM-L-12-v2',
    'cross-encoder/ms-marco-TinyBERT-L-2',
    'cross-encoder/ms-marco-TinyBERT-L-4',
    'cross-encoder/ms-marco-TinyBERT-L-6',
    'cross-encoder/ms-marco-electra-base',
    'nboost/pt-tinybert-msmarco',
    'nboost/pt-bert-base-uncased-msmarco',
    'nboost/pt-bert-large-msmarco',
    'Capreolus/electra-base-msmarco',
    'amberoad/bert-multilingual-passage-reranking-msmarco',
    # q2q (query-to-query) models:
    'cross-encoder/quora-distilroberta-base',
    'cross-encoder/quora-roberta-base',
    'cross-encoder/quora-roberta-large',
]

for name in tqdm(model_name_list):
    print('######################\nname=', name)

    ### Test Python QPS
    for device in ['cpu', 'cuda:3']:  # assumes GPU index 3 is available
        # op = ReRank(model_name=name, threshold=0, device='cpu')
        op = ops.rerank(model_name=name, threshold=0, device=device)
        for batch_size in batch_size_list:
            qps = []
            for ind in range(try_times):
                start = time.time()
                op(test_query, ['dummy input'] * batch_size)
                end = time.time()
                if ind == 0:  # skip the warm-up iteration
                    continue
                qps.append(1 / (end - start))  # queries per second for this call
            print(f'device = {device}, batch_size = {batch_size}, mean qps = {sum(qps) / len(qps)}')

    ### Test ONNX export correctness
    op = ops.rerank(model_name=name, threshold=0, device='cpu').get_op()
    out1 = op(test_query, test_docs)
    scores1 = out1[1]  # the op returns (docs, scores) sorted by descending score
    onnx_path = str(op.save_model(format='onnx'))
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)

    # Transpose (query, doc) pairs into the column lists the tokenizer expects.
    batch = [(test_query, doc) for doc in test_docs]
    texts = [[] for _ in range(len(batch[0]))]
    for example in batch:
        for idx, text in enumerate(example):
            texts[idx].append(text.strip())

    sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers())
    inputs = op.tokenizer(*texts, padding=True, truncation='longest_first',
                          return_tensors='np', max_length=op.max_length)
    out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))[0]
    scores2 = op.post_proc(torch.from_numpy(out2))
    scores2 = sorted(scores2, reverse=True)  # match the descending order of scores1
    assert numpy.allclose(scores1, scores2, atol=atol)

    ### Test ONNX QPS
    for batch_size in batch_size_list:
        qps = []        # end-to-end: tokenize + ONNX run + post-processing
        model_qps = []  # ONNX session run only
        for ind in range(try_times):
            start = time.time()
            batch = [(test_query, doc) for doc in ['dummy input'] * batch_size]
            texts = [[] for _ in range(len(batch[0]))]
            for example in batch:
                for idx, text in enumerate(example):
                    texts[idx].append(text.strip())
            inputs = op.tokenizer(*texts, padding=True, truncation='longest_first',
                                  return_tensors='np', max_length=op.max_length)
            model_start = time.time()
            out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))[0]
            model_end = time.time()
            scores2 = op.post_proc(torch.from_numpy(out2))
            scores2 = sorted(scores2, reverse=True)
            end = time.time()
            if ind == 0:  # skip the warm-up iteration
                continue
            qps.append(1 / (end - start))
            model_qps.append(1 / (model_end - model_start))
        print(f'onnx, batch_size = {batch_size}, mean qps = {sum(qps) / len(qps)}')
        print(f'model onnx, batch_size = {batch_size}, mean qps = {sum(model_qps) / len(model_qps)}')
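
# A minimal sketch, assuming the exported graph's output name can vary by model
# (e.g. 'logits' for sequence-classification heads rather than the
# 'last_hidden_state' hard-coded above): onnxruntime can report the graph's
# actual output names, which is a quick way to confirm the name before benchmarking.
def print_onnx_output_names(path: str) -> None:
    """Print the output tensor names of the ONNX graph at `path`."""
    session = onnxruntime.InferenceSession(path, providers=onnxruntime.get_available_providers())
    print([o.name for o in session.get_outputs()])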