import os
import numpy
from pathlib import Path
from PIL import Image as PImage
from torchvision import transforms

from towhee import register
from towhee.operator import Operator, OperatorFlag
from towhee.types import arg, to_image_color
from towhee._types import Image

import warnings

# NOTE(review): this silences *all* warnings process-wide as soon as this
# module is imported, not only warnings raised by this operator.
warnings.filterwarnings('ignore')


@register(output_schema=['styled_image'], flag=OperatorFlag.STATELESS | OperatorFlag.REUSEABLE,)
class Animegan(Operator):
    """
    Image style-transfer operator backed by an AnimeGAN model.

    The model implementation is loaded dynamically from ``pytorch/model.py``,
    located in the same directory as this source file.

    Args:
        model_name: Forwarded unchanged to ``Model(model_name)`` defined in
            ``pytorch/model.py``.
        framework: Backend to use. Only ``'pytorch'`` is supported.

    Raises:
        ValueError: If ``framework`` is not ``'pytorch'``.
    """

    def __init__(self, model_name: str, framework: str = 'pytorch') -> None:
        super().__init__()
        if framework != 'pytorch':
            # Fail fast: otherwise ``self.model`` is never set and the error
            # only surfaces later as an AttributeError inside ``__call__``.
            raise ValueError(
                f"Unsupported framework: {framework!r}; only 'pytorch' is supported."
            )

        import importlib.util

        # Resolve relative to this source file (the ``__file__`` variable, not
        # the string "__file__"), so loading does not depend on the cwd.
        here = Path(__file__)
        path = str(here.parent / 'pytorch' / 'model.py')
        opname = here.name.split('.')[0]
        spec = importlib.util.spec_from_file_location(opname, path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        self.model = module.Model(model_name)
        # torchvision's ToTensor turns an HWC uint8 array into a CHW float
        # tensor scaled to [0, 1] (assumes the input is uint8 HWC — TODO confirm).
        self.tfms = transforms.Compose([
            transforms.ToTensor()
        ])

    @arg(1, to_image_color('RGB'))
    def __call__(self, image):
        """
        Stylize ``image`` with the loaded model.

        Args:
            image: Input image; the ``to_image_color('RGB')`` decorator converts
                it to RGB before this method runs.

        Returns:
            A towhee ``Image`` built from the stylized uint8 pixel array.
        """
        # Add a leading batch dimension of size 1.
        img = self.tfms(image).unsqueeze(0)
        styled_image = self.model(img)
        # The (1, 2, 0) axes require a rank-3 result; presumably the model
        # returns a channels-first (C, H, W) numpy array with values in
        # [0, 1] — TODO confirm against pytorch/model.py.
        styled_image = numpy.transpose(styled_image, (1, 2, 0))
        # Clip before casting: converting out-of-range floats to uint8 wraps
        # around (or is undefined), which shows up as speckle artifacts.
        styled_image = numpy.clip(styled_image * 255, 0, 255).astype(numpy.uint8)
        # Reverse the channel axis.
        # NOTE(review): this flips RGB -> BGR, yet the result is labelled 'RGB'
        # below. Confirm the model's output channel order; either the flip or
        # the label is likely wrong.
        styled_image = styled_image[:, :, ::-1].copy()
        return Image(styled_image, 'RGB')