import logging import torch from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler from towhee.operator import PyOperator log = logging.getLogger(PyOperator) class StableDiffusion(PyOperator): def __init__(self, model_id: str ='stabilityai/stable-diffusion-2-1', device: str = None ): if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = device if self.device == 'cpu': self.pipe= StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32) else: self.pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16) self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config) self.pipe = self.pipe.to('cuda') def __call__(self, prompt: str): image = self.pipe(prompt).images[0] return image