From ea8020947c41642a388fe5675b2d8e98ffa6d6a2 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Thu, 9 Feb 2023 13:49:48 +0800 Subject: [PATCH] Update qps test Signed-off-by: Jael Gu --- benchmark/qps_test.py | 3 ++- benchmark/test_client.py | 24 ++++++++++-------------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/benchmark/qps_test.py b/benchmark/qps_test.py index aa9f9b9..9e6753c 100644 --- a/benchmark/qps_test.py +++ b/benchmark/qps_test.py @@ -95,7 +95,8 @@ if args.onnx: # if not os.path.exists('test.onnx'): op.save_model('onnx', 'test.onnx') sess = onnxruntime.InferenceSession('test.onnx', - providers=['CUDAExecutionProvider']) + providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + sess.set_providers(['CUDAExecutionProvider'], [{'device_id': 0 if args.device < 0 else args.device}]) inputs = op.tokenizer([text], padding=True, truncation=True, return_tensors='np') # inputs = {} # for k, v in tokens.items(): diff --git a/benchmark/test_client.py b/benchmark/test_client.py index 7ca35e2..26ec4a4 100644 --- a/benchmark/test_client.py +++ b/benchmark/test_client.py @@ -3,22 +3,18 @@ import sys import time num = int(sys.argv[-1]) -client = triton_client.Client(url='localhost:8000') +data = 'Hello, world.' +client = triton_client.Client('localhost:8000') -data = 'hello' -# data = 'Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.' +# warm up +client.batch([data]) +print('client: ok') -res = client(data) -print(res[0][0].shape) +time.sleep(5) -res = client([data]) -print(len(res[0][0]), res[0][0][0].shape) - -# Or run data with batch +print('test...') start = time.time() -res = client.batch([data] * num, batch_size=8) +client.batch([data] * num, batch_size=8) end = time.time() -print(num / (end-start)) -# print(len(res), res[0][0][0].shape) - -client.close() +print(f'duration: {end - start}') +print(f'qps: {num / (end - start)}')