nnfp
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
112 lines
3.5 KiB
112 lines
3.5 KiB
2 years ago
|
import towhee
|
||
|
from towhee.dc2 import pipe, ops
|
||
|
from towhee import triton_client
|
||
|
|
||
|
import onnxruntime
|
||
|
import numpy
|
||
|
import torch
|
||
|
from statistics import mean
|
||
|
|
||
|
import time
|
||
|
import argparse
|
||
|
|
||
|
import os
|
||
|
import re
|
||
|
import warnings
|
||
|
import logging
|
||
|
|
||
|
# Suppress UserWarning messages so they do not clutter the script's printed output.
warnings.filterwarnings("ignore", category=UserWarning)

# Command-line interface of this check script.
parser = argparse.ArgumentParser()

# Boolean switches: each is False unless the flag is given on the command line.
for switch in ('--pipe', '--triton', '--onnx'):
    parser.add_argument(switch, action='store_true')

# Valued options, listed as (flag, type, default):
#   --atol    absolute tolerance passed to numpy.allclose in the accuracy checks
#   --num     number of inputs per timed benchmark round (0 disables benchmarks)
#   --device  device string handed to the nnfp operator
for option, caster, fallback in (
    ('--atol', float, 1e-3),
    ('--num', int, 100),
    ('--device', str, 'cpu'),
):
    parser.add_argument(option, type=caster, default=fallback)

args = parser.parse_args()
|
||
|
|
||
|
# Assemble the pipeline one stage at a time:
#   'path'  --audio_decode.ffmpeg-->  'frame'  --audio_embedding.nnfp-->  'vecs'
source = pipe.input('path')
decoded = source.map('path', 'frame', ops.audio_decode.ffmpeg())
embedded = decoded.map('frame', 'vecs', ops.audio_embedding.nnfp(device=args.device))
p = embedded.output('vecs')

# Audio file used as the single test input by every section of this script.
data = '1sec.wav'

# Run the pipeline once. The first item of the result is kept as the reference
# output that the Triton and ONNX sections compare against.
result = p(data)
out1 = result.get()[0]
print('Pipe: OK', out1.shape)
|
||
|
|
||
|
if args.num and args.pipe:
    # Throughput benchmark for the local pipeline: 10 timed rounds, each
    # pushing `args.num` copies of the test file through p.batch(); the
    # reported figure is the mean queries-per-second over the rounds.
    qps = []
    for _ in range(10):
        # perf_counter() is monotonic and high-resolution, so the measurement
        # is not disturbed by system clock adjustments the way time.time() is.
        start = time.perf_counter()
        p.batch([data] * args.num)
        elapsed = time.perf_counter() - start
        qps.append(args.num / elapsed)
    print('Pipe qps:', mean(qps))
|
||
|
|
||
|
if args.triton:
    # Query the same input through a Triton server expected at localhost:8000
    # and compare its result with the local pipeline output `out1`.
    client = triton_client.Client(url='localhost:8000')
    out2 = client(data)[0][0]
    print('Triton: OK')

    if numpy.allclose(out1, out2, atol=args.atol):
        print('Check accuracy: OK')
    else:
        # Compute the element-wise absolute difference once and reuse it for
        # all three summary statistics.
        abs_diff = numpy.abs(out1 - out2)
        max_diff = abs_diff.max()
        min_diff = abs_diff.min()
        mean_diff = abs_diff.mean()
        print(f'Check accuracy: atol is larger than {args.atol}.')
        print(f'Maximum absolute difference is {max_diff}.')
        print(f'Minimum absolute difference is {min_diff}.')
        print(f'Mean difference is {mean_diff}.')

    if args.num:
        # Throughput benchmark: 10 timed rounds of `args.num` requests sent
        # via client.batch(); report the mean queries-per-second.
        qps = []
        for _ in range(10):
            # perf_counter() is monotonic and high-resolution, which makes it
            # the appropriate clock for measuring elapsed time.
            start = time.perf_counter()
            client.batch([data] * args.num)
            elapsed = time.perf_counter() - start
            qps.append(args.num / elapsed)
        print('Triton qps:', mean(qps))
|
||
|
|
||
|
if args.onnx:
    # Export the nnfp operator to ONNX, run it with onnxruntime on the same
    # input, and compare the result with the local pipeline output `out1`.
    op = ops.audio_embedding.nnfp(device=args.device).get_op()
    decoder = ops.audio_decode.ffmpeg()

    # The model is re-exported on every run (no "already exists" shortcut),
    # so 'test.onnx' always reflects the operator that was just loaded.
    op.save_model('onnx', 'test.onnx')

    # Honour --device instead of unconditionally requesting CUDA: a CPU run
    # uses only the CPU provider, anything else prefers CUDA and falls back
    # to CPU when the CUDA provider is unavailable.
    # NOTE(review): a device index such as 'cuda:1' is not forwarded to
    # onnxruntime here; the CUDA provider uses its default device.
    if args.device == 'cpu':
        providers = ['CPUExecutionProvider']
    else:
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
    sess = onnxruntime.InferenceSession('test.onnx', providers=providers)

    # Decode the audio, let the operator preprocess it, and feed the result
    # to the session as a numpy array under the input name 'input'.
    inputs = list(decoder(data))
    inputs = op.preprocess(inputs)
    out3 = sess.run(None, input_feed={'input': inputs.cpu().detach().numpy()})
    print('Onnx: OK', out3[0].shape)

    # sess.run() returns a list of outputs; compare against the first output
    # array explicitly rather than relying on broadcasting over the list.
    onnx_vecs = out3[0]
    if numpy.allclose(out1, onnx_vecs, atol=args.atol):
        print('Check accuracy: OK')
    else:
        # Compute the element-wise absolute difference once and reuse it for
        # all three summary statistics.
        abs_diff = numpy.abs(out1 - onnx_vecs)
        max_diff = abs_diff.max()
        min_diff = abs_diff.min()
        mean_diff = abs_diff.mean()
        print(f'Check accuracy: atol is larger than {args.atol}.')
        print(f'Maximum absolute difference is {max_diff}.')
        print(f'Minimum absolute difference is {min_diff}.')
        print(f'Mean difference is {mean_diff}.')

    if args.num:
        # Throughput benchmark: 10 timed rounds, each running the full
        # decode -> preprocess -> inference path `args.num` times; report the
        # mean queries-per-second.
        qps = []
        for _ in range(10):
            # perf_counter() is monotonic and high-resolution, which makes it
            # the appropriate clock for measuring elapsed time.
            start = time.perf_counter()
            for _run in range(args.num):
                inputs = list(decoder(data))
                inputs = op.preprocess(inputs)
                sess.run(None, input_feed={'input': inputs.cpu().detach().numpy()})
            elapsed = time.perf_counter() - start
            qps.append(args.num / elapsed)
        print('Onnx qps:', mean(qps))
|