diff --git a/benchmark/qps_test.py b/benchmark/qps_test.py index 1c326ee..aa9f9b9 100644 --- a/benchmark/qps_test.py +++ b/benchmark/qps_test.py @@ -104,7 +104,7 @@ if args.onnx: out3 = sess.run(None, input_feed=dict(inputs)) torch_inputs = {} for k, v in inputs.items(): - torch_inputs[k] = torch.from_numpy(v).to(device) + torch_inputs[k] = torch.from_numpy(v) out3 = op.post_proc(torch.from_numpy(out3[0]), torch_inputs).cpu().detach().numpy() print('Onnx: OK') if numpy.allclose(out1, out3, atol=args.atol): @@ -125,7 +125,7 @@ if args.onnx: for _ in range(args.num): tokens = op.tokenizer([text], padding=True, truncation=True, return_tensors='pt') outs = sess.run(None, input_feed=dict(inputs)) - op.post_proc(torch.from_numpy(outs[0]).to(device), torch_inputs) + op.post_proc(torch.from_numpy(outs[0]), torch_inputs) end = time.time() q = args.num / (end - start) qps.append(q)