|
@ -8,7 +8,7 @@ log = logging.getLogger(PyOperator) |
|
|
class StableDiffusion(PyOperator): |
|
|
class StableDiffusion(PyOperator): |
|
|
def __init__(self,model_id='stabilityai/stable-diffusion-2-1',pipe='StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32'): |
|
|
def __init__(self,model_id='stabilityai/stable-diffusion-2-1',pipe='StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32'): |
|
|
self._model_id=model_id |
|
|
self._model_id=model_id |
|
|
self.pipe = pipe |
|
|
|
|
|
|
|
|
self._pipe = pipe |
|
|
|
|
|
|
|
|
def __call__(self, prompt:str): |
|
|
def __call__(self, prompt:str): |
|
|
# pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) |
|
|
# pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) |
|
|