|
|
@ -7,13 +7,14 @@ import platform |
|
|
|
import cv2 |
|
|
|
import requests |
|
|
|
import numpy as np |
|
|
|
from typing import Union |
|
|
|
from towhee.types import Image |
|
|
|
from towhee.operator import PyOperator, SharedType |
|
|
|
|
|
|
|
logger = logging.getLogger() |
|
|
|
|
|
|
|
|
|
|
|
def create_pynvjpeg(): |
|
|
|
def create_pynvjpeg(device): |
|
|
|
if platform.system().lower() != 'linux': |
|
|
|
logger.warning('The GPU deocder only support linux') |
|
|
|
return None |
|
|
@ -26,7 +27,7 @@ def create_pynvjpeg(): |
|
|
|
try: |
|
|
|
import pynvjpeg |
|
|
|
decoder = pynvjpeg.Decoder() |
|
|
|
if not decoder.bind_device(0): |
|
|
|
if not decoder.bind_device(device): |
|
|
|
logger.info('Init GPU deocder failed, use CPU decoder') |
|
|
|
return None |
|
|
|
else: |
|
|
@ -37,19 +38,8 @@ def create_pynvjpeg(): |
|
|
|
|
|
|
|
|
|
|
|
class ImageDecoder(PyOperator): |
|
|
|
def __init__(self): |
|
|
|
self._nvjpeg = create_pynvjpeg() |
|
|
|
|
|
|
|
def to_device(self): |
|
|
|
if self._nvjpeg is not None and self.device >= 0: |
|
|
|
if not self._nvjpeg.bind_device(self.device): |
|
|
|
self._nvjpeg = None |
|
|
|
|
|
|
|
@property |
|
|
|
def device(self): |
|
|
|
if hasattr(self, '_device'): |
|
|
|
return self._device |
|
|
|
return -1 |
|
|
|
def __init__(self, device=0): |
|
|
|
self._nvjpeg = create_pynvjpeg(device) |
|
|
|
|
|
|
|
def _download_image(self, image_url): |
|
|
|
image_type = image_url.split('?')[0].split('.')[-1].lower() |
|
|
@ -68,12 +58,14 @@ class ImageDecoder(PyOperator): |
|
|
|
return is_jpeg, f.read() |
|
|
|
|
|
|
|
def load_image(self, image_path): |
|
|
|
if isinstance(image_path, bytes): |
|
|
|
return False, image_path |
|
|
|
if image_path.startswith('http'): |
|
|
|
return self._download_image(image_path) |
|
|
|
else: |
|
|
|
return self._read_image(image_path) |
|
|
|
|
|
|
|
def __call__(self, image_path: str): |
|
|
|
def __call__(self, image_path: Union[str, bytes]): |
|
|
|
is_jpeg, image_content = self.load_image(image_path) |
|
|
|
if is_jpeg and self._nvjpeg is not None: |
|
|
|
image = self._nvjpeg.imdecode(image_content) |
|
|
@ -86,12 +78,6 @@ class ImageDecoder(PyOperator): |
|
|
|
raise RuntimeError('Decode image %s failed' % image_path) |
|
|
|
return Image(image, 'RGB') |
|
|
|
|
|
|
|
def input_schema(self): |
|
|
|
return [(str, (1,))] |
|
|
|
|
|
|
|
def output_schema(self): |
|
|
|
return [(Image, (-1, -1, 3))] |
|
|
|
|
|
|
|
@property |
|
|
|
def shared_type(self): |
|
|
|
return SharedType.NotShareable |
|
|
|