logo
Browse Source

Update

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
af2086bdce
  1. 5
      timm_image.py

5
timm_image.py

@ -144,15 +144,14 @@ class TimmImage(NNOperator):
return img return img
def post_proc(self, features): def post_proc(self, features):
features = features.to(self.device)
features = features.to('cpu')
if features.dim() == 3: if features.dim() == 3:
features = features[:, 0] features = features[:, 0]
if features.dim() == 4: if features.dim() == 4:
global_pool = nn.AdaptiveAvgPool2d(1).to(self.device)
global_pool = nn.AdaptiveAvgPool2d(1)
features = global_pool(features) features = global_pool(features)
features = features.flatten(1) features = features.flatten(1)
assert features.dim() == 2, f'Invalid output dim {features.dim()}' assert features.dim() == 2, f'Invalid output dim {features.dim()}'
features = features.to('cpu')
return features return features
def save_model(self, format: str = 'pytorch', path: str = 'default'): def save_model(self, format: str = 'pytorch', path: str = 'default'):

Loading…
Cancel
Save