nnfp
copied
Jael Gu
2 years ago
2 changed files with 115 additions and 3 deletions
@ -0,0 +1,111 @@ |
|||
import towhee |
|||
from towhee.dc2 import pipe, ops |
|||
from towhee import triton_client |
|||
|
|||
import onnxruntime |
|||
import numpy |
|||
import torch |
|||
from statistics import mean |
|||
|
|||
import time |
|||
import argparse |
|||
|
|||
import os |
|||
import re |
|||
import warnings |
|||
import logging |
|||
|
|||
warnings.filterwarnings("ignore", category=UserWarning) |
|||
|
|||
parser = argparse.ArgumentParser() |
|||
parser.add_argument('--pipe', action='store_true') |
|||
parser.add_argument('--triton', action='store_true') |
|||
parser.add_argument('--onnx', action='store_true') |
|||
parser.add_argument('--atol', type=float, default=1e-3) |
|||
parser.add_argument('--num', type=int, default=100) |
|||
parser.add_argument('--device', type=str, default='cpu') |
|||
args = parser.parse_args() |
|||
|
|||
p = ( |
|||
pipe.input('path') |
|||
.map('path', 'frame', ops.audio_decode.ffmpeg()) |
|||
.map('frame', 'vecs', ops.audio_embedding.nnfp(device=args.device)) |
|||
.output('vecs') |
|||
) |
|||
|
|||
data = '1sec.wav' |
|||
out1 = p(data).get()[0] |
|||
print('Pipe: OK', out1.shape) |
|||
|
|||
if args.num and args.pipe: |
|||
qps = [] |
|||
for _ in range(10): |
|||
start = time.time() |
|||
# for _ in range(args.num): |
|||
# p(data) |
|||
p.batch([data] * args.num) |
|||
end = time.time() |
|||
q = args.num / (end - start) |
|||
qps.append(q) |
|||
print('Pipe qps:', mean(qps)) |
|||
|
|||
if args.triton: |
|||
client = triton_client.Client(url='localhost:8000') |
|||
out2 = client(data)[0][0] |
|||
print('Triton: OK') |
|||
|
|||
if numpy.allclose(out1, out2, atol=args.atol): |
|||
print('Check accuracy: OK') |
|||
else: |
|||
max_diff = numpy.abs(out1 - out2).max() |
|||
min_diff = numpy.abs(out1 - out2).min() |
|||
mean_diff = numpy.abs(out1 - out2).mean() |
|||
print(f'Check accuracy: atol is larger than {args.atol}.') |
|||
print(f'Maximum absolute difference is {max_diff}.') |
|||
print(f'Minimum absolute difference is {min_diff}.') |
|||
print(f'Mean difference is {mean_diff}.') |
|||
|
|||
if args.num: |
|||
qps = [] |
|||
for _ in range(10): |
|||
start = time.time() |
|||
client.batch([data] * args.num) |
|||
end = time.time() |
|||
q = args.num / (end - start) |
|||
qps.append(q) |
|||
print('Triton qps:', mean(qps)) |
|||
|
|||
if args.onnx: |
|||
op = ops.audio_embedding.nnfp(device=args.device).get_op() |
|||
decoder = ops.audio_decode.ffmpeg() |
|||
# if not os.path.exists('test.onnx'): |
|||
op.save_model('onnx', 'test.onnx') |
|||
sess = onnxruntime.InferenceSession('test.onnx', |
|||
providers=['CUDAExecutionProvider']) |
|||
inputs = [x for x in decoder(data)] |
|||
inputs = op.preprocess(inputs) |
|||
out3 = sess.run(None, input_feed={'input': inputs.cpu().detach().numpy()}) |
|||
print('Onnx: OK', out3[0].shape) |
|||
if numpy.allclose(out1, out3, atol=args.atol): |
|||
print('Check accuracy: OK') |
|||
else: |
|||
max_diff = numpy.abs(out1 - out3).max() |
|||
min_diff = numpy.abs(out1 - out3).min() |
|||
mean_diff = numpy.abs(out1 - out3).mean() |
|||
print(f'Check accuracy: atol is larger than {args.atol}.') |
|||
print(f'Maximum absolute difference is {max_diff}.') |
|||
print(f'Minimum absolute difference is {min_diff}.') |
|||
print(f'Mean difference is {mean_diff}.') |
|||
|
|||
if args.num: |
|||
qps = [] |
|||
for _ in range(10): |
|||
start = time.time() |
|||
for _ in range(args.num): |
|||
inputs = [x for x in decoder(data)] |
|||
inputs = op.preprocess(inputs) |
|||
sess.run(None, input_feed={'input': inputs.cpu().detach().numpy()}) |
|||
end = time.time() |
|||
q = args.num / (end - start) |
|||
qps.append(q) |
|||
print('Onnx qps:', mean(qps)) |
Loading…
Reference in new issue