diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..32030ff --- /dev/null +++ b/.gitignore @@ -0,0 +1,209 @@ +### Linux ### +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +### OSX ### +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
@register(output_schema=['styled_image'], flag=OperatorFlag.STATELESS | OperatorFlag.REUSEABLE,)
class Cartoongan(NNOperator):
    """
    Operator that redraws an image with a pretrained CartoonGAN generator.

    Args:
        model_name (`str`):
            Prefix of the weight file to load; the backend model reads
            ``pytorch/weights/{model_name}_net_G_float.pth``.
        framework (`str`):
            Backend to use. Only ``'pytorch'`` is implemented.
        device (`str`):
            Device string handed to the backend model (default ``'cpu'``).

    Raises:
        ValueError: If ``framework`` is not ``'pytorch'``.
    """

    def __init__(self, model_name: str, framework: str = 'pytorch', device: str = 'cpu') -> None:
        super().__init__()
        self._device = device
        if framework != 'pytorch':
            # An unknown framework used to leave ``self.model`` / ``self.tfms``
            # unset, so the error only surfaced as an AttributeError on the
            # first call. Fail at construction time instead.
            raise ValueError(
                'Unsupported framework {!r}: only "pytorch" is implemented.'.format(framework)
            )

        # The backend lives in ./pytorch/model.py next to this file and is
        # loaded by path rather than through a package import.
        import importlib.util
        path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py')
        opname = os.path.basename(str(Path(__file__))).split('.')[0]
        spec = importlib.util.spec_from_file_location(opname, path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        self.model = module.Model(model_name, self._device)
        self.tfms = transforms.Compose([
            transforms.ToTensor()
        ])

    @arg(1, to_image_color('RGB'))
    def __call__(self, image):
        """
        Stylise one image.

        Args:
            image: Input image; the ``arg`` decorator converts it to RGB
                before this body runs.

        Returns:
            A towhee ``Image`` holding the styled picture as uint8 data
            tagged ``'BGR'``.
        """
        # NOTE(review): the backend's ``Model.__call__`` reorders its output
        # channels under a "BGR -> RGB" comment, which suggests the generator
        # works in BGR, yet the tensor built here is RGB. Confirm which input
        # channel order the shipped weights expect.
        img = self.tfms(image).unsqueeze(0)
        styled_image = self.model(img)

        # The backend returns a channel-first array; move channels last and
        # scale to 8-bit.
        styled_image = numpy.transpose(styled_image, (1, 2, 0))
        styled_image = (styled_image * 255).astype(numpy.uint8)
        # Reverse the channel axis (RGB -> BGR); copy() makes the result
        # contiguous instead of a negative-stride view.
        styled_image = styled_image[:, :, ::-1].copy()

        return Image(styled_image, 'BGR')
class Transformer(nn.Module):
    """CartoonGAN generator.

    Structure: a reflection-padded 7x7 stem, two stride-2 down-sampling
    stages, eight residual blocks at 256 channels, two transposed-convolution
    up-sampling stages, and a reflection-padded 7x7 output convolution
    followed by tanh.

    Every layer is registered under a flat attribute name (``conv01_1``,
    ``in04_2``, ``deconv02_1`` ...). Those names are kept exactly because
    they determine the keys used when a state dict is loaded.
    """

    # Residual blocks are registered as refpadNN_k / convNN_k / inNN_k with
    # NN running over these ids and k in (1, 2).
    _RES_BLOCK_IDS = tuple(range(4, 12))

    def __init__(self):
        super().__init__()

        # Stem: 3 -> 64 channels, spatial size preserved by the padding.
        self.refpad01_1 = nn.ReflectionPad2d(3)
        self.conv01_1 = nn.Conv2d(3, 64, 7)
        self.in01_1 = InstanceNormalization(64)

        # Down-sampling stage 1: 64 -> 128 channels, stride 2.
        self.conv02_1 = nn.Conv2d(64, 128, 3, 2, 1)
        self.conv02_2 = nn.Conv2d(128, 128, 3, 1, 1)
        self.in02_1 = InstanceNormalization(128)

        # Down-sampling stage 2: 128 -> 256 channels, stride 2.
        self.conv03_1 = nn.Conv2d(128, 256, 3, 2, 1)
        self.conv03_2 = nn.Conv2d(256, 256, 3, 1, 1)
        self.in03_1 = InstanceNormalization(256)

        # Residual blocks: two (pad, conv, norm) triples each, created in the
        # same order as a hand-written listing would create them.
        for block_id in self._RES_BLOCK_IDS:
            for half in (1, 2):
                suffix = '{:02d}_{}'.format(block_id, half)
                setattr(self, 'refpad' + suffix, nn.ReflectionPad2d(1))
                setattr(self, 'conv' + suffix, nn.Conv2d(256, 256, 3))
                setattr(self, 'in' + suffix, InstanceNormalization(256))

        # Up-sampling stage 1: 256 -> 128 channels.
        self.deconv01_1 = nn.ConvTranspose2d(256, 128, 3, 2, 1, 1)
        self.deconv01_2 = nn.Conv2d(128, 128, 3, 1, 1)
        self.in12_1 = InstanceNormalization(128)

        # Up-sampling stage 2: 128 -> 64 channels.
        self.deconv02_1 = nn.ConvTranspose2d(128, 64, 3, 2, 1, 1)
        self.deconv02_2 = nn.Conv2d(64, 64, 3, 1, 1)
        self.in13_1 = InstanceNormalization(64)

        # Output head: 64 -> 3 channels.
        self.refpad12_1 = nn.ReflectionPad2d(3)
        self.deconv03_1 = nn.Conv2d(64, 3, 7)

    def _res_block(self, block_id, skip):
        """Run residual block ``block_id`` on ``skip`` and add the skip path."""
        def layer(kind, half):
            return getattr(self, '{}{:02d}_{}'.format(kind, block_id, half))

        hidden = F.relu(layer('in', 1)(layer('conv', 1)(layer('refpad', 1)(skip))))
        return layer('in', 2)(layer('conv', 2)(layer('refpad', 2)(hidden))) + skip

    def forward(self, x):
        """Map an input batch ``x`` to the generator output (tanh-activated)."""
        out = F.relu(self.in01_1(self.conv01_1(self.refpad01_1(x))))
        out = F.relu(self.in02_1(self.conv02_2(self.conv02_1(out))))
        out = F.relu(self.in03_1(self.conv03_2(self.conv03_1(out))))

        for block_id in self._RES_BLOCK_IDS:
            out = self._res_block(block_id, out)

        out = F.relu(self.in12_1(self.deconv01_2(self.deconv01_1(out))))
        out = F.relu(self.in13_1(self.deconv02_2(self.deconv02_1(out))))
        return F.tanh(self.deconv03_1(self.refpad12_1(out)))
class InstanceNormalization(nn.Module):
    """Instance normalisation with a learnable per-channel scale and shift.

    Each (sample, channel) plane of a 4-D input is normalised to zero mean
    and unit (biased) variance over its two trailing dimensions, then
    multiplied by ``scale`` and offset by ``shift``.

    Args:
        dim: Number of channels (size of dimension 1 of the input).
        eps: Constant added to the variance before the square root.
    """

    def __init__(self, dim, eps=1e-9):
        super().__init__()
        self.scale = nn.Parameter(torch.FloatTensor(dim))
        self.shift = nn.Parameter(torch.FloatTensor(dim))
        self.eps = eps
        self._reset_parameters()

    def _reset_parameters(self):
        # scale ~ U[0, 1), shift = 0
        self.scale.data.uniform_()
        self.shift.data.zero_()

    def forward(self, x):
        """Normalise ``x`` (4-D, channels in dimension 1) and apply the affine.

        Implemented as ``forward`` rather than by overriding ``__call__`` so
        that ``nn.Module`` hooks keep working.
        """
        # Reducing over dims (2, 3) directly works for non-contiguous input
        # (``view`` does not). The biased variance is requested explicitly so
        # a 1x1 feature map gives 0 instead of the NaN produced by the
        # unbiased estimator.
        mean = x.mean(dim=(2, 3), keepdim=True)
        var = x.var(dim=(2, 3), unbiased=False, keepdim=True)
        scale = self.scale.view(1, -1, 1, 1)
        shift = self.shift.view(1, -1, 1, 1)
        out = (x - mean) / torch.sqrt(var + self.eps)
        return out * scale + shift


class Model():
    """Inference wrapper around the ``Transformer`` generator.

    Args:
        model_name: Prefix of the weight file; the file
            ``weights/{model_name}_net_G_float.pth`` next to this module is
            loaded.
        device: Device the network and its inputs are moved to.
    """

    def __init__(self, model_name, device) -> None:
        self._device = device
        self._model = Transformer()
        path = os.path.join(str(Path(__file__).parent), 'weights', model_name + '_net_G_float.pth')
        # Load onto the CPU first so a checkpoint saved from a CUDA device can
        # still be read on a CPU-only host; ``.to`` below moves the weights to
        # the requested device.
        self._model.load_state_dict(torch.load(path, map_location='cpu'))
        self._model.to(self._device)
        self._model.eval()

    def __call__(self, img_tensor: torch.Tensor):
        """Run the generator on a batch and return the first result.

        Args:
            img_tensor: Input batch; values are mapped with ``x * 2 - 1``
                before the forward pass.

        Returns:
            A numpy array for the first batch item, with channels reordered
            by ``[2, 1, 0]`` and values mapped with ``y * 0.5 + 0.5``.
        """
        img_tensor = img_tensor.to(self._device)
        img_tensor = img_tensor * 2 - 1

        # Pure inference: do not record an autograd graph.
        with torch.no_grad():
            output_image = self._model(img_tensor)
        output_image = output_image[0]
        # BGR -> RGB
        output_image = output_image[[2, 1, 0], :, :]
        output_image = output_image.detach().cpu().float() * 0.5 + 0.5
        return output_image.numpy()
b/pytorch/weights/Hayao_net_G_float.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8ab0e492efb3b705487db38679e363dc8b1f016692913bbe100587d695a9e2b5 +size 44529096 diff --git a/pytorch/weights/Hosoda_net_G_float.pth b/pytorch/weights/Hosoda_net_G_float.pth new file mode 100644 index 0000000..7dd5313 --- /dev/null +++ b/pytorch/weights/Hosoda_net_G_float.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c666eea7700864d5972765cc43e926d900174648297bfef494006dc230fd1bf0 +size 44529096 diff --git a/pytorch/weights/Paprika_net_G_float.pth b/pytorch/weights/Paprika_net_G_float.pth new file mode 100644 index 0000000..c8f87bd --- /dev/null +++ b/pytorch/weights/Paprika_net_G_float.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0629352a54838e56a2ad7fca3e6e51e6889d4338c37469f9ddb43e5929ef9475 +size 44529096 diff --git a/pytorch/weights/Shinkai_net_G_float.pth b/pytorch/weights/Shinkai_net_G_float.pth new file mode 100644 index 0000000..8de7435 --- /dev/null +++ b/pytorch/weights/Shinkai_net_G_float.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c3547f611e780e79aebde7f7bc2b6c278555d701620f125583d666351044c486 +size 44529096