diff --git a/timm_image.py b/timm_image.py index ab132b4..80ed6bb 100644 --- a/timm_image.py +++ b/timm_image.py @@ -125,7 +125,7 @@ class TimmImage(NNOperator): vecs = [list(x.detach().numpy()) for x in features] if isinstance(features, list) \ else list(features.detach().numpy()) else: - vecs = [x.squeeze(0).detach().numpy()] if isinstance(features, list) \ + vecs = [x.squeeze(0).detach().numpy() for x in features] if isinstance(features, list) \ else features.squeeze(0).detach().numpy() return vecs