From b7859ed3a3609dc59c755dc9df40e15ff6438fc4 Mon Sep 17 00:00:00 2001 From: shiyu22 Date: Thu, 26 May 2022 18:56:31 +0800 Subject: [PATCH] Add logger Signed-off-by: shiyu22 --- pytorch/model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch/model.py b/pytorch/model.py index d921b25..c300151 100644 --- a/pytorch/model.py +++ b/pytorch/model.py @@ -1,4 +1,4 @@ -import imp +import logging import os import torch import torch.nn as nn @@ -7,6 +7,8 @@ from torch import Tensor from pathlib import Path from towhee.hub.repo_manager import RepoManager +log = logging.getLogger() + class Transformer(nn.Module): def __init__(self): super(Transformer, self).__init__() @@ -193,7 +195,7 @@ class Model(): repo.download_executor(tag='main', file_name='pytorch/weights/'+ model_name + '_net_G_float.pth', lfs_files=('.pth'), local_repo_path=str(Path(__file__).parent.parent)) except: - print('Error when downloading model, please make sure the mode_name in (Hayao, Hosoda, Shinkai, Paprika).') + log.error('Error when downloading model, please make sure the mode_name in (Hayao, Hosoda, Shinkai, Paprika).') self._model.load_state_dict(torch.load(path)) self._model.to(self._device)