From c7881ca7a04325f75c5674cdcd3ce108329792e4 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Mon, 26 Dec 2022 12:16:47 +0800 Subject: [PATCH] Fix post process Signed-off-by: Jael Gu --- timm_image.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm_image.py b/timm_image.py index 94d1eec..b9e03cd 100644 --- a/timm_image.py +++ b/timm_image.py @@ -128,7 +128,7 @@ class TimmImage(NNOperator): return img def post_proc(self, features): - if 'vit' in self.model_name and features.dim() == 3: + if features.dim() == 3: features = features[:, 0] if features.dim() == 4: global_pool = nn.AdaptiveAvgPool2d(1).to(self.device)