From 3f503747f3c196412c6354c8781bac3256fd6005 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Tue, 20 Dec 2022 11:39:13 +0800 Subject: [PATCH] Fix dim issue Signed-off-by: Jael Gu --- benchmark/run.py | 2 +- benchmark/run.sh | 8 ++++++++ timm_image.py | 3 ++- 3 files changed, 11 insertions(+), 2 deletions(-) create mode 100755 benchmark/run.sh diff --git a/benchmark/run.py b/benchmark/run.py index 7052ad3..2d5f5fd 100644 --- a/benchmark/run.py +++ b/benchmark/run.py @@ -120,7 +120,7 @@ if args.format == 'pytorch': elif args.format == 'onnx': collection_name = collection_name + '_onnx' if not os.path.exists(onnx_path): - onnx_path = op.save_model(format='onnx') + onnx_path = str(op.save_model(format='onnx')) sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers()) diff --git a/benchmark/run.sh b/benchmark/run.sh new file mode 100755 index 0000000..bdf4b65 --- /dev/null +++ b/benchmark/run.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +for name in beit_base_patch16_224 beit_base_patch16_224_in22k beit_base_patch16_384 beit_large_patch16_224 beit_large_patch16_224_in22k beit_large_patch16_384 beit_large_patch16_512 beitv2_base_patch16_224 beitv2_base_patch16_224_in22k beitv2_large_patch16_224 beitv2_large_patch16_224_in22k cait_m36_384 cait_m48_448 cait_s24_224 cait_xs24_384 convnext_large_in22ft1k convnext_small_384_in22ft1k convnext_tiny_in22k convnext_xlarge_in22ft1k convnext_xlarge_in22k deit3_medium_patch16_224 deit3_small_patch16_384 deit_base_distilled_patch16_384 mixer_b16_224 mixer_b16_224_in21k mixer_b16_224_miil mixer_b16_224_miil_in21k mixer_l16_224 mixer_l16_224_in21k mobilevitv2_175_384_in22ft1k mobilevitv2_200_384_in22ft1k repvgg_b2g4 res2net50_26w_8s resmlp_big_24_distilled_224 seresnextaa101d_32x8d vit_base_patch16_224_in21k vit_base_patch16_384 vit_base_patch8_224 vit_base_patch8_224_in21k vit_giant_patch14_224_clip_laion2b vit_large_patch16_224 vit_large_patch16_224_in21k vit_large_patch16_384 vit_large_patch32_384 
vit_large_r50_s32_224 vit_large_r50_s32_384 vit_relpos_base_patch16_clsgap_224 vit_relpos_medium_patch16_224 vit_relpos_small_patch16_224 vit_small_patch32_224 vit_small_patch32_224_in21k vit_small_r26_s32_384 xcit_large_24_p8_224 xcit_large_24_p8_224_dist xcit_large_24_p8_384_dist xcit_nano_12_p16_384_dist xcit_nano_12_p8_224 xcit_nano_12_p8_224_dist xcit_nano_12_p8_384_dist xcit_small_24_p8_224 xcit_tiny_12_p8_224 xcit_tiny_12_p8_384_dist xcit_tiny_24_p8_224 xcit_tiny_24_p8_384_dist +do + echo "***${name}***" + python run.py --model ${name} --format pytorch + python run.py --model ${name} --format onnx +done diff --git a/timm_image.py b/timm_image.py index b8eb495..94d1eec 100644 --- a/timm_image.py +++ b/timm_image.py @@ -110,7 +110,7 @@ class TimmImage(NNOperator): vecs = [list(x.detach().numpy()) for x in features] if isinstance(features, list) \ else list(features.detach().numpy()) else: - vecs = [x.squeeze(0).detach().numpy()] if instance(features, list) \ + vecs = [x.squeeze(0).detach().numpy()] if isinstance(features, list) \ else features.squeeze(0).detach().numpy() return vecs @@ -133,6 +133,7 @@ class TimmImage(NNOperator): if features.dim() == 4: global_pool = nn.AdaptiveAvgPool2d(1).to(self.device) features = global_pool(features) + features = features.flatten(1) assert features.dim() == 2, f'Invalid output dim {features.dim()}' features = features.to('cpu') return features