From 712619032597b78319f08aca6a1e8e3d6aac35f9 Mon Sep 17 00:00:00 2001 From: Jael Gu Date: Fri, 24 Jun 2022 13:18:29 +0800 Subject: [PATCH] Update Signed-off-by: Jael Gu --- timm_image.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/timm_image.py b/timm_image.py index 8ef3583..d587b7a 100644 --- a/timm_image.py +++ b/timm_image.py @@ -64,17 +64,14 @@ class TimmImage(NNOperator): self.skip_tfms = skip_preprocess def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]): - if isinstance(data, list): - imgs = [] - for img in data: - img = self.convert_img(img) - img = img if self.skip_tfms else self.tfms(img) - imgs.append(img) - inputs = torch.stack(imgs) - else: - img = self.convert_img(data) + if not isinstance(data, list): + data = [data] + imgs = [] + for img in data: + img = self.convert_img(img) img = img if self.skip_tfms else self.tfms(img) - inputs = img.unsqueeze(0) + imgs.append(img) + inputs = torch.stack(imgs) inputs = inputs.to(self.device) features = self.model.forward_features(inputs) if features.dim() == 4: