stable-diffusion
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
26 lines
994 B
26 lines
994 B
import logging
|
|
import torch
|
|
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
|
|
from towhee.operator import PyOperator
|
|
|
|
# Module-level logger. logging.getLogger requires a string name; passing the
# PyOperator class (as before) raises TypeError when the module is imported.
log = logging.getLogger(__name__)
|
|
|
|
class StableDiffusion(PyOperator):
    """Towhee operator that generates an image from a text prompt with Stable Diffusion.

    Args:
        model_id: Model id or path passed to
            ``StableDiffusionPipeline.from_pretrained``.
        device: Torch device string such as ``'cpu'``, ``'cuda'`` or
            ``'cuda:1'``. When ``None``, ``'cuda'`` is used if
            ``torch.cuda.is_available()``, otherwise ``'cpu'``.
    """

    def __init__(self,
                 model_id: str = 'stabilityai/stable-diffusion-2-1',
                 device: str = None
                 ):
        super().__init__()
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

        # float32 on CPU, float16 on any other device (same policy as before).
        # torch.device(...).type normalises indexed strings such as 'cpu:0',
        # which a plain string comparison with 'cpu' would miss.
        is_cpu = torch.device(self.device).type == 'cpu'
        dtype = torch.float32 if is_cpu else torch.float16
        self.pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)

        # Use the multistep DPM-Solver scheduler, built from the pipeline's own
        # scheduler config. Applied on every device so CPU and GPU runs sample
        # with the same scheduler.
        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)

        # Move to the requested device instead of a hard-coded 'cuda', so an
        # explicit device such as 'cuda:1' is honoured.
        self.pipe = self.pipe.to(self.device)

    def __call__(self, prompt: str):
        """Generate one image for ``prompt``.

        Args:
            prompt: Text prompt forwarded to the diffusers pipeline.

        Returns:
            The first entry of the pipeline output's ``images`` list
            (presumably a PIL image with diffusers' default output type —
            confirm if the pipeline's output_type is changed).
        """
        return self.pipe(prompt).images[0]
|