Browse Source
Update
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
1 changed files with
2 additions and
3 deletions
-
timm_image.py
|
@ -144,15 +144,14 @@ class TimmImage(NNOperator): |
|
|
return img |
|
|
return img |
|
|
|
|
|
|
|
|
def post_proc(self, features): |
|
|
def post_proc(self, features): |
|
|
features = features.to(self.device) |
|
|
|
|
|
|
|
|
features = features.to('cpu') |
|
|
if features.dim() == 3: |
|
|
if features.dim() == 3: |
|
|
features = features[:, 0] |
|
|
features = features[:, 0] |
|
|
if features.dim() == 4: |
|
|
if features.dim() == 4: |
|
|
global_pool = nn.AdaptiveAvgPool2d(1).to(self.device) |
|
|
|
|
|
|
|
|
global_pool = nn.AdaptiveAvgPool2d(1) |
|
|
features = global_pool(features) |
|
|
features = global_pool(features) |
|
|
features = features.flatten(1) |
|
|
features = features.flatten(1) |
|
|
assert features.dim() == 2, f'Invalid output dim {features.dim()}' |
|
|
assert features.dim() == 2, f'Invalid output dim {features.dim()}' |
|
|
features = features.to('cpu') |
|
|
|
|
|
return features |
|
|
return features |
|
|
|
|
|
|
|
|
def save_model(self, format: str = 'pytorch', path: str = 'default'): |
|
|
def save_model(self, format: str = 'pytorch', path: str = 'default'): |
|
|