logo
Browse Source

Update qps_test device

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
78ff5258cf
  1. 4
      benchmark/qps_test.py

4
benchmark/qps_test.py

@ -104,7 +104,7 @@ if args.onnx:
out3 = sess.run(None, input_feed=dict(inputs)) out3 = sess.run(None, input_feed=dict(inputs))
torch_inputs = {} torch_inputs = {}
for k, v in inputs.items(): for k, v in inputs.items():
torch_inputs[k] = torch.from_numpy(v).to(device)
torch_inputs[k] = torch.from_numpy(v)
out3 = op.post_proc(torch.from_numpy(out3[0]), torch_inputs).cpu().detach().numpy() out3 = op.post_proc(torch.from_numpy(out3[0]), torch_inputs).cpu().detach().numpy()
print('Onnx: OK') print('Onnx: OK')
if numpy.allclose(out1, out3, atol=args.atol): if numpy.allclose(out1, out3, atol=args.atol):
@ -125,7 +125,7 @@ if args.onnx:
for _ in range(args.num): for _ in range(args.num):
tokens = op.tokenizer([text], padding=True, truncation=True, return_tensors='pt') tokens = op.tokenizer([text], padding=True, truncation=True, return_tensors='pt')
outs = sess.run(None, input_feed=dict(inputs)) outs = sess.run(None, input_feed=dict(inputs))
op.post_proc(torch.from_numpy(outs[0]).to(device), torch_inputs)
op.post_proc(torch.from_numpy(outs[0]), torch_inputs)
end = time.time() end = time.time()
q = args.num / (end - start) q = args.num / (end - start)
qps.append(q) qps.append(q)

Loading…
Cancel
Save