diff --git a/README copy.md b/README copy.md new file mode 100644 index 0000000..aa1961e --- /dev/null +++ b/README copy.md @@ -0,0 +1,54 @@ +# AnimeGanV2 Style-Transfer Operator + +Authors: filip + +## Overview + +AnimeGanV2 is a style transfer net that transforms images to looking like they fit in an anime movie. + +## Interface + +```python +__init__(self, model_name: str, framework: str = 'pytorch') +``` + +**Args:** + +- model_name: + - which weights to use for inference. + - supports 'celeba', 'facepaintv1', 'facepaitv2', 'hayao', 'paprika', 'shinkai' +- framework: + - the framework of the model + - supported types: `str`, default is 'pytorch' + +```python +__call__(self, image: 'towhee.types.Image') +``` + +**Args:** + +- image: + - the input image + - supported types: `towhee.types.Image` + +**Returns:** + +The Operator returns a tuple `Tuple[('styled_image', numpy.ndarray)]` containing following fields: + +- styled_image: + - styled photo + - data type: `numpy.ndarray` + - shape: (3, x, x) + - format: RGB + - values: [0,1] + +## Requirements + +You can get the required python package by [requirements.txt](./requirements.txt). + + +## Reference + +Jie Chen, Gang Liu, Xin Chen +"AnimeGAN: A Novel Lightweight GAN for Photo Animation." +ISICA 2019: Artificial Intelligence Algorithms and Applications pp 242-256, 2019. diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..47ec405 --- /dev/null +++ b/__init__.py @@ -0,0 +1,3 @@ +from .animegan import Animegan +def animegan(name): + return Animegan(name) \ No newline at end of file diff --git a/animegan.py b/animegan.py new file mode 100644 index 0000000..18c3a08 --- /dev/null +++ b/animegan.py @@ -0,0 +1,34 @@ +import os +from pathlib import Path +from torchvision import transforms + +from towhee import register +from towhee.operator import Operator, OperatorFlag +from towhee.types import arg, to_image_color +from towhee._types import Image +import warnings +warnings.filterwarnings('ignore') + +@register(output_schema=['styled_image'], flag=OperatorFlag.STATELESS | OperatorFlag.REUSEABLE,) +class Animegan(Operator): + """ + PyTorch model for image embedding. + """ + def __init__(self, model_name: str, framework: str = 'pytorch') -> None: + super().__init__() + if framework == 'pytorch': + import importlib.util + path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py') + opname = os.path.basename(str(Path(__file__))).split('.')[0] + spec = importlib.util.spec_from_file_location(opname, path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + self.model = module.Model(model_name) + self.tfms = transforms.Compose([ + transforms.ToTensor() + ]) + @arg(1, to_image_color('RGB') ) + def __call__(self, image): + img = self.tfms(image).unsqueeze(0) + styled_image = self.model(img) + return Image(styled_image, 'RGB') diff --git a/pytorch/__init__.py b/pytorch/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pytorch/model.py b/pytorch/model.py new file mode 100644 index 0000000..361ecc7 --- /dev/null +++ b/pytorch/model.py @@ -0,0 +1,133 @@ +from torch import nn, load, Tensor +import os +from pathlib import Path + + +class ConvNormLReLU(nn.Sequential): + def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, pad_mode="reflect", groups=1, bias=False): + + pad_layer = { + "zero": nn.ZeroPad2d, + "same": nn.ReplicationPad2d, + "reflect": nn.ReflectionPad2d, + } + if pad_mode not in pad_layer: + raise NotImplementedError + + super(ConvNormLReLU, self).__init__( + pad_layer[pad_mode](padding), + nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=0, groups=groups, bias=bias), + nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True), + nn.LeakyReLU(0.2, inplace=True) + ) + + +class InvertedResBlock(nn.Module): + def __init__(self, in_ch, out_ch, expansion_ratio=2): + super(InvertedResBlock, self).__init__() + + self.use_res_connect = in_ch == out_ch + bottleneck = int(round(in_ch*expansion_ratio)) + layers = [] + if expansion_ratio != 1: + layers.append(ConvNormLReLU(in_ch, bottleneck, kernel_size=1, padding=0)) + + # dw + layers.append(ConvNormLReLU(bottleneck, bottleneck, groups=bottleneck, bias=True)) + # pw + layers.append(nn.Conv2d(bottleneck, out_ch, kernel_size=1, padding=0, bias=False)) + layers.append(nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True)) + + self.layers = nn.Sequential(*layers) + + def forward(self, input): + out = self.layers(input) + if self.use_res_connect: + out = input + out + return out + + +class Generator(nn.Module): + def __init__(self, ): + super().__init__() + + self.block_a = nn.Sequential( + ConvNormLReLU(3, 32, kernel_size=7, padding=3), + ConvNormLReLU(32, 64, stride=2, padding=(0,1,0,1)), + ConvNormLReLU(64, 64) + ) + + self.block_b = nn.Sequential( + ConvNormLReLU(64, 128, stride=2, padding=(0,1,0,1)), + ConvNormLReLU(128, 128) + ) + + self.block_c = nn.Sequential( + ConvNormLReLU(128, 128), + InvertedResBlock(128, 256, 2), + InvertedResBlock(256, 256, 2), + InvertedResBlock(256, 256, 2), + InvertedResBlock(256, 256, 2), + ConvNormLReLU(256, 128), + ) + + self.block_d = nn.Sequential( + ConvNormLReLU(128, 128), + ConvNormLReLU(128, 128) + ) + + self.block_e = nn.Sequential( + ConvNormLReLU(128, 64), + ConvNormLReLU(64, 64), + ConvNormLReLU(64, 32, kernel_size=7, padding=3) + ) + + self.out_layer = nn.Sequential( + nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0, bias=False), + nn.Tanh() + ) + + def forward(self, input, align_corners=True): + out = self.block_a(input) + half_size = out.size()[-2:] + out = self.block_b(out) + out = self.block_c(out) + + if align_corners: + out = nn.functional.interpolate(out, half_size, mode="bilinear", align_corners=True) + else: + out = nn.functional.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False) + out = self.block_d(out) + + if align_corners: + out = nn.functional.interpolate(out, input.size()[-2:], mode="bilinear", align_corners=True) + else: + out = nn.functional.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False) + out = self.block_e(out) + + out = self.out_layer(out) + return out + +class Model(): + def __init__(self, model_name) -> None: + self._model = Generator() + path = os.path.join(str(Path(__file__).parent), 'weights', model_name + '.pt') + ckpt = load(path) + self._model.load_state_dict(ckpt) + self._model.eval() + + + def __call__(self, img_tensor: Tensor): + img_tensor = img_tensor * 2 - 1 + out = self._model(img_tensor).detach() + out = out.squeeze(0).clip(-1, 1) * 0.5 + 0.5 + return out.numpy() + + def train(self): + """ + For training model + """ + pass + + + \ No newline at end of file diff --git a/pytorch/weights/celeba.pt b/pytorch/weights/celeba.pt new file mode 100644 index 0000000..b269a2e --- /dev/null +++ b/pytorch/weights/celeba.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a3740d98f99efe2ee6c332de2b800f542ddbb2d15e835c07e9bf667c29cef8a7 +size 8603556 diff --git a/pytorch/weights/facepaintv1.pt b/pytorch/weights/facepaintv1.pt new file mode 100644 index 0000000..ff16e49 --- /dev/null +++ b/pytorch/weights/facepaintv1.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f27b45d17c6f4d027753026aeb7e0a558bb95cff5d03a207ac06ca0a372d5316 +size 8603556 diff --git a/pytorch/weights/facepaintv2.pt b/pytorch/weights/facepaintv2.pt new file mode 100644 index 0000000..76f61ce --- /dev/null +++ b/pytorch/weights/facepaintv2.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:06b88a204eb230889444ad868ee5608f4fce5d4ff7b7738acaa4209c2b8fdca7 +size 8601086 diff --git a/pytorch/weights/hayao.pt b/pytorch/weights/hayao.pt new file mode 100644 index 0000000..7312811 --- /dev/null +++ b/pytorch/weights/hayao.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:96e5c586c944fbca18a108a698c02011a96baa22803b77aab8d0b49f5d0b204d +size 8601086 diff --git a/pytorch/weights/paprika.pt b/pytorch/weights/paprika.pt new file mode 100644 index 0000000..997eded --- /dev/null +++ b/pytorch/weights/paprika.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3eaba1b6d01e88ea32b16a6006b04eb3f60327c1dcc841de640d3196898da344 +size 8603556 diff --git a/pytorch/weights/shinkai.pt b/pytorch/weights/shinkai.pt new file mode 100644 index 0000000..d0ed068 --- /dev/null +++ b/pytorch/weights/shinkai.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b2c8512f263f2ccf112a6cbe3ec1cef3bb6e17c9076211e37e6e1f2324fe2b1e +size 8601086 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..8414abf --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +pathlib +torchvision \ No newline at end of file