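"""Benchmark and validation script for the rerank operator.

For each model in ``model_name_list`` this script:
  1. measures Python (PyTorch) QPS on CPU and GPU for several batch sizes,
  2. exports the model to ONNX, checks the graph, and verifies that the ONNX
     scores match the original operator's scores within ``atol``,
  3. measures ONNX Runtime QPS (end-to-end and model-only).
"""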
import time

import numpy
import torch
from towhee import ops
import onnx
import onnxruntime
from tqdm import tqdm

from rerank import ReRank

test_query = 'abc'
test_docs = ['123', 'ABC', 'ABCabc']
atol = 1e-3  # tolerance when comparing PyTorch and ONNX scores
try_times = 500  # timing iterations per configuration; the first run is discarded as warm-up
batch_size_list = [2, 4, 8, 16, 32, 64]
model_name_list = [
    # q2p (query-to-passage) models:
    'cross-encoder/ms-marco-TinyBERT-L-2-v2',
    'cross-encoder/ms-marco-MiniLM-L-2-v2',
    'cross-encoder/ms-marco-MiniLM-L-4-v2',
    'cross-encoder/ms-marco-MiniLM-L-6-v2',
    'cross-encoder/ms-marco-MiniLM-L-12-v2',
    'cross-encoder/ms-marco-TinyBERT-L-2',
    'cross-encoder/ms-marco-TinyBERT-L-4',
    'cross-encoder/ms-marco-TinyBERT-L-6',
    'cross-encoder/ms-marco-electra-base',
    'nboost/pt-tinybert-msmarco',
    'nboost/pt-bert-base-uncased-msmarco',
    'nboost/pt-bert-large-msmarco',
    'Capreolus/electra-base-msmarco',
    'amberoad/bert-multilingual-passage-reranking-msmarco',
    # q2q (query-to-query) models:
    'cross-encoder/quora-distilroberta-base',
    'cross-encoder/quora-roberta-base',
    'cross-encoder/quora-roberta-large',
]
for name in tqdm(model_name_list):
    print('######################\nname=', name)

    ### Test python qps
    for device in ['cpu', 'cuda:3']:
        # op = ReRank(model_name=name, threshold=0, device='cpu')
        op = ops.rerank(model_name=name, threshold=0, device=device)
        for batch_size in batch_size_list:
            qps = []
            for ind, _ in enumerate(range(try_times)):
                start = time.time()
                # 'dump input' is just dummy text repeated to fill the batch
                out = op(test_query, ['dump input'] * batch_size)
                end = time.time()
                if ind == 0:
                    # discard the first iteration as warm-up
                    continue
                q = 1 / (end - start)
                qps.append(q)
            print(f'device = {device}, batch_size = {batch_size}, mean qps = {sum(qps) / len(qps)}')
    ### Test onnx checking
    op = ops.rerank(model_name=name, threshold=0, device='cpu').get_op()
    out1 = op(test_query, test_docs)
    scores1 = out1[1]

    onnx_path = str(op.save_model(format='onnx'))
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)

    # Regroup the (query, doc) pairs column-wise so the tokenizer receives
    # all queries as the first sequence and all docs as the second.
    batch = [(test_query, doc) for doc in test_docs]
    texts = [[] for _ in range(len(batch[0]))]
    for example in batch:
        for idx, text in enumerate(example):
            texts[idx].append(text.strip())

    sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers())
    inputs = op.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="np",
                          max_length=op.max_length)
    out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))[0]
    scores2 = op.post_proc(torch.from_numpy(out2))
    scores2 = sorted(scores2, reverse=True)
    # The ONNX scores must match the original operator's scores within tolerance.
    assert numpy.allclose(scores1, scores2, atol=atol)
    ### Test onnx qps
    for batch_size in batch_size_list:
        qps = []
        model_qps = []
        for ind, _ in enumerate(range(try_times)):
            start = time.time()
            batch = [(test_query, doc) for doc in ['dump input'] * batch_size]
            texts = [[] for _ in range(len(batch[0]))]
            for example in batch:
                for idx, text in enumerate(example):
                    texts[idx].append(text.strip())
            inputs = op.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="np",
                                  max_length=op.max_length)
            model_start = time.time()
            out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))[0]
            model_end = time.time()
            scores2 = op.post_proc(torch.from_numpy(out2))
            scores2 = sorted(scores2, reverse=True)
            end = time.time()
            if ind == 0:
                # discard the first iteration as warm-up
                continue
            q = 1 / (end - start)  # end-to-end qps: tokenization + inference + post-processing
            model_q = 1 / (model_end - model_start)  # model-only qps: ONNX Runtime inference
            qps.append(q)
            model_qps.append(model_q)
        print(f'onnx, batch_size = {batch_size}, mean qps = {sum(qps) / len(qps)}')
        print(f'model onnx, batch_size = {batch_size}, mean qps = {sum(model_qps) / len(model_qps)}')