|
@ -1,4 +1,4 @@ |
|
|
import imp |
|
|
|
|
|
|
|
|
import logging |
|
|
import os |
|
|
import os |
|
|
import torch |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn as nn |
|
@ -7,6 +7,8 @@ from torch import Tensor |
|
|
from pathlib import Path |
|
|
from pathlib import Path |
|
|
from towhee.hub.repo_manager import RepoManager |
|
|
from towhee.hub.repo_manager import RepoManager |
|
|
|
|
|
|
|
|
|
|
|
log = logging.getLogger() |
|
|
|
|
|
|
|
|
class Transformer(nn.Module): |
|
|
class Transformer(nn.Module): |
|
|
def __init__(self): |
|
|
def __init__(self): |
|
|
super(Transformer, self).__init__() |
|
|
super(Transformer, self).__init__() |
|
@ -193,7 +195,7 @@ class Model(): |
|
|
repo.download_executor(tag='main', file_name='pytorch/weights/'+ model_name + '_net_G_float.pth', |
|
|
repo.download_executor(tag='main', file_name='pytorch/weights/'+ model_name + '_net_G_float.pth', |
|
|
lfs_files=('.pth'), local_repo_path=str(Path(__file__).parent.parent)) |
|
|
lfs_files=('.pth'), local_repo_path=str(Path(__file__).parent.parent)) |
|
|
except: |
|
|
except: |
|
|
print('Error when downloading model, please make sure the mode_name in (Hayao, Hosoda, Shinkai, Paprika).') |
|
|
|
|
|
|
|
|
log.error('Error when downloading model, please make sure the mode_name in (Hayao, Hosoda, Shinkai, Paprika).') |
|
|
|
|
|
|
|
|
self._model.load_state_dict(torch.load(path)) |
|
|
self._model.load_state_dict(torch.load(path)) |
|
|
self._model.to(self._device) |
|
|
self._model.to(self._device) |
|
|