import os
from pathlib import Path

import torch
from torch import Tensor, load, nn


class ConvNormLReLU(nn.Sequential):
    """Convolution -> GroupNorm -> LeakyReLU block with explicit padding."""

    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, pad_mode="reflect", groups=1, bias=False):
        pad_layer = {
            "zero": nn.ZeroPad2d,
            "same": nn.ReplicationPad2d,
            "reflect": nn.ReflectionPad2d,
        }
        if pad_mode not in pad_layer:
            raise NotImplementedError(f"Unsupported pad_mode: {pad_mode!r}")

        super().__init__(
            # Padding is applied as a separate layer so asymmetric padding
            # tuples like (left, right, top, bottom) work with the conv below.
            pad_layer[pad_mode](padding),
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=0, groups=groups, bias=bias),
            # num_groups=1 makes GroupNorm act as layer norm over all channels.
            nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
        )

class InvertedResBlock(nn.Module):
    """MobileNetV2-style inverted residual block with a linear projection."""

    def __init__(self, in_ch, out_ch, expansion_ratio=2):
        super().__init__()
        # The residual connection is only valid when input/output shapes match.
        self.use_res_connect = in_ch == out_ch
        bottleneck = int(round(in_ch * expansion_ratio))

        layers = []
        if expansion_ratio != 1:
            # pointwise expansion
            layers.append(ConvNormLReLU(in_ch, bottleneck, kernel_size=1, padding=0))
        # depthwise convolution
        layers.append(ConvNormLReLU(bottleneck, bottleneck, groups=bottleneck, bias=True))
        # pointwise (linear) projection: no activation after this conv
        layers.append(nn.Conv2d(bottleneck, out_ch, kernel_size=1, padding=0, bias=False))
        layers.append(nn.GroupNorm(num_groups=1, num_channels=out_ch, affine=True))

        self.layers = nn.Sequential(*layers)

    def forward(self, input):
        out = self.layers(input)
        if self.use_res_connect:
            out = input + out
        return out
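
# Example: InvertedResBlock(128, 256, 2) expands 128 channels to a 256-channel
# bottleneck (128 * 2) with a 1x1 conv, filters it with a depthwise 3x3 conv,
# then linearly projects to 256 output channels; because in_ch != out_ch, the
# residual connection is skipped.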

class Generator(nn.Module):
    """Generator network: two stride-2 downsampling stages, an inverted
    residual trunk, then two 2x upsampling stages back to input resolution."""

    def __init__(self):
        super().__init__()

        # Encoder, stage 1: full resolution -> 1/2 resolution.
        self.block_a = nn.Sequential(
            ConvNormLReLU(3, 32, kernel_size=7, padding=3),
            ConvNormLReLU(32, 64, stride=2, padding=(0, 1, 0, 1)),
            ConvNormLReLU(64, 64),
        )

        # Encoder, stage 2: 1/2 resolution -> 1/4 resolution.
        self.block_b = nn.Sequential(
            ConvNormLReLU(64, 128, stride=2, padding=(0, 1, 0, 1)),
            ConvNormLReLU(128, 128),
        )

        # Bottleneck trunk of inverted residual blocks at 1/4 resolution.
        self.block_c = nn.Sequential(
            ConvNormLReLU(128, 128),
            InvertedResBlock(128, 256, 2),
            InvertedResBlock(256, 256, 2),
            InvertedResBlock(256, 256, 2),
            InvertedResBlock(256, 256, 2),
            ConvNormLReLU(256, 128),
        )

        # Decoder, stage 1: refinement after upsampling to 1/2 resolution.
        self.block_d = nn.Sequential(
            ConvNormLReLU(128, 128),
            ConvNormLReLU(128, 128),
        )

        # Decoder, stage 2: refinement after upsampling to full resolution.
        self.block_e = nn.Sequential(
            ConvNormLReLU(128, 64),
            ConvNormLReLU(64, 64),
            ConvNormLReLU(64, 32, kernel_size=7, padding=3),
        )

        # Project to RGB; tanh maps the output into [-1, 1].
        self.out_layer = nn.Sequential(
            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0, bias=False),
            nn.Tanh(),
        )

    def forward(self, input, align_corners=True):
        out = self.block_a(input)
        half_size = out.size()[-2:]  # remember the 1/2-resolution spatial size
        out = self.block_b(out)
        out = self.block_c(out)

        if align_corners:
            # Upsample to the exact 1/2-resolution size recorded above.
            out = nn.functional.interpolate(out, half_size, mode="bilinear", align_corners=True)
        else:
            out = nn.functional.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
        out = self.block_d(out)

        if align_corners:
            # Upsample back to the original input size.
            out = nn.functional.interpolate(out, input.size()[-2:], mode="bilinear", align_corners=True)
        else:
            out = nn.functional.interpolate(out, scale_factor=2, mode="bilinear", align_corners=False)
        out = self.block_e(out)

        out = self.out_layer(out)
        return out

class Model:
    """Thin inference wrapper that loads a Generator checkpoint and runs it."""

    def __init__(self, model_name, device) -> None:
        self._device = device
        self._model = Generator()
        # Checkpoints are expected next to this file under weights/<model_name>.pt.
        path = os.path.join(str(Path(__file__).parent), 'weights', model_name + '.pt')
        # map_location avoids failures when a GPU-saved checkpoint is loaded on CPU.
        ckpt = load(path, map_location='cpu')
        self._model.load_state_dict(ckpt)
        self._model.to(self._device)
        self._model.eval()

    def __call__(self, img_tensor: Tensor):
        img_tensor = img_tensor.to(self._device)
        # Map [0, 1] input into the [-1, 1] range the generator expects.
        img_tensor = img_tensor * 2 - 1
        # Inference only: no_grad avoids building the autograd graph.
        with torch.no_grad():
            img_tensor = self._model(img_tensor).to('cpu').detach()
        # Map the tanh output back from [-1, 1] to [0, 1].
        img_tensor = img_tensor.squeeze(0).clip(-1, 1) * 0.5 + 0.5
        return img_tensor.numpy()

    def train(self):
        """Placeholder for training logic (not implemented)."""
        pass
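
# Minimal usage sketch (illustrative, not part of the module): "example" is a
# placeholder for the name of a checkpoint actually present under weights/.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Model("example", device)  # loads weights/example.pt
    # Random RGB image in [0, 1], NCHW; sides divisible by 4 keep the two
    # stride-2 downsamples aligned with the upsampling path.
    dummy = torch.rand(1, 3, 256, 256)
    stylized = model(dummy)  # numpy array, CHW, values in [0, 1]
    print(stylized.shape)  # (3, 256, 256)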