From 97f155fcaf3ab398707eb98fadef92b11094a263 Mon Sep 17 00:00:00 2001
From: ChengZi
Date: Thu, 20 Jul 2023 17:52:58 +0800
Subject: [PATCH] test onnx and qps

Signed-off-by: ChengZi
---
 test_onnx_and_qps.py | 121 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 121 insertions(+)
 create mode 100644 test_onnx_and_qps.py

diff --git a/test_onnx_and_qps.py b/test_onnx_and_qps.py
new file mode 100644
index 0000000..af34903
--- /dev/null
+++ b/test_onnx_and_qps.py
@@ -0,0 +1,121 @@
import time

import numpy
import torch
from towhee import ops
import onnx
import onnxruntime
from tqdm import tqdm

from rerank import ReRank

test_query = 'abc'
test_docs = ['123', 'ABC', 'ABCabc']
atol = 1e-3  # absolute tolerance for the PyTorch-vs-ONNX score comparison

try_times = 500  # timing iterations per configuration; the first is treated as warm-up
batch_size_list = [2, 4, 8, 16, 32, 64]

model_name_list = [
    # q2p (query-to-passage) models:
    'cross-encoder/ms-marco-TinyBERT-L-2-v2',
    'cross-encoder/ms-marco-MiniLM-L-2-v2',
    'cross-encoder/ms-marco-MiniLM-L-4-v2',
    'cross-encoder/ms-marco-MiniLM-L-6-v2',
    'cross-encoder/ms-marco-MiniLM-L-12-v2',
    'cross-encoder/ms-marco-TinyBERT-L-2',
    'cross-encoder/ms-marco-TinyBERT-L-4',
    'cross-encoder/ms-marco-TinyBERT-L-6',
    'cross-encoder/ms-marco-electra-base',
    'nboost/pt-tinybert-msmarco',
    'nboost/pt-bert-base-uncased-msmarco',
    'nboost/pt-bert-large-msmarco',
    'Capreolus/electra-base-msmarco',
    'amberoad/bert-multilingual-passage-reranking-msmarco',

    # q2q (query-to-query) models:
    'cross-encoder/quora-distilroberta-base',
    'cross-encoder/quora-roberta-base',
    'cross-encoder/quora-roberta-large',
]


for name in tqdm(model_name_list):
    print('######################\nname=', name)

    ### Test python qps
    for device in ['cpu', 'cuda:3']:  # adjust the CUDA device index to your machine
        # op = ReRank(model_name=name, threshold=0, device='cpu')
        op = ops.rerank(model_name=name, threshold=0, device=device)

        for batch_size in batch_size_list:
            qps = []
            for ind in range(try_times):
                start = time.time()
                out = op(test_query, ['dummy input'] * batch_size)
                end = time.time()
                if ind == 0:  # discard the warm-up iteration
                    continue
                qps.append(1 / (end - start))
            print(f'device = {device}, batch_size = {batch_size}, mean qps = {sum(qps) / len(qps)}')

    ### Test onnx checking
    op = ops.rerank(model_name=name, threshold=0, device='cpu').get_op()
    out1 = op(test_query, test_docs)
    scores1 = out1[1]  # the second element of the op output holds the descending-sorted scores
    onnx_path = str(op.save_model(format='onnx'))
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)

    # Regroup the (query, doc) pairs into parallel lists of queries and docs,
    # which is the input layout the tokenizer expects.
    batch = [(test_query, doc) for doc in test_docs]
    texts = [[] for _ in range(len(batch[0]))]
    for example in batch:
        for idx, text in enumerate(example):
            texts[idx].append(text.strip())

    sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers())

    inputs = op.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="np",
                          max_length=op.max_length)
    out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))[0]
    scores2 = op.post_proc(torch.from_numpy(out2))
    scores2 = sorted(scores2, reverse=True)
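    # Editor's note (illustrative addition, not in the original patch): surfacing
    # the largest PyTorch-vs-ONNX score gap before asserting makes tolerance
    # failures easier to diagnose. numpy.allclose passes elementwise when
    # |a - b| <= atol + rtol * |b|, with a default rtol of 1e-5.
    max_diff = float(numpy.max(numpy.abs(numpy.asarray(scores1) - numpy.asarray(scores2))))
    print(f'max |pytorch - onnx| score diff = {max_diff}')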
    assert numpy.allclose(scores1, scores2, atol=atol)

    ### Test onnx qps
    for batch_size in batch_size_list:
        qps = []
        model_qps = []  # QPS of the bare ONNX forward pass, excluding tokenization and post-processing
        for ind in range(try_times):
            start = time.time()
            batch = [(test_query, doc) for doc in ['dummy input'] * batch_size]
            texts = [[] for _ in range(len(batch[0]))]
            for example in batch:
                for idx, text in enumerate(example):
                    texts[idx].append(text.strip())

            inputs = op.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="np",
                                  max_length=op.max_length)
            model_start = time.time()
            out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))[0]
            model_end = time.time()
            scores2 = op.post_proc(torch.from_numpy(out2))
            scores2 = sorted(scores2, reverse=True)
            end = time.time()
            if ind == 0:  # discard the warm-up iteration
                continue
            qps.append(1 / (end - start))
            model_qps.append(1 / (model_end - model_start))
        print(f'onnx, batch_size = {batch_size}, mean qps = {sum(qps) / len(qps)}')
        print(f'model onnx, batch_size = {batch_size}, mean qps = {sum(model_qps) / len(model_qps)}')
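
# --- Editor's sketch (not part of the original patch) ---
# Mean QPS can hide tail latency, so benchmark reports often add latency
# percentiles. A minimal sketch using only numpy; `report_percentiles` is a
# hypothetical helper name, and it assumes the per-iteration latencies in
# seconds (i.e. the 1 / q values gathered in the loops above).
def report_percentiles(latencies_s):
    """Print p50/p95/p99 latency in milliseconds."""
    lat_ms = numpy.asarray(latencies_s) * 1000.0
    for p in (50, 95, 99):
        print(f'p{p} latency = {numpy.percentile(lat_ms, p):.2f} ms')

# Example usage right after a timing loop: report_percentiles([1 / q for q in qps])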