From 66446ded34a0c2c31a4611a694c6d22fc41f4929 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Mon, 6 Feb 2023 14:57:28 +0800 Subject: [PATCH] Add qps_test Signed-off-by: Jael Gu --- benchmark/qps_test.py | 111 ++++++++++++++++++++++++++++++++++++++++++ nn_fingerprint.py | 7 +-- 2 files changed, 115 insertions(+), 3 deletions(-) create mode 100644 benchmark/qps_test.py diff --git a/benchmark/qps_test.py b/benchmark/qps_test.py new file mode 100644 index 0000000..7a9f092 --- /dev/null +++ b/benchmark/qps_test.py @@ -0,0 +1,111 @@ +import towhee +from towhee.dc2 import pipe, ops +from towhee import triton_client + +import onnxruntime +import numpy +import torch +from statistics import mean + +import time +import argparse + +import os +import re +import warnings +import logging + +warnings.filterwarnings("ignore", category=UserWarning) + +parser = argparse.ArgumentParser() +parser.add_argument('--pipe', action='store_true') +parser.add_argument('--triton', action='store_true') +parser.add_argument('--onnx', action='store_true') +parser.add_argument('--atol', type=float, default=1e-3) +parser.add_argument('--num', type=int, default=100) +parser.add_argument('--device', type=str, default='cpu') +args = parser.parse_args() + +p = ( + pipe.input('path') + .map('path', 'frame', ops.audio_decode.ffmpeg()) + .map('frame', 'vecs', ops.audio_embedding.nnfp(device=args.device)) + .output('vecs') +) + +data = '1sec.wav' +out1 = p(data).get()[0] +print('Pipe: OK', out1.shape) + +if args.num and args.pipe: + qps = [] + for _ in range(10): + start = time.time() + # for _ in range(args.num): + # p(data) + p.batch([data] * args.num) + end = time.time() + q = args.num / (end - start) + qps.append(q) + print('Pipe qps:', mean(qps)) + +if args.triton: + client = triton_client.Client(url='localhost:8000') + out2 = client(data)[0][0] + print('Triton: OK') + + if numpy.allclose(out1, out2, atol=args.atol): + print('Check accuracy: OK') + else: + max_diff = numpy.abs(out1 - out2).max() + min_diff = numpy.abs(out1 - out2).min() + mean_diff = numpy.abs(out1 - out2).mean() + print(f'Check accuracy: atol is larger than {args.atol}.') + print(f'Maximum absolute difference is {max_diff}.') + print(f'Minimum absolute difference is {min_diff}.') + print(f'Mean difference is {mean_diff}.') + + if args.num: + qps = [] + for _ in range(10): + start = time.time() + client.batch([data] * args.num) + end = time.time() + q = args.num / (end - start) + qps.append(q) + print('Triton qps:', mean(qps)) + +if args.onnx: + op = ops.audio_embedding.nnfp(device=args.device).get_op() + decoder = ops.audio_decode.ffmpeg() + # if not os.path.exists('test.onnx'): + op.save_model('onnx', 'test.onnx') + sess = onnxruntime.InferenceSession('test.onnx', + providers=['CUDAExecutionProvider']) + inputs = [x for x in decoder(data)] + inputs = op.preprocess(inputs) + out3 = sess.run(None, input_feed={'input': inputs.cpu().detach().numpy()}) + print('Onnx: OK', out3[0].shape) + if numpy.allclose(out1, out3, atol=args.atol): + print('Check accuracy: OK') + else: + max_diff = numpy.abs(out1 - out3).max() + min_diff = numpy.abs(out1 - out3).min() + mean_diff = numpy.abs(out1 - out3).mean() + print(f'Check accuracy: atol is larger than {args.atol}.') + print(f'Maximum absolute difference is {max_diff}.') + print(f'Minimum absolute difference is {min_diff}.') + print(f'Mean difference is {mean_diff}.') + + if args.num: + qps = [] + for _ in range(10): + start = time.time() + for _ in range(args.num): + inputs = [x for x in decoder(data)] + inputs = op.preprocess(inputs) + sess.run(None, input_feed={'input': inputs.cpu().detach().numpy()}) + end = time.time() + q = args.num / (end - start) + qps.append(q) + print('Onnx qps:', mean(qps)) diff --git a/nn_fingerprint.py b/nn_fingerprint.py index dbefb44..d71f6d7 100644 --- a/nn_fingerprint.py +++ b/nn_fingerprint.py @@ -34,6 +34,7 @@ from .configs import default_params, hop25_params, distill_params warnings.filterwarnings('ignore') log = logging.getLogger('nnfp_op') +log.setLevel(logging.ERROR) # @accelerate @@ -199,7 +200,7 @@ class NNFingerprint(NNOperator): raise ValueError(f'Invalid format {format}.') dummy_input = torch.rand( (1,) + (self.params['n_mels'], self.params['u']) - ).to(self.device) + ) if format == 'pytorch': torch.save(self._model, path) elif format == 'torchscript': @@ -210,14 +211,14 @@ class NNFingerprint(NNOperator): log.warning( 'Failed to directly export as torchscript.' 'Using dummy input in shape of %s now.', dummy_input.shape) - jit_model = torch.jit.trace(self._model, dummy_input, strict=False) + jit_model = torch.jit.trace(self._model.to('cpu'), dummy_input, strict=False) torch.jit.save(jit_model, path) except Exception as e: log.error('Fail to save as torchscript: %s.', e) raise RuntimeError(f'Fail to save as torchscript: {e}.') elif format == 'onnx': try: - torch.onnx.export(self._model, + torch.onnx.export(self._model.to('cpu'), dummy_input, path, input_names=['input'],