|
@ -104,7 +104,7 @@ if args.onnx: |
|
|
out3 = sess.run(None, input_feed=dict(inputs)) |
|
|
out3 = sess.run(None, input_feed=dict(inputs)) |
|
|
torch_inputs = {} |
|
|
torch_inputs = {} |
|
|
for k, v in inputs.items(): |
|
|
for k, v in inputs.items(): |
|
|
torch_inputs[k] = torch.from_numpy(v).to(device) |
|
|
|
|
|
|
|
|
torch_inputs[k] = torch.from_numpy(v) |
|
|
out3 = op.post_proc(torch.from_numpy(out3[0]), torch_inputs).cpu().detach().numpy() |
|
|
out3 = op.post_proc(torch.from_numpy(out3[0]), torch_inputs).cpu().detach().numpy() |
|
|
print('Onnx: OK') |
|
|
print('Onnx: OK') |
|
|
if numpy.allclose(out1, out3, atol=args.atol): |
|
|
if numpy.allclose(out1, out3, atol=args.atol): |
|
@ -125,7 +125,7 @@ if args.onnx: |
|
|
for _ in range(args.num): |
|
|
for _ in range(args.num): |
|
|
tokens = op.tokenizer([text], padding=True, truncation=True, return_tensors='pt') |
|
|
tokens = op.tokenizer([text], padding=True, truncation=True, return_tensors='pt') |
|
|
outs = sess.run(None, input_feed=dict(inputs)) |
|
|
outs = sess.run(None, input_feed=dict(inputs)) |
|
|
op.post_proc(torch.from_numpy(outs[0]).to(device), torch_inputs) |
|
|
|
|
|
|
|
|
op.post_proc(torch.from_numpy(outs[0]), torch_inputs) |
|
|
end = time.time() |
|
|
end = time.time() |
|
|
q = args.num / (end - start) |
|
|
q = args.num / (end - start) |
|
|
qps.append(q) |
|
|
qps.append(q) |
|
|