logo
junjie.jiangjjj 2 years ago
parent
commit
a94e20ac26
  1. 60
      image_decode_cv2.py

60
image_decode_cv2.py

@ -1,60 +1,56 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
# encoding=utf-8
from collections import namedtuple
import cv2 import cv2
import requests import requests
import numpy as np import numpy as np
from towhee.types import Image
from towhee.operator import PyOperator, SharedType
from towhee import register from towhee import register
from towhee.operator import PyOperator, OperatorFlag
from towhee._types import Image
from towhee.utils import ndarray_utils
from towhee.utils.log import engine_log
log = logging.getLogger()
@register(
input_schema=[(str, (-1,))],
output_schema=[(Image, (-1, -1, 3))]
)
class ImageDecoder(PyOperator):
def __init__(self):
super().__init__()
@register(output_schema=['img'],
flag=OperatorFlag.STATELESS | OperatorFlag.REUSEABLE)
class ImageDecodeCV2(PyOperator):
@staticmethod
def _load_from_remote(image_url: str) -> np.ndarray: def _load_from_remote(image_url: str) -> np.ndarray:
try: try:
r = requests.get(image_url, timeout=(20, 20)) r = requests.get(image_url, timeout=(20, 20))
if r.status_code // 100 != 2: if r.status_code // 100 != 2:
log.error('Download image from %s failed, error msg: %s, request code: %s ',
image_url, r.text, r.status_code)
engine_log.error('Download image from %s failed, error msg: %s, request code: %s ' % (image_url,
r.text,
r.status_code))
return None return None
arr = np.asarray(bytearray(r.content), dtype=np.uint8) arr = np.asarray(bytearray(r.content), dtype=np.uint8)
return cv2.imdecode(arr, -1) return cv2.imdecode(arr, -1)
except Exception as e: except Exception as e:
log.error('Download image from %s failed, error msg: %s', image_url, str(e))
engine_log.error('Download image from %s failed, error msg: %s' % (image_url, str(e)))
return False return False
@staticmethod
def _load_from_local(image_path: str) -> np.ndarray: def _load_from_local(image_path: str) -> np.ndarray:
return cv2.imread(image_path) return cv2.imread(image_path)
def __call__(self, image_path: str): def __call__(self, image_path: str):
if image_path.startswith('http'): if image_path.startswith('http'):
bgr_cv_image = ImageDecodeCV2._load_from_remote(image_path)
bgr_cv_image = ImageDecoder._load_from_remote(image_path)
else: else:
bgr_cv_image = ImageDecodeCV2._load_from_local(image_path)
bgr_cv_image = ImageDecoder._load_from_local(image_path)
if bgr_cv_image is None: if bgr_cv_image is None:
err = 'Read image %s failed' % image_path err = 'Read image %s failed' % image_path
log.error(err)
engine_log.error(err)
raise RuntimeError(err) raise RuntimeError(err)
rgb = cv2.cvtColor(bgr_cv_image, cv2.COLOR_BGR2RGB)
return Image(rgb, 'RGB')
rgb_cv_image = cv2.cvtColor(bgr_cv_image, cv2.COLOR_BGR2RGB)
return ndarray_utils.from_ndarray(rgb_cv_image, 'RGB')
@property
def shared_type(self):
return SharedType.Shareable

Loading…
Cancel
Save