diff --git a/animegan.py b/animegan.py index b6ffba0..9577265 100644 --- a/animegan.py +++ b/animegan.py @@ -1,13 +1,12 @@ import os import numpy from pathlib import Path -from PIL import Image as PImage from torchvision import transforms +import torch from towhee import register from towhee.operator import Operator, OperatorFlag -from towhee.types import arg, to_image_color -from towhee._types import Image +from towhee.types import arg, to_image_color, Image import warnings warnings.filterwarnings('ignore') @@ -27,17 +26,12 @@ class Animegan(Operator): module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) self.model = module.Model(model_name, self._device) - self.tfms = transforms.Compose([ - transforms.ToTensor() - ]) + self.tfms = transforms.ToTensor() + @arg(1, to_image_color('RGB')) - def __call__(self, image): - img = self.tfms(image).unsqueeze(0) + def __call__(self, img): + img = self.tfms(img).unsqueeze(0) styled_image = self.model(img) - styled_image = numpy.transpose(styled_image, (1,2,0)) - styled_image = PImage.fromarray((styled_image * 255).astype(numpy.uint8)) - styled_image = numpy.array(styled_image) - styled_image = styled_image[:, :, ::-1].copy() - return Image(styled_image, 'BGR') + return Image(styled_image, 'RGB')