diff --git a/README.md b/README.md index b8797b7..0f3387c 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,27 @@ -# anime-transfer-gradio +# Operator: anime-transfer-gradio + +Author: shiyu22 + +## Overview + +Use gradio to call [style-transfer-animegan](https://towhee.io/filip-halt/style-transfer-animegan). + +## Interface + +```python +__init__(self) +``` + +None + +```python +__call__(self, source: str = 'upload') +``` + +Args: + +- dource: + - image soure for input image, defauts to 'upload', you can also change to 'webcam'. + + - supported types: str diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/anime_transfer_gradio.py b/anime_transfer_gradio.py new file mode 100644 index 0000000..5ddc40d --- /dev/null +++ b/anime_transfer_gradio.py @@ -0,0 +1,51 @@ +import gradio +import numpy +from PIL import Image +from pathlib import Path + +from towhee.operator import Operator +from towhee import pipeline + +class AnimeTransferGradio(Operator): + """ + AnimeTransferGradio operator + """ + def __init__(self) -> None: + super().__init__() + + def __call__(self, source: str = 'upload') -> None: + interface = gradio.Interface(self.trans_img, [gradio.inputs.Image(type="pil", source=source), + gradio.inputs.Radio(["celeba", "facepaintv1","facepaintv2", "hayao", "paprika", 'shinkai'])], + gradio.outputs.Image(type="pil"), allow_flagging='never', allow_screenshot=False) + interface.launch(enable_queue=True) + + @staticmethod + def trans_img(input, version): + trans_pipeline = 'filip-halt/style-transfer-animegan' + + # Resizing the image while keeping aspect ratio. + size = (512, 512) + input.thumbnail(size, Image.ANTIALIAS) + # Saving image to file for input. Very low chance of concurrent file saves during the time + # between saving and taking first step of pipeline, so avoiding locks for now. In addition, + # current gradio is set to queue so there will never be parallel runs for this. + path = str(Path.cwd() / 'test.jpg') + input.save(path) + + if version == 'celeba': + x = pipeline(trans_pipeline, tag='celeba')(path) + elif version == 'facepaintv1': + x = pipeline(trans_pipeline, tag='facepaintv1')(path) + elif version == 'facepaintv2': + x = pipeline(trans_pipeline, tag='facepaintv2')(path) + elif version == 'hayao': + x = pipeline(trans_pipeline, tag='hayao')(path) + elif version == 'paprika': + x = pipeline(trans_pipeline, tag='paprika')(path) + elif version == 'shinkai': + x = pipeline(trans_pipeline, tag='shinkai')(path) + + # Converting from channel-first, [0,1] value RGB, numpy array to PIL image. + x = numpy.transpose(x[0][0], (1, 2, 0)) + x = Image.fromarray((x * 255).astype(numpy.uint8)) + return x