diff --git a/timm_image.py b/timm_image.py index 8ef3583..d587b7a 100644 --- a/timm_image.py +++ b/timm_image.py @@ -64,17 +64,14 @@ class TimmImage(NNOperator): self.skip_tfms = skip_preprocess def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]): - if isinstance(data, list): - imgs = [] - for img in data: - img = self.convert_img(img) - img = img if self.skip_tfms else self.tfms(img) - imgs.append(img) - inputs = torch.stack(imgs) - else: - img = self.convert_img(data) + if not isinstance(data, list): + data = [data] + imgs = [] + for img in data: + img = self.convert_img(img) img = img if self.skip_tfms else self.tfms(img) - inputs = img.unsqueeze(0) + imgs.append(img) + inputs = torch.stack(imgs) inputs = inputs.to(self.device) features = self.model.forward_features(inputs) if features.dim() == 4: