logo
Browse Source

Fix dim issue

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
3f503747f3
  1. 2
      benchmark/run.py
  2. 8
      benchmark/run.sh
  3. 3
      timm_image.py

2
benchmark/run.py

@ -120,7 +120,7 @@ if args.format == 'pytorch':
elif args.format == 'onnx':
collection_name = collection_name + '_onnx'
if not os.path.exists(onnx_path):
onnx_path = op.save_model(format='onnx')
onnx_path = str(op.save_model(format='onnx'))
sess = onnxruntime.InferenceSession(onnx_path,
providers=onnxruntime.get_available_providers())

8
benchmark/run.sh

@ -0,0 +1,8 @@
#!/bin/bash
for name in beit_base_patch16_224 beit_base_patch16_224_in22k beit_base_patch16_384 beit_large_patch16_224 beit_large_patch16_224_in22k beit_large_patch16_384 beit_large_patch16_512 beitv2_base_patch16_224 beitv2_base_patch16_224_in22k beitv2_large_patch16_224 beitv2_large_patch16_224_in22k cait_m36_384 cait_m48_448 cait_s24_224 cait_xs24_384 convnext_large_in22ft1k convnext_small_384_in22ft1k convnext_tiny_in22k convnext_xlarge_in22ft1k convnext_xlarge_in22k deit3_medium_patch16_224 deit3_small_patch16_384 deit_base_distilled_patch16_384 mixer_b16_224 mixer_b16_224_in21k mixer_b16_224_miil mixer_b16_224_miil_in21k mixer_l16_224 mixer_l16_224_in21k mobilevitv2_175_384_in22ft1k mobilevitv2_200_384_in22ft1k repvgg_b2g4 res2net50_26w_8s resmlp_big_24_distilled_224 seresnextaa101d_32x8d vit_base_patch16_224_in21k vit_base_patch16_384 vit_base_patch8_224 vit_base_patch8_224_in21k vit_giant_patch14_224_clip_laion2b vit_large_patch16_224 vit_large_patch16_224_in21k vit_large_patch16_384 vit_large_patch32_384 vit_large_r50_s32_224 vit_large_r50_s32_384 vit_relpos_base_patch16_clsgap_224 vit_relpos_medium_patch16_224 vit_relpos_small_patch16_224 vit_small_patch32_224 vit_small_patch32_224_in21k vit_small_r26_s32_384 xcit_large_24_p8_224 xcit_large_24_p8_224_dist xcit_large_24_p8_384_dist xcit_nano_12_p16_384_dist xcit_nano_12_p8_224 xcit_nano_12_p8_224_dist xcit_nano_12_p8_384_dist xcit_small_24_p8_224 xcit_tiny_12_p8_224 xcit_tiny_12_p8_384_dist xcit_tiny_24_p8_224 xcit_tiny_24_p8_384_dist
do
echo ***${name}***
python run.py --model ${name} --format pytorch
python run.py --model ${name} --format onnx
done

3
timm_image.py

@ -110,7 +110,7 @@ class TimmImage(NNOperator):
vecs = [list(x.detach().numpy()) for x in features] if isinstance(features, list) \
else list(features.detach().numpy())
else:
vecs = [x.squeeze(0).detach().numpy()] if instance(features, list) \
vecs = [x.squeeze(0).detach().numpy()] if isinstance(features, list) \
else features.squeeze(0).detach().numpy()
return vecs
@ -133,6 +133,7 @@ class TimmImage(NNOperator):
if features.dim() == 4:
global_pool = nn.AdaptiveAvgPool2d(1).to(self.device)
features = global_pool(features)
features = features.flatten(1)
assert features.dim() == 2, f'Invalid output dim {features.dim()}'
features = features.to('cpu')
return features

Loading…
Cancel
Save