nvjpeg
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
91 lines
3.0 KiB
91 lines
3.0 KiB
# encoding=utf-8
|
|
import os
|
|
import sys
|
|
import logging
|
|
import platform
|
|
|
|
import cv2
|
|
import requests
|
|
import numpy as np
|
|
from towhee.types import Image
|
|
from towhee.operator import PyOperator, SharedType
|
|
|
|
# Module-level logger; named after the module (instead of the root logger) so
# its output can be filtered/configured independently. Records still propagate
# to the root logger's handlers by default.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def create_pynvjpeg():
    """Try to create a GPU JPEG decoder from the bundled ``pynvjpeg`` module.

    The extension module is looked up in a directory named
    ``py<major><minor>nvjpeg`` (e.g. ``py38nvjpeg``) next to this file; that
    directory is prepended to ``sys.path`` before importing.

    Returns:
        A ``pynvjpeg.Decoder`` bound to device 0, or ``None`` when the platform
        is not Linux, the module cannot be imported, the decoder cannot be
        created, or binding to device 0 fails. ``None`` means the caller should
        fall back to the OpenCV (CPU) decoder.
    """
    if platform.system().lower() != 'linux':
        logger.warning('The GPU deocder only support linux')
        return None

    # The directory name encodes the interpreter's major/minor version,
    # e.g. "3.8.10" -> "38".
    version = ''.join(platform.python_version().split('.')[0:2])
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    lib_path = os.path.join(cur_dir, 'py' + version + 'nvjpeg')
    # This function runs once per decoder instance; don't stack duplicate
    # sys.path entries on every call.
    if lib_path not in sys.path:
        sys.path.insert(0, lib_path)

    try:
        import pynvjpeg
        decoder = pynvjpeg.Decoder()
        if not decoder.bind_device(0):
            logger.info('Init GPU deocder failed, use CPU decoder')
            return None
        return decoder
    except Exception as e:
        # Deliberately best-effort: any failure here just means "use the CPU
        # decoder". The %s placeholder is required -- passing the error as an
        # extra argument without one makes logging fail to format the record.
        logger.error('Create nvjpeg failed, use opencv decoder, errors: %s', e)
        return None
|
|
|
|
|
|
class ImageDecoder(PyOperator):
    """Load an image from a local path or HTTP(S) URL and decode it to RGB.

    JPEG inputs are decoded with the GPU ``pynvjpeg`` decoder when one could be
    created; all other inputs (and every input when no GPU decoder is
    available) are decoded on the CPU with OpenCV.
    """

    def __init__(self):
        # None means "no GPU decoder available, always use OpenCV".
        self._nvjpeg = create_pynvjpeg()

    def to_device(self, device):
        """Rebind the GPU decoder to ``device``; drop to the CPU decoder if that fails."""
        if self._nvjpeg is not None and not self._nvjpeg.bind_device(device):
            self._nvjpeg = None

    def _download_image(self, image_url):
        """Fetch ``image_url`` over HTTP(S).

        Returns:
            ``(is_jpeg, content_bytes)`` for a 2xx response, otherwise ``None``.
            ``is_jpeg`` is derived from the URL's file extension, ignoring any
            query string.
        """
        image_type = image_url.split('?')[0].split('.')[-1].lower()
        is_jpeg = image_type in ('jpeg', 'jpg')
        r = requests.get(image_url, timeout=(20, 20))
        if r.status_code // 100 != 2:
            logger.error('Download image from %s failed, error msg: %s, request code: %s ',
                         image_url, r.text, r.status_code)
            return None
        return is_jpeg, r.content

    def _read_image(self, image_path):
        """Read a local file; JPEG-ness is guessed from a 'jpg'/'jpeg' suffix (case-insensitive)."""
        is_jpeg = image_path.lower().endswith(('jpg', 'jpeg'))
        with open(image_path, 'rb') as f:
            return is_jpeg, f.read()

    def load_image(self, image_path):
        """Return ``(is_jpeg, bytes)`` for a URL or local path, or ``None`` if a download fails."""
        # Match the URL scheme explicitly so local paths that merely begin with
        # "http" (e.g. "http_images/cat.jpg") are still read from disk.
        if image_path.startswith(('http://', 'https://')):
            return self._download_image(image_path)
        return self._read_image(image_path)

    @staticmethod
    def _to_rgb(image):
        """Convert an OpenCV-decoded array (gray, BGR or BGRA) to 3-channel RGB."""
        if image.ndim == 2 or image.shape[2] == 1:
            # COLOR_BGR2RGB rejects single-channel input, so grayscale needs its own code.
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.shape[2] == 4:
            # Drop the alpha channel so the output matches output_schema's 3 channels.
            return cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __call__(self, image_path: str):
        """Load and decode ``image_path`` into an RGB ``Image``.

        Raises:
            RuntimeError: if the image cannot be downloaded or decoded.
        """
        loaded = self.load_image(image_path)
        if loaded is None:
            # _download_image signals HTTP failure with None; unpacking it
            # directly would surface as an unhelpful TypeError.
            raise RuntimeError('Load image %s failed' % image_path)
        is_jpeg, image_content = loaded

        if is_jpeg and self._nvjpeg is not None:
            # NOTE(review): no color conversion is applied on this path; this
            # assumes pynvjpeg already returns RGB -- confirm.
            image = self._nvjpeg.imdecode(image_content)
        else:
            arr = np.asarray(bytearray(image_content), dtype=np.uint8)
            image = cv2.imdecode(arr, cv2.IMREAD_UNCHANGED)
            if image is not None:
                image = self._to_rgb(image)

        if image is None:
            raise RuntimeError('Decode image %s failed' % image_path)
        return Image(image, 'RGB')

    def input_schema(self):
        """Declare a single ``str`` input (the image path or URL)."""
        return [(str, (1,))]

    def output_schema(self):
        """Declare a single 3-channel ``Image`` output (presumably -1 = variable height/width)."""
        return [(Image, (-1, -1, 3))]

    @property
    def shared_type(self):
        # NOTE(review): presumably not shareable because each instance owns its
        # own (possibly GPU-bound) decoder handle -- confirm.
        return SharedType.NotShareable
|