diff --git a/benchmark/run.py b/benchmark/run.py index 2c2d5a3..d702f2e 100644 --- a/benchmark/run.py +++ b/benchmark/run.py @@ -118,7 +118,6 @@ elif args.format == 'onnx': collection_name = collection_name + '_onnx' if not os.path.exists(onnx_path): onnx_path = op.save_model(format='onnx') - sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers()) @@ -126,7 +125,7 @@ elif args.format == 'onnx': @towhee.register def run_onnx(img): img = img.convert('RGB') - img = op.tfms(img).cpu().detach().numpy() + img = op.tfms(img).unsqueeze(0).cpu().detach().numpy() features = sess.run(output_names=['output_0'], input_feed={'input_0': img})[0] if len(features.shape) == 4: global_pool = nn.AdaptiveAvgPool2d(1)