logo
Browse Source

Add logger

Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
main
shiyu22 3 years ago
parent
commit
b7859ed3a3
  1. 6
      pytorch/model.py

6
pytorch/model.py

@ -1,4 +1,4 @@
import imp
import logging
import os import os
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -7,6 +7,8 @@ from torch import Tensor
from pathlib import Path from pathlib import Path
from towhee.hub.repo_manager import RepoManager from towhee.hub.repo_manager import RepoManager
log = logging.getLogger()
class Transformer(nn.Module): class Transformer(nn.Module):
def __init__(self): def __init__(self):
super(Transformer, self).__init__() super(Transformer, self).__init__()
@ -193,7 +195,7 @@ class Model():
repo.download_executor(tag='main', file_name='pytorch/weights/'+ model_name + '_net_G_float.pth', repo.download_executor(tag='main', file_name='pytorch/weights/'+ model_name + '_net_G_float.pth',
lfs_files=('.pth'), local_repo_path=str(Path(__file__).parent.parent)) lfs_files=('.pth'), local_repo_path=str(Path(__file__).parent.parent))
except: except:
print('Error when downloading model, please make sure the mode_name in (Hayao, Hosoda, Shinkai, Paprika).')
log.error('Error when downloading model, please make sure the mode_name in (Hayao, Hosoda, Shinkai, Paprika).')
self._model.load_state_dict(torch.load(path)) self._model.load_state_dict(torch.load(path))
self._model.to(self._device) self._model.to(self._device)

Loading…
Cancel
Save