logo
Browse Source

support decode bytes

Signed-off-by: junjiejiangjjj <junjie.jiang@zilliz.com>
main
junjiejiangjjj 2 years ago
parent
commit
a7c2a089e6
  1. 4
      __init__.py
  2. 30
      image_decoder.py

4
__init__.py

@ -16,5 +16,5 @@ from .image_decoder import ImageDecoder
# The factory method
def nvjpeg():
return ImageDecoder()
def nvjpeg(*args, **kwargs):
return ImageDecoder(*args, **kwargs)

30
image_decoder.py

@ -7,13 +7,14 @@ import platform
import cv2
import requests
import numpy as np
from typing import Union
from towhee.types import Image
from towhee.operator import PyOperator, SharedType
logger = logging.getLogger()
def create_pynvjpeg():
def create_pynvjpeg(device):
if platform.system().lower() != 'linux':
logger.warning('The GPU deocder only support linux')
return None
@ -26,7 +27,7 @@ def create_pynvjpeg():
try:
import pynvjpeg
decoder = pynvjpeg.Decoder()
if not decoder.bind_device(0):
if not decoder.bind_device(device):
logger.info('Init GPU deocder failed, use CPU decoder')
return None
else:
@ -37,19 +38,8 @@ def create_pynvjpeg():
class ImageDecoder(PyOperator):
def __init__(self):
self._nvjpeg = create_pynvjpeg()
def to_device(self):
if self._nvjpeg is not None and self.device >= 0:
if not self._nvjpeg.bind_device(self.device):
self._nvjpeg = None
@property
def device(self):
if hasattr(self, '_device'):
return self._device
return -1
def __init__(self, device=0):
self._nvjpeg = create_pynvjpeg(device)
def _download_image(self, image_url):
image_type = image_url.split('?')[0].split('.')[-1].lower()
@ -68,12 +58,14 @@ class ImageDecoder(PyOperator):
return is_jpeg, f.read()
def load_image(self, image_path):
if isinstance(image_path, bytes):
return False, image_path
if image_path.startswith('http'):
return self._download_image(image_path)
else:
return self._read_image(image_path)
def __call__(self, image_path: str):
def __call__(self, image_path: Union[str, bytes]):
is_jpeg, image_content = self.load_image(image_path)
if is_jpeg and self._nvjpeg is not None:
image = self._nvjpeg.imdecode(image_content)
@ -86,12 +78,6 @@ class ImageDecoder(PyOperator):
raise RuntimeError('Decode image %s failed' % image_path)
return Image(image, 'RGB')
def input_schema(self):
return [(str, (1,))]
def output_schema(self):
return [(Image, (-1, -1, 3))]
@property
def shared_type(self):
return SharedType.NotShareable

Loading…
Cancel
Save