cartoongan
copied
shiyu22
3 years ago
9 changed files with 474 additions and 0 deletions
@ -0,0 +1,209 @@ |
|||||
|
### Linux ### |
||||
|
*~ |
||||
|
|
||||
|
# temporary files which can be created if a process still has a handle open of a deleted file |
||||
|
.fuse_hidden* |
||||
|
|
||||
|
# KDE directory preferences |
||||
|
.directory |
||||
|
|
||||
|
# Linux trash folder which might appear on any partition or disk |
||||
|
.Trash-* |
||||
|
|
||||
|
# .nfs files are created when an open file is removed but is still being accessed |
||||
|
.nfs* |
||||
|
|
||||
|
### OSX ### |
||||
|
# General |
||||
|
.DS_Store |
||||
|
.AppleDouble |
||||
|
.LSOverride |
||||
|
|
||||
|
# Icon must end with two \r |
||||
|
Icon |
||||
|
|
||||
|
|
||||
|
# Thumbnails |
||||
|
._* |
||||
|
|
||||
|
# Files that might appear in the root of a volume |
||||
|
.DocumentRevisions-V100 |
||||
|
.fseventsd |
||||
|
.Spotlight-V100 |
||||
|
.TemporaryItems |
||||
|
.Trashes |
||||
|
.VolumeIcon.icns |
||||
|
.com.apple.timemachine.donotpresent |
||||
|
|
||||
|
# Directories potentially created on remote AFP share |
||||
|
.AppleDB |
||||
|
.AppleDesktop |
||||
|
Network Trash Folder |
||||
|
Temporary Items |
||||
|
.apdisk |
||||
|
|
||||
|
### Python ### |
||||
|
# Byte-compiled / optimized / DLL files |
||||
|
__pycache__/ |
||||
|
*.py[cod] |
||||
|
*$py.class |
||||
|
|
||||
|
# C extensions |
||||
|
*.so |
||||
|
|
||||
|
# Distribution / packaging |
||||
|
.Python |
||||
|
build/ |
||||
|
develop-eggs/ |
||||
|
dist/ |
||||
|
downloads/ |
||||
|
eggs/ |
||||
|
.eggs/ |
||||
|
lib/ |
||||
|
lib64/ |
||||
|
parts/ |
||||
|
sdist/ |
||||
|
var/ |
||||
|
wheels/ |
||||
|
share/python-wheels/ |
||||
|
*.egg-info/ |
||||
|
.installed.cfg |
||||
|
*.egg |
||||
|
MANIFEST |
||||
|
|
||||
|
# PyInstaller |
||||
|
# Usually these files are written by a python script from a template |
||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it. |
||||
|
*.manifest |
||||
|
*.spec |
||||
|
|
||||
|
# Installer logs |
||||
|
pip-log.txt |
||||
|
pip-delete-this-directory.txt |
||||
|
|
||||
|
# Unit test / coverage reports |
||||
|
htmlcov/ |
||||
|
.tox/ |
||||
|
.nox/ |
||||
|
.coverage |
||||
|
.coverage.* |
||||
|
.cache |
||||
|
nosetests.xml |
||||
|
coverage.xml |
||||
|
*.cover |
||||
|
*.py,cover |
||||
|
.hypothesis/ |
||||
|
.pytest_cache/ |
||||
|
cover/ |
||||
|
|
||||
|
# Translations |
||||
|
*.mo |
||||
|
*.pot |
||||
|
|
||||
|
# Django stuff: |
||||
|
*.log |
||||
|
local_settings.py |
||||
|
db.sqlite3 |
||||
|
db.sqlite3-journal |
||||
|
|
||||
|
# Flask stuff: |
||||
|
instance/ |
||||
|
.webassets-cache |
||||
|
|
||||
|
# Scrapy stuff: |
||||
|
.scrapy |
||||
|
|
||||
|
# Sphinx documentation |
||||
|
docs/_build/ |
||||
|
|
||||
|
# PyBuilder |
||||
|
.pybuilder/ |
||||
|
target/ |
||||
|
|
||||
|
# Jupyter Notebook |
||||
|
.ipynb_checkpoints |
||||
|
|
||||
|
# IPython |
||||
|
profile_default/ |
||||
|
ipython_config.py |
||||
|
|
||||
|
# pyenv |
||||
|
# For a library or package, you might want to ignore these files since the code is |
||||
|
# intended to run in multiple environments; otherwise, check them in: |
||||
|
# .python-version |
||||
|
|
||||
|
# pipenv |
||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. |
||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies |
||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not |
||||
|
# install all needed dependencies. |
||||
|
#Pipfile.lock |
||||
|
|
||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow |
||||
|
__pypackages__/ |
||||
|
|
||||
|
# Celery stuff |
||||
|
celerybeat-schedule |
||||
|
celerybeat.pid |
||||
|
|
||||
|
# SageMath parsed files |
||||
|
*.sage.py |
||||
|
|
||||
|
# Environments |
||||
|
.env |
||||
|
.venv |
||||
|
env/ |
||||
|
venv/ |
||||
|
ENV/ |
||||
|
env.bak/ |
||||
|
venv.bak/ |
||||
|
|
||||
|
# Spyder project settings |
||||
|
.spyderproject |
||||
|
.spyproject |
||||
|
|
||||
|
# Rope project settings |
||||
|
.ropeproject |
||||
|
|
||||
|
# mkdocs documentation |
||||
|
/site |
||||
|
|
||||
|
# mypy |
||||
|
.mypy_cache/ |
||||
|
.dmypy.json |
||||
|
dmypy.json |
||||
|
|
||||
|
# Pyre type checker |
||||
|
.pyre/ |
||||
|
|
||||
|
# pytype static type analyzer |
||||
|
.pytype/ |
||||
|
|
||||
|
# Cython debug symbols |
||||
|
cython_debug/ |
||||
|
|
||||
|
### Windows ### |
||||
|
# Windows thumbnail cache files |
||||
|
Thumbs.db |
||||
|
Thumbs.db:encryptable |
||||
|
ehthumbs.db |
||||
|
ehthumbs_vista.db |
||||
|
|
||||
|
# Dump file |
||||
|
*.stackdump |
||||
|
|
||||
|
# Folder config file |
||||
|
[Dd]esktop.ini |
||||
|
|
||||
|
# Recycle Bin used on file shares |
||||
|
$RECYCLE.BIN/ |
||||
|
|
||||
|
# Windows Installer files |
||||
|
*.cab |
||||
|
*.msi |
||||
|
*.msix |
||||
|
*.msm |
||||
|
*.msp |
||||
|
|
||||
|
# Windows shortcuts |
||||
|
*.lnk |
@ -0,0 +1,4 @@ |
|||||
|
from .cartoongan import Cartoongan |
||||
|
|
||||
|
def nnoperator_template(*args, **kwargs): |
||||
|
return Cartoongan(*args, **kwargs) |
@ -0,0 +1,49 @@ |
|||||
|
import logging |
||||
|
import os |
||||
|
import numpy |
||||
|
from pathlib import Path |
||||
|
from PIL import Image as PImage |
||||
|
from torchvision import transforms |
||||
|
|
||||
|
from towhee import register |
||||
|
from towhee.operator import NNOperator, OperatorFlag |
||||
|
from towhee.types import arg, to_image_color |
||||
|
from towhee._types import Image |
||||
|
import warnings |
||||
|
warnings.filterwarnings('ignore') |
||||
|
|
||||
|
log = logging.getLogger() |
||||
|
|
||||
|
|
||||
|
@register(output_schema=['styled_image'], flag=OperatorFlag.STATELESS | OperatorFlag.REUSEABLE,) |
||||
|
class Cartoongan(NNOperator): |
||||
|
""" |
||||
|
A one line summary of this class. |
||||
|
""" |
||||
|
|
||||
|
def __init__(self, model_name: str, framework: str = 'pytorch', device: str = 'cpu') -> None: |
||||
|
super().__init__() |
||||
|
self._device = device |
||||
|
if framework == 'pytorch': |
||||
|
import importlib.util |
||||
|
path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py') |
||||
|
opname = os.path.basename(str(Path(__file__))).split('.')[0] |
||||
|
spec = importlib.util.spec_from_file_location(opname, path) |
||||
|
module = importlib.util.module_from_spec(spec) |
||||
|
spec.loader.exec_module(module) |
||||
|
self.model = module.Model(model_name, self._device) |
||||
|
self.tfms = transforms.Compose([ |
||||
|
transforms.ToTensor() |
||||
|
]) |
||||
|
|
||||
|
@arg(1, to_image_color('RGB')) |
||||
|
def __call__(self, image): |
||||
|
img = self.tfms(image).unsqueeze(0) |
||||
|
styled_image = self.model(img) |
||||
|
|
||||
|
styled_image = numpy.transpose(styled_image, (1, 2, 0)) |
||||
|
styled_image = PImage.fromarray((styled_image * 255).astype(numpy.uint8)) |
||||
|
styled_image = numpy.array(styled_image) |
||||
|
styled_image = styled_image[:, :, ::-1].copy() |
||||
|
|
||||
|
return Image(styled_image, 'BGR') |
@ -0,0 +1,200 @@ |
|||||
|
import os |
||||
|
import torch |
||||
|
import torch.nn as nn |
||||
|
import torch.nn.functional as F |
||||
|
from pathlib import Path |
||||
|
|
||||
|
class Transformer(nn.Module): |
||||
|
def __init__(self): |
||||
|
super(Transformer, self).__init__() |
||||
|
# |
||||
|
self.refpad01_1 = nn.ReflectionPad2d(3) |
||||
|
self.conv01_1 = nn.Conv2d(3, 64, 7) |
||||
|
self.in01_1 = InstanceNormalization(64) |
||||
|
# relu |
||||
|
self.conv02_1 = nn.Conv2d(64, 128, 3, 2, 1) |
||||
|
self.conv02_2 = nn.Conv2d(128, 128, 3, 1, 1) |
||||
|
self.in02_1 = InstanceNormalization(128) |
||||
|
# relu |
||||
|
self.conv03_1 = nn.Conv2d(128, 256, 3, 2, 1) |
||||
|
self.conv03_2 = nn.Conv2d(256, 256, 3, 1, 1) |
||||
|
self.in03_1 = InstanceNormalization(256) |
||||
|
# relu |
||||
|
|
||||
|
## res block 1 |
||||
|
self.refpad04_1 = nn.ReflectionPad2d(1) |
||||
|
self.conv04_1 = nn.Conv2d(256, 256, 3) |
||||
|
self.in04_1 = InstanceNormalization(256) |
||||
|
# relu |
||||
|
self.refpad04_2 = nn.ReflectionPad2d(1) |
||||
|
self.conv04_2 = nn.Conv2d(256, 256, 3) |
||||
|
self.in04_2 = InstanceNormalization(256) |
||||
|
# + input |
||||
|
|
||||
|
## res block 2 |
||||
|
self.refpad05_1 = nn.ReflectionPad2d(1) |
||||
|
self.conv05_1 = nn.Conv2d(256, 256, 3) |
||||
|
self.in05_1 = InstanceNormalization(256) |
||||
|
# relu |
||||
|
self.refpad05_2 = nn.ReflectionPad2d(1) |
||||
|
self.conv05_2 = nn.Conv2d(256, 256, 3) |
||||
|
self.in05_2 = InstanceNormalization(256) |
||||
|
# + input |
||||
|
|
||||
|
## res block 3 |
||||
|
self.refpad06_1 = nn.ReflectionPad2d(1) |
||||
|
self.conv06_1 = nn.Conv2d(256, 256, 3) |
||||
|
self.in06_1 = InstanceNormalization(256) |
||||
|
# relu |
||||
|
self.refpad06_2 = nn.ReflectionPad2d(1) |
||||
|
self.conv06_2 = nn.Conv2d(256, 256, 3) |
||||
|
self.in06_2 = InstanceNormalization(256) |
||||
|
# + input |
||||
|
|
||||
|
## res block 4 |
||||
|
self.refpad07_1 = nn.ReflectionPad2d(1) |
||||
|
self.conv07_1 = nn.Conv2d(256, 256, 3) |
||||
|
self.in07_1 = InstanceNormalization(256) |
||||
|
# relu |
||||
|
self.refpad07_2 = nn.ReflectionPad2d(1) |
||||
|
self.conv07_2 = nn.Conv2d(256, 256, 3) |
||||
|
self.in07_2 = InstanceNormalization(256) |
||||
|
# + input |
||||
|
|
||||
|
## res block 5 |
||||
|
self.refpad08_1 = nn.ReflectionPad2d(1) |
||||
|
self.conv08_1 = nn.Conv2d(256, 256, 3) |
||||
|
self.in08_1 = InstanceNormalization(256) |
||||
|
# relu |
||||
|
self.refpad08_2 = nn.ReflectionPad2d(1) |
||||
|
self.conv08_2 = nn.Conv2d(256, 256, 3) |
||||
|
self.in08_2 = InstanceNormalization(256) |
||||
|
# + input |
||||
|
|
||||
|
## res block 6 |
||||
|
self.refpad09_1 = nn.ReflectionPad2d(1) |
||||
|
self.conv09_1 = nn.Conv2d(256, 256, 3) |
||||
|
self.in09_1 = InstanceNormalization(256) |
||||
|
# relu |
||||
|
self.refpad09_2 = nn.ReflectionPad2d(1) |
||||
|
self.conv09_2 = nn.Conv2d(256, 256, 3) |
||||
|
self.in09_2 = InstanceNormalization(256) |
||||
|
# + input |
||||
|
|
||||
|
## res block 7 |
||||
|
self.refpad10_1 = nn.ReflectionPad2d(1) |
||||
|
self.conv10_1 = nn.Conv2d(256, 256, 3) |
||||
|
self.in10_1 = InstanceNormalization(256) |
||||
|
# relu |
||||
|
self.refpad10_2 = nn.ReflectionPad2d(1) |
||||
|
self.conv10_2 = nn.Conv2d(256, 256, 3) |
||||
|
self.in10_2 = InstanceNormalization(256) |
||||
|
# + input |
||||
|
|
||||
|
## res block 8 |
||||
|
self.refpad11_1 = nn.ReflectionPad2d(1) |
||||
|
self.conv11_1 = nn.Conv2d(256, 256, 3) |
||||
|
self.in11_1 = InstanceNormalization(256) |
||||
|
# relu |
||||
|
self.refpad11_2 = nn.ReflectionPad2d(1) |
||||
|
self.conv11_2 = nn.Conv2d(256, 256, 3) |
||||
|
self.in11_2 = InstanceNormalization(256) |
||||
|
# + input |
||||
|
|
||||
|
##------------------------------------## |
||||
|
self.deconv01_1 = nn.ConvTranspose2d(256, 128, 3, 2, 1, 1) |
||||
|
self.deconv01_2 = nn.Conv2d(128, 128, 3, 1, 1) |
||||
|
self.in12_1 = InstanceNormalization(128) |
||||
|
# relu |
||||
|
self.deconv02_1 = nn.ConvTranspose2d(128, 64, 3, 2, 1, 1) |
||||
|
self.deconv02_2 = nn.Conv2d(64, 64, 3, 1, 1) |
||||
|
self.in13_1 = InstanceNormalization(64) |
||||
|
# relu |
||||
|
self.refpad12_1 = nn.ReflectionPad2d(3) |
||||
|
self.deconv03_1 = nn.Conv2d(64, 3, 7) |
||||
|
# tanh |
||||
|
|
||||
|
def forward(self, x): |
||||
|
y = F.relu(self.in01_1(self.conv01_1(self.refpad01_1(x)))) |
||||
|
y = F.relu(self.in02_1(self.conv02_2(self.conv02_1(y)))) |
||||
|
t04 = F.relu(self.in03_1(self.conv03_2(self.conv03_1(y)))) |
||||
|
|
||||
|
## |
||||
|
y = F.relu(self.in04_1(self.conv04_1(self.refpad04_1(t04)))) |
||||
|
t05 = self.in04_2(self.conv04_2(self.refpad04_2(y))) + t04 |
||||
|
|
||||
|
y = F.relu(self.in05_1(self.conv05_1(self.refpad05_1(t05)))) |
||||
|
t06 = self.in05_2(self.conv05_2(self.refpad05_2(y))) + t05 |
||||
|
|
||||
|
y = F.relu(self.in06_1(self.conv06_1(self.refpad06_1(t06)))) |
||||
|
t07 = self.in06_2(self.conv06_2(self.refpad06_2(y))) + t06 |
||||
|
|
||||
|
y = F.relu(self.in07_1(self.conv07_1(self.refpad07_1(t07)))) |
||||
|
t08 = self.in07_2(self.conv07_2(self.refpad07_2(y))) + t07 |
||||
|
|
||||
|
y = F.relu(self.in08_1(self.conv08_1(self.refpad08_1(t08)))) |
||||
|
t09 = self.in08_2(self.conv08_2(self.refpad08_2(y))) + t08 |
||||
|
|
||||
|
y = F.relu(self.in09_1(self.conv09_1(self.refpad09_1(t09)))) |
||||
|
t10 = self.in09_2(self.conv09_2(self.refpad09_2(y))) + t09 |
||||
|
|
||||
|
y = F.relu(self.in10_1(self.conv10_1(self.refpad10_1(t10)))) |
||||
|
t11 = self.in10_2(self.conv10_2(self.refpad10_2(y))) + t10 |
||||
|
|
||||
|
y = F.relu(self.in11_1(self.conv11_1(self.refpad11_1(t11)))) |
||||
|
y = self.in11_2(self.conv11_2(self.refpad11_2(y))) + t11 |
||||
|
## |
||||
|
|
||||
|
y = F.relu(self.in12_1(self.deconv01_2(self.deconv01_1(y)))) |
||||
|
y = F.relu(self.in13_1(self.deconv02_2(self.deconv02_1(y)))) |
||||
|
y = F.tanh(self.deconv03_1(self.refpad12_1(y))) |
||||
|
|
||||
|
return y |
||||
|
|
||||
|
|
||||
|
class InstanceNormalization(nn.Module): |
||||
|
def __init__(self, dim, eps=1e-9): |
||||
|
super(InstanceNormalization, self).__init__() |
||||
|
self.scale = nn.Parameter(torch.FloatTensor(dim)) |
||||
|
self.shift = nn.Parameter(torch.FloatTensor(dim)) |
||||
|
self.eps = eps |
||||
|
self._reset_parameters() |
||||
|
|
||||
|
def _reset_parameters(self): |
||||
|
self.scale.data.uniform_() |
||||
|
self.shift.data.zero_() |
||||
|
|
||||
|
def __call__(self, x): |
||||
|
n = x.size(2) * x.size(3) |
||||
|
t = x.view(x.size(0), x.size(1), n) |
||||
|
mean = torch.mean(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x) |
||||
|
# Calculate the biased var. torch.var returns unbiased var |
||||
|
var = torch.var(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x) * ((n - 1) / float(n)) |
||||
|
scale_broadcast = self.scale.unsqueeze(1).unsqueeze(1).unsqueeze(0) |
||||
|
scale_broadcast = scale_broadcast.expand_as(x) |
||||
|
shift_broadcast = self.shift.unsqueeze(1).unsqueeze(1).unsqueeze(0) |
||||
|
shift_broadcast = shift_broadcast.expand_as(x) |
||||
|
out = (x - mean) / torch.sqrt(var + self.eps) |
||||
|
out = out * scale_broadcast + shift_broadcast |
||||
|
return out |
||||
|
|
||||
|
|
||||
|
class Model(): |
||||
|
def __init__(self, model_name, device) -> None: |
||||
|
self._device = device |
||||
|
self._model = Transformer() |
||||
|
path = os.path.join(str(Path(__file__).parent), 'weights', model_name + '_net_G_float.pth') |
||||
|
self._model.load_state_dict(torch.load(path)) |
||||
|
self._model.to(self._device) |
||||
|
self._model.eval() |
||||
|
|
||||
|
def __call__(self, img_tensor: Tensor): |
||||
|
img_tensor = img_tensor.to(self._device) |
||||
|
img_tensor = img_tensor * 2 - 1 |
||||
|
|
||||
|
output_image = self._model(img_tensor) |
||||
|
output_image = output_image[0] |
||||
|
# BGR -> RGB |
||||
|
output_image = output_image[[2, 1, 0], :, :] |
||||
|
output_image = output_image.data.cpu().float() * 0.5 + 0.5 |
||||
|
return output_image.numpy() |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in new issue