From af2086bdce870349bf1db21d84509ce37c049cb7 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Thu, 9 Feb 2023 10:11:59 +0800 Subject: [PATCH] Update Signed-off-by: Jael Gu --- timm_image.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/timm_image.py b/timm_image.py index 286c9bc..1e869b9 100644 --- a/timm_image.py +++ b/timm_image.py @@ -144,15 +144,14 @@ class TimmImage(NNOperator): return img def post_proc(self, features): - features = features.to(self.device) + features = features.to('cpu') if features.dim() == 3: features = features[:, 0] if features.dim() == 4: - global_pool = nn.AdaptiveAvgPool2d(1).to(self.device) + global_pool = nn.AdaptiveAvgPool2d(1) features = global_pool(features) features = features.flatten(1) assert features.dim() == 2, f'Invalid output dim {features.dim()}' - features = features.to('cpu') return features def save_model(self, format: str = 'pytorch', path: str = 'default'):