|
@ -1,6 +1,7 @@ |
|
|
from torch import nn, load, Tensor |
|
|
from torch import nn, load, Tensor |
|
|
import os |
|
|
import os |
|
|
from pathlib import Path |
|
|
from pathlib import Path |
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ConvNormLReLU(nn.Sequential): |
|
|
class ConvNormLReLU(nn.Sequential): |
|
@ -109,19 +110,22 @@ class Generator(nn.Module): |
|
|
return out |
|
|
return out |
|
|
|
|
|
|
|
|
class Model(): |
|
|
class Model(): |
|
|
def __init__(self, model_name) -> None: |
|
|
|
|
|
|
|
|
def __init__(self, model_name, device) -> None: |
|
|
|
|
|
self._device = device |
|
|
self._model = Generator() |
|
|
self._model = Generator() |
|
|
path = os.path.join(str(Path(__file__).parent), 'weights', model_name + '.pt') |
|
|
path = os.path.join(str(Path(__file__).parent), 'weights', model_name + '.pt') |
|
|
ckpt = load(path) |
|
|
ckpt = load(path) |
|
|
self._model.load_state_dict(ckpt) |
|
|
self._model.load_state_dict(ckpt) |
|
|
|
|
|
self._model.to(self._device) |
|
|
self._model.eval() |
|
|
self._model.eval() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def __call__(self, img_tensor: Tensor): |
|
|
def __call__(self, img_tensor: Tensor): |
|
|
|
|
|
img_tensor = img_tensor.to(self._device) |
|
|
img_tensor = img_tensor * 2 - 1 |
|
|
img_tensor = img_tensor * 2 - 1 |
|
|
out = self._model(img_tensor).detach() |
|
|
|
|
|
out = out.squeeze(0).clip(-1, 1) * 0.5 + 0.5 |
|
|
|
|
|
return out.numpy() |
|
|
|
|
|
|
|
|
img_tensor = self._model(img_tensor).to('cpu').detach() |
|
|
|
|
|
img_tensor = img_tensor.squeeze(0).clip(-1, 1) * 0.5 + 0.5 |
|
|
|
|
|
return img_tensor.numpy() |
|
|
|
|
|
|
|
|
def train(self): |
|
|
def train(self): |
|
|
""" |
|
|
""" |
|
|