logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

51 lines
2.1 KiB

import tempfile
from pathlib import Path

import gradio
import numpy
from PIL import Image
from towhee import pipeline
from towhee.operator import Operator
class AnimeTransferGradio(Operator):
    """
    AnimeTransferGradio operator.

    Builds a Gradio interface that takes an image and an AnimeGAN style name
    and returns the image restyled by the towhee
    ``filip-halt/style-transfer-animegan`` pipeline.
    """

    # Style tags offered in the UI. Each one is passed verbatim as the ``tag``
    # of the towhee pipeline, so this tuple is the single source of truth for
    # both the radio choices and input validation.
    STYLES = ("celeba", "facepaintv1", "facepaintv2", "hayao", "paprika", "shinkai")

    def __init__(self) -> None:
        super().__init__()

    def __call__(self, source: str = 'upload') -> None:
        """
        Build the Gradio interface and launch it with queueing enabled.

        Args:
            source: Forwarded as ``source=`` to ``gradio.inputs.Image``
                (defaults to ``'upload'``).
        """
        interface = gradio.Interface(
            self.trans_img,
            [
                gradio.inputs.Image(type="pil", source=source),
                gradio.inputs.Radio(list(self.STYLES)),
            ],
            gradio.outputs.Image(type="pil"),
            allow_flagging='never',
            allow_screenshot=False,
        )
        interface.launch(enable_queue=True)

    @staticmethod
    def trans_img(input, version):
        """
        Apply the selected AnimeGAN style to an image.

        Args:
            input: PIL image from the Gradio image input. It is not modified;
                a resized RGB copy is processed instead.
            version: Style tag; must be one of ``AnimeTransferGradio.STYLES``.

        Returns:
            The stylized image as a ``PIL.Image.Image``.

        Raises:
            ValueError: If ``version`` is not a supported style (for example
                ``None`` when no radio option was selected).
        """
        if version not in AnimeTransferGradio.STYLES:
            raise ValueError(
                f"Unsupported style {version!r}; expected one of "
                f"{', '.join(AnimeTransferGradio.STYLES)}."
            )
        trans_pipeline = 'filip-halt/style-transfer-animegan'
        # Work on an RGB copy: convert() always returns a new image, so the
        # caller's object is untouched, and RGB is a mode JPEG can store
        # (RGBA/P inputs could not be saved as .jpg).
        img = input.convert('RGB')
        # Shrink to fit within 512x512 while keeping the aspect ratio.
        # LANCZOS is the filter ANTIALIAS was an alias for; ANTIALIAS was
        # removed in Pillow 10.
        img.thumbnail((512, 512), Image.LANCZOS)
        # The pipeline consumes a file path. A private temporary directory per
        # call avoids clashes between concurrent requests, does not depend on a
        # writable cwd, and is cleaned up even if the pipeline raises.
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = str(Path(tmp_dir) / 'input.jpg')
            img.save(path)
            x = pipeline(trans_pipeline, tag=version)(path)
        # Converting from channel-first, [0,1] value RGB, numpy array to PIL image.
        # Clip first so slightly out-of-range values cannot wrap around when
        # cast to uint8; in-range values are unaffected.
        x = numpy.transpose(x[0][0], (1, 2, 0))
        x = Image.fromarray((numpy.clip(x, 0, 1) * 255).astype(numpy.uint8))
        return x