cartoongan
copied
shiyu22
3 years ago
9 changed files with 474 additions and 0 deletions
@@ -0,0 +1,209 @@
### Linux ###
*~

# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*

# KDE directory preferences
.directory

# Linux trash folder which might appear on any partition or disk
.Trash-*

# .nfs files are created when an open file is removed but is still being accessed
.nfs*

### OSX ###
# General
.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon


# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

### Windows ###
# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db

# Dump file
*.stackdump

# Folder config file
[Dd]esktop.ini

# Recycle Bin used on file shares
$RECYCLE.BIN/

# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp

# Windows shortcuts
*.lnk
@@ -0,0 +1,4 @@
from .cartoongan import Cartoongan

def nnoperator_template(*args, **kwargs):
    return Cartoongan(*args, **kwargs)
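The factory above simply forwards its arguments to Cartoongan, so calling it is interchangeable with constructing the class directly. A minimal sketch (not part of the diff; it assumes the package is importable locally and that weights exist for a placeholder style name such as 'hayao'):

    from cartoongan import Cartoongan  # the module __init__.py imports from

    op = Cartoongan(model_name='hayao', device='cpu')
    # equivalent: op = nnoperator_template('hayao', device='cpu')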
@@ -0,0 +1,49 @@
import logging
import os
import numpy
from pathlib import Path
from PIL import Image as PImage
from torchvision import transforms

from towhee import register
from towhee.operator import NNOperator, OperatorFlag
from towhee.types import arg, to_image_color
from towhee._types import Image
import warnings
warnings.filterwarnings('ignore')

log = logging.getLogger()


@register(output_schema=['styled_image'], flag=OperatorFlag.STATELESS | OperatorFlag.REUSEABLE,)
class Cartoongan(NNOperator):
    """
    Applies CartoonGAN style transfer to an input image and returns the
    stylized result as a BGR towhee Image.
    """

    def __init__(self, model_name: str, framework: str = 'pytorch', device: str = 'cpu') -> None:
        super().__init__()
        self._device = device
        if framework == 'pytorch':
            import importlib.util
            # Load pytorch/model.py (next to this file) as a standalone module.
            path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py')
            opname = os.path.basename(str(Path(__file__))).split('.')[0]
            spec = importlib.util.spec_from_file_location(opname, path)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            self.model = module.Model(model_name, self._device)
        self.tfms = transforms.Compose([
            transforms.ToTensor()
        ])

    @arg(1, to_image_color('RGB'))
    def __call__(self, image):
        img = self.tfms(image).unsqueeze(0)
        styled_image = self.model(img)

        styled_image = numpy.transpose(styled_image, (1, 2, 0))
        styled_image = PImage.fromarray((styled_image * 255).astype(numpy.uint8))
        styled_image = numpy.array(styled_image)
        styled_image = styled_image[:, :, ::-1].copy()  # RGB -> BGR

        return Image(styled_image, 'BGR')
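A hedged end-to-end sketch of calling the operator defined above (not part of the diff). It assumes a 'hayao' generator checkpoint exists as pytorch/weights/hayao_net_G_float.pth next to this file, matching the path logic in pytorch/model.py below; the style name and image path are placeholders:

    import numpy
    from PIL import Image as PImage
    from towhee._types import Image
    from cartoongan import Cartoongan

    op = Cartoongan(model_name='hayao', device='cpu')

    # Decode a test image to an RGB ndarray and wrap it in towhee's Image type;
    # the @arg(1, to_image_color('RGB')) decorator then passes it through unchanged.
    rgb = numpy.array(PImage.open('test.jpg').convert('RGB'))
    styled = op(Image(rgb, 'RGB'))
    print(styled.shape)  # stylized image as an H x W x 3 uint8 ndarray in BGR order

ToTensor scales pixels to [0, 1]; the Model class below rescales them to [-1, 1] for the generator and maps the output back before the operator converts it to uint8 BGR.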
@@ -0,0 +1,200 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path


class Transformer(nn.Module):
    def __init__(self):
        super(Transformer, self).__init__()
        # Encoder: downsampling convolutions, 3 -> 64 -> 128 -> 256 channels.
        self.refpad01_1 = nn.ReflectionPad2d(3)
        self.conv01_1 = nn.Conv2d(3, 64, 7)
        self.in01_1 = InstanceNormalization(64)
        # relu
        self.conv02_1 = nn.Conv2d(64, 128, 3, 2, 1)
        self.conv02_2 = nn.Conv2d(128, 128, 3, 1, 1)
        self.in02_1 = InstanceNormalization(128)
        # relu
        self.conv03_1 = nn.Conv2d(128, 256, 3, 2, 1)
        self.conv03_2 = nn.Conv2d(256, 256, 3, 1, 1)
        self.in03_1 = InstanceNormalization(256)
        # relu

        ## res block 1
        self.refpad04_1 = nn.ReflectionPad2d(1)
        self.conv04_1 = nn.Conv2d(256, 256, 3)
        self.in04_1 = InstanceNormalization(256)
        # relu
        self.refpad04_2 = nn.ReflectionPad2d(1)
        self.conv04_2 = nn.Conv2d(256, 256, 3)
        self.in04_2 = InstanceNormalization(256)
        # + input

        ## res block 2
        self.refpad05_1 = nn.ReflectionPad2d(1)
        self.conv05_1 = nn.Conv2d(256, 256, 3)
        self.in05_1 = InstanceNormalization(256)
        # relu
        self.refpad05_2 = nn.ReflectionPad2d(1)
        self.conv05_2 = nn.Conv2d(256, 256, 3)
        self.in05_2 = InstanceNormalization(256)
        # + input

        ## res block 3
        self.refpad06_1 = nn.ReflectionPad2d(1)
        self.conv06_1 = nn.Conv2d(256, 256, 3)
        self.in06_1 = InstanceNormalization(256)
        # relu
        self.refpad06_2 = nn.ReflectionPad2d(1)
        self.conv06_2 = nn.Conv2d(256, 256, 3)
        self.in06_2 = InstanceNormalization(256)
        # + input

        ## res block 4
        self.refpad07_1 = nn.ReflectionPad2d(1)
        self.conv07_1 = nn.Conv2d(256, 256, 3)
        self.in07_1 = InstanceNormalization(256)
        # relu
        self.refpad07_2 = nn.ReflectionPad2d(1)
        self.conv07_2 = nn.Conv2d(256, 256, 3)
        self.in07_2 = InstanceNormalization(256)
        # + input

        ## res block 5
        self.refpad08_1 = nn.ReflectionPad2d(1)
        self.conv08_1 = nn.Conv2d(256, 256, 3)
        self.in08_1 = InstanceNormalization(256)
        # relu
        self.refpad08_2 = nn.ReflectionPad2d(1)
        self.conv08_2 = nn.Conv2d(256, 256, 3)
        self.in08_2 = InstanceNormalization(256)
        # + input

        ## res block 6
        self.refpad09_1 = nn.ReflectionPad2d(1)
        self.conv09_1 = nn.Conv2d(256, 256, 3)
        self.in09_1 = InstanceNormalization(256)
        # relu
        self.refpad09_2 = nn.ReflectionPad2d(1)
        self.conv09_2 = nn.Conv2d(256, 256, 3)
        self.in09_2 = InstanceNormalization(256)
        # + input

        ## res block 7
        self.refpad10_1 = nn.ReflectionPad2d(1)
        self.conv10_1 = nn.Conv2d(256, 256, 3)
        self.in10_1 = InstanceNormalization(256)
        # relu
        self.refpad10_2 = nn.ReflectionPad2d(1)
        self.conv10_2 = nn.Conv2d(256, 256, 3)
        self.in10_2 = InstanceNormalization(256)
        # + input

        ## res block 8
        self.refpad11_1 = nn.ReflectionPad2d(1)
        self.conv11_1 = nn.Conv2d(256, 256, 3)
        self.in11_1 = InstanceNormalization(256)
        # relu
        self.refpad11_2 = nn.ReflectionPad2d(1)
        self.conv11_2 = nn.Conv2d(256, 256, 3)
        self.in11_2 = InstanceNormalization(256)
        # + input

        ##------------------------------------##
        # Decoder: two upsampling stages back to a 3-channel image.
        self.deconv01_1 = nn.ConvTranspose2d(256, 128, 3, 2, 1, 1)
        self.deconv01_2 = nn.Conv2d(128, 128, 3, 1, 1)
        self.in12_1 = InstanceNormalization(128)
        # relu
        self.deconv02_1 = nn.ConvTranspose2d(128, 64, 3, 2, 1, 1)
        self.deconv02_2 = nn.Conv2d(64, 64, 3, 1, 1)
        self.in13_1 = InstanceNormalization(64)
        # relu
        self.refpad12_1 = nn.ReflectionPad2d(3)
        self.deconv03_1 = nn.Conv2d(64, 3, 7)
        # tanh

    def forward(self, x):
        y = F.relu(self.in01_1(self.conv01_1(self.refpad01_1(x))))
        y = F.relu(self.in02_1(self.conv02_2(self.conv02_1(y))))
        t04 = F.relu(self.in03_1(self.conv03_2(self.conv03_1(y))))

        ## residual blocks
        y = F.relu(self.in04_1(self.conv04_1(self.refpad04_1(t04))))
        t05 = self.in04_2(self.conv04_2(self.refpad04_2(y))) + t04

        y = F.relu(self.in05_1(self.conv05_1(self.refpad05_1(t05))))
        t06 = self.in05_2(self.conv05_2(self.refpad05_2(y))) + t05

        y = F.relu(self.in06_1(self.conv06_1(self.refpad06_1(t06))))
        t07 = self.in06_2(self.conv06_2(self.refpad06_2(y))) + t06

        y = F.relu(self.in07_1(self.conv07_1(self.refpad07_1(t07))))
        t08 = self.in07_2(self.conv07_2(self.refpad07_2(y))) + t07

        y = F.relu(self.in08_1(self.conv08_1(self.refpad08_1(t08))))
        t09 = self.in08_2(self.conv08_2(self.refpad08_2(y))) + t08

        y = F.relu(self.in09_1(self.conv09_1(self.refpad09_1(t09))))
        t10 = self.in09_2(self.conv09_2(self.refpad09_2(y))) + t09

        y = F.relu(self.in10_1(self.conv10_1(self.refpad10_1(t10))))
        t11 = self.in10_2(self.conv10_2(self.refpad10_2(y))) + t10

        y = F.relu(self.in11_1(self.conv11_1(self.refpad11_1(t11))))
        y = self.in11_2(self.conv11_2(self.refpad11_2(y))) + t11
        ##

        y = F.relu(self.in12_1(self.deconv01_2(self.deconv01_1(y))))
        y = F.relu(self.in13_1(self.deconv02_2(self.deconv02_1(y))))
        y = torch.tanh(self.deconv03_1(self.refpad12_1(y)))

        return y


class InstanceNormalization(nn.Module):
    def __init__(self, dim, eps=1e-9):
        super(InstanceNormalization, self).__init__()
        self.scale = nn.Parameter(torch.FloatTensor(dim))
        self.shift = nn.Parameter(torch.FloatTensor(dim))
        self.eps = eps
        self._reset_parameters()

    def _reset_parameters(self):
        self.scale.data.uniform_()
        self.shift.data.zero_()

    # Note: overriding __call__ (rather than forward) bypasses nn.Module hooks.
    def __call__(self, x):
        n = x.size(2) * x.size(3)
        t = x.view(x.size(0), x.size(1), n)
        mean = torch.mean(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x)
        # Calculate the biased var. torch.var returns unbiased var
        var = torch.var(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x) * ((n - 1) / float(n))
        scale_broadcast = self.scale.unsqueeze(1).unsqueeze(1).unsqueeze(0)
        scale_broadcast = scale_broadcast.expand_as(x)
        shift_broadcast = self.shift.unsqueeze(1).unsqueeze(1).unsqueeze(0)
        shift_broadcast = shift_broadcast.expand_as(x)
        out = (x - mean) / torch.sqrt(var + self.eps)
        out = out * scale_broadcast + shift_broadcast
        return out


class Model():
    def __init__(self, model_name, device) -> None:
        self._device = device
        self._model = Transformer()
        path = os.path.join(str(Path(__file__).parent), 'weights', model_name + '_net_G_float.pth')
        self._model.load_state_dict(torch.load(path, map_location=self._device))
        self._model.to(self._device)
        self._model.eval()

    def __call__(self, img_tensor: torch.Tensor):
        img_tensor = img_tensor.to(self._device)
        img_tensor = img_tensor * 2 - 1  # ToTensor gives [0, 1]; the generator expects [-1, 1]

        with torch.no_grad():
            output_image = self._model(img_tensor)
        output_image = output_image[0]
        # BGR -> RGB
        output_image = output_image[[2, 1, 0], :, :]
        output_image = output_image.data.cpu().float() * 0.5 + 0.5  # map back to [0, 1]
        return output_image.numpy()
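A quick smoke-test sketch for the generator above (not part of the diff). It assumes this file is importable as a module named model, for example by running Python from the pytorch/ directory; no pretrained weights are needed because Transformer initializes its own parameters:

    import torch
    from model import Transformer

    net = Transformer().eval()
    x = torch.rand(1, 3, 256, 256) * 2 - 1  # fake image already scaled to [-1, 1]
    with torch.no_grad():
        y = net(x)
    print(y.shape)                         # torch.Size([1, 3, 256, 256]): spatial size preserved
    print(float(y.min()), float(y.max()))  # stays within [-1, 1] because of the final tanh

For reference, the hand-rolled InstanceNormalization implements the same per-sample, per-channel normalization with a learned scale and shift as torch.nn.InstanceNorm2d(dim, affine=True), apart from the smaller eps; keeping the custom class presumably preserves the scale/shift parameter names that the *_net_G_float.pth checkpoints were saved with.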
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.