logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

91 lines
3.0 KiB

# encoding=utf-8
import os
import sys
import logging
import platform
import cv2
import requests
import numpy as np
from towhee.types import Image
from towhee.operator import PyOperator, SharedType
# Module-level logger named after this module instead of the root logger, so
# applications can filter or configure this operator's log output on its own.
# Records still propagate to the root logger's handlers.
logger = logging.getLogger(__name__)
def create_pynvjpeg():
    """Try to create the nvjpeg GPU JPEG decoder bundled with this operator.

    The ``pynvjpeg`` extension is looked up in a directory next to this file
    named ``py<major><minor>nvjpeg`` (e.g. ``py38nvjpeg`` under Python 3.8).
    That directory is prepended to ``sys.path`` before the import.

    Returns:
        A ``pynvjpeg.Decoder`` bound to device 0, or ``None`` in any of these
        cases: not running on Linux, the extension cannot be imported or
        instantiated, or binding to device 0 fails. ``None`` tells the caller
        to decode on the CPU instead.
    """
    if platform.system().lower() != 'linux':
        logger.warning('The GPU deocder only support linux')
        return None

    py_version = platform.python_version()
    # '3.8.10' -> '38'
    version = ''.join(py_version.split('.')[0:2])
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    lib_path = os.path.join(cur_dir, 'py' + version + 'nvjpeg')
    # Avoid adding a duplicate sys.path entry on every call.
    if lib_path not in sys.path:
        sys.path.insert(0, lib_path)

    try:
        import pynvjpeg
        decoder = pynvjpeg.Decoder()
        if not decoder.bind_device(0):
            logger.info('Init GPU deocder failed, use CPU decoder')
            return None
        return decoder
    except Exception as e:
        # Best-effort: any failure here just means "use the CPU decoder".
        # The '%s' placeholder is required; without it logging cannot format
        # the argument and the error reason is lost.
        logger.error('Create nvjpeg failed, use opencv decoder, errors: %s', str(e))
        return None
class ImageDecoder(PyOperator):
    """Towhee operator that loads an image from a local path or HTTP(S) URL and decodes it to RGB.

    Inputs whose extension is ``.jpg``/``.jpeg`` are decoded with the nvjpeg
    GPU decoder when one could be created. All other inputs, and every input
    when no GPU decoder is available, are decoded with OpenCV.
    """

    def __init__(self):
        # None when the GPU decoder is unavailable (see create_pynvjpeg()).
        self._nvjpeg = create_pynvjpeg()

    def to_device(self, device):
        """Rebind the GPU decoder to ``device``.

        If binding fails, the GPU decoder is dropped and later calls decode on
        the CPU. ``device`` is presumably a CUDA device index (TODO confirm).
        """
        if self._nvjpeg is not None and not self._nvjpeg.bind_device(device):
            self._nvjpeg = None

    def _download_image(self, image_url):
        """Download ``image_url``.

        Returns:
            ``(is_jpeg, content_bytes)`` on a 2xx response, otherwise ``None``
            (the failure is logged). ``is_jpeg`` comes from the extension of
            the URL with its query string removed.
        """
        image_type = image_url.split('?')[0].split('.')[-1].lower()
        is_jpeg = image_type in ('jpeg', 'jpg')
        # (connect timeout, read timeout) in seconds.
        r = requests.get(image_url, timeout=(20, 20))
        if r.status_code // 100 != 2:
            logger.error('Download image from %s failed, error msg: %s, request code: %s ',
                         image_url, r.text, r.status_code)
            return None
        return is_jpeg, r.content

    def _read_image(self, image_path):
        """Read a local image file and return ``(is_jpeg, content_bytes)``."""
        is_jpeg = image_path.lower().endswith(('jpg', 'jpeg'))
        with open(image_path, 'rb') as f:
            return is_jpeg, f.read()

    def load_image(self, image_path):
        """Load raw image bytes from an HTTP(S) URL or a local file path.

        Returns:
            ``(is_jpeg, content_bytes)``, or ``None`` if a download failed.
        """
        # Match the full scheme so local paths such as 'http_images/a.jpg'
        # are not mistaken for URLs.
        if image_path.lower().startswith(('http://', 'https://')):
            return self._download_image(image_path)
        return self._read_image(image_path)

    @staticmethod
    def _rgb_conversion_code(image):
        """Choose the OpenCV conversion that turns a decoded array into 3-channel RGB."""
        if image.ndim == 2 or image.shape[2] == 1:
            return cv2.COLOR_GRAY2RGB
        if image.shape[2] == 4:
            return cv2.COLOR_BGRA2RGB
        return cv2.COLOR_BGR2RGB

    def __call__(self, image_path: str):
        """Load and decode ``image_path`` into an RGB ``Image``.

        Raises:
            RuntimeError: if the image cannot be downloaded or decoded.
        """
        loaded = self.load_image(image_path)
        if loaded is None:
            # _download_image returns None on a non-2xx response. Unpacking it
            # directly would raise an unhelpful TypeError.
            raise RuntimeError('Load image %s failed' % image_path)
        is_jpeg, image_content = loaded

        if is_jpeg and self._nvjpeg is not None:
            # NOTE(review): nvjpeg output is passed through without color
            # conversion, presumably because it is already RGB. Confirm.
            image = self._nvjpeg.imdecode(image_content)
        else:
            arr = np.asarray(bytearray(image_content), dtype=np.uint8)
            # IMREAD_UNCHANGED (-1) keeps the native channel count, so the
            # result can be grayscale (2-D) or BGRA as well as BGR.
            image = cv2.imdecode(arr, cv2.IMREAD_UNCHANGED)
            if image is not None:
                image = cv2.cvtColor(image, self._rgb_conversion_code(image))

        if image is None:
            raise RuntimeError('Decode image %s failed' % image_path)
        return Image(image, 'RGB')

    def input_schema(self):
        """Declare a single ``str`` input (local path or URL)."""
        return [(str, (1,))]

    def output_schema(self):
        """Declare a single ``Image`` output of shape (height, width, 3)."""
        return [(Image, (-1, -1, 3))]

    @property
    def shared_type(self):
        """Declare this operator as not shareable.

        Presumably this is because each instance owns a device-bound nvjpeg
        decoder.
        """
        return SharedType.NotShareable