logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

37 lines
1.3 KiB

import os
import numpy
from pathlib import Path
from torchvision import transforms
import torch
from towhee import register
from towhee.operator import Operator, OperatorFlag
from towhee.types import arg, to_image_color, Image
import warnings
# Suppress every warning for the whole interpreter. This runs once at import
# time and modifies the global warnings filter list, so it also hides warnings
# raised by unrelated code, not just this operator.
# NOTE(review): consider scoping this with warnings.catch_warnings() around the
# specific calls that are noisy instead of silencing everything globally.
warnings.filterwarnings('ignore')
@register(output_schema=['styled_image'], flag=OperatorFlag.STATELESS | OperatorFlag.REUSEABLE,)
class Animegan(Operator):
    """
    AnimeGAN image-to-image style-transfer operator.

    Loads the ``Model`` class defined in ``pytorch/model.py`` next to this file
    and applies it to an RGB image, emitting the result as an RGB ``Image``
    under the ``styled_image`` output.

    Args:
        model_name: Weights/model identifier, forwarded unchanged to ``Model``.
        framework: Backend to load. Only ``'pytorch'`` is supported.
        device: Device string forwarded unchanged to ``Model`` (default ``'cpu'``).

    Raises:
        ValueError: If ``framework`` is anything other than ``'pytorch'``.
    """

    def __init__(self, model_name: str, framework: str = 'pytorch', device: str = 'cpu') -> None:
        super().__init__()
        self._device = device
        if framework != 'pytorch':
            # Without this check an unsupported framework left ``self.model``
            # unset, and the error only surfaced later as an AttributeError on
            # the first call. Fail fast with an explicit message instead.
            raise ValueError(
                f"Unsupported framework {framework!r}; only 'pytorch' is supported."
            )
        self.model = self._load_pytorch_model(model_name, self._device)
        # Converts the incoming image into a tensor the model can consume.
        self.tfms = transforms.ToTensor()

    @staticmethod
    def _load_pytorch_model(model_name: str, device: str):
        """Import ``pytorch/model.py`` beside this file and construct its ``Model``."""
        import importlib.util

        here = Path(__file__)
        path = here.parent / 'pytorch' / 'model.py'
        # Module name is this file's name up to its first dot (e.g. 'animegan').
        opname = here.name.split('.')[0]
        spec = importlib.util.spec_from_file_location(opname, str(path))
        if spec is None or spec.loader is None:
            raise ImportError(f'Cannot load model definition from {path}')
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module.Model(model_name, device)

    @arg(1, to_image_color('RGB'))
    def __call__(self, img):
        """
        Stylize a single image.

        Args:
            img: Input image. Argument 1 (``img``) is declared with
                ``to_image_color('RGB')``, presumably converting it to RGB
                before this body runs.

        Returns:
            Image: The stylized image, tagged with the ``'RGB'`` color order.
        """
        # Add a leading batch dimension of size 1 for the model.
        batch = self.tfms(img).unsqueeze(0)
        styled_image = self.model(batch)
        # NOTE(review): assumes ``self.model`` returns a single channel-first
        # (C, H, W) array; the transpose moves channels last, (H, W, C), for
        # ``Image`` — confirm against pytorch/model.py.
        styled_image = numpy.transpose(styled_image, (1, 2, 0))
        return Image(styled_image, 'RGB')