From 6d1c7ffe85f149ee11e5e1c8fc0b8e5ff2d03566 Mon Sep 17 00:00:00 2001 From: "mengjia.gu@10.xxx.30.11" Date: Mon, 13 Feb 2023 14:05:23 +0800 Subject: [PATCH] Fix for image channels Signed-off-by: mengjia.gu@10.xxx.30.11 --- animegan.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/animegan.py b/animegan.py index b6ffba0..9577265 100644 --- a/animegan.py +++ b/animegan.py @@ -1,13 +1,12 @@ import os import numpy from pathlib import Path -from PIL import Image as PImage from torchvision import transforms +import torch from towhee import register from towhee.operator import Operator, OperatorFlag -from towhee.types import arg, to_image_color -from towhee._types import Image +from towhee.types import arg, to_image_color, Image import warnings warnings.filterwarnings('ignore') @@ -27,17 +26,12 @@ class Animegan(Operator): module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) self.model = module.Model(model_name, self._device) - self.tfms = transforms.Compose([ - transforms.ToTensor() - ]) + self.tfms = transforms.ToTensor() + @arg(1, to_image_color('RGB')) - def __call__(self, image): - img = self.tfms(image).unsqueeze(0) + def __call__(self, img): + img = self.tfms(img).unsqueeze(0) styled_image = self.model(img) - styled_image = numpy.transpose(styled_image, (1,2,0)) - styled_image = PImage.fromarray((styled_image * 255).astype(numpy.uint8)) - styled_image = numpy.array(styled_image) - styled_image = styled_image[:, :, ::-1].copy() - return Image(styled_image, 'BGR') + return Image(styled_image, 'RGB')