logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

53 lines
2.0 KiB

import gradio
import numpy
from PIL import Image
from pathlib import Path
from typing import NamedTuple
from towhee.operator import Operator
from towhee import pipeline
# Absolute path, under the working directory at import time, where
# AnimeTransferGradio.trans_img saves each resized input image. Every request
# reuses this same file.
path = str(Path.cwd().joinpath('test.jpg'))
class AnimeTransferGradio(Operator):
    """
    AnimeTransferGradio operator.

    Serves a Gradio web UI where the user provides an image and chooses an
    AnimeGAN style. ``trans_img`` restyles the image with the towhee
    ``filip-halt/style-transfer-animegan`` pipeline.
    """

    # towhee pipeline used for every style; the style is selected via its tag.
    TRANS_PIPELINE = 'filip-halt/style-transfer-animegan'
    # Styles offered in the UI. Each value is passed verbatim as the pipeline
    # tag, so the radio options and the accepted versions cannot drift apart.
    VERSIONS = ('celeba', 'facepaintv1', 'facepaintv2', 'hayao', 'paprika', 'shinkai')
    # Bounding box the input is shrunk into (aspect ratio kept) before inference.
    MAX_SIZE = (512, 512)
    # Return type of __call__; built once instead of on every call.
    Outputs = NamedTuple('Outputs', [('img_path', str)])

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, source: str = 'upload') -> Outputs:
        """Build and launch the Gradio interface.

        Args:
            source: image source handed to ``gradio.inputs.Image``
                (default ``'upload'``).

        Returns:
            ``Outputs(img_path=path)``, where ``path`` is the module-level
            file that ``trans_img`` writes each resized input image to.
        """
        interface = gradio.Interface(
            self.trans_img,
            [
                gradio.inputs.Image(type="pil", source=source),
                gradio.inputs.Radio(list(self.VERSIONS)),
            ],
            gradio.outputs.Image(type="pil"),
            allow_flagging='never',
            allow_screenshot=False,
        )
        interface.launch(enable_queue=True)
        return self.Outputs(path)

    @staticmethod
    def trans_img(input, version):
        """Restyle ``input`` with the AnimeGAN pipeline variant ``version``.

        Args:
            input: PIL image from the Gradio input widget. It is shrunk in
                place (aspect ratio kept) to fit within ``MAX_SIZE`` and then
                saved to the module-level ``path``.
            version: style tag; must be one of ``VERSIONS``.

        Returns:
            The stylised image as a PIL image.

        Raises:
            ValueError: if ``version`` is not a supported style (for example
                ``None`` when no radio option was selected).
        """
        # Validate before touching the image, so a bad request has no side
        # effects.
        if version not in AnimeTransferGradio.VERSIONS:
            raise ValueError(
                f'Unsupported style {version!r}; expected one of: '
                f'{", ".join(AnimeTransferGradio.VERSIONS)}'
            )
        # Resize while keeping aspect ratio. LANCZOS is the filter formerly
        # aliased as ANTIALIAS, which Pillow 10 removed.
        input.thumbnail(AnimeTransferGradio.MAX_SIZE, Image.LANCZOS)
        # The pipeline reads its input from disk, so round-trip through `path`.
        input.save(path)
        result = pipeline(AnimeTransferGradio.TRANS_PIPELINE, tag=version)(path)
        # NOTE(review): assumes result[0][0] is the output image array, as the
        # original code indexed it; TODO confirm against the pipeline's schema.
        return AnimeTransferGradio._to_pil(result[0][0])

    @staticmethod
    def _to_pil(chw):
        """Convert a channel-first RGB array with values in [0, 1] to a PIL image."""
        hwc = numpy.transpose(chw, (1, 2, 0))
        # Clip first so slightly out-of-range values saturate instead of
        # wrapping around when cast to uint8.
        return Image.fromarray((numpy.clip(hwc, 0, 1) * 255).astype(numpy.uint8))