|
@ -1,9 +1,11 @@ |
|
|
|
|
|
import imp |
|
|
import os |
|
|
import os |
|
|
import torch |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torch.nn.functional as F |
|
|
from torch import Tensor |
|
|
from torch import Tensor |
|
|
from pathlib import Path |
|
|
from pathlib import Path |
|
|
|
|
|
from towhee.hub.repo_manager import RepoManager |
|
|
|
|
|
|
|
|
class Transformer(nn.Module): |
|
|
class Transformer(nn.Module): |
|
|
def __init__(self): |
|
|
def __init__(self): |
|
@ -185,6 +187,14 @@ class Model(): |
|
|
self._device = device |
|
|
self._device = device |
|
|
self._model = Transformer() |
|
|
self._model = Transformer() |
|
|
path = os.path.join(str(Path(__file__).parent), 'weights', model_name + '_net_G_float.pth') |
|
|
path = os.path.join(str(Path(__file__).parent), 'weights', model_name + '_net_G_float.pth') |
|
|
|
|
|
if not os.path.exists(path): |
|
|
|
|
|
try: |
|
|
|
|
|
repo = RepoManager('img2img-translation', 'cartoongan') |
|
|
|
|
|
repo.download_executor(tag='main', file_name='pytorch/weights/'+ model_name + '_net_G_float.pth', |
|
|
|
|
|
lfs_files=('.pth'), local_repo_path=str(Path(__file__).parent.parent)) |
|
|
|
|
|
except: |
|
|
|
|
|
print('Error when downloading model, please make sure the mode_name in (Hayao, Hosoda, Shinkai, Paprika).') |
|
|
|
|
|
|
|
|
self._model.load_state_dict(torch.load(path)) |
|
|
self._model.load_state_dict(torch.load(path)) |
|
|
self._model.to(self._device) |
|
|
self._model.to(self._device) |
|
|
self._model.eval() |
|
|
self._model.eval() |
|
|