logo
Browse Source

Add anime-transfer

Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
main
shiyu22 3 years ago
parent
commit
ee18c11ae2
  1. 27
      README.md
  2. 0
      __init__.py
  3. 51
      anime_transfer_gradio.py

27
README.md

@ -1,2 +1,27 @@
# anime-transfer-gradio
# Operator: anime-transfer-gradio
Author: shiyu22
## Overview
Use gradio to call [style-transfer-animegan](https://towhee.io/filip-halt/style-transfer-animegan).
## Interface
```python
__init__(self)
```
None
```python
__call__(self, source: str = 'upload')
```
Args:
- dource:
- image soure for input image, defauts to 'upload', you can also change to 'webcam'.
- supported types: str

0
__init__.py

51
anime_transfer_gradio.py

@ -0,0 +1,51 @@
import gradio
import numpy
from PIL import Image
from pathlib import Path
from towhee.operator import Operator
from towhee import pipeline
class AnimeTransferGradio(Operator):
"""
AnimeTransferGradio operator
"""
def __init__(self) -> None:
super().__init__()
def __call__(self, source: str = 'upload') -> None:
interface = gradio.Interface(self.trans_img, [gradio.inputs.Image(type="pil", source=source),
gradio.inputs.Radio(["celeba", "facepaintv1","facepaintv2", "hayao", "paprika", 'shinkai'])],
gradio.outputs.Image(type="pil"), allow_flagging='never', allow_screenshot=False)
interface.launch(enable_queue=True)
@staticmethod
def trans_img(input, version):
trans_pipeline = 'filip-halt/style-transfer-animegan'
# Resizing the image while keeping aspect ratio.
size = (512, 512)
input.thumbnail(size, Image.ANTIALIAS)
# Saving image to file for input. Very low chance of concurrent file saves during the time
# between saving and taking first step of pipeline, so avoiding locks for now. In addition,
# current gradio is set to queue so there will never be parallel runs for this.
path = str(Path.cwd() / 'test.jpg')
input.save(path)
if version == 'celeba':
x = pipeline(trans_pipeline, tag='celeba')(path)
elif version == 'facepaintv1':
x = pipeline(trans_pipeline, tag='facepaintv1')(path)
elif version == 'facepaintv2':
x = pipeline(trans_pipeline, tag='facepaintv2')(path)
elif version == 'hayao':
x = pipeline(trans_pipeline, tag='hayao')(path)
elif version == 'paprika':
x = pipeline(trans_pipeline, tag='paprika')(path)
elif version == 'shinkai':
x = pipeline(trans_pipeline, tag='shinkai')(path)
# Converting from channel-first, [0,1] value RGB, numpy array to PIL image.
x = numpy.transpose(x[0][0], (1, 2, 0))
x = Image.fromarray((x * 255).astype(numpy.uint8))
return x
Loading…
Cancel
Save