logo
Browse Source

Add qps_test

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
66446ded34
  1. 111
      benchmark/qps_test.py
  2. 7
      nn_fingerprint.py

111
benchmark/qps_test.py

@ -0,0 +1,111 @@
import towhee
from towhee.dc2 import pipe, ops
from towhee import triton_client
import onnxruntime
import numpy
import torch
from statistics import mean
import time
import argparse
import os
import re
import warnings
import logging
warnings.filterwarnings("ignore", category=UserWarning)
def _build_parser():
    """Create the command-line parser for this benchmark script.

    Boolean switches (`--pipe`, `--triton`, `--onnx`) select which sections
    of the script run; the valued options carry the comparison tolerance
    (`--atol`), the number of items per timed round (`--num`) and the device
    string handed to the nnfp operator (`--device`).
    """
    cli = argparse.ArgumentParser()
    # Section toggles: all default to False and are enabled by presence.
    for switch in ('--pipe', '--triton', '--onnx'):
        cli.add_argument(switch, action='store_true')
    # Valued options as (flag, type, default), registered in this order.
    valued = (
        ('--atol', float, 1e-3),
        ('--num', int, 100),
        ('--device', str, 'cpu'),
    )
    for flag, kind, fallback in valued:
        cli.add_argument(flag, type=kind, default=fallback)
    return cli


parser = _build_parser()
args = parser.parse_args()
# Reference pipeline: takes a file path, runs the ffmpeg audio-decode op on
# it, feeds the decoded frames to the nnfp audio-embedding op on the
# requested device, and outputs the resulting vectors.
p = pipe.input('path')
p = p.map('path', 'frame', ops.audio_decode.ffmpeg())
p = p.map('frame', 'vecs', ops.audio_embedding.nnfp(device=args.device))
p = p.output('vecs')

# Sample input used by every section of this script.
data = '1sec.wav'

# One sanity run; `out1` is the baseline the Triton/ONNX outputs are
# compared against further down.
out1 = p(data).get()[0]
print('Pipe: OK', out1.shape)
if args.num and args.pipe:
    # Throughput of the local pipeline: 10 timed rounds, each submitting
    # `args.num` copies of the sample, then the mean queries-per-second
    # over all rounds is reported.
    qps = []
    for _ in range(10):
        # perf_counter() is monotonic and high-resolution, so the measured
        # interval cannot be skewed by system clock adjustments the way a
        # time.time() difference can.
        start = time.perf_counter()
        # Inputs are submitted as a single batch rather than one p(data)
        # call per item.
        p.batch([data] * args.num)
        elapsed = time.perf_counter() - start
        qps.append(args.num / elapsed)
    print('Pipe qps:', mean(qps))
if args.triton:
    # Query the Triton server on localhost:8000 with the same sample and
    # compare its output against the local pipeline baseline `out1`.
    client = triton_client.Client(url='localhost:8000')
    out2 = client(data)[0][0]
    print('Triton: OK')

    if numpy.allclose(out1, out2, atol=args.atol):
        print('Check accuracy: OK')
    else:
        # Compute the element-wise absolute difference once and derive all
        # reported statistics from it.
        abs_diff = numpy.abs(out1 - out2)
        max_diff = abs_diff.max()
        min_diff = abs_diff.min()
        mean_diff = abs_diff.mean()
        print(f'Check accuracy: atol is larger than {args.atol}.')
        print(f'Maximum absolute difference is {max_diff}.')
        print(f'Minimum absolute difference is {min_diff}.')
        print(f'Mean difference is {mean_diff}.')

    if args.num:
        # Throughput of the Triton client: 10 timed rounds of `args.num`
        # batched requests each; the mean rate is reported.
        qps = []
        for _ in range(10):
            # perf_counter() is monotonic and high-resolution, unlike
            # time.time(), which follows the adjustable wall clock.
            start = time.perf_counter()
            client.batch([data] * args.num)
            elapsed = time.perf_counter() - start
            qps.append(args.num / elapsed)
        print('Triton qps:', mean(qps))
if args.onnx:
    # Export the nnfp operator's model to ONNX, run it through onnxruntime
    # on the same sample, and compare against the pipeline baseline `out1`.
    op = ops.audio_embedding.nnfp(device=args.device).get_op()
    decoder = ops.audio_decode.ffmpeg()
    # The model is re-exported on every run; an existing 'test.onnx' is
    # overwritten rather than reused.
    op.save_model('onnx', 'test.onnx')

    # Honour --device instead of always requesting CUDA: a CUDA device
    # string selects the CUDA provider (with CPU as fallback), anything
    # else runs on the CPU provider so the script also works on hosts
    # without a GPU build of onnxruntime.
    if args.device.startswith('cuda'):
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    else:
        providers = ['CPUExecutionProvider']
    sess = onnxruntime.InferenceSession('test.onnx', providers=providers)

    inputs = op.preprocess(list(decoder(data)))
    out3 = sess.run(None, input_feed={'input': inputs.cpu().detach().numpy()})
    print('Onnx: OK', out3[0].shape)

    # sess.run returns a list of outputs; compare the first output (the one
    # whose shape is printed above) rather than the enclosing list.
    onnx_vecs = out3[0]
    if numpy.allclose(out1, onnx_vecs, atol=args.atol):
        print('Check accuracy: OK')
    else:
        # Compute the element-wise absolute difference once and derive all
        # reported statistics from it.
        abs_diff = numpy.abs(out1 - onnx_vecs)
        max_diff = abs_diff.max()
        min_diff = abs_diff.min()
        mean_diff = abs_diff.mean()
        print(f'Check accuracy: atol is larger than {args.atol}.')
        print(f'Maximum absolute difference is {max_diff}.')
        print(f'Minimum absolute difference is {min_diff}.')
        print(f'Mean difference is {mean_diff}.')

    if args.num:
        # Throughput of the ONNX path: 10 timed rounds of `args.num`
        # sequential requests. Decoding and preprocessing are inside the
        # timed region, so each request covers decode + preprocess + run.
        qps = []
        for _ in range(10):
            # perf_counter() is monotonic and high-resolution, unlike
            # time.time(), which follows the adjustable wall clock.
            start = time.perf_counter()
            for _ in range(args.num):
                inputs = op.preprocess(list(decoder(data)))
                sess.run(None, input_feed={'input': inputs.cpu().detach().numpy()})
            elapsed = time.perf_counter() - start
            qps.append(args.num / elapsed)
        print('Onnx qps:', mean(qps))

7
nn_fingerprint.py

@ -34,6 +34,7 @@ from .configs import default_params, hop25_params, distill_params
warnings.filterwarnings('ignore')
log = logging.getLogger('nnfp_op')
log.setLevel(logging.ERROR)
# @accelerate
@ -199,7 +200,7 @@ class NNFingerprint(NNOperator):
raise ValueError(f'Invalid format {format}.')
dummy_input = torch.rand(
(1,) + (self.params['n_mels'], self.params['u'])
).to(self.device)
)
if format == 'pytorch':
torch.save(self._model, path)
elif format == 'torchscript':
@ -210,14 +211,14 @@ class NNFingerprint(NNOperator):
log.warning(
'Failed to directly export as torchscript.'
'Using dummy input in shape of %s now.', dummy_input.shape)
jit_model = torch.jit.trace(self._model, dummy_input, strict=False)
jit_model = torch.jit.trace(self._model.to('cpu'), dummy_input, strict=False)
torch.jit.save(jit_model, path)
except Exception as e:
log.error('Fail to save as torchscript: %s.', e)
raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onnx':
try:
torch.onnx.export(self._model,
torch.onnx.export(self._model.to('cpu'),
dummy_input,
path,
input_names=['input'],

Loading…
Cancel
Save