cartoongan
import logging
import os
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from towhee.hub.repo_manager import RepoManager

log = logging.getLogger()

class Transformer(nn.Module):
    """CartoonGAN generator: a flat convolution stage, two down-convolution
    stages, eight residual blocks, and two up-convolution stages."""

    def __init__(self):
        super(Transformer, self).__init__()
        # flat convolution stage
        self.refpad01_1 = nn.ReflectionPad2d(3)
        self.conv01_1 = nn.Conv2d(3, 64, 7)
        self.in01_1 = InstanceNormalization(64)
        # relu

        # down-convolution
        self.conv02_1 = nn.Conv2d(64, 128, 3, 2, 1)
        self.conv02_2 = nn.Conv2d(128, 128, 3, 1, 1)
        self.in02_1 = InstanceNormalization(128)
        # relu
        self.conv03_1 = nn.Conv2d(128, 256, 3, 2, 1)
        self.conv03_2 = nn.Conv2d(256, 256, 3, 1, 1)
        self.in03_1 = InstanceNormalization(256)
        # relu

        ## res block 1
        self.refpad04_1 = nn.ReflectionPad2d(1)
        self.conv04_1 = nn.Conv2d(256, 256, 3)
        self.in04_1 = InstanceNormalization(256)
        # relu
        self.refpad04_2 = nn.ReflectionPad2d(1)
        self.conv04_2 = nn.Conv2d(256, 256, 3)
        self.in04_2 = InstanceNormalization(256)
        # + input

        ## res block 2
        self.refpad05_1 = nn.ReflectionPad2d(1)
        self.conv05_1 = nn.Conv2d(256, 256, 3)
        self.in05_1 = InstanceNormalization(256)
        # relu
        self.refpad05_2 = nn.ReflectionPad2d(1)
        self.conv05_2 = nn.Conv2d(256, 256, 3)
        self.in05_2 = InstanceNormalization(256)
        # + input

        ## res block 3
        self.refpad06_1 = nn.ReflectionPad2d(1)
        self.conv06_1 = nn.Conv2d(256, 256, 3)
        self.in06_1 = InstanceNormalization(256)
        # relu
        self.refpad06_2 = nn.ReflectionPad2d(1)
        self.conv06_2 = nn.Conv2d(256, 256, 3)
        self.in06_2 = InstanceNormalization(256)
        # + input

        ## res block 4
        self.refpad07_1 = nn.ReflectionPad2d(1)
        self.conv07_1 = nn.Conv2d(256, 256, 3)
        self.in07_1 = InstanceNormalization(256)
        # relu
        self.refpad07_2 = nn.ReflectionPad2d(1)
        self.conv07_2 = nn.Conv2d(256, 256, 3)
        self.in07_2 = InstanceNormalization(256)
        # + input

        ## res block 5
        self.refpad08_1 = nn.ReflectionPad2d(1)
        self.conv08_1 = nn.Conv2d(256, 256, 3)
        self.in08_1 = InstanceNormalization(256)
        # relu
        self.refpad08_2 = nn.ReflectionPad2d(1)
        self.conv08_2 = nn.Conv2d(256, 256, 3)
        self.in08_2 = InstanceNormalization(256)
        # + input

        ## res block 6
        self.refpad09_1 = nn.ReflectionPad2d(1)
        self.conv09_1 = nn.Conv2d(256, 256, 3)
        self.in09_1 = InstanceNormalization(256)
        # relu
        self.refpad09_2 = nn.ReflectionPad2d(1)
        self.conv09_2 = nn.Conv2d(256, 256, 3)
        self.in09_2 = InstanceNormalization(256)
        # + input

        ## res block 7
        self.refpad10_1 = nn.ReflectionPad2d(1)
        self.conv10_1 = nn.Conv2d(256, 256, 3)
        self.in10_1 = InstanceNormalization(256)
        # relu
        self.refpad10_2 = nn.ReflectionPad2d(1)
        self.conv10_2 = nn.Conv2d(256, 256, 3)
        self.in10_2 = InstanceNormalization(256)
        # + input

        ## res block 8
        self.refpad11_1 = nn.ReflectionPad2d(1)
        self.conv11_1 = nn.Conv2d(256, 256, 3)
        self.in11_1 = InstanceNormalization(256)
        # relu
        self.refpad11_2 = nn.ReflectionPad2d(1)
        self.conv11_2 = nn.Conv2d(256, 256, 3)
        self.in11_2 = InstanceNormalization(256)
        # + input

        ## up-convolution
        self.deconv01_1 = nn.ConvTranspose2d(256, 128, 3, 2, 1, 1)
        self.deconv01_2 = nn.Conv2d(128, 128, 3, 1, 1)
        self.in12_1 = InstanceNormalization(128)
        # relu
        self.deconv02_1 = nn.ConvTranspose2d(128, 64, 3, 2, 1, 1)
        self.deconv02_2 = nn.Conv2d(64, 64, 3, 1, 1)
        self.in13_1 = InstanceNormalization(64)
        # relu
        self.refpad12_1 = nn.ReflectionPad2d(3)
        self.deconv03_1 = nn.Conv2d(64, 3, 7)
        # tanh

    def forward(self, x):
        # flat convolution + down-convolution
        y = F.relu(self.in01_1(self.conv01_1(self.refpad01_1(x))))
        y = F.relu(self.in02_1(self.conv02_2(self.conv02_1(y))))
        t04 = F.relu(self.in03_1(self.conv03_2(self.conv03_1(y))))

        ## residual blocks, each with an identity skip connection
        y = F.relu(self.in04_1(self.conv04_1(self.refpad04_1(t04))))
        t05 = self.in04_2(self.conv04_2(self.refpad04_2(y))) + t04

        y = F.relu(self.in05_1(self.conv05_1(self.refpad05_1(t05))))
        t06 = self.in05_2(self.conv05_2(self.refpad05_2(y))) + t05

        y = F.relu(self.in06_1(self.conv06_1(self.refpad06_1(t06))))
        t07 = self.in06_2(self.conv06_2(self.refpad06_2(y))) + t06

        y = F.relu(self.in07_1(self.conv07_1(self.refpad07_1(t07))))
        t08 = self.in07_2(self.conv07_2(self.refpad07_2(y))) + t07

        y = F.relu(self.in08_1(self.conv08_1(self.refpad08_1(t08))))
        t09 = self.in08_2(self.conv08_2(self.refpad08_2(y))) + t08

        y = F.relu(self.in09_1(self.conv09_1(self.refpad09_1(t09))))
        t10 = self.in09_2(self.conv09_2(self.refpad09_2(y))) + t09

        y = F.relu(self.in10_1(self.conv10_1(self.refpad10_1(t10))))
        t11 = self.in10_2(self.conv10_2(self.refpad10_2(y))) + t10

        y = F.relu(self.in11_1(self.conv11_1(self.refpad11_1(t11))))
        y = self.in11_2(self.conv11_2(self.refpad11_2(y))) + t11

        ## up-convolution back to input resolution
        y = F.relu(self.in12_1(self.deconv01_2(self.deconv01_1(y))))
        y = F.relu(self.in13_1(self.deconv02_2(self.deconv02_1(y))))
        # F.tanh is deprecated; torch.tanh is the supported equivalent
        y = torch.tanh(self.deconv03_1(self.refpad12_1(y)))

        return y
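

# Hedged sanity check (illustration only, not part of the original file):
# the generator is fully convolutional, so it preserves the input shape as
# long as height and width are divisible by 4 (two stride-2 stages down,
# two stride-2 transposed stages back up).
def _check_transformer_shape():
    net = Transformer()
    dummy = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        out = net(dummy)
    assert out.shape == dummy.shape
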
class InstanceNormalization(nn.Module):
    """Instance normalization with a learnable per-channel scale and shift."""

    def __init__(self, dim, eps=1e-9):
        super(InstanceNormalization, self).__init__()
        self.scale = nn.Parameter(torch.FloatTensor(dim))
        self.shift = nn.Parameter(torch.FloatTensor(dim))
        self.eps = eps
        self._reset_parameters()

    def _reset_parameters(self):
        self.scale.data.uniform_()
        self.shift.data.zero_()

    # defined as forward (rather than __call__) so nn.Module's call
    # machinery and hooks keep working
    def forward(self, x):
        # normalize each (sample, channel) plane over its spatial positions
        n = x.size(2) * x.size(3)
        t = x.view(x.size(0), x.size(1), n)
        mean = torch.mean(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x)
        # calculate the biased var (torch.var returns the unbiased var)
        var = torch.var(t, 2).unsqueeze(2).unsqueeze(3).expand_as(x) * ((n - 1) / float(n))
        # broadcast the per-channel affine parameters to (1, C, 1, 1)
        scale_broadcast = self.scale.unsqueeze(1).unsqueeze(1).unsqueeze(0)
        scale_broadcast = scale_broadcast.expand_as(x)
        shift_broadcast = self.shift.unsqueeze(1).unsqueeze(1).unsqueeze(0)
        shift_broadcast = shift_broadcast.expand_as(x)
        out = (x - mean) / torch.sqrt(var + self.eps)
        out = out * scale_broadcast + shift_broadcast
        return out
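

# Hedged equivalence sketch (an assumption, not part of the original file):
# with matching affine parameters and eps, this hand-rolled layer should
# agree with torch.nn.InstanceNorm2d(affine=True) up to numerical tolerance.
def _check_instance_norm():
    inorm = InstanceNormalization(8, eps=1e-5)
    ref = nn.InstanceNorm2d(8, eps=1e-5, affine=True)
    with torch.no_grad():
        ref.weight.copy_(inorm.scale)
        ref.bias.copy_(inorm.shift)
        x = torch.randn(2, 8, 16, 16)
        assert torch.allclose(inorm(x), ref(x), atol=1e-5)
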
class Model:
    def __init__(self, model_name, device) -> None:
        self._device = device
        self._model = Transformer()
        path = os.path.join(str(Path(__file__).parent), 'weights', model_name + '_net_G_float.pth')
        if not os.path.exists(path):
            try:
                repo = RepoManager('img2img-translation', 'cartoongan')
                # note the one-element tuple: ('.pth') would be a plain string
                repo.download_executor(tag='main',
                                       file_name='pytorch/weights/' + model_name + '_net_G_float.pth',
                                       lfs_files=('.pth',),
                                       local_repo_path=str(Path(__file__).parent.parent))
            except Exception:
                log.error('Error when downloading the model; please make sure model_name is one of (Hayao, Hosoda, Shinkai, Paprika).')

        # map_location keeps CUDA-trained weights loadable on CPU-only hosts
        self._model.load_state_dict(torch.load(path, map_location=self._device))
        self._model.to(self._device)
        self._model.eval()

    def __call__(self, img_tensor: Tensor):
        img_tensor = img_tensor.to(self._device)
        # map input from [0, 1] to the [-1, 1] range the generator expects
        img_tensor = img_tensor * 2 - 1

        output_image = self._model(img_tensor)
        output_image = output_image[0]
        # BGR -> RGB
        output_image = output_image[[2, 1, 0], :, :]
        # map the tanh output from [-1, 1] back to [0, 1]
        output_image = output_image.data.cpu().float() * 0.5 + 0.5

        return output_image.numpy()
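

# Hedged end-to-end sketch (illustration only, not part of the original
# file). The 'Hayao' weight name comes from the error message above;
# torchvision/PIL availability and the BGR input convention are assumptions.
if __name__ == '__main__':
    from PIL import Image
    from torchvision import transforms

    model = Model('Hayao', torch.device('cpu'))
    img = Image.open('input.jpg').convert('RGB')
    # flip RGB -> BGR to mirror the BGR -> RGB flip applied to the output,
    # and add a batch dimension; H and W should be divisible by 4
    tensor = transforms.ToTensor()(img)[[2, 1, 0], :, :].unsqueeze(0)
    out = model(tensor)  # CHW float32 numpy array, RGB, values in [0, 1]
    Image.fromarray((out.transpose(1, 2, 0) * 255).astype('uint8')).save('cartoon.jpg')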