diff --git a/README.md b/README.md index 246770a..28e1f61 100644 --- a/README.md +++ b/README.md @@ -92,6 +92,10 @@ Takes in a numpy rgb image in channels first. It transforms input into animated ​ Which ML framework being used, for now only supports PyTorch. + ***device***: *str* + +​ Which device to use ('cpu' or 'cuda'); defaults to 'cpu'. + **Returns**: *towhee.types.Image (a sub-class of numpy.ndarray)* diff --git a/animegan.py b/animegan.py index 161a339..b6ffba0 100644 --- a/animegan.py +++ b/animegan.py @@ -16,8 +16,9 @@ class Animegan(Operator): """ PyTorch model for image embedding. """ - def __init__(self, model_name: str, framework: str = 'pytorch') -> None: + def __init__(self, model_name: str, framework: str = 'pytorch', device: str = 'cpu') -> None: super().__init__() + self._device = device if framework == 'pytorch': import importlib.util path = os.path.join(str(Path(__file__).parent), 'pytorch', 'model.py') @@ -25,7 +26,7 @@ class Animegan(Operator): spec = importlib.util.spec_from_file_location(opname, path) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) - self.model = module.Model(model_name) + self.model = module.Model(model_name, self._device) self.tfms = transforms.Compose([ transforms.ToTensor() ]) diff --git a/pytorch/model.py b/pytorch/model.py index 361ecc7..8875b36 100644 --- a/pytorch/model.py +++ b/pytorch/model.py @@ -1,6 +1,7 @@ from torch import nn, load, Tensor import os from pathlib import Path +import torch class ConvNormLReLU(nn.Sequential): @@ -109,19 +110,22 @@ class Generator(nn.Module): return out class Model(): - def __init__(self, model_name) -> None: + def __init__(self, model_name, device) -> None: + self._device = device self._model = Generator() path = os.path.join(str(Path(__file__).parent), 'weights', model_name + '.pt') ckpt = load(path) self._model.load_state_dict(ckpt) + self._model.to(self._device) self._model.eval() def __call__(self, img_tensor: Tensor): + img_tensor = 
img_tensor.to(self._device) img_tensor = img_tensor * 2 - 1 - out = self._model(img_tensor).detach() - out = out.squeeze(0).clip(-1, 1) * 0.5 + 0.5 - return out.numpy() + img_tensor = self._model(img_tensor).to('cpu').detach() + img_tensor = img_tensor.squeeze(0).clip(-1, 1) * 0.5 + 0.5 + return img_tensor.numpy() def train(self): """