From ea091310ca403faed66313427535869e79128102 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Wed, 14 Dec 2022 19:29:14 +0800 Subject: [PATCH] Fix run.py Signed-off-by: Jael Gu --- benchmark/run.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/benchmark/run.py b/benchmark/run.py index 2c2d5a3..d702f2e 100644 --- a/benchmark/run.py +++ b/benchmark/run.py @@ -118,7 +118,6 @@ elif args.format == 'onnx': collection_name = collection_name + '_onnx' if not os.path.exists(onnx_path): onnx_path = op.save_model(format='onnx') - sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers()) @@ -126,7 +125,7 @@ elif args.format == 'onnx': @towhee.register def run_onnx(img): img = img.convert('RGB') - img = op.tfms(img).cpu().detach().numpy() + img = op.tfms(img).unsqueeze(0).cpu().detach().numpy() features = sess.run(output_names=['output_0'], input_feed={'input_0': img})[0] if len(features.shape) == 4: global_pool = nn.AdaptiveAvgPool2d(1)