logo
Browse Source

Support GPU

Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
main
shiyu22 3 years ago
parent
commit
7c9bce5f51
  1. 4
      README.md
  2. 5
      animegan.py
  3. 12
      pytorch/model.py

4
README.md

@ -92,6 +92,10 @@ Takes in a numpy rgb image in channels first. It transforms input into animated
​ Which ML framework being used, for now only supports PyTorch.
***device***: *str*
​ Which device being used('cpu' or 'cuda'), defaults to 'cpu'.
**Returns**: *towhee.types.Image (a sub-class of numpy.ndarray)*

5
animegan.py

@ -16,8 +16,9 @@ class Animegan(Operator):
"""
PyTorch model for image embedding.
"""
def __init__(self, model_name: str, framework: str = 'pytorch') -> None:
def __init__(self, model_name: str, framework: str = 'pytorch', device: str = 'cpu') -> None:
super().__init__()
self._device = device
if framework == 'pytorch':
import importlib.util
path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py')
@ -25,7 +26,7 @@ class Animegan(Operator):
spec = importlib.util.spec_from_file_location(opname, path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
self.model = module.Model(model_name)
self.model = module.Model(model_name, self._device)
self.tfms = transforms.Compose([
transforms.ToTensor()
])

12
pytorch/model.py

@ -1,6 +1,7 @@
from torch import nn, load, Tensor
import os
from pathlib import Path
import torch
class ConvNormLReLU(nn.Sequential):
@ -109,19 +110,22 @@ class Generator(nn.Module):
return out
class Model():
def __init__(self, model_name) -> None:
def __init__(self, model_name, device) -> None:
self._device = device
self._model = Generator()
path = os.path.join(str(Path(__file__).parent), 'weights', model_name + '.pt')
ckpt = load(path)
self._model.load_state_dict(ckpt)
self._model.to(self._device)
self._model.eval()
def __call__(self, img_tensor: Tensor):
img_tensor = img_tensor.to(self._device)
img_tensor = img_tensor * 2 - 1
out = self._model(img_tensor).detach()
out = out.squeeze(0).clip(-1, 1) * 0.5 + 0.5
return out.numpy()
img_tensor = self._model(img_tensor).to('cpu').detach()
img_tensor = img_tensor.squeeze(0).clip(-1, 1) * 0.5 + 0.5
return img_tensor.numpy()
def train(self):
"""

Loading…
Cancel
Save