diff --git a/timm_image.py b/timm_image.py index fe2cc97..f937cab 100644 --- a/timm_image.py +++ b/timm_image.py @@ -22,6 +22,7 @@ import torch from torch import nn from PIL import Image as PILImage +import cv2 from timm.data.transforms_factory import create_transform from timm.data import resolve_data_config @@ -55,9 +56,12 @@ class TimmImage(NNOperator): def __call__(self, img: numpy.ndarray) -> numpy.ndarray: if hasattr(img, 'mode'): - if img.mode != 'RGB': + if img.mode not in ['RGB', 'BGR']: log.error(f'Invalid image mode: expect "RGB" but receive "{img.mode}".') raise AssertionError(f'Invalid image mode "{img.mode}".') + elif img.mode == 'BGR': + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + log.warning('Converting image mode from "BGR" to "RGB" ...') else: log.warning(f'Image mode is not specified. Using "RGB" now.')