# encoding=utf-8
import os
import sys
import logging
import platform

import cv2
import requests
import numpy as np

from towhee.types import Image
from towhee.operator import PyOperator, SharedType

logger = logging.getLogger()


def create_pynvjpeg():
    """Try to build a GPU JPEG decoder backed by the bundled ``pynvjpeg`` module.

    The module is looked up in a ``py<major><minor>nvjpeg`` directory that sits
    next to this file; that directory is put at the front of ``sys.path``
    before the import is attempted.

    Returns:
        A ``pynvjpeg.Decoder`` bound to device 0, or ``None`` when the platform
        is not Linux, the import/initialisation raises, or binding device 0
        reports failure. Callers treat ``None`` as "use the OpenCV decoder".
    """
    if platform.system().lower() != 'linux':
        logger.warning('The GPU deocder only support linux')
        return None

    py_version = platform.python_version()
    version = ''.join(py_version.split('.')[0:2])
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    lib_path = os.path.join(cur_dir, 'py' + version + 'nvjpeg')
    # Avoid stacking duplicate entries when several operators are created.
    if lib_path not in sys.path:
        sys.path.insert(0, lib_path)

    # Best-effort: any failure here deliberately degrades to the CPU decoder.
    try:
        import pynvjpeg  # pylint: disable=import-outside-toplevel
        decoder = pynvjpeg.Decoder()
        if not decoder.bind_device(0):
            logger.info('Init GPU deocder failed, use CPU decoder')
            return None
        return decoder
    except Exception as e:  # pylint: disable=broad-except
        # The original call passed ``str(e)`` without a ``%s`` placeholder,
        # which made the logging module fail to format the record.
        logger.error('Create nvjpeg failed, use opencv decoder, errors: %s', str(e))
        return None


class ImageDecoder(PyOperator):
    """Operator that loads an image from a local path or URL and decodes it.

    JPEG inputs (judged by file extension) are decoded with the GPU decoder
    when one is available; everything else goes through OpenCV. The result is
    wrapped in a towhee ``Image`` tagged ``'RGB'``.
    """

    def __init__(self):
        # NOTE(review): the base-class initialiser is not invoked here (as in
        # the original); confirm PyOperator does not require it.
        self._nvjpeg = create_pynvjpeg()

    def to_device(self, device):
        """Rebind the GPU decoder to ``device``; drop it if binding fails."""
        if self._nvjpeg is not None:
            if not self._nvjpeg.bind_device(device):
                self._nvjpeg = None

    def _download_image(self, image_url):
        """Fetch ``image_url`` over HTTP.

        Returns:
            ``(is_jpeg, content)`` on a 2xx response, where ``is_jpeg`` is
            derived from the extension of the URL path (query string
            stripped) and ``content`` is the raw response body; ``None`` when
            the response status is not 2xx.
        """
        image_type = image_url.split('?')[0].split('.')[-1].lower()
        is_jpeg = image_type in ('jpeg', 'jpg')
        r = requests.get(image_url, timeout=(20, 20))
        if r.status_code // 100 != 2:
            logger.error('Download image from %s failed, error msg: %s, request code: %s ',
                         image_url, r.text, r.status_code)
            return None
        return is_jpeg, r.content

    def _read_image(self, image_path):
        """Read a local file and return ``(is_jpeg, raw_bytes)``.

        ``is_jpeg`` is decided from the (case-insensitive) file-name suffix.
        """
        is_jpeg = image_path.lower().endswith(('jpg', 'jpeg'))
        with open(image_path, 'rb') as f:
            return is_jpeg, f.read()

    def load_image(self, image_path):
        """Load raw image bytes from a URL or a local path.

        Only strings starting with ``http://`` or ``https://`` are treated as
        URLs, so local paths that merely begin with "http" are read from disk.

        Returns:
            ``(is_jpeg, raw_bytes)``, or ``None`` if a download failed.
        """
        if image_path.startswith(('http://', 'https://')):
            return self._download_image(image_path)
        return self._read_image(image_path)

    @staticmethod
    def _decode_cpu(image_content):
        """Decode raw bytes with OpenCV and return an RGB array, or ``None``."""
        arr = np.frombuffer(image_content, dtype=np.uint8)
        # Flag -1 keeps the image as stored, so the result may be 2-D.
        image = cv2.imdecode(arr, -1)
        if image is None:
            return None
        if image.ndim == 2:
            # Single-channel input: BGR2RGB would reject it, so expand instead.
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __call__(self, image_path: str):
        """Load and decode ``image_path`` into a towhee ``Image``.

        Raises:
            RuntimeError: if the image could not be downloaded or decoded.
        """
        loaded = self.load_image(image_path)
        if loaded is None:
            # Previously this surfaced as an opaque tuple-unpacking TypeError.
            raise RuntimeError('Download image %s failed' % image_path)
        is_jpeg, image_content = loaded

        image = None
        if is_jpeg and self._nvjpeg is not None:
            image = self._nvjpeg.imdecode(image_content)
        if image is None:
            # Either no GPU decoder applies, or it returned nothing; the JPEG
            # guess is extension-based, so give OpenCV a chance as well.
            image = self._decode_cpu(image_content)

        if image is None:
            raise RuntimeError('Decode image %s failed' % image_path)

        return Image(image, 'RGB')

    def input_schema(self):
        """Describe the operator input: a single string."""
        return [(str, (1,))]

    def output_schema(self):
        """Describe the operator output: an ``Image`` with three channels."""
        return [(Image, (-1, -1, 3))]

    @property
    def shared_type(self):
        """This operator reports itself as not shareable."""
        return SharedType.NotShareable