1 changed file with 121 additions and 0 deletions

@@ -0,0 +1,121 @@
import time

import numpy
import torch
from towhee import ops
import onnx
import onnxruntime
from tqdm import tqdm

from rerank import ReRank

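# Test fixtures: a toy query/doc set for the ONNX correctness check, plus
# timing parameters for the throughput (QPS) runs.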
test_query = 'abc'
test_docs = ['123', 'ABC', 'ABCabc']
atol = 1e-3

try_times = 500
batch_size_list = [2, 4, 8, 16, 32, 64]

model_name_list = [
    # q2p (query-to-passage) rerankers:
    'cross-encoder/ms-marco-TinyBERT-L-2-v2',
    'cross-encoder/ms-marco-MiniLM-L-2-v2',
    'cross-encoder/ms-marco-MiniLM-L-4-v2',
    'cross-encoder/ms-marco-MiniLM-L-6-v2',
    'cross-encoder/ms-marco-MiniLM-L-12-v2',
    'cross-encoder/ms-marco-TinyBERT-L-2',
    'cross-encoder/ms-marco-TinyBERT-L-4',
    'cross-encoder/ms-marco-TinyBERT-L-6',
    'cross-encoder/ms-marco-electra-base',
    'nboost/pt-tinybert-msmarco',
    'nboost/pt-bert-base-uncased-msmarco',
    'nboost/pt-bert-large-msmarco',
    'Capreolus/electra-base-msmarco',
    'amberoad/bert-multilingual-passage-reranking-msmarco',

    # q2q (query-to-query) rerankers:
    'cross-encoder/quora-distilroberta-base',
    'cross-encoder/quora-roberta-base',
    'cross-encoder/quora-roberta-large',
]

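# For each model: measure the Python op's QPS on CPU and GPU, export to ONNX
# and verify the scores, then measure ONNX QPS.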
for name in tqdm(model_name_list):
    print('######################\nname=', name)

    ### Test python qps
    for device in ['cpu', 'cuda:3']:
        # op = ReRank(model_name=name, threshold=0, device='cpu')
        op = ops.rerank(model_name=name, threshold=0, device=device)

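        # Time try_times calls per batch size; the first call is treated as
        # warm-up (model load / CUDA init) and excluded from the mean QPS.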
        for batch_size in batch_size_list:
            qps = []
            for ind in range(try_times):
                start = time.time()
                out = op(test_query, ['dummy input'] * batch_size)
                end = time.time()
                if ind == 0:
                    continue
                q = 1 / (end - start)
                qps.append(q)
            print(f'device = {device}, batch_size = {batch_size}, mean qps = {sum(qps) / len(qps)}')


    ### Test onnx checking
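    # Export the CPU op to ONNX, validate the graph with onnx.checker, and
    # confirm the ONNX scores match the Python op's scores within atol.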
    op = ops.rerank(model_name=name, threshold=0, device='cpu').get_op()
    out1 = op(test_query, test_docs)
    scores1 = out1[1]
    onnx_path = str(op.save_model(format='onnx'))
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)

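    # Transpose the (query, doc) pairs into two parallel lists (all queries,
    # then all docs), the paired-input layout the tokenizer expects.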
    batch = [(test_query, doc) for doc in test_docs]
    texts = [[] for _ in range(len(batch[0]))]

    for example in batch:
        for idx, text in enumerate(example):
            texts[idx].append(text.strip())

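    # Run the exported model under onnxruntime with whichever execution
    # providers are available (CUDA is preferred over CPU when present).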
    sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers())

    inputs = op.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="np",
                          max_length=op.max_length)
    out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))[0]
    scores2 = op.post_proc(torch.from_numpy(out2))
    scores2 = sorted(scores2, reverse=True)
    # numpy.allclose can return numpy.bool_, so don't compare it with `is True`
    assert numpy.allclose(scores1, scores2, atol=atol)


    ### Test onnx qps
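    # Two timers per batch: "onnx" covers tokenize + run + post-process, while
    # "model onnx" isolates the sess.run() call itself. As above, the first
    # iteration is warm-up and excluded.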
    for batch_size in batch_size_list:
        qps = []
        model_qps = []
        for ind in range(try_times):
            start = time.time()
            batch = [(test_query, doc) for doc in ['dummy input'] * batch_size]
            texts = [[] for _ in range(len(batch[0]))]

            for example in batch:
                for idx, text in enumerate(example):
                    texts[idx].append(text.strip())

            inputs = op.tokenizer(*texts, padding=True, truncation='longest_first', return_tensors="np",
                                  max_length=op.max_length)
            model_start = time.time()
            out2 = sess.run(output_names=['last_hidden_state'], input_feed=dict(inputs))[0]
            model_end = time.time()
            scores2 = op.post_proc(torch.from_numpy(out2))
            scores2 = sorted(scores2, reverse=True)
            end = time.time()
            if ind == 0:
                continue
            q = 1 / (end - start)
            model_q = 1 / (model_end - model_start)
            qps.append(q)
            model_qps.append(model_q)
        print(f'onnx, batch_size = {batch_size}, mean qps = {sum(qps) / len(qps)}')
        print(f'model onnx, batch_size = {batch_size}, mean qps = {sum(model_qps) / len(model_qps)}')