You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
122 lines
4.1 KiB
122 lines
4.1 KiB
2 years ago
|
import time
|
||
|
|
||
|
import numpy
|
||
|
import torch
|
||
|
from towhee import ops
|
||
|
import onnx
|
||
|
import onnxruntime
|
||
|
from tqdm import tqdm
|
||
|
|
||
|
from rerank import ReRank
|
||
|
|
||
|
# Inputs for the ONNX-vs-PyTorch consistency check: one query scored against
# a handful of short documents.
test_query = 'abc'
test_docs = ['123', 'ABC', 'ABCabc']
# Absolute tolerance passed to numpy.allclose when comparing the two score lists.
atol = 1e-3

# Number of timed calls per (model, device, batch size); the benchmark loops
# discard the very first call, so try_times - 1 samples are averaged.
try_times = 500
# Number of documents sent with the query in each benchmarked call.
batch_size_list = [2, 4, 8, 16, 32, 64]

# Model identifiers handed to the rerank operator as ``model_name``.
model_name_list = [
    # q2p models (presumably query-to-passage -- TODO confirm):
    'cross-encoder/ms-marco-TinyBERT-L-2-v2',
    'cross-encoder/ms-marco-MiniLM-L-2-v2',
    'cross-encoder/ms-marco-MiniLM-L-4-v2',
    'cross-encoder/ms-marco-MiniLM-L-6-v2',
    'cross-encoder/ms-marco-MiniLM-L-12-v2',
    'cross-encoder/ms-marco-TinyBERT-L-2',
    'cross-encoder/ms-marco-TinyBERT-L-4',
    'cross-encoder/ms-marco-TinyBERT-L-6',
    'cross-encoder/ms-marco-electra-base',
    'nboost/pt-tinybert-msmarco',
    'nboost/pt-bert-base-uncased-msmarco',
    'nboost/pt-bert-large-msmarco',
    'Capreolus/electra-base-msmarco',
    'amberoad/bert-multilingual-passage-reranking-msmarco',

    # q2q models (presumably query-to-query -- TODO confirm):
    'cross-encoder/quora-distilroberta-base',
    'cross-encoder/quora-roberta-base',
    'cross-encoder/quora-roberta-large',
]
|
||
|
|
||
|
|
||
|
def _to_text_columns(query, docs):
    """Lay out (query, doc) pairs column-wise for the tokenizer.

    Returns a two-element list: the stripped query repeated once per
    document, followed by the stripped documents, so the result can be
    splatted as ``tokenizer(*columns, ...)``.
    """
    stripped_query = query.strip()
    return [[stripped_query] * len(docs), [doc.strip() for doc in docs]]


def _mean(values):
    """Return the arithmetic mean of a non-empty sequence of numbers."""
    return sum(values) / len(values)


# For every model: benchmark the Python operator, verify that the exported
# ONNX graph reproduces its scores, then benchmark the ONNX session.
for name in tqdm(model_name_list):
    print('######################\nname=', name)

    ### Test python qps
    # NOTE(review): 'cuda:3' is hard-coded to the fourth GPU; adjust the
    # index for hosts with a different device layout.
    for device in ['cpu', 'cuda:3']:
        op = ops.rerank(model_name=name, threshold=0, device=device)

        for batch_size in batch_size_list:
            qps = []
            for ind in range(try_times):
                # perf_counter is monotonic and high-resolution, so a fast
                # call cannot yield a zero or negative elapsed time.
                start = time.perf_counter()
                op(test_query, ['dump input'] * batch_size)
                end = time.perf_counter()
                if ind == 0:
                    # The first call is excluded from the statistics
                    # (presumably warm-up / lazy initialisation).
                    continue
                qps.append(1 / (end - start))
            print(f'device = {device}, batch_size = {batch_size}, mean qps = {_mean(qps)}')

    ### Test onnx checking
    op = ops.rerank(model_name=name, threshold=0, device='cpu').get_op()
    out1 = op(test_query, test_docs)
    # Element 1 of the operator output is compared against the ONNX scores
    # below (assumed to be the score list in descending order -- TODO confirm).
    scores1 = out1[1]
    onnx_path = str(op.save_model(format='onnx'))
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)

    texts = _to_text_columns(test_query, test_docs)

    sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers())

    inputs = op.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="np",
                          max_length=op.max_length)
    out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))[0]
    scores2 = op.post_proc(torch.from_numpy(out2))
    # Sort descending before the element-wise comparison with scores1.
    scores2 = sorted(scores2, reverse=True)
    assert numpy.allclose(scores1, scores2, atol=atol), \
        f'ONNX scores differ from operator scores for {name}'

    ### Test onnx qps
    for batch_size in batch_size_list:
        qps = []
        model_qps = []
        for ind in range(try_times):
            # The outer timer covers pre-processing, inference and
            # post-processing; the inner timer covers only sess.run.
            start = time.perf_counter()
            texts = _to_text_columns(test_query, ['dump input'] * batch_size)
            inputs = op.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="np",
                                  max_length=op.max_length)
            model_start = time.perf_counter()
            out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))[0]
            model_end = time.perf_counter()
            scores2 = op.post_proc(torch.from_numpy(out2))
            scores2 = sorted(scores2, reverse=True)
            end = time.perf_counter()
            if ind == 0:
                # The first call is excluded from the statistics
                # (presumably warm-up / lazy initialisation).
                continue
            qps.append(1 / (end - start))
            model_qps.append(1 / (model_end - model_start))
        print(f'onnx, batch_size = {batch_size}, mean qps = {_mean(qps)}')
        print(f'model onnx, batch_size = {batch_size}, mean qps = {_mean(model_qps)}')
|
||
|
|
||
|
|