logo
Browse Source

Fix post process

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
c7881ca7a0
  1. 2
      timm_image.py

2
timm_image.py

@ -128,7 +128,7 @@ class TimmImage(NNOperator):
return img
def post_proc(self, features):
if 'vit' in self.model_name and features.dim() == 3:
if features.dim() == 3:
features = features[:, 0]
if features.dim() == 4:
global_pool = nn.AdaptiveAvgPool2d(1).to(self.device)

Loading…
Cancel
Save