logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

49 lines
1.7 KiB

import logging
import os
import numpy
from pathlib import Path
from PIL import Image as PImage
from torchvision import transforms
from towhee import register
from towhee.operator import NNOperator, OperatorFlag
from towhee.types import arg, to_image_color
from towhee._types import Image
import warnings
# NOTE(review): this silences *all* warnings process-wide as soon as the module
# is imported, not just those raised by this operator — consider narrowing it.
warnings.filterwarnings('ignore')
# Module-scoped logger (rather than the root logger) so records can be filtered
# per module; they still propagate to the root handlers.
log = logging.getLogger(__name__)
@register(output_schema=['styled_image'], flag=OperatorFlag.STATELESS | OperatorFlag.REUSEABLE,)
class Cartoongan(NNOperator):
    """
    Towhee operator that restyles an image with a CartoonGAN model.

    The network is loaded from ``pytorch/model.py`` next to this file and
    instantiated as ``Model(model_name, device)``.

    Args:
        model_name: Name forwarded to the backend ``Model`` constructor
            (presumably selects the style weights — confirm in pytorch/model.py).
        framework: Backend to load; only ``'pytorch'`` is supported.
        device: Device string forwarded to the backend ``Model``.

    Raises:
        ValueError: If ``framework`` is not ``'pytorch'``.
    """

    def __init__(self, model_name: str, framework: str = 'pytorch', device: str = 'cpu') -> None:
        super().__init__()
        self._device = device
        # Without this check an unsupported framework would leave self.model
        # unset and only fail later, inside __call__.
        if framework != 'pytorch':
            raise ValueError(
                f"Unsupported framework {framework!r}; only 'pytorch' is supported."
            )
        import importlib.util
        # Load the backend module straight from its file path, so the operator
        # directory does not have to be importable as a package.
        path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py')
        opname = os.path.basename(str(Path(__file__))).split('.')[0]
        spec = importlib.util.spec_from_file_location(opname, path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        self.model = module.Model(model_name, self._device)
        # ToTensor turns an HWC image into a CHW float tensor (uint8 input is
        # scaled to [0, 1]).
        self.tfms = transforms.Compose([
            transforms.ToTensor()
        ])

    @arg(1, to_image_color('RGB'))
    def __call__(self, image):
        """
        Stylize a single image.

        Args:
            image: Input image; the ``arg`` decorator converts it to RGB first.

        Returns:
            A towhee ``Image`` wrapping the stylized uint8 pixels in BGR order.
        """
        batch = self.tfms(image).unsqueeze(0)  # add a leading batch dimension
        # NOTE(review): the backend appears to return one CHW numpy array with
        # values nominally in [0, 1] (it is transposed and scaled below) —
        # confirm against pytorch/model.py.
        styled = self.model(batch)
        styled = numpy.transpose(styled, (1, 2, 0))  # CHW -> HWC
        # Clip before casting: the float -> uint8 conversion does not saturate,
        # so values slightly outside [0, 1] would otherwise become garbage pixels.
        styled = numpy.clip(styled * 255, 0, 255).astype(numpy.uint8)
        # Reverse the channel axis (RGB -> BGR); copy() yields a contiguous array.
        styled = styled[:, :, ::-1].copy()
        return Image(styled, 'BGR')