From a7c2a089e646db1d092927587efbae16eb0fbc9e Mon Sep 17 00:00:00 2001 From: junjiejiangjjj Date: Wed, 1 Feb 2023 16:15:41 +0800 Subject: [PATCH] support decode bytes Signed-off-by: junjiejiangjjj --- __init__.py | 4 ++-- image_decoder.py | 30 ++++++++---------------------- 2 files changed, 10 insertions(+), 24 deletions(-) diff --git a/__init__.py b/__init__.py index 1bf901e..1accef4 100644 --- a/__init__.py +++ b/__init__.py @@ -16,5 +16,5 @@ from .image_decoder import ImageDecoder # The factory method -def nvjpeg(): - return ImageDecoder() +def nvjpeg(*args, **kwargs): + return ImageDecoder(*args, **kwargs) diff --git a/image_decoder.py b/image_decoder.py index 5527ee0..edf429f 100644 --- a/image_decoder.py +++ b/image_decoder.py @@ -7,13 +7,14 @@ import platform import cv2 import requests import numpy as np +from typing import Union from towhee.types import Image from towhee.operator import PyOperator, SharedType logger = logging.getLogger() -def create_pynvjpeg(): +def create_pynvjpeg(device): if platform.system().lower() != 'linux': logger.warning('The GPU deocder only support linux') return None @@ -26,7 +27,7 @@ def create_pynvjpeg(): try: import pynvjpeg decoder = pynvjpeg.Decoder() - if not decoder.bind_device(0): + if not decoder.bind_device(device): logger.info('Init GPU deocder failed, use CPU decoder') return None else: @@ -37,19 +38,8 @@ def create_pynvjpeg(): class ImageDecoder(PyOperator): - def __init__(self): - self._nvjpeg = create_pynvjpeg() - - def to_device(self): - if self._nvjpeg is not None and self.device >= 0: - if not self._nvjpeg.bind_device(self.device): - self._nvjpeg = None - - @property - def device(self): - if hasattr(self, '_device'): - return self._device - return -1 + def __init__(self, device=0): + self._nvjpeg = create_pynvjpeg(device) def _download_image(self, image_url): image_type = image_url.split('?')[0].split('.')[-1].lower() @@ -68,12 +58,14 @@ class ImageDecoder(PyOperator): return is_jpeg, f.read() def load_image(self, image_path): + if isinstance(image_path, bytes): + return False, image_path if image_path.startswith('http'): return self._download_image(image_path) else: return self._read_image(image_path) - def __call__(self, image_path: str): + def __call__(self, image_path: Union[str, bytes]): is_jpeg, image_content = self.load_image(image_path) if is_jpeg and self._nvjpeg is not None: image = self._nvjpeg.imdecode(image_content) @@ -86,12 +78,6 @@ class ImageDecoder(PyOperator): raise RuntimeError('Decode image %s failed' % image_path) return Image(image, 'RGB') - def input_schema(self): - return [(str, (1,))] - - def output_schema(self): - return [(Image, (-1, -1, 3))] - @property def shared_type(self): return SharedType.NotShareable