Browse Source
Support GPU
Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
main
3 changed files with
15 additions and
6 deletions
-
README.md
-
animegan.py
-
pytorch/model.py
|
|
@ -92,6 +92,10 @@ Takes in a numpy rgb image in channels first. It transforms input into animated |
|
|
|
|
|
|
|
Which ML framework being used, for now only supports PyTorch. |
|
|
|
|
|
|
|
***device***: *str* |
|
|
|
|
|
|
|
Which device being used('cpu' or 'cuda'), defaults to 'cpu'. |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
**Returns**: *towhee.types.Image (a sub-class of numpy.ndarray)* |
|
|
|
|
|
@ -16,8 +16,9 @@ class Animegan(Operator): |
|
|
|
""" |
|
|
|
PyTorch model for image embedding. |
|
|
|
""" |
|
|
|
def __init__(self, model_name: str, framework: str = 'pytorch') -> None: |
|
|
|
def __init__(self, model_name: str, framework: str = 'pytorch', device: str = 'cpu') -> None: |
|
|
|
super().__init__() |
|
|
|
self._device = device |
|
|
|
if framework == 'pytorch': |
|
|
|
import importlib.util |
|
|
|
path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py') |
|
|
@ -25,7 +26,7 @@ class Animegan(Operator): |
|
|
|
spec = importlib.util.spec_from_file_location(opname, path) |
|
|
|
module = importlib.util.module_from_spec(spec) |
|
|
|
spec.loader.exec_module(module) |
|
|
|
self.model = module.Model(model_name) |
|
|
|
self.model = module.Model(model_name, self._device) |
|
|
|
self.tfms = transforms.Compose([ |
|
|
|
transforms.ToTensor() |
|
|
|
]) |
|
|
|
|
|
@ -1,6 +1,7 @@ |
|
|
|
from torch import nn, load, Tensor |
|
|
|
import os |
|
|
|
from pathlib import Path |
|
|
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
class ConvNormLReLU(nn.Sequential): |
|
|
@ -109,19 +110,22 @@ class Generator(nn.Module): |
|
|
|
return out |
|
|
|
|
|
|
|
class Model(): |
|
|
|
def __init__(self, model_name) -> None: |
|
|
|
def __init__(self, model_name, device) -> None: |
|
|
|
self._device = device |
|
|
|
self._model = Generator() |
|
|
|
path = os.path.join(str(Path(__file__).parent), 'weights', model_name + '.pt') |
|
|
|
ckpt = load(path) |
|
|
|
self._model.load_state_dict(ckpt) |
|
|
|
self._model.to(self._device) |
|
|
|
self._model.eval() |
|
|
|
|
|
|
|
|
|
|
|
def __call__(self, img_tensor: Tensor): |
|
|
|
img_tensor = img_tensor.to(self._device) |
|
|
|
img_tensor = img_tensor * 2 - 1 |
|
|
|
out = self._model(img_tensor).detach() |
|
|
|
out = out.squeeze(0).clip(-1, 1) * 0.5 + 0.5 |
|
|
|
return out.numpy() |
|
|
|
img_tensor = self._model(img_tensor).to('cpu').detach() |
|
|
|
img_tensor = img_tensor.squeeze(0).clip(-1, 1) * 0.5 + 0.5 |
|
|
|
return img_tensor.numpy() |
|
|
|
|
|
|
|
def train(self): |
|
|
|
""" |
|
|
|