diff --git a/pytorch/model.py b/pytorch/model.py index ff94d35..d921b25 100644 --- a/pytorch/model.py +++ b/pytorch/model.py @@ -1,9 +1,11 @@ +import imp import os import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor from pathlib import Path +from towhee.hub.repo_manager import RepoManager class Transformer(nn.Module): def __init__(self): @@ -185,6 +187,14 @@ class Model(): self._device = device self._model = Transformer() path = os.path.join(str(Path(__file__).parent), 'weights', model_name + '_net_G_float.pth') + if not os.path.exists(path): + try: + repo = RepoManager('img2img-translation', 'cartoongan') + repo.download_executor(tag='main', file_name='pytorch/weights/'+ model_name + '_net_G_float.pth', + lfs_files=('.pth'), local_repo_path=str(Path(__file__).parent.parent)) + except: + print('Error when downloading model, please make sure the mode_name in (Hayao, Hosoda, Shinkai, Paprika).') + self._model.load_state_dict(torch.load(path)) self._model.to(self._device) self._model.eval()