logo
Browse Source

Fix for image channels

Signed-off-by: mengjia.gu@10.xxx.30.11 <mengjia.gu@10.xxx.30.11>
main
mengjia.gu@10.xxx.30.11 2 years ago
parent
commit
6d1c7ffe85
  1. 20
      animegan.py

20
animegan.py

@ -1,13 +1,12 @@
import os
import numpy
from pathlib import Path
from PIL import Image as PImage
from torchvision import transforms
import torch
from towhee import register
from towhee.operator import Operator, OperatorFlag
from towhee.types import arg, to_image_color
from towhee._types import Image
from towhee.types import arg, to_image_color, Image
import warnings
warnings.filterwarnings('ignore')
@ -27,17 +26,12 @@ class Animegan(Operator):
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
self.model = module.Model(model_name, self._device)
self.tfms = transforms.Compose([
transforms.ToTensor()
])
self.tfms = transforms.ToTensor()
@arg(1, to_image_color('RGB'))
def __call__(self, image):
img = self.tfms(image).unsqueeze(0)
def __call__(self, img):
img = self.tfms(img).unsqueeze(0)
styled_image = self.model(img)
styled_image = numpy.transpose(styled_image, (1,2,0))
styled_image = PImage.fromarray((styled_image * 255).astype(numpy.uint8))
styled_image = numpy.array(styled_image)
styled_image = styled_image[:, :, ::-1].copy()
return Image(styled_image, 'BGR')
return Image(styled_image, 'RGB')

Loading…
Cancel
Save