diff --git a/pytorch/model.py b/pytorch/model.py index d921b25..c300151 100644 --- a/pytorch/model.py +++ b/pytorch/model.py @@ -1,4 +1,4 @@ -import imp +import logging import os import torch import torch.nn as nn @@ -7,6 +7,8 @@ from torch import Tensor from pathlib import Path from towhee.hub.repo_manager import RepoManager +log = logging.getLogger() + class Transformer(nn.Module): def __init__(self): super(Transformer, self).__init__() @@ -193,7 +195,7 @@ class Model(): repo.download_executor(tag='main', file_name='pytorch/weights/'+ model_name + '_net_G_float.pth', lfs_files=('.pth'), local_repo_path=str(Path(__file__).parent.parent)) except: - print('Error when downloading model, please make sure the mode_name in (Hayao, Hosoda, Shinkai, Paprika).') + log.error('Error when downloading model, please make sure the mode_name in (Hayao, Hosoda, Shinkai, Paprika).') self._model.load_state_dict(torch.load(path)) self._model.to(self._device)