logo
Browse Source

support decode bytes

Signed-off-by: junjiejiangjjj <junjie.jiang@zilliz.com>
main
junjiejiangjjj 2 years ago
parent
commit
a7c2a089e6
  1. 4
      __init__.py
  2. 30
      image_decoder.py

4
__init__.py

@ -16,5 +16,5 @@ from .image_decoder import ImageDecoder
# The factory method # The factory method
def nvjpeg():
return ImageDecoder()
def nvjpeg(*args, **kwargs):
return ImageDecoder(*args, **kwargs)

30
image_decoder.py

@ -7,13 +7,14 @@ import platform
import cv2 import cv2
import requests import requests
import numpy as np import numpy as np
from typing import Union
from towhee.types import Image from towhee.types import Image
from towhee.operator import PyOperator, SharedType from towhee.operator import PyOperator, SharedType
logger = logging.getLogger() logger = logging.getLogger()
def create_pynvjpeg():
def create_pynvjpeg(device):
if platform.system().lower() != 'linux': if platform.system().lower() != 'linux':
logger.warning('The GPU deocder only support linux') logger.warning('The GPU deocder only support linux')
return None return None
@ -26,7 +27,7 @@ def create_pynvjpeg():
try: try:
import pynvjpeg import pynvjpeg
decoder = pynvjpeg.Decoder() decoder = pynvjpeg.Decoder()
if not decoder.bind_device(0):
if not decoder.bind_device(device):
logger.info('Init GPU deocder failed, use CPU decoder') logger.info('Init GPU deocder failed, use CPU decoder')
return None return None
else: else:
@ -37,19 +38,8 @@ def create_pynvjpeg():
class ImageDecoder(PyOperator): class ImageDecoder(PyOperator):
def __init__(self):
self._nvjpeg = create_pynvjpeg()
def to_device(self):
if self._nvjpeg is not None and self.device >= 0:
if not self._nvjpeg.bind_device(self.device):
self._nvjpeg = None
@property
def device(self):
if hasattr(self, '_device'):
return self._device
return -1
def __init__(self, device=0):
self._nvjpeg = create_pynvjpeg(device)
def _download_image(self, image_url): def _download_image(self, image_url):
image_type = image_url.split('?')[0].split('.')[-1].lower() image_type = image_url.split('?')[0].split('.')[-1].lower()
@ -68,12 +58,14 @@ class ImageDecoder(PyOperator):
return is_jpeg, f.read() return is_jpeg, f.read()
def load_image(self, image_path): def load_image(self, image_path):
if isinstance(image_path, bytes):
return False, image_path
if image_path.startswith('http'): if image_path.startswith('http'):
return self._download_image(image_path) return self._download_image(image_path)
else: else:
return self._read_image(image_path) return self._read_image(image_path)
def __call__(self, image_path: str):
def __call__(self, image_path: Union[str, bytes]):
is_jpeg, image_content = self.load_image(image_path) is_jpeg, image_content = self.load_image(image_path)
if is_jpeg and self._nvjpeg is not None: if is_jpeg and self._nvjpeg is not None:
image = self._nvjpeg.imdecode(image_content) image = self._nvjpeg.imdecode(image_content)
@ -86,12 +78,6 @@ class ImageDecoder(PyOperator):
raise RuntimeError('Decode image %s failed' % image_path) raise RuntimeError('Decode image %s failed' % image_path)
return Image(image, 'RGB') return Image(image, 'RGB')
def input_schema(self):
return [(str, (1,))]
def output_schema(self):
return [(Image, (-1, -1, 3))]
@property @property
def shared_type(self): def shared_type(self):
return SharedType.NotShareable return SharedType.NotShareable

Loading…
Cancel
Save