import logging import os import numpy from pathlib import Path from PIL import Image as PImage from torchvision import transforms from towhee import register from towhee.operator import NNOperator, OperatorFlag from towhee.types import arg, to_image_color from towhee._types import Image import warnings warnings.filterwarnings('ignore') log = logging.getLogger() @register(output_schema=['styled_image'], flag=OperatorFlag.STATELESS | OperatorFlag.REUSEABLE,) class Cartoongan(NNOperator): """ A one line summary of this class. """ def __init__(self, model_name: str, framework: str = 'pytorch', device: str = 'cpu') -> None: super().__init__() self._device = device if framework == 'pytorch': import importlib.util path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py') opname = os.path.basename(str(Path(__file__))).split('.')[0] spec = importlib.util.spec_from_file_location(opname, path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) self.model = module.Model(model_name, self._device) self.tfms = transforms.Compose([ transforms.ToTensor() ]) @arg(1, to_image_color('BGR')) def __call__(self, image): image = self.tfms(image).unsqueeze(0) styled_image = self.model(image) styled_image = numpy.transpose(styled_image, (1, 2, 0)) styled_image = PImage.fromarray((styled_image * 255).astype(numpy.uint8)) styled_image = numpy.array(styled_image) styled_image = styled_image[:, :, ::-1].copy() return Image(styled_image, 'BGR')