1 changed file with 121 additions and 0 deletions
@@ -0,0 +1,121 @@
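# Benchmark and validation script for rerank cross-encoder models:
#   1. measure end-to-end QPS of the towhee rerank op on CPU and GPU,
#   2. export each model to ONNX, check the graph with onnx.checker, and
#      verify the ONNX scores match the PyTorch scores within `atol`,
#   3. measure QPS of the ONNX path, end-to-end and for the session run alone.
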
import time

import numpy
import torch
from towhee import ops
import onnx
import onnxruntime
from tqdm import tqdm

from rerank import ReRank  # local module; used by the commented-out direct-op call below

test_query = 'abc'
test_docs = ['123', 'ABC', 'ABCabc']
atol = 1e-3  # tolerance for comparing PyTorch scores against ONNX scores

# Every configuration is timed over `try_times` calls at each batch size; the
# first call is discarded as warm-up and the per-call QPS values are averaged.
try_times = 500
batch_size_list = [2, 4, 8, 16, 32, 64]
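
# The models below are Hugging Face cross-encoder checkpoints: the q2p group
# scores query-passage pairs (MS MARCO rerankers), the q2q group scores
# query-query pairs (Quora duplicate-question models).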
model_name_list = [
    # q2p models:
    'cross-encoder/ms-marco-TinyBERT-L-2-v2',
    'cross-encoder/ms-marco-MiniLM-L-2-v2',
    'cross-encoder/ms-marco-MiniLM-L-4-v2',
    'cross-encoder/ms-marco-MiniLM-L-6-v2',
    'cross-encoder/ms-marco-MiniLM-L-12-v2',
    'cross-encoder/ms-marco-TinyBERT-L-2',
    'cross-encoder/ms-marco-TinyBERT-L-4',
    'cross-encoder/ms-marco-TinyBERT-L-6',
    'cross-encoder/ms-marco-electra-base',
    'nboost/pt-tinybert-msmarco',
    'nboost/pt-bert-base-uncased-msmarco',
    'nboost/pt-bert-large-msmarco',
    'Capreolus/electra-base-msmarco',
    'amberoad/bert-multilingual-passage-reranking-msmarco',

    # q2q models:
    'cross-encoder/quora-distilroberta-base',
    'cross-encoder/quora-roberta-base',
    'cross-encoder/quora-roberta-large',
]


for name in tqdm(model_name_list):
    print('######################\nname=', name)

    ### Test python qps
    for device in ['cpu', 'cuda:3']:  # 'cuda:3' assumes a multi-GPU host; adjust to the available hardware
        # op = ReRank(model_name=name, threshold=0, device='cpu')
        op = ops.rerank(model_name=name, threshold=0, device=device)

        for batch_size in batch_size_list:
            qps = []
            for ind in range(try_times):
                start = time.time()
                out = op(test_query, ['dummy input'] * batch_size)
                end = time.time()
                # Discard the first call: it includes model warm-up.
                if ind == 0:
                    continue
                q = 1 / (end - start)  # one batched call counts as one query
                qps.append(q)
            print(f'device = {device}, batch_size = {batch_size}, mean qps = {sum(qps) / len(qps)}')

    ### Test onnx checking
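    # The check below exports the model with save_model, validates the graph
    # with onnx.checker, and compares ONNX Runtime scores against the op's
    # PyTorch scores within `atol`; get_op() returns the underlying operator
    # so its tokenizer, max_length, and post_proc can be accessed directly.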
    op = ops.rerank(model_name=name, threshold=0, device='cpu').get_op()
    out1 = op(test_query, test_docs)
    scores1 = out1[1]
    onnx_path = str(op.save_model(format='onnx'))
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)

    # Pair the query with every doc, then transpose the pairs into two
    # parallel lists (all queries, all docs) as the tokenizer expects.
    batch = [(test_query, doc) for doc in test_docs]
    texts = [[] for _ in range(len(batch[0]))]

    for example in batch:
        for idx, text in enumerate(example):
            texts[idx].append(text.strip())

    sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers())

    inputs = op.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors='np',
                          max_length=op.max_length)
    out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))[0]
    scores2 = op.post_proc(torch.from_numpy(out2))
    # The rerank op returns scores sorted in descending order, so sort the
    # raw ONNX scores before comparing.
    scores2 = sorted(scores2, reverse=True)
    assert numpy.allclose(scores1, scores2, atol=atol)
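
    # ONNX Runtime scores can differ from PyTorch by small floating-point
    # deviations (different kernels and graph optimizations), hence the
    # tolerance-based check above rather than exact equality.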

    ### Test onnx qps
    for batch_size in batch_size_list:
        qps = []
        model_qps = []  # ONNX session run alone, excluding tokenization and post-processing
        for ind in range(try_times):
            start = time.time()
            batch = [(test_query, doc) for doc in ['dummy input'] * batch_size]
            texts = [[] for _ in range(len(batch[0]))]

            for example in batch:
                for idx, text in enumerate(example):
                    texts[idx].append(text.strip())

            inputs = op.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors='np',
                                  max_length=op.max_length)
            model_start = time.time()
            out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))[0]
            model_end = time.time()
            scores2 = op.post_proc(torch.from_numpy(out2))
            scores2 = sorted(scores2, reverse=True)
            end = time.time()
            # Discard the first call: it includes session warm-up.
            if ind == 0:
                continue
            q = 1 / (end - start)
            model_q = 1 / (model_end - model_start)
            qps.append(q)
            model_qps.append(model_q)
        print(f'onnx, batch_size = {batch_size}, mean qps = {sum(qps) / len(qps)}')
        print(f'model onnx, batch_size = {batch_size}, mean qps = {sum(model_qps) / len(model_qps)}')
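
# Note: onnxruntime.get_available_providers() lists CUDA first when
# onnxruntime-gpu is installed, so the ONNX timings above may run on GPU.
# To force a CPU-only measurement, one could instead build the session with
# providers=['CPUExecutionProvider'].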