diff --git a/timm_image.py b/timm_image.py index b0188a7..98bf302 100644 --- a/timm_image.py +++ b/timm_image.py @@ -83,7 +83,7 @@ class TimmImage(NNOperator): imgs = data img_list = [] for img in imgs: - img = self.convert_img(img) if isinstance(img, numpy.ndarray) else img + img = self.convert_img(img) if isinstance(img, numpy.ndarray) else img.convert('RGB') img = img if self.skip_tfms else self.tfms(img) img_list.append(img) inputs = torch.stack(img_list)