From 53144950f1b6cb6922bbaaa5c3616c198fc466e1 Mon Sep 17 00:00:00 2001 From: junjiejiangjjj Date: Wed, 1 Feb 2023 15:52:09 +0800 Subject: [PATCH] support decode from bytes Signed-off-by: junjiejiangjjj --- __init__.py | 4 ++-- image_decode_cv2.py | 38 +++++++++++++++++++++++++------------- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/__init__.py b/__init__.py index c3f28e4..169ca5d 100644 --- a/__init__.py +++ b/__init__.py @@ -16,5 +16,5 @@ from .image_decode_cv2 import ImageDecodeCV2 # The factory method -def cv2(): - return ImageDecodeCV2() +def cv2(*args, **kwargs): + return ImageDecodeCV2(*args, **kwargs) diff --git a/image_decode_cv2.py b/image_decode_cv2.py index b5c4983..bdc8612 100644 --- a/image_decode_cv2.py +++ b/image_decode_cv2.py @@ -14,22 +14,23 @@ import logging +from typing import Union import cv2 import requests import numpy as np -from towhee import register -from towhee.operator import PyOperator, OperatorFlag -from towhee._types import Image +from towhee.operator import PyOperator +from towhee.types import Image log = logging.getLogger() -@register(output_schema=['img'], - flag=OperatorFlag.STATELESS | OperatorFlag.REUSEABLE) class ImageDecodeCV2(PyOperator): - def __init__(self): - pass + def __init__(self, mode='BGR'): + mode = mode.upper() + if mode not in ['BGR', 'RGB']: + raise RuntimeError("Mode only support BRG and RGB") + self._mode = mode @staticmethod def _load_from_remote(image_url: str) -> np.ndarray: @@ -49,14 +50,25 @@ class ImageDecodeCV2(PyOperator): def _load_from_local(image_path: str) -> np.ndarray: return cv2.imread(image_path) - def __call__(self, image_path: str): - if image_path.startswith('http'): - bgr_cv_image = ImageDecodeCV2._load_from_remote(image_path) + @staticmethod + def _from_bytes(image_bytes) -> np.ndarray: + arr = np.asarray(bytearray(image_bytes), dtype=np.uint8) + return cv2.imdecode(arr, -1) + + def __call__(self, image_data: Union[str, bytes]): + if isinstance(image_data, bytes): + bgr_cv_image = ImageDecodeCV2._from_bytes(image_data) + elif image_data.startswith('http'): + bgr_cv_image = ImageDecodeCV2._load_from_remote(image_data) else: - bgr_cv_image = ImageDecodeCV2._load_from_local(image_path) + bgr_cv_image = ImageDecodeCV2._load_from_local(image_data) + if bgr_cv_image is None: - err = 'Read image %s failed' % image_path + err = 'Read image %s failed' % image_data log.error(err) raise RuntimeError(err) - return Image(bgr_cv_image, 'BGR') + if self._mode == 'BGR': + return Image(bgr_cv_image, 'BGR') + rgb_cv_image = cv2.cvtColor(bgr_cv_image, cv2.COLOR_BGR2RGB) + return Image(rgb_cv_image, 'RGB')