logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

43 lines
1.7 KiB

import os
import numpy
from pathlib import Path
from PIL import Image as PImage
from torchvision import transforms
from towhee import register
from towhee.operator import Operator, OperatorFlag
from towhee.types import arg, to_image_color
from towhee._types import Image
import warnings
# NOTE(review): this installs an 'ignore' filter matching every warning, so it
# silences *all* Python warnings process-wide as soon as this module is
# imported (not just warnings raised by this operator). Consider scoping it
# with warnings.catch_warnings() around the specific noisy call instead.
warnings.filterwarnings('ignore')
@register(output_schema=['styled_image'], flag=OperatorFlag.STATELESS | OperatorFlag.REUSEABLE,)
class Animegan(Operator):
    """
    Towhee operator that restyles an image with an AnimeGAN model.

    The model implementation is loaded from ``pytorch/model.py`` located next
    to this file. Each call takes an image and returns the styled result as a
    towhee ``Image`` in BGR channel order.

    Args:
        model_name: Model/weights identifier, forwarded to ``Model`` in
            ``pytorch/model.py``.
        framework: Backend to load. Only ``'pytorch'`` is supported.
        device: Device string forwarded to ``Model`` (e.g. ``'cpu'``).

    Raises:
        ValueError: If ``framework`` is not ``'pytorch'``.
    """

    def __init__(self, model_name: str, framework: str = 'pytorch', device: str = 'cpu') -> None:
        super().__init__()
        self._device = device
        if framework != 'pytorch':
            # Fail fast: otherwise ``self.model`` is never set and the error only
            # surfaces later as an AttributeError inside __call__.
            raise ValueError(
                f"Unsupported framework {framework!r}; only 'pytorch' is supported."
            )
        import importlib.util

        # Load the bundled model implementation by file path: it lives in a
        # sibling ``pytorch`` directory, not in an importable package.
        here = Path(__file__)
        path = str(here.parent / 'pytorch' / 'model.py')
        opname = here.name.split('.')[0]
        spec = importlib.util.spec_from_file_location(opname, path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        self.model = module.Model(model_name, self._device)
        # ToTensor: HWC array -> CHW float tensor (scaled to [0, 1] for uint8 input).
        self.tfms = transforms.Compose([
            transforms.ToTensor()
        ])

    @arg(1, to_image_color('RGB'))
    def __call__(self, image):
        """
        Restyle ``image`` and return it as a towhee ``Image`` in BGR order.

        The ``arg`` decorator is expected to convert ``image`` to RGB before
        this body runs.

        Args:
            image: Input image.

        Returns:
            A towhee ``Image`` holding a uint8 HxWxC array in BGR order.
        """
        img = self.tfms(image).unsqueeze(0)  # add a leading batch dimension
        styled_image = self.model(img)
        # NOTE(review): assumes Model returns a CHW numpy array with values in
        # [0, 1] -- confirm against pytorch/model.py.
        styled_image = numpy.transpose(styled_image, (1, 2, 0))
        # Clip before casting: out-of-range floats would otherwise wrap or
        # overflow when cast to uint8. In-range values are unchanged.
        styled_image = numpy.clip(styled_image * 255, 0, 255).astype(numpy.uint8)
        # Reverse channels (RGB -> BGR); copy() yields a contiguous array
        # instead of a negative-stride view.
        styled_image = styled_image[:, :, ::-1].copy()
        return Image(styled_image, 'BGR')