logo
Browse Source

Download model

Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
main
shiyu22 3 years ago
parent
commit
ede56f71cb
  1. 10
      pytorch/model.py

10
pytorch/model.py

@ -1,9 +1,11 @@
import imp
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from pathlib import Path
from towhee.hub.repo_manager import RepoManager
class Transformer(nn.Module):
def __init__(self):
@ -185,6 +187,14 @@ class Model():
self._device = device
self._model = Transformer()
path = os.path.join(str(Path(__file__).parent), 'weights', model_name + '_net_G_float.pth')
if not os.path.exists(path):
try:
repo = RepoManager('img2img-translation', 'cartoongan')
repo.download_executor(tag='main', file_name='pytorch/weights/'+ model_name + '_net_G_float.pth',
lfs_files=('.pth'), local_repo_path=str(Path(__file__).parent.parent))
except:
print('Error when downloading model, please make sure the mode_name in (Hayao, Hosoda, Shinkai, Paprika).')
self._model.load_state_dict(torch.load(path))
self._model.to(self._device)
self._model.eval()

Loading…
Cancel
Save