From f944a6458e1f55183254a123478967c9b8e9e194 Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Mon, 23 Mar 2026 15:53:04 +0000 Subject: [PATCH 01/14] initial architecture --- .../stable_diffusion_3/__init__.py | 46 +++ .../stable_diffusion_3/before_denoise.py | 280 ++++++++++++++++ .../stable_diffusion_3/decoders.py | 51 +++ .../stable_diffusion_3/denoise.py | 153 +++++++++ .../stable_diffusion_3/encoders.py | 304 ++++++++++++++++++ .../stable_diffusion_3/inputs.py | 141 ++++++++ .../modular_blocks_stable_diffusion_3.py | 105 ++++++ .../stable_diffusion_3/modular_pipeline.py | 48 +++ .../stable_diffusion_3/__init__.py | 0 ...est_modular_pipeline_stable_diffusion_3.py | 122 +++++++ 10 files changed, 1250 insertions(+) create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py create mode 100644 src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py create mode 100644 tests/modular_pipelines/stable_diffusion_3/__init__.py create mode 100644 tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py new file mode 100644 index 000000000000..13396327ee7c --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py @@ -0,0 +1,46 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, +) + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects # noqa F403 + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modular_blocks_stable_diffusion_3"] = ["SD3AutoBlocks"] + _import_structure["modular_pipeline"] = ["SD3ModularPipeline"] + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 + else: + from .modular_blocks_stable_diffusion_3 import SD3AutoBlocks + from .modular_pipeline import SD3ModularPipeline +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py new file mode 100644 index 000000000000..7eee1d7dc652 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py @@ -0,0 +1,280 @@ +import inspect + +import numpy as np +import torch + +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ...utils.torch_utils import randn_tensor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import SD3ModularPipeline + + +logger = logging.get_logger(__name__) + + +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +def _get_initial_timesteps_and_optionals( + transformer, + scheduler, + height, + width, + patch_size, + vae_scale_factor, + num_inference_steps, + sigmas, + device, + mu=None, +): + scheduler_kwargs = {} + if scheduler.config.get("use_dynamic_shifting", None) and mu is None: + image_seq_len = (height // vae_scale_factor // patch_size) * (width // vae_scale_factor // patch_size) + mu = calculate_shift( + image_seq_len, + scheduler.config.get("base_image_seq_len", 256), + scheduler.config.get("max_image_seq_len", 4096), + scheduler.config.get("base_shift", 0.5), + scheduler.config.get("max_shift", 1.16), + ) + scheduler_kwargs["mu"] = mu + elif mu is not None: + scheduler_kwargs["mu"] = mu + + timesteps, num_inference_steps = retrieve_timesteps( + scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs + ) + return timesteps, num_inference_steps + + +class SD3SetTimestepsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def expected_components(self) -> list[ComponentSpec]: + return[ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for inference" + + @property + def inputs(self) -> list[InputParam]: + return[ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("mu", type_hint=float), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return[ + OutputParam("timesteps", type_hint=torch.Tensor), + OutputParam("num_inference_steps", type_hint=int), + ] + + @torch.no_grad() + def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + timesteps, num_inference_steps = _get_initial_timesteps_and_optionals( + components.transformer, + components.scheduler, + block_state.height, + block_state.width, + components.patch_size, + components.vae_scale_factor, + block_state.num_inference_steps, + block_state.sigmas, + block_state.device, + getattr(block_state, "mu", None) + ) + + block_state.timesteps = timesteps + block_state.num_inference_steps = num_inference_steps + + self.set_block_state(state, block_state) + return components, state + + +class SD3Img2ImgSetTimestepsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def expected_components(self) -> list[ComponentSpec]: + return[ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def description(self) -> str: + return "Step that sets the scheduler's timesteps for img2img inference" + + @property + def inputs(self) -> list[InputParam]: + return[ + InputParam("num_inference_steps", default=50), + InputParam("timesteps"), + InputParam("sigmas"), + InputParam("strength", default=0.6), + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("mu", type_hint=float), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return[ + OutputParam("timesteps", type_hint=torch.Tensor), + OutputParam("num_inference_steps", type_hint=int), + ] + + @staticmethod + def get_timesteps(scheduler, num_inference_steps, strength): + init_timestep = min(num_inference_steps * strength, num_inference_steps) + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = scheduler.timesteps[t_start * scheduler.order :] + if hasattr(scheduler, "set_begin_index"): + scheduler.set_begin_index(t_start * scheduler.order) + return timesteps, num_inference_steps - t_start + + @torch.no_grad() + def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + timesteps, num_inference_steps = _get_initial_timesteps_and_optionals( + components.transformer, + components.scheduler, + block_state.height, + block_state.width, + components.patch_size, + components.vae_scale_factor, + block_state.num_inference_steps, + block_state.sigmas, + block_state.device, + getattr(block_state, "mu", None) + ) + + timesteps, num_inference_steps = self.get_timesteps( + components.scheduler, num_inference_steps, block_state.strength + ) + + block_state.timesteps = timesteps + block_state.num_inference_steps = num_inference_steps + + self.set_block_state(state, block_state) + return components, state + + +class SD3PrepareLatentsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def description(self) -> str: + return "Prepare latents step for Text-to-Image" + + @property + def inputs(self) -> list[InputParam]: + return[ + InputParam("height", type_hint=int), + InputParam("width", type_hint=int), + InputParam("latents", type_hint=torch.Tensor | None), + InputParam("num_images_per_prompt", type_hint=int, default=1), + InputParam("generator"), + InputParam("batch_size", required=True, type_hint=int), + InputParam("dtype", type_hint=torch.dtype), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return[OutputParam("latents", type_hint=torch.Tensor)] + + @torch.no_grad() + def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + batch_size = block_state.batch_size * block_state.num_images_per_prompt + + if block_state.latents is not None: + block_state.latents = block_state.latents.to(device=block_state.device, dtype=block_state.dtype) + else: + shape = ( + batch_size, + components.num_channels_latents, + int(block_state.height) // components.vae_scale_factor, + int(block_state.width) // components.vae_scale_factor, + ) + block_state.latents = randn_tensor(shape, generator=block_state.generator, device=block_state.device, dtype=block_state.dtype) + + self.set_block_state(state, block_state) + return components, state + + +class SD3Img2ImgPrepareLatentsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def inputs(self) -> list[InputParam]: + return[ + InputParam("latents", required=True, type_hint=torch.Tensor), + InputParam("image_latents", required=True, type_hint=torch.Tensor), + InputParam("timesteps", required=True, type_hint=torch.Tensor), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam("initial_noise", type_hint=torch.Tensor)] + + @torch.no_grad() + def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0]) + block_state.initial_noise = block_state.latents + block_state.latents = components.scheduler.scale_noise( + block_state.image_latents, latent_timestep, block_state.latents + ) + self.set_block_state(state, block_state) + return components, state \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py new file mode 100644 index 000000000000..3f037f1fee01 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py @@ -0,0 +1,51 @@ +import PIL +import torch + +from ...configuration_utils import FrozenDict +from ...models import AutoencoderKL +from ...utils import logging +from ...image_processor import VaeImageProcessor +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam + + +logger = logging.get_logger(__name__) + + +class SD3DecodeStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def expected_components(self) -> list[ComponentSpec]: + return[ + ComponentSpec("vae", AutoencoderKL), + ComponentSpec("image_processor", VaeImageProcessor, config=FrozenDict({"vae_scale_factor": 8}), default_creation_method="from_config"), + ] + + @property + def inputs(self) -> list[InputParam]: + return[ + InputParam("output_type", default="pil"), + InputParam("latents", required=True, type_hint=torch.Tensor), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return [OutputParam("images", type_hint=list[PIL.Image.Image] | torch.Tensor)] + + @torch.no_grad() + def __call__(self, components, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + vae = components.vae + + if not block_state.output_type == "latent": + latents = (block_state.latents / vae.config.scaling_factor) + vae.config.shift_factor + block_state.images = vae.decode(latents, return_dict=False)[0] + block_state.images = components.image_processor.postprocess( + block_state.images, output_type=block_state.output_type + ) + else: + block_state.images = block_state.latents + + self.set_block_state(state, block_state) + return components, state \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py new file mode 100644 index 000000000000..4341c3daf3c9 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py @@ -0,0 +1,153 @@ +from typing import Any + +import torch + +from ...models.transformers import SD3Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import logging +from ..modular_pipeline import ( + BlockState, + LoopSequentialPipelineBlocks, + ModularPipelineBlocks, + PipelineState, +) +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import SD3ModularPipeline + +logger = logging.get_logger(__name__) + + +class SD3LoopDenoiser(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("transformer", SD3Transformer2DModel)] + + @property + def description(self) -> str: + return "Step within the denoising loop that denoise the latents." + + @property + def inputs(self) -> list[tuple[str, Any]]: + return[ + InputParam("joint_attention_kwargs"), + InputParam("latents", required=True, type_hint=torch.Tensor), + InputParam("prompt_embeds", required=True, type_hint=torch.Tensor), + InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor), + InputParam("do_classifier_free_guidance", type_hint=bool), + InputParam("guidance_scale", default=7.0), + InputParam("skip_guidance_layers", type_hint=list), + InputParam("skip_layer_guidance_scale", default=2.8), + InputParam("skip_layer_guidance_stop", default=0.2), + InputParam("skip_layer_guidance_start", default=0.01), + InputParam("original_prompt_embeds", type_hint=torch.Tensor), + InputParam("original_pooled_prompt_embeds", type_hint=torch.Tensor), + InputParam("num_inference_steps", type_hint=int), + ] + + @torch.no_grad() + def __call__( + self, components: SD3ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + ) -> PipelineState: + latent_model_input = torch.cat([block_state.latents] * 2) if block_state.do_classifier_free_guidance else block_state.latents + timestep = t.expand(latent_model_input.shape[0]) + + noise_pred = components.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=block_state.prompt_embeds, + pooled_projections=block_state.pooled_prompt_embeds, + joint_attention_kwargs=getattr(block_state, "joint_attention_kwargs", None), + return_dict=False, + )[0] + + if block_state.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + block_state.guidance_scale * (noise_pred_text - noise_pred_uncond) + + should_skip_layers = ( + getattr(block_state, "skip_guidance_layers", None) is not None + and i > getattr(block_state, "num_inference_steps", 50) * getattr(block_state, "skip_layer_guidance_start", 0.01) + and i < getattr(block_state, "num_inference_steps", 50) * getattr(block_state, "skip_layer_guidance_stop", 0.2) + ) + + if should_skip_layers: + timestep_skip = t.expand(block_state.latents.shape[0]) + noise_pred_skip_layers = components.transformer( + hidden_states=block_state.latents, + timestep=timestep_skip, + encoder_hidden_states=block_state.original_prompt_embeds, + pooled_projections=block_state.original_pooled_prompt_embeds, + joint_attention_kwargs=getattr(block_state, "joint_attention_kwargs", None), + return_dict=False, + skip_layers=block_state.skip_guidance_layers, + )[0] + noise_pred = noise_pred + (noise_pred_text - noise_pred_skip_layers) * getattr(block_state, "skip_layer_guidance_scale", 2.8) + + block_state.noise_pred = noise_pred + return components, block_state + + +class SD3LoopAfterDenoiser(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return[OutputParam("latents", type_hint=torch.Tensor)] + + @torch.no_grad() + def __call__(self, components: SD3ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + latents_dtype = block_state.latents.dtype + block_state.latents = components.scheduler.step( + block_state.noise_pred, + t, + block_state.latents, + return_dict=False, + )[0] + + if block_state.latents.dtype != latents_dtype: + block_state.latents = block_state.latents.to(latents_dtype) + + return components, block_state + + +class SD3DenoiseLoopWrapper(LoopSequentialPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def loop_expected_components(self) -> list[ComponentSpec]: + return[ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ComponentSpec("transformer", SD3Transformer2DModel), + ] + + @property + def loop_inputs(self) -> list[InputParam]: + return[ + InputParam("timesteps", required=True, type_hint=torch.Tensor), + InputParam("num_inference_steps", required=True, type_hint=int), + ] + + @torch.no_grad() + def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: + for i, t in enumerate(block_state.timesteps): + components, block_state = self.loop_step(components, block_state, i=i, t=t) + if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): + progress_bar.update() + + self.set_block_state(state, block_state) + return components, state + + +class SD3DenoiseStep(SD3DenoiseLoopWrapper): + block_classes = [SD3LoopDenoiser, SD3LoopAfterDenoiser] + block_names = ["denoiser", "after_denoiser"] \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py new file mode 100644 index 000000000000..24f38fbfce38 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py @@ -0,0 +1,304 @@ +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast + +from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor +from ...loaders import SD3LoraLoaderMixin +from ...models import AutoencoderKL +from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam +from .modular_pipeline import SD3ModularPipeline + +logger = logging.get_logger(__name__) + +def retrieve_latents( + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + +def encode_vae_image(vae: AutoencoderKL, image: torch.Tensor, generator: torch.Generator, sample_mode="sample"): + if isinstance(generator, list): + image_latents =[ + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode) + + image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor + return image_latents + +class SD3ProcessImagesInputStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def description(self) -> str: + return "Image Preprocess step for SD3." + + @property + def expected_components(self) -> list[ComponentSpec]: + return[ + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8, "vae_latent_channels": 16}), + default_creation_method="from_config", + ), + ] + + @property + def inputs(self) -> list[InputParam]: + return[InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return[OutputParam(name="processed_image")] + + @staticmethod + def check_inputs(height, width, vae_scale_factor, patch_size): + if height is not None and height % (vae_scale_factor * patch_size) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * patch_size} but is {height}") + + if width is not None and width % (vae_scale_factor * patch_size) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * patch_size} but is {width}") + + @torch.no_grad() + def __call__(self, components: SD3ModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + if block_state.resized_image is None and block_state.image is None: + raise ValueError("`resized_image` and `image` cannot be None at the same time") + + if block_state.resized_image is None: + image = block_state.image + self.check_inputs( + height=block_state.height, width=block_state.width, + vae_scale_factor=components.vae_scale_factor, patch_size=components.patch_size + ) + height = block_state.height or components.default_height + width = block_state.width or components.default_width + else: + width, height = block_state.resized_image[0].size + image = block_state.resized_image + + block_state.processed_image = components.image_processor.preprocess(image=image, height=height, width=width) + + self.set_block_state(state, block_state) + return components, state + +class SD3VaeEncoderStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + def __init__(self, input_name: str = "processed_image", output_name: str = "image_latents", sample_mode: str = "sample"): + self._image_input_name = input_name + self._image_latents_output_name = output_name + self.sample_mode = sample_mode + super().__init__() + + @property + def description(self) -> str: + return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}." + + @property + def expected_components(self) -> list[ComponentSpec]: + return [ComponentSpec("vae", AutoencoderKL)] + + @property + def inputs(self) -> list[InputParam]: + return[InputParam(self._image_input_name), InputParam("generator")] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return[ + OutputParam(self._image_latents_output_name, type_hint=torch.Tensor, description="The latents representing the reference image") + ] + + @torch.no_grad() + def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + image = getattr(block_state, self._image_input_name) + + if image is None: + setattr(block_state, self._image_latents_output_name, None) + else: + device = components._execution_device + dtype = components.vae.dtype + image = image.to(device=device, dtype=dtype) + image_latents = encode_vae_image( + image=image, vae=components.vae, generator=block_state.generator, sample_mode=self.sample_mode + ) + setattr(block_state, self._image_latents_output_name, image_latents) + + self.set_block_state(state, block_state) + return components, state + +class SD3TextEncoderStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def description(self) -> str: + return "Text Encoder step that generates text embeddings to guide the image generation for SD3." + + @property + def expected_components(self) -> list[ComponentSpec]: + return[ + ComponentSpec("text_encoder", CLIPTextModelWithProjection), + ComponentSpec("tokenizer", CLIPTokenizer), + ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), + ComponentSpec("tokenizer_2", CLIPTokenizer), + ComponentSpec("text_encoder_3", T5EncoderModel), + ComponentSpec("tokenizer_3", T5TokenizerFast), + ] + + @property + def inputs(self) -> list[InputParam]: + return[ + InputParam("prompt"), + InputParam("prompt_2"), + InputParam("prompt_3"), + InputParam("negative_prompt"), + InputParam("negative_prompt_2"), + InputParam("negative_prompt_3"), + InputParam("prompt_embeds", type_hint=torch.Tensor), + InputParam("negative_prompt_embeds", type_hint=torch.Tensor), + InputParam("pooled_prompt_embeds", type_hint=torch.Tensor), + InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor), + InputParam("guidance_scale", default=7.0), + InputParam("clip_skip", type_hint=int), + InputParam("max_sequence_length", type_hint=int, default=256), + InputParam("joint_attention_kwargs"), + ] + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return[ + OutputParam("prompt_embeds", type_hint=torch.Tensor), + OutputParam("negative_prompt_embeds", type_hint=torch.Tensor), + OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor), + OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor), + ] + + @staticmethod + def _get_t5_prompt_embeds(components, prompt, max_sequence_length, device): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if components.text_encoder_3 is None: + return torch.zeros( + (batch_size, max_sequence_length, components.transformer.config.joint_attention_dim), + device=device, + dtype=components.text_encoder.dtype, + ) + + text_inputs = components.tokenizer_3( + prompt, padding="max_length", max_length=max_sequence_length, + truncation=True, add_special_tokens=True, return_tensors="pt", + ) + prompt_embeds = components.text_encoder_3(text_inputs.input_ids.to(device))[0] + return prompt_embeds.to(dtype=components.text_encoder_3.dtype, device=device) + + @staticmethod + def _get_clip_prompt_embeds(components, prompt, device, clip_skip, clip_model_index): + clip_tokenizers = [components.tokenizer, components.tokenizer_2] + clip_text_encoders =[components.text_encoder, components.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + text_inputs = tokenizer( + prompt, padding="max_length", max_length=tokenizer.model_max_length, + truncation=True, return_tensors="pt", + ) + + prompt_embeds = text_encoder(text_inputs.input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + return prompt_embeds.to(dtype=components.text_encoder.dtype, device=device), pooled_prompt_embeds + + @staticmethod + def encode_prompt(components, block_state, device, do_classifier_free_guidance, lora_scale): + if lora_scale is not None and isinstance(components, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + if components.text_encoder is not None: scale_lora_layers(components.text_encoder, lora_scale) + if components.text_encoder_2 is not None: scale_lora_layers(components.text_encoder_2, lora_scale) + + prompt_embeds = block_state.prompt_embeds + pooled_prompt_embeds = block_state.pooled_prompt_embeds + + if prompt_embeds is None: + prompt = [block_state.prompt] if isinstance(block_state.prompt, str) else block_state.prompt + prompt_2 = block_state.prompt_2 or prompt + prompt_3 = block_state.prompt_3 or prompt + + prompt_embed, pooled_embed = SD3TextEncoderStep._get_clip_prompt_embeds(components, prompt, device, block_state.clip_skip, 0) + prompt_2_embed, pooled_2_embed = SD3TextEncoderStep._get_clip_prompt_embeds(components, prompt_2, device, block_state.clip_skip, 1) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = SD3TextEncoderStep._get_t5_prompt_embeds(components, prompt_3, block_state.max_sequence_length, device) + clip_prompt_embeds = torch.nn.functional.pad(clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_embed, pooled_2_embed], dim=-1) + + negative_prompt_embeds = block_state.negative_prompt_embeds + negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds + + if do_classifier_free_guidance and negative_prompt_embeds is None: + batch_size = prompt_embeds.shape[0] + neg_prompt = block_state.negative_prompt or "" + neg_prompt_2 = block_state.negative_prompt_2 or neg_prompt + neg_prompt_3 = block_state.negative_prompt_3 or neg_prompt + + neg_prompt = batch_size * [neg_prompt] if isinstance(neg_prompt, str) else neg_prompt + neg_prompt_2 = batch_size * [neg_prompt_2] if isinstance(neg_prompt_2, str) else neg_prompt_2 + neg_prompt_3 = batch_size * [neg_prompt_3] if isinstance(neg_prompt_3, str) else neg_prompt_3 + + neg_embed, neg_pooled_embed = SD3TextEncoderStep._get_clip_prompt_embeds(components, neg_prompt, device, None, 0) + neg_2_embed, neg_2_pooled_embed = SD3TextEncoderStep._get_clip_prompt_embeds(components, neg_prompt_2, device, None, 1) + neg_clip_embeds = torch.cat([neg_embed, neg_2_embed], dim=-1) + + t5_neg_embed = SD3TextEncoderStep._get_t5_prompt_embeds(components, neg_prompt_3, block_state.max_sequence_length, device) + neg_clip_embeds = torch.nn.functional.pad(neg_clip_embeds, (0, t5_neg_embed.shape[-1] - neg_clip_embeds.shape[-1])) + + negative_prompt_embeds = torch.cat([neg_clip_embeds, t5_neg_embed], dim=-2) + negative_pooled_prompt_embeds = torch.cat([neg_pooled_embed, neg_2_pooled_embed], dim=-1) + + if lora_scale is not None and isinstance(components, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + if components.text_encoder is not None: unscale_lora_layers(components.text_encoder, lora_scale) + if components.text_encoder_2 is not None: unscale_lora_layers(components.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + @torch.no_grad() + def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + block_state.device = components._execution_device + + do_classifier_free_guidance = block_state.guidance_scale > 1.0 + lora_scale = block_state.joint_attention_kwargs.get("scale", None) if block_state.joint_attention_kwargs else None + + prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.encode_prompt( + components, block_state, block_state.device, do_classifier_free_guidance, lora_scale + ) + + block_state.prompt_embeds = prompt_embeds + block_state.negative_prompt_embeds = negative_prompt_embeds + block_state.pooled_prompt_embeds = pooled_prompt_embeds + block_state.negative_pooled_prompt_embeds = negative_pooled_prompt_embeds + + self.set_block_state(state, block_state) + return components, state \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py new file mode 100644 index 000000000000..61ca894faafc --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py @@ -0,0 +1,141 @@ +import torch +from ...utils import logging +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import InputParam, OutputParam +from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_to_batch_size +from .modular_pipeline import SD3ModularPipeline + +logger = logging.get_logger(__name__) + +class SD3TextInputStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + @property + def description(self) -> str: + return "Text input processing step that standardizes text embeddings for SD3, applying CFG duplication if needed." + + @property + def inputs(self) -> list[InputParam]: + return[ + InputParam("num_images_per_prompt", default=1), + InputParam("guidance_scale", default=7.0), + InputParam("skip_guidance_layers", type_hint=list), + InputParam("prompt_embeds", required=True, type_hint=torch.Tensor), + InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor), + InputParam("negative_prompt_embeds", type_hint=torch.Tensor), + InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor), + ] + + @property + def intermediate_outputs(self) -> list[str]: + return[ + OutputParam("batch_size", type_hint=int), + OutputParam("dtype", type_hint=torch.dtype), + OutputParam("do_classifier_free_guidance", type_hint=bool), + OutputParam("prompt_embeds", type_hint=torch.Tensor), + OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor), + OutputParam("original_prompt_embeds", type_hint=torch.Tensor), + OutputParam("original_pooled_prompt_embeds", type_hint=torch.Tensor), + ] + + @torch.no_grad() + def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + block_state.do_classifier_free_guidance = block_state.guidance_scale > 1.0 + + _, seq_len, _ = block_state.prompt_embeds.shape + prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt) + pooled_prompt_embeds = pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) + + if block_state.do_classifier_free_guidance and block_state.negative_prompt_embeds is not None: + _, neg_seq_len, _ = block_state.negative_prompt_embeds.shape + negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, neg_seq_len, -1) + + negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) + + if block_state.skip_guidance_layers is not None: + block_state.original_prompt_embeds = prompt_embeds + block_state.original_pooled_prompt_embeds = pooled_prompt_embeds + + block_state.prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + block_state.pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + else: + block_state.prompt_embeds = prompt_embeds + block_state.pooled_prompt_embeds = pooled_prompt_embeds + + self.set_block_state(state, block_state) + return components, state + +class SD3AdditionalInputsStep(ModularPipelineBlocks): + model_name = "stable-diffusion-3" + + def __init__(self, image_latent_inputs: list[str] = ["image_latents"], additional_batch_inputs: list[str] =[]): + self._image_latent_inputs = image_latent_inputs if isinstance(image_latent_inputs, list) else [image_latent_inputs] + self._additional_batch_inputs = additional_batch_inputs if isinstance(additional_batch_inputs, list) else[additional_batch_inputs] + super().__init__() + + @property + def description(self) -> str: + return "Updates height/width if None, and expands batch size. SD3 does not pack latents on pipeline level." + + @property + def inputs(self) -> list[InputParam]: + inputs =[ + InputParam("num_images_per_prompt", default=1), + InputParam("batch_size", required=True), + InputParam("height"), + InputParam("width"), + ] + for name in self._image_latent_inputs + self._additional_batch_inputs: + inputs.append(InputParam(name)) + return inputs + + @property + def intermediate_outputs(self) -> list[OutputParam]: + return[ + OutputParam("image_height", type_hint=int), + OutputParam("image_width", type_hint=int), + ] + + def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + for input_name in self._image_latent_inputs: + tensor = getattr(block_state, input_name) + if tensor is None: + continue + + height, width = calculate_dimension_from_latents(tensor, components.vae_scale_factor) + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + if not hasattr(block_state, "image_height"): + block_state.image_height = height + if not hasattr(block_state, "image_width"): + block_state.image_width = width + + tensor = repeat_tensor_to_batch_size( + input_name=input_name, input_tensor=tensor, + num_images_per_prompt=block_state.num_images_per_prompt, batch_size=block_state.batch_size + ) + setattr(block_state, input_name, tensor) + + for input_name in self._additional_batch_inputs: + tensor = getattr(block_state, input_name) + if tensor is None: continue + tensor = repeat_tensor_to_batch_size( + input_name=input_name, input_tensor=tensor, + num_images_per_prompt=block_state.num_images_per_prompt, batch_size=block_state.batch_size + ) + setattr(block_state, input_name, tensor) + + self.set_block_state(state, block_state) + return components, state \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py new file mode 100644 index 000000000000..719910b5ca8f --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py @@ -0,0 +1,105 @@ +from ...utils import logging +from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks +from ..modular_pipeline_utils import InsertableDict, OutputParam +from .before_denoise import ( + SD3Img2ImgPrepareLatentsStep, + SD3Img2ImgSetTimestepsStep, + SD3PrepareLatentsStep, + SD3SetTimestepsStep, +) +from .decoders import SD3DecodeStep +from .denoise import SD3DenoiseStep +from .encoders import ( + SD3ProcessImagesInputStep, + SD3TextEncoderStep, + SD3VaeEncoderStep, +) +from .inputs import ( + SD3AdditionalInputsStep, + SD3TextInputStep, +) + + +logger = logging.get_logger(__name__) + + +class SD3Img2ImgVaeEncoderStep(SequentialPipelineBlocks): + model_name = "stable-diffusion-3" + block_classes = [SD3ProcessImagesInputStep(), SD3VaeEncoderStep()] + block_names = ["preprocess", "encode"] + + +class SD3AutoVaeEncoderStep(AutoPipelineBlocks): + model_name = "stable-diffusion-3" + block_classes =[SD3Img2ImgVaeEncoderStep] + block_names = ["img2img"] + block_trigger_inputs =["image"] + + +class SD3BeforeDenoiseStep(SequentialPipelineBlocks): + model_name = "stable-diffusion-3" + block_classes =[SD3PrepareLatentsStep(), SD3SetTimestepsStep()] + block_names = ["prepare_latents", "set_timesteps"] + + +class SD3Img2ImgBeforeDenoiseStep(SequentialPipelineBlocks): + model_name = "stable-diffusion-3" + block_classes =[ + SD3PrepareLatentsStep(), + SD3Img2ImgSetTimestepsStep(), + SD3Img2ImgPrepareLatentsStep(), + ] + block_names = ["prepare_latents", "set_timesteps", "prepare_img2img_latents"] + + +class SD3AutoBeforeDenoiseStep(AutoPipelineBlocks): + model_name = "stable-diffusion-3" + block_classes =[SD3Img2ImgBeforeDenoiseStep, SD3BeforeDenoiseStep] + block_names = ["img2img", "text2image"] + block_trigger_inputs = ["image_latents", None] + + +class SD3Img2ImgInputStep(SequentialPipelineBlocks): + model_name = "stable-diffusion-3" + block_classes =[SD3TextInputStep(), SD3AdditionalInputsStep()] + block_names =["text_inputs", "additional_inputs"] + + +class SD3AutoInputStep(AutoPipelineBlocks): + model_name = "stable-diffusion-3" + block_classes = [SD3Img2ImgInputStep, SD3TextInputStep] + block_names = ["img2img", "text2image"] + block_trigger_inputs = ["image_latents", None] + + +class SD3CoreDenoiseStep(SequentialPipelineBlocks): + model_name = "stable-diffusion-3" + block_classes =[SD3AutoInputStep, SD3AutoBeforeDenoiseStep, SD3DenoiseStep] + block_names =["input", "before_denoise", "denoise"] + @property + def outputs(self): + return [OutputParam.template("latents")] + + +AUTO_BLOCKS = InsertableDict([ + ("text_encoder", SD3TextEncoderStep()), + ("vae_encoder", SD3AutoVaeEncoderStep()), + ("denoise", SD3CoreDenoiseStep()), + ("decode", SD3DecodeStep()), + ] +) + + +class SD3AutoBlocks(SequentialPipelineBlocks): + model_name = "stable-diffusion-3" + block_classes = AUTO_BLOCKS.values() + block_names = AUTO_BLOCKS.keys() + + _workflow_map = { + "text2image": {"prompt": True}, + "image2image": {"image": True, "prompt": True}, + } + + @property + def outputs(self): + return [OutputParam.template("images")] \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py new file mode 100644 index 000000000000..56033fa08bc7 --- /dev/null +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py @@ -0,0 +1,48 @@ +from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin +from ...utils import logging +from ..modular_pipeline import ModularPipeline + +logger = logging.get_logger(__name__) + + +class SD3ModularPipeline(ModularPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin): + """ + A ModularPipeline for Stable Diffusion 3. + + >[!WARNING] > This is an experimental feature and is likely to change in the future. + """ + + default_blocks_name = "SD3AutoBlocks" + + @property + def default_height(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_width(self): + return self.default_sample_size * self.vae_scale_factor + + @property + def default_sample_size(self): + if getattr(self, "transformer", None) is not None: + return self.transformer.config.sample_size + return 128 + + @property + def patch_size(self): + if getattr(self, "transformer", None) is not None: + return self.transformer.config.patch_size + return 2 + + @property + def vae_scale_factor(self): + vae_scale_factor = 8 + if getattr(self, "vae", None) is not None: + vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + return vae_scale_factor + + @property + def num_channels_latents(self): + if getattr(self, "transformer", None) is not None: + return self.transformer.config.in_channels + return 16 \ No newline at end of file diff --git a/tests/modular_pipelines/stable_diffusion_3/__init__.py b/tests/modular_pipelines/stable_diffusion_3/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py new file mode 100644 index 000000000000..20c1542ee3ab --- /dev/null +++ b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py @@ -0,0 +1,122 @@ +# coding=utf-8 +# Copyright 2025 HuggingFace Inc. +import random +import numpy as np +import PIL +import torch + +from diffusers.image_processor import VaeImageProcessor +from diffusers.modular_pipelines import ModularPipeline +from diffusers.modular_pipelines.stable_diffusion_3 import SD3AutoBlocks, SD3ModularPipeline + +from ...testing_utils import floats_tensor, torch_device +from ..test_modular_pipelines_common import ModularPipelineTesterMixin + + +SD3_TEXT2IMAGE_WORKFLOWS = { + "text2image":[ + ("text_encoder", "SD3TextEncoderStep"), + ("denoise.input", "SD3TextInputStep"), + ("denoise.before_denoise.prepare_latents", "SD3PrepareLatentsStep"), + ("denoise.before_denoise.set_timesteps", "SD3SetTimestepsStep"), + ("denoise.denoise", "SD3DenoiseStep"), + ("decode", "SD3DecodeStep"), + ] +} + +class TestSD3ModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = SD3ModularPipeline + pipeline_blocks_class = SD3AutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe" + + params = frozenset(["prompt", "height", "width", "guidance_scale"]) + batch_params = frozenset(["prompt"]) + expected_workflow_blocks = SD3_TEXT2IMAGE_WORKFLOWS + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + return { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 5.0, + "height": 32, + "width": 32, + "max_sequence_length": 48, + "output_type": "pt", + } + + def test_float16_inference(self): + super().test_float16_inference(9e-2) + + +SD3_IMAGE2IMAGE_WORKFLOWS = { + "image2image":[ + ("text_encoder", "SD3TextEncoderStep"), + ("vae_encoder.preprocess", "SD3ProcessImagesInputStep"), + ("vae_encoder.encode", "SD3VaeEncoderStep"), + ("denoise.input.text_inputs", "SD3TextInputStep"), + ("denoise.input.additional_inputs", "SD3AdditionalInputsStep"), + ("denoise.before_denoise.prepare_latents", "SD3PrepareLatentsStep"), + ("denoise.before_denoise.set_timesteps", "SD3Img2ImgSetTimestepsStep"), + ("denoise.before_denoise.prepare_img2img_latents", "SD3Img2ImgPrepareLatentsStep"), + ("denoise.denoise", "SD3DenoiseStep"), + ("decode", "SD3DecodeStep"), + ] +} + +class TestSD3Img2ImgModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = SD3ModularPipeline + pipeline_blocks_class = SD3AutoBlocks + pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe" + + params = frozenset(["prompt", "height", "width", "guidance_scale", "image"]) + batch_params = frozenset(["prompt", "image"]) + expected_workflow_blocks = SD3_IMAGE2IMAGE_WORKFLOWS + + def get_pipeline(self, components_manager=None, torch_dtype=torch.float32): + pipeline = super().get_pipeline(components_manager, torch_dtype) + pipeline.image_processor = VaeImageProcessor(vae_scale_factor=8) + return pipeline + + def get_dummy_inputs(self, seed=0): + generator = self.get_generator(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 4, + "guidance_scale": 5.0, + "height": 32, + "width": 32, + "max_sequence_length": 48, + "output_type": "pt", + } + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(torch_device) + image = image.cpu().permute(0, 2, 3, 1)[0] + init_image = PIL.Image.fromarray(np.uint8(image)).convert("RGB") + inputs["image"] = init_image + inputs["strength"] = 0.5 + return inputs + + def test_save_from_pretrained(self, tmp_path): + pipes =[] + base_pipe = self.get_pipeline().to(torch_device) + pipes.append(base_pipe) + + base_pipe.save_pretrained(str(tmp_path)) + pipe = ModularPipeline.from_pretrained(tmp_path).to(torch_device) + pipe.load_components(torch_dtype=torch.float32) + pipe.to(torch_device) + pipe.image_processor = VaeImageProcessor(vae_scale_factor=8) + pipes.append(pipe) + + image_slices =[] + for pipe in pipes: + inputs = self.get_dummy_inputs() + image = pipe(**inputs, output="images") + image_slices.append(image[0, -3:, -3:, -1].flatten()) + + assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + + def test_float16_inference(self): + super().test_float16_inference(8e-2) \ No newline at end of file From 08d14c60e7d155568974b02792bc34248776779d Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Tue, 24 Mar 2026 17:17:32 +0000 Subject: [PATCH 02/14] add blocks to various inits --- src/diffusers/__init__.py | 4 + src/diffusers/modular_pipelines/__init__.py | 2 + .../modular_pipelines/modular_pipeline.py | 3 +- .../stable_diffusion_3/before_denoise.py | 14 ++++ .../stable_diffusion_3/decoders.py | 14 ++++ .../stable_diffusion_3/denoise.py | 14 ++++ .../stable_diffusion_3/encoders.py | 14 ++++ .../stable_diffusion_3/inputs.py | 21 +++++- .../modular_blocks_stable_diffusion_3.py | 14 ++++ .../stable_diffusion_3/modular_pipeline.py | 14 ++++ .../dummy_torch_and_transformers_objects.py | 29 ++++++++ ...est_modular_pipeline_stable_diffusion_3.py | 73 +++++++++++++++++-- 12 files changed, 205 insertions(+), 11 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0be7b8166a37..0f2852baf421 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -450,6 +450,8 @@ "QwenImageModularPipeline", "StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline", + "SD3AutoBlocks", + "SD3ModularPipeline", "Wan22Blocks", "Wan22Image2VideoBlocks", "Wan22Image2VideoModularPipeline", @@ -1211,6 +1213,8 @@ QwenImageModularPipeline, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline, + SD3AutoBlocks, + SD3ModularPipeline, Wan22Blocks, Wan22Image2VideoBlocks, Wan22Image2VideoModularPipeline, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index fd9bd691ca87..e9a92c5704ac 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -46,6 +46,7 @@ "InsertableDict", ] _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"] + _import_structure["stable_diffusion_3"] =["SD3AutoBlocks", "SD3ModularPipeline"] _import_structure["wan"] = [ "WanBlocks", "Wan22Blocks", @@ -141,6 +142,7 @@ QwenImageModularPipeline, ) from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline + from .stable_diffusion_3 import SD3AutoBlocks, SD3ModularPipeline from .wan import ( Wan22Blocks, Wan22Image2VideoBlocks, diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 9cd2f9f5c6ae..e2ca24812e72 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -119,8 +119,9 @@ def _helios_pyramid_map_fn(config_dict=None): MODULAR_PIPELINE_MAPPING = OrderedDict( [ ("stable-diffusion-xl", _create_default_map_fn("StableDiffusionXLModularPipeline")), + ("stable-diffusion-3", _create_default_map_fn("SD3ModularPipeline")), ("wan", _wan_map_fn), - ("wan-i2v", _wan_i2v_map_fn), + ("wan-i2v", _wan_i2v_map_fn), ("flux", _create_default_map_fn("FluxModularPipeline")), ("flux-kontext", _create_default_map_fn("FluxKontextModularPipeline")), ("flux2", _create_default_map_fn("Flux2ModularPipeline")), diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py index 7eee1d7dc652..ebadf45236da 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py @@ -1,3 +1,17 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import inspect import numpy as np diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py index 3f037f1fee01..c8d9f6a562c1 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py @@ -1,3 +1,17 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import PIL import torch diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py index 4341c3daf3c9..a41e87665ede 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py @@ -1,3 +1,17 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Any import torch diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py index 24f38fbfce38..6087f349a691 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py @@ -1,3 +1,17 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py index 61ca894faafc..5ae213b09040 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py @@ -1,3 +1,17 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch from ...utils import logging from ..modular_pipeline import ModularPipelineBlocks, PipelineState @@ -53,6 +67,9 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt) pooled_prompt_embeds = pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) + block_state.original_prompt_embeds = prompt_embeds + block_state.original_pooled_prompt_embeds = pooled_prompt_embeds + if block_state.do_classifier_free_guidance and block_state.negative_prompt_embeds is not None: _, neg_seq_len, _ = block_state.negative_prompt_embeds.shape negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) @@ -61,10 +78,6 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt) negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) - if block_state.skip_guidance_layers is not None: - block_state.original_prompt_embeds = prompt_embeds - block_state.original_pooled_prompt_embeds = pooled_prompt_embeds - block_state.prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) block_state.pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) else: diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py index 719910b5ca8f..0595d26346c2 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py @@ -1,3 +1,17 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from ...utils import logging from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks from ..modular_pipeline_utils import InsertableDict, OutputParam diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py index 56033fa08bc7..a54b1fd54423 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py @@ -1,3 +1,17 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin from ...utils import logging from ..modular_pipeline import ModularPipeline diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 2ec5bc002f41..a3d9f8bcf56c 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -391,6 +391,35 @@ def from_config(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class SD3AutoBlocks(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class SD3ModularPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls,["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + class Wan22Blocks(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py index 20c1542ee3ab..860fa70f0565 100644 --- a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py +++ b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py @@ -1,5 +1,18 @@ # coding=utf-8 -# Copyright 2025 HuggingFace Inc. +# Copyright 2026 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import random import numpy as np import PIL @@ -46,10 +59,45 @@ def get_dummy_inputs(self, seed=0): "output_type": "pt", } + def get_pipeline(self, components_manager=None, torch_dtype=torch.float32): + pipeline = self.pipeline_class.from_pretrained( + self.pretrained_model_name_or_path, torch_dtype=torch_dtype + ) + if components_manager is not None: + pipeline.components_manager = components_manager + return pipeline + + def test_save_from_pretrained(self, tmp_path): + pipes =[] + base_pipe = self.get_pipeline().to(torch_device) + pipes.append(base_pipe) + + base_pipe.save_pretrained(str(tmp_path)) + pipe = self.pipeline_class.from_pretrained(tmp_path).to(torch_device) + pipe.load_components(torch_dtype=torch.float32) + pipe.to(torch_device) + pipes.append(pipe) + + image_slices =[] + for p in pipes: + inputs = self.get_dummy_inputs() + image = p(**inputs, output="images") + image_slices.append(image[0, -3:, -3:, -1].flatten()) + + assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + + def test_load_expected_components_from_save_pretrained(self, tmp_path): + base_pipe = self.get_pipeline() + base_pipe.save_pretrained(str(tmp_path)) + + pipe = self.pipeline_class.from_pretrained(tmp_path) + pipe.load_components(torch_dtype=torch.float32) + + assert set(base_pipe.components.keys()) == set(pipe.components.keys()) + def test_float16_inference(self): super().test_float16_inference(9e-2) - SD3_IMAGE2IMAGE_WORKFLOWS = { "image2image":[ ("text_encoder", "SD3TextEncoderStep"), @@ -75,7 +123,11 @@ class TestSD3Img2ImgModularPipelineFast(ModularPipelineTesterMixin): expected_workflow_blocks = SD3_IMAGE2IMAGE_WORKFLOWS def get_pipeline(self, components_manager=None, torch_dtype=torch.float32): - pipeline = super().get_pipeline(components_manager, torch_dtype) + pipeline = self.pipeline_class.from_pretrained( + self.pretrained_model_name_or_path, torch_dtype=torch_dtype + ) + if components_manager is not None: + pipeline.components_manager = components_manager pipeline.image_processor = VaeImageProcessor(vae_scale_factor=8) return pipeline @@ -104,19 +156,28 @@ def test_save_from_pretrained(self, tmp_path): pipes.append(base_pipe) base_pipe.save_pretrained(str(tmp_path)) - pipe = ModularPipeline.from_pretrained(tmp_path).to(torch_device) + pipe = self.pipeline_class.from_pretrained(tmp_path).to(torch_device) pipe.load_components(torch_dtype=torch.float32) pipe.to(torch_device) pipe.image_processor = VaeImageProcessor(vae_scale_factor=8) pipes.append(pipe) image_slices =[] - for pipe in pipes: + for p in pipes: inputs = self.get_dummy_inputs() - image = pipe(**inputs, output="images") + image = p(**inputs, output="images") image_slices.append(image[0, -3:, -3:, -1].flatten()) assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3 + def test_load_expected_components_from_save_pretrained(self, tmp_path): + base_pipe = self.get_pipeline() + base_pipe.save_pretrained(str(tmp_path)) + + pipe = self.pipeline_class.from_pretrained(tmp_path) + pipe.load_components(torch_dtype=torch.float32) + + assert set(base_pipe.components.keys()) == set(pipe.components.keys()) + def test_float16_inference(self): super().test_float16_inference(8e-2) \ No newline at end of file From 0a81741904319427b82e77a80c922084dbd933ce Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Tue, 24 Mar 2026 17:20:28 +0000 Subject: [PATCH 03/14] styling --- src/diffusers/__init__.py | 4 ++-- src/diffusers/modular_pipelines/__init__.py | 2 +- .../modular_pipelines/modular_pipeline.py | 2 +- .../stable_diffusion_3/__init__.py | 3 ++- .../stable_diffusion_3/before_denoise.py | 14 ++++++-------- .../stable_diffusion_3/decoders.py | 4 ++-- .../stable_diffusion_3/denoise.py | 7 ++++--- .../stable_diffusion_3/encoders.py | 19 ++++++++++++------- .../stable_diffusion_3/inputs.py | 11 +++++++---- .../modular_blocks_stable_diffusion_3.py | 2 +- .../stable_diffusion_3/modular_pipeline.py | 3 ++- .../dummy_torch_and_transformers_objects.py | 2 +- ...est_modular_pipeline_stable_diffusion_3.py | 12 ++++++------ 13 files changed, 47 insertions(+), 38 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0f2852baf421..c1fcef28465b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -1211,10 +1211,10 @@ QwenImageLayeredAutoBlocks, QwenImageLayeredModularPipeline, QwenImageModularPipeline, - StableDiffusionXLAutoBlocks, - StableDiffusionXLModularPipeline, SD3AutoBlocks, SD3ModularPipeline, + StableDiffusionXLAutoBlocks, + StableDiffusionXLModularPipeline, Wan22Blocks, Wan22Image2VideoBlocks, Wan22Image2VideoModularPipeline, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index e9a92c5704ac..3e4802609be3 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -141,8 +141,8 @@ QwenImageLayeredModularPipeline, QwenImageModularPipeline, ) - from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline from .stable_diffusion_3 import SD3AutoBlocks, SD3ModularPipeline + from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline from .wan import ( Wan22Blocks, Wan22Image2VideoBlocks, diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index e2ca24812e72..2f36d8526cdc 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -121,7 +121,7 @@ def _helios_pyramid_map_fn(config_dict=None): ("stable-diffusion-xl", _create_default_map_fn("StableDiffusionXLModularPipeline")), ("stable-diffusion-3", _create_default_map_fn("SD3ModularPipeline")), ("wan", _wan_map_fn), - ("wan-i2v", _wan_i2v_map_fn), + ("wan-i2v", _wan_i2v_map_fn), ("flux", _create_default_map_fn("FluxModularPipeline")), ("flux-kontext", _create_default_map_fn("FluxKontextModularPipeline")), ("flux2", _create_default_map_fn("Flux2ModularPipeline")), diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py index 13396327ee7c..d6a8b5891986 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py @@ -9,6 +9,7 @@ is_transformers_available, ) + _dummy_objects = {} _import_structure = {} @@ -43,4 +44,4 @@ ) for name, value in _dummy_objects.items(): - setattr(sys.modules[__name__], name, value) \ No newline at end of file + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py index ebadf45236da..6781235e1ac8 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect -import numpy as np import torch from ...schedulers import FlowMatchEulerDiscreteScheduler @@ -143,10 +141,10 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe block_state.device, getattr(block_state, "mu", None) ) - + block_state.timesteps = timesteps block_state.num_inference_steps = num_inference_steps - + self.set_block_state(state, block_state) return components, state @@ -207,11 +205,11 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe block_state.device, getattr(block_state, "mu", None) ) - + timesteps, num_inference_steps = self.get_timesteps( components.scheduler, num_inference_steps, block_state.strength ) - + block_state.timesteps = timesteps block_state.num_inference_steps = num_inference_steps @@ -247,7 +245,7 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe block_state = self.get_block_state(state) block_state.device = components._execution_device batch_size = block_state.batch_size * block_state.num_images_per_prompt - + if block_state.latents is not None: block_state.latents = block_state.latents.to(device=block_state.device, dtype=block_state.dtype) else: @@ -291,4 +289,4 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe block_state.image_latents, latent_timestep, block_state.latents ) self.set_block_state(state, block_state) - return components, state \ No newline at end of file + return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py index c8d9f6a562c1..939df4b5bf36 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py @@ -16,9 +16,9 @@ import torch from ...configuration_utils import FrozenDict +from ...image_processor import VaeImageProcessor from ...models import AutoencoderKL from ...utils import logging -from ...image_processor import VaeImageProcessor from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam @@ -62,4 +62,4 @@ def __call__(self, components, state: PipelineState) -> PipelineState: block_state.images = block_state.latents self.set_block_state(state, block_state) - return components, state \ No newline at end of file + return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py index a41e87665ede..dc57b994e33b 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py @@ -28,6 +28,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from .modular_pipeline import SD3ModularPipeline + logger = logging.get_logger(__name__) @@ -79,7 +80,7 @@ def __call__( if block_state.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + block_state.guidance_scale * (noise_pred_text - noise_pred_uncond) - + should_skip_layers = ( getattr(block_state, "skip_guidance_layers", None) is not None and i > getattr(block_state, "num_inference_steps", 50) * getattr(block_state, "skip_layer_guidance_start", 0.01) @@ -151,7 +152,7 @@ def loop_inputs(self) -> list[InputParam]: def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) - + with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: for i, t in enumerate(block_state.timesteps): components, block_state = self.loop_step(components, block_state, i=i, t=t) @@ -164,4 +165,4 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe class SD3DenoiseStep(SD3DenoiseLoopWrapper): block_classes = [SD3LoopDenoiser, SD3LoopAfterDenoiser] - block_names = ["denoiser", "after_denoiser"] \ No newline at end of file + block_names = ["denoiser", "after_denoiser"] diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py index 6087f349a691..46ae89ac65c9 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py @@ -24,6 +24,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from .modular_pipeline import SD3ModularPipeline + logger = logging.get_logger(__name__) def retrieve_latents( @@ -95,7 +96,7 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState): if block_state.resized_image is None: image = block_state.image self.check_inputs( - height=block_state.height, width=block_state.width, + height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor, patch_size=components.patch_size ) height = block_state.height or components.default_height @@ -233,7 +234,7 @@ def _get_clip_prompt_embeds(components, prompt, device, clip_skip, clip_model_in prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) - + prompt_embeds = text_encoder(text_inputs.input_ids.to(device), output_hidden_states=True) pooled_prompt_embeds = prompt_embeds[0] @@ -247,8 +248,10 @@ def _get_clip_prompt_embeds(components, prompt, device, clip_skip, clip_model_in @staticmethod def encode_prompt(components, block_state, device, do_classifier_free_guidance, lora_scale): if lora_scale is not None and isinstance(components, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: - if components.text_encoder is not None: scale_lora_layers(components.text_encoder, lora_scale) - if components.text_encoder_2 is not None: scale_lora_layers(components.text_encoder_2, lora_scale) + if components.text_encoder is not None: + scale_lora_layers(components.text_encoder, lora_scale) + if components.text_encoder_2 is not None: + scale_lora_layers(components.text_encoder_2, lora_scale) prompt_embeds = block_state.prompt_embeds pooled_prompt_embeds = block_state.pooled_prompt_embeds @@ -292,8 +295,10 @@ def encode_prompt(components, block_state, device, do_classifier_free_guidance, negative_pooled_prompt_embeds = torch.cat([neg_pooled_embed, neg_2_pooled_embed], dim=-1) if lora_scale is not None and isinstance(components, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: - if components.text_encoder is not None: unscale_lora_layers(components.text_encoder, lora_scale) - if components.text_encoder_2 is not None: unscale_lora_layers(components.text_encoder_2, lora_scale) + if components.text_encoder is not None: + unscale_lora_layers(components.text_encoder, lora_scale) + if components.text_encoder_2 is not None: + unscale_lora_layers(components.text_encoder_2, lora_scale) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @@ -315,4 +320,4 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe block_state.negative_pooled_prompt_embeds = negative_pooled_prompt_embeds self.set_block_state(state, block_state) - return components, state \ No newline at end of file + return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py index 5ae213b09040..225e23994f1f 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py @@ -13,12 +13,14 @@ # limitations under the License. import torch + from ...utils import logging from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import InputParam, OutputParam from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_to_batch_size from .modular_pipeline import SD3ModularPipeline + logger = logging.get_logger(__name__) class SD3TextInputStep(ModularPipelineBlocks): @@ -55,7 +57,7 @@ def intermediate_outputs(self) -> list[str]: @torch.no_grad() def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - + block_state.batch_size = block_state.prompt_embeds.shape[0] block_state.dtype = block_state.prompt_embeds.dtype block_state.do_classifier_free_guidance = block_state.guidance_scale > 1.0 @@ -129,7 +131,7 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe height, width = calculate_dimension_from_latents(tensor, components.vae_scale_factor) block_state.height = block_state.height or height block_state.width = block_state.width or width - + if not hasattr(block_state, "image_height"): block_state.image_height = height if not hasattr(block_state, "image_width"): @@ -143,7 +145,8 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe for input_name in self._additional_batch_inputs: tensor = getattr(block_state, input_name) - if tensor is None: continue + if tensor is None: + continue tensor = repeat_tensor_to_batch_size( input_name=input_name, input_tensor=tensor, num_images_per_prompt=block_state.num_images_per_prompt, batch_size=block_state.batch_size @@ -151,4 +154,4 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe setattr(block_state, input_name, tensor) self.set_block_state(state, block_state) - return components, state \ No newline at end of file + return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py index 0595d26346c2..34e850bf11b8 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py @@ -116,4 +116,4 @@ class SD3AutoBlocks(SequentialPipelineBlocks): @property def outputs(self): - return [OutputParam.template("images")] \ No newline at end of file + return [OutputParam.template("images")] diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py index a54b1fd54423..657cda1a08ad 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py @@ -16,6 +16,7 @@ from ...utils import logging from ..modular_pipeline import ModularPipeline + logger = logging.get_logger(__name__) @@ -59,4 +60,4 @@ def vae_scale_factor(self): def num_channels_latents(self): if getattr(self, "transformer", None) is not None: return self.transformer.config.in_channels - return 16 \ No newline at end of file + return 16 diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index a3d9f8bcf56c..6f23acf9e9fd 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -419,7 +419,7 @@ def from_config(cls, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) - + class Wan22Blocks(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py index 860fa70f0565..d256f22b9ae8 100644 --- a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py +++ b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py @@ -14,12 +14,12 @@ # limitations under the License. import random + import numpy as np import PIL import torch from diffusers.image_processor import VaeImageProcessor -from diffusers.modular_pipelines import ModularPipeline from diffusers.modular_pipelines.stable_diffusion_3 import SD3AutoBlocks, SD3ModularPipeline from ...testing_utils import floats_tensor, torch_device @@ -89,10 +89,10 @@ def test_save_from_pretrained(self, tmp_path): def test_load_expected_components_from_save_pretrained(self, tmp_path): base_pipe = self.get_pipeline() base_pipe.save_pretrained(str(tmp_path)) - + pipe = self.pipeline_class.from_pretrained(tmp_path) pipe.load_components(torch_dtype=torch.float32) - + assert set(base_pipe.components.keys()) == set(pipe.components.keys()) def test_float16_inference(self): @@ -173,11 +173,11 @@ def test_save_from_pretrained(self, tmp_path): def test_load_expected_components_from_save_pretrained(self, tmp_path): base_pipe = self.get_pipeline() base_pipe.save_pretrained(str(tmp_path)) - + pipe = self.pipeline_class.from_pretrained(tmp_path) pipe.load_components(torch_dtype=torch.float32) - + assert set(base_pipe.components.keys()) == set(pipe.components.keys()) def test_float16_inference(self): - super().test_float16_inference(8e-2) \ No newline at end of file + super().test_float16_inference(8e-2) From ba2938310015fbc8db2dbe39bcc0f6c4416b005a Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Thu, 26 Mar 2026 15:09:45 +0000 Subject: [PATCH 04/14] push tiny-sd3-modular to hub and fix the tests --- ...est_modular_pipeline_stable_diffusion_3.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py index d256f22b9ae8..0a8c44717092 100644 --- a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py +++ b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py @@ -40,7 +40,7 @@ class TestSD3ModularPipelineFast(ModularPipelineTesterMixin): pipeline_class = SD3ModularPipeline pipeline_blocks_class = SD3AutoBlocks - pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe" + pretrained_model_name_or_path = "AlanPonnachan/tiny-sd3-modular" params = frozenset(["prompt", "height", "width", "guidance_scale"]) batch_params = frozenset(["prompt"]) @@ -60,12 +60,7 @@ def get_dummy_inputs(self, seed=0): } def get_pipeline(self, components_manager=None, torch_dtype=torch.float32): - pipeline = self.pipeline_class.from_pretrained( - self.pretrained_model_name_or_path, torch_dtype=torch_dtype - ) - if components_manager is not None: - pipeline.components_manager = components_manager - return pipeline + return super().get_pipeline(components_manager, torch_dtype) def test_save_from_pretrained(self, tmp_path): pipes =[] @@ -116,18 +111,14 @@ def test_float16_inference(self): class TestSD3Img2ImgModularPipelineFast(ModularPipelineTesterMixin): pipeline_class = SD3ModularPipeline pipeline_blocks_class = SD3AutoBlocks - pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe" + pretrained_model_name_or_path = "AlanPonnachan/tiny-sd3-modular" params = frozenset(["prompt", "height", "width", "guidance_scale", "image"]) batch_params = frozenset(["prompt", "image"]) expected_workflow_blocks = SD3_IMAGE2IMAGE_WORKFLOWS def get_pipeline(self, components_manager=None, torch_dtype=torch.float32): - pipeline = self.pipeline_class.from_pretrained( - self.pretrained_model_name_or_path, torch_dtype=torch_dtype - ) - if components_manager is not None: - pipeline.components_manager = components_manager + pipeline = super().get_pipeline(components_manager, torch_dtype) pipeline.image_processor = VaeImageProcessor(vae_scale_factor=8) return pipeline @@ -180,4 +171,4 @@ def test_load_expected_components_from_save_pretrained(self, tmp_path): assert set(base_pipe.components.keys()) == set(pipe.components.keys()) def test_float16_inference(self): - super().test_float16_inference(8e-2) + super().test_float16_inference(9e-2) From ad15c9da676f5a90976c05f8863c01cb100973b4 Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Thu, 26 Mar 2026 16:41:36 +0000 Subject: [PATCH 05/14] rename modules --- src/diffusers/__init__.py | 8 +-- src/diffusers/modular_pipelines/__init__.py | 4 +- .../modular_pipelines/modular_pipeline.py | 2 +- .../stable_diffusion_3/__init__.py | 8 +-- .../stable_diffusion_3/before_denoise.py | 56 +++------------ .../stable_diffusion_3/decoders.py | 2 +- .../stable_diffusion_3/denoise.py | 20 +++--- .../stable_diffusion_3/encoders.py | 26 +++---- .../stable_diffusion_3/inputs.py | 10 +-- .../modular_blocks_stable_diffusion_3.py | 68 +++++++++---------- .../stable_diffusion_3/modular_pipeline.py | 4 +- .../dummy_torch_and_transformers_objects.py | 4 +- ...est_modular_pipeline_stable_diffusion_3.py | 46 ++++++------- 13 files changed, 111 insertions(+), 147 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index c1fcef28465b..b08524e12b47 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -450,8 +450,8 @@ "QwenImageModularPipeline", "StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline", - "SD3AutoBlocks", - "SD3ModularPipeline", + "StableDiffusion3AutoBlocks", + "StableDiffusion3ModularPipeline", "Wan22Blocks", "Wan22Image2VideoBlocks", "Wan22Image2VideoModularPipeline", @@ -1211,8 +1211,8 @@ QwenImageLayeredAutoBlocks, QwenImageLayeredModularPipeline, QwenImageModularPipeline, - SD3AutoBlocks, - SD3ModularPipeline, + StableDiffusion3AutoBlocks, + StableDiffusion3ModularPipeline, StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline, Wan22Blocks, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 3e4802609be3..1cd4c14bc844 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -46,7 +46,7 @@ "InsertableDict", ] _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"] - _import_structure["stable_diffusion_3"] =["SD3AutoBlocks", "SD3ModularPipeline"] + _import_structure["stable_diffusion_3"] =["StableDiffusion3AutoBlocks", "StableDiffusion3ModularPipeline"] _import_structure["wan"] = [ "WanBlocks", "Wan22Blocks", @@ -141,7 +141,7 @@ QwenImageLayeredModularPipeline, QwenImageModularPipeline, ) - from .stable_diffusion_3 import SD3AutoBlocks, SD3ModularPipeline + from .stable_diffusion_3 import StableDiffusion3AutoBlocks, StableDiffusion3ModularPipeline from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline from .wan import ( Wan22Blocks, diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 2f36d8526cdc..25fc5baa6779 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -119,7 +119,7 @@ def _helios_pyramid_map_fn(config_dict=None): MODULAR_PIPELINE_MAPPING = OrderedDict( [ ("stable-diffusion-xl", _create_default_map_fn("StableDiffusionXLModularPipeline")), - ("stable-diffusion-3", _create_default_map_fn("SD3ModularPipeline")), + ("stable-diffusion-3", _create_default_map_fn("StableDiffusion3ModularPipeline")), ("wan", _wan_map_fn), ("wan-i2v", _wan_i2v_map_fn), ("flux", _create_default_map_fn("FluxModularPipeline")), diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py index d6a8b5891986..d7bc6020a816 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py @@ -21,8 +21,8 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["modular_blocks_stable_diffusion_3"] = ["SD3AutoBlocks"] - _import_structure["modular_pipeline"] = ["SD3ModularPipeline"] + _import_structure["modular_blocks_stable_diffusion_3"] = ["StableDiffusion3AutoBlocks"] + _import_structure["modular_pipeline"] = ["StableDiffusion3ModularPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -31,8 +31,8 @@ except OptionalDependencyNotAvailable: from ...utils.dummy_torch_and_transformers_objects import * # noqa F403 else: - from .modular_blocks_stable_diffusion_3 import SD3AutoBlocks - from .modular_pipeline import SD3ModularPipeline + from .modular_blocks_stable_diffusion_3 import StableDiffusion3AutoBlocks + from .modular_pipeline import StableDiffusion3ModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py index 6781235e1ac8..fd854d5cd659 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py @@ -15,54 +15,18 @@ import torch +from ...pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import calculate_shift, retrieve_timesteps from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging from ...utils.torch_utils import randn_tensor from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam -from .modular_pipeline import SD3ModularPipeline +from .modular_pipeline import StableDiffusion3ModularPipeline logger = logging.get_logger(__name__) -def retrieve_timesteps( - scheduler, - num_inference_steps: int | None = None, - device: str | torch.device | None = None, - timesteps: list[int] | None = None, - sigmas: list[float] | None = None, - **kwargs, -): - if timesteps is not None and sigmas is not None: - raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") - if timesteps is not None: - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -def calculate_shift( - image_seq_len, - base_seq_len: int = 256, - max_seq_len: int = 4096, - base_shift: float = 0.5, - max_shift: float = 1.15, -): - m = (max_shift - base_shift) / (max_seq_len - base_seq_len) - b = base_shift - m * base_seq_len - mu = image_seq_len * m + b - return mu - - def _get_initial_timesteps_and_optionals( transformer, scheduler, @@ -95,7 +59,7 @@ def _get_initial_timesteps_and_optionals( return timesteps, num_inference_steps -class SD3SetTimestepsStep(ModularPipelineBlocks): +class StableDiffusion3SetTimestepsStep(ModularPipelineBlocks): model_name = "stable-diffusion-3" @property @@ -125,7 +89,7 @@ def intermediate_outputs(self) -> list[OutputParam]: ] @torch.no_grad() - def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.device = components._execution_device @@ -149,7 +113,7 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe return components, state -class SD3Img2ImgSetTimestepsStep(ModularPipelineBlocks): +class StableDiffusion3Img2ImgSetTimestepsStep(ModularPipelineBlocks): model_name = "stable-diffusion-3" @property @@ -189,7 +153,7 @@ def get_timesteps(scheduler, num_inference_steps, strength): return timesteps, num_inference_steps - t_start @torch.no_grad() - def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.device = components._execution_device @@ -217,7 +181,7 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe return components, state -class SD3PrepareLatentsStep(ModularPipelineBlocks): +class StableDiffusion3PrepareLatentsStep(ModularPipelineBlocks): model_name = "stable-diffusion-3" @property @@ -241,7 +205,7 @@ def intermediate_outputs(self) -> list[OutputParam]: return[OutputParam("latents", type_hint=torch.Tensor)] @torch.no_grad() - def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.device = components._execution_device batch_size = block_state.batch_size * block_state.num_images_per_prompt @@ -261,7 +225,7 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe return components, state -class SD3Img2ImgPrepareLatentsStep(ModularPipelineBlocks): +class StableDiffusion3Img2ImgPrepareLatentsStep(ModularPipelineBlocks): model_name = "stable-diffusion-3" @property @@ -281,7 +245,7 @@ def intermediate_outputs(self) -> list[OutputParam]: return [OutputParam("initial_noise", type_hint=torch.Tensor)] @torch.no_grad() - def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0]) block_state.initial_noise = block_state.latents diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py index 939df4b5bf36..4b500cd2c95e 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py @@ -26,7 +26,7 @@ logger = logging.get_logger(__name__) -class SD3DecodeStep(ModularPipelineBlocks): +class StableDiffusion3DecodeStep(ModularPipelineBlocks): model_name = "stable-diffusion-3" @property diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py index dc57b994e33b..0e7ad8abccd7 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py @@ -26,13 +26,13 @@ PipelineState, ) from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam -from .modular_pipeline import SD3ModularPipeline +from .modular_pipeline import StableDiffusion3ModularPipeline logger = logging.get_logger(__name__) -class SD3LoopDenoiser(ModularPipelineBlocks): +class StableDiffusion3LoopDenoiser(ModularPipelineBlocks): model_name = "stable-diffusion-3" @property @@ -46,7 +46,7 @@ def description(self) -> str: @property def inputs(self) -> list[tuple[str, Any]]: return[ - InputParam("joint_attention_kwargs"), + InputParam("joint_attention_kwargs", type_hint=dict), InputParam("latents", required=True, type_hint=torch.Tensor), InputParam("prompt_embeds", required=True, type_hint=torch.Tensor), InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor), @@ -63,7 +63,7 @@ def inputs(self) -> list[tuple[str, Any]]: @torch.no_grad() def __call__( - self, components: SD3ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + self, components: StableDiffusion3ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor ) -> PipelineState: latent_model_input = torch.cat([block_state.latents] * 2) if block_state.do_classifier_free_guidance else block_state.latents timestep = t.expand(latent_model_input.shape[0]) @@ -104,7 +104,7 @@ def __call__( return components, block_state -class SD3LoopAfterDenoiser(ModularPipelineBlocks): +class StableDiffusion3LoopAfterDenoiser(ModularPipelineBlocks): model_name = "stable-diffusion-3" @property @@ -116,7 +116,7 @@ def intermediate_outputs(self) -> list[OutputParam]: return[OutputParam("latents", type_hint=torch.Tensor)] @torch.no_grad() - def __call__(self, components: SD3ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + def __call__(self, components: StableDiffusion3ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): latents_dtype = block_state.latents.dtype block_state.latents = components.scheduler.step( block_state.noise_pred, @@ -131,7 +131,7 @@ def __call__(self, components: SD3ModularPipeline, block_state: BlockState, i: i return components, block_state -class SD3DenoiseLoopWrapper(LoopSequentialPipelineBlocks): +class StableDiffusion3DenoiseLoopWrapper(LoopSequentialPipelineBlocks): model_name = "stable-diffusion-3" @property @@ -149,7 +149,7 @@ def loop_inputs(self) -> list[InputParam]: ] @torch.no_grad() - def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) @@ -163,6 +163,6 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe return components, state -class SD3DenoiseStep(SD3DenoiseLoopWrapper): - block_classes = [SD3LoopDenoiser, SD3LoopAfterDenoiser] +class StableDiffusion3DenoiseStep(StableDiffusion3DenoiseLoopWrapper): + block_classes = [StableDiffusion3LoopDenoiser, StableDiffusion3LoopAfterDenoiser] block_names = ["denoiser", "after_denoiser"] diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py index 46ae89ac65c9..a8b654abb456 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py @@ -22,7 +22,7 @@ from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam -from .modular_pipeline import SD3ModularPipeline +from .modular_pipeline import StableDiffusion3ModularPipeline logger = logging.get_logger(__name__) @@ -52,7 +52,7 @@ def encode_vae_image(vae: AutoencoderKL, image: torch.Tensor, generator: torch.G image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor return image_latents -class SD3ProcessImagesInputStep(ModularPipelineBlocks): +class StableDiffusion3ProcessImagesInputStep(ModularPipelineBlocks): model_name = "stable-diffusion-3" @property @@ -87,7 +87,7 @@ def check_inputs(height, width, vae_scale_factor, patch_size): raise ValueError(f"Width must be divisible by {vae_scale_factor * patch_size} but is {width}") @torch.no_grad() - def __call__(self, components: SD3ModularPipeline, state: PipelineState): + def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState): block_state = self.get_block_state(state) if block_state.resized_image is None and block_state.image is None: @@ -110,7 +110,7 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState): self.set_block_state(state, block_state) return components, state -class SD3VaeEncoderStep(ModularPipelineBlocks): +class StableDiffusion3VaeEncoderStep(ModularPipelineBlocks): model_name = "stable-diffusion-3" def __init__(self, input_name: str = "processed_image", output_name: str = "image_latents", sample_mode: str = "sample"): @@ -138,7 +138,7 @@ def intermediate_outputs(self) -> list[OutputParam]: ] @torch.no_grad() - def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) image = getattr(block_state, self._image_input_name) @@ -156,7 +156,7 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe self.set_block_state(state, block_state) return components, state -class SD3TextEncoderStep(ModularPipelineBlocks): +class StableDiffusion3TextEncoderStep(ModularPipelineBlocks): model_name = "stable-diffusion-3" @property @@ -261,11 +261,11 @@ def encode_prompt(components, block_state, device, do_classifier_free_guidance, prompt_2 = block_state.prompt_2 or prompt prompt_3 = block_state.prompt_3 or prompt - prompt_embed, pooled_embed = SD3TextEncoderStep._get_clip_prompt_embeds(components, prompt, device, block_state.clip_skip, 0) - prompt_2_embed, pooled_2_embed = SD3TextEncoderStep._get_clip_prompt_embeds(components, prompt_2, device, block_state.clip_skip, 1) + prompt_embed, pooled_embed = StableDiffusion3TextEncoderStep._get_clip_prompt_embeds(components, prompt, device, block_state.clip_skip, 0) + prompt_2_embed, pooled_2_embed = StableDiffusion3TextEncoderStep._get_clip_prompt_embeds(components, prompt_2, device, block_state.clip_skip, 1) clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) - t5_prompt_embed = SD3TextEncoderStep._get_t5_prompt_embeds(components, prompt_3, block_state.max_sequence_length, device) + t5_prompt_embed = StableDiffusion3TextEncoderStep._get_t5_prompt_embeds(components, prompt_3, block_state.max_sequence_length, device) clip_prompt_embeds = torch.nn.functional.pad(clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])) prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) @@ -284,11 +284,11 @@ def encode_prompt(components, block_state, device, do_classifier_free_guidance, neg_prompt_2 = batch_size * [neg_prompt_2] if isinstance(neg_prompt_2, str) else neg_prompt_2 neg_prompt_3 = batch_size * [neg_prompt_3] if isinstance(neg_prompt_3, str) else neg_prompt_3 - neg_embed, neg_pooled_embed = SD3TextEncoderStep._get_clip_prompt_embeds(components, neg_prompt, device, None, 0) - neg_2_embed, neg_2_pooled_embed = SD3TextEncoderStep._get_clip_prompt_embeds(components, neg_prompt_2, device, None, 1) + neg_embed, neg_pooled_embed = StableDiffusion3TextEncoderStep._get_clip_prompt_embeds(components, neg_prompt, device, None, 0) + neg_2_embed, neg_2_pooled_embed = StableDiffusion3TextEncoderStep._get_clip_prompt_embeds(components, neg_prompt_2, device, None, 1) neg_clip_embeds = torch.cat([neg_embed, neg_2_embed], dim=-1) - t5_neg_embed = SD3TextEncoderStep._get_t5_prompt_embeds(components, neg_prompt_3, block_state.max_sequence_length, device) + t5_neg_embed = StableDiffusion3TextEncoderStep._get_t5_prompt_embeds(components, neg_prompt_3, block_state.max_sequence_length, device) neg_clip_embeds = torch.nn.functional.pad(neg_clip_embeds, (0, t5_neg_embed.shape[-1] - neg_clip_embeds.shape[-1])) negative_prompt_embeds = torch.cat([neg_clip_embeds, t5_neg_embed], dim=-2) @@ -303,7 +303,7 @@ def encode_prompt(components, block_state, device, do_classifier_free_guidance, return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @torch.no_grad() - def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.device = components._execution_device diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py index 225e23994f1f..97755443078f 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py @@ -18,12 +18,12 @@ from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import InputParam, OutputParam from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_to_batch_size -from .modular_pipeline import SD3ModularPipeline +from .modular_pipeline import StableDiffusion3ModularPipeline logger = logging.get_logger(__name__) -class SD3TextInputStep(ModularPipelineBlocks): +class StableDiffusion3TextInputStep(ModularPipelineBlocks): model_name = "stable-diffusion-3" @property @@ -55,7 +55,7 @@ def intermediate_outputs(self) -> list[str]: ] @torch.no_grad() - def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.batch_size = block_state.prompt_embeds.shape[0] @@ -89,7 +89,7 @@ def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> Pipe self.set_block_state(state, block_state) return components, state -class SD3AdditionalInputsStep(ModularPipelineBlocks): +class StableDiffusion3AdditionalInputsStep(ModularPipelineBlocks): model_name = "stable-diffusion-3" def __init__(self, image_latent_inputs: list[str] = ["image_latents"], additional_batch_inputs: list[str] =[]): @@ -120,7 +120,7 @@ def intermediate_outputs(self) -> list[OutputParam]: OutputParam("image_width", type_hint=int), ] - def __call__(self, components: SD3ModularPipeline, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) for input_name in self._image_latent_inputs: diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py index 34e850bf11b8..e823a58ea723 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py @@ -16,79 +16,79 @@ from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks from ..modular_pipeline_utils import InsertableDict, OutputParam from .before_denoise import ( - SD3Img2ImgPrepareLatentsStep, - SD3Img2ImgSetTimestepsStep, - SD3PrepareLatentsStep, - SD3SetTimestepsStep, + StableDiffusion3Img2ImgPrepareLatentsStep, + StableDiffusion3Img2ImgSetTimestepsStep, + StableDiffusion3PrepareLatentsStep, + StableDiffusion3SetTimestepsStep, ) -from .decoders import SD3DecodeStep -from .denoise import SD3DenoiseStep +from .decoders import StableDiffusion3DecodeStep +from .denoise import StableDiffusion3DenoiseStep from .encoders import ( - SD3ProcessImagesInputStep, - SD3TextEncoderStep, - SD3VaeEncoderStep, + StableDiffusion3ProcessImagesInputStep, + StableDiffusion3TextEncoderStep, + StableDiffusion3VaeEncoderStep, ) from .inputs import ( - SD3AdditionalInputsStep, - SD3TextInputStep, + StableDiffusion3AdditionalInputsStep, + StableDiffusion3TextInputStep, ) logger = logging.get_logger(__name__) -class SD3Img2ImgVaeEncoderStep(SequentialPipelineBlocks): +class StableDiffusion3Img2ImgVaeEncoderStep(SequentialPipelineBlocks): model_name = "stable-diffusion-3" - block_classes = [SD3ProcessImagesInputStep(), SD3VaeEncoderStep()] + block_classes = [StableDiffusion3ProcessImagesInputStep(), StableDiffusion3VaeEncoderStep()] block_names = ["preprocess", "encode"] -class SD3AutoVaeEncoderStep(AutoPipelineBlocks): +class StableDiffusion3AutoVaeEncoderStep(AutoPipelineBlocks): model_name = "stable-diffusion-3" - block_classes =[SD3Img2ImgVaeEncoderStep] + block_classes =[StableDiffusion3Img2ImgVaeEncoderStep] block_names = ["img2img"] block_trigger_inputs =["image"] -class SD3BeforeDenoiseStep(SequentialPipelineBlocks): +class StableDiffusion3BeforeDenoiseStep(SequentialPipelineBlocks): model_name = "stable-diffusion-3" - block_classes =[SD3PrepareLatentsStep(), SD3SetTimestepsStep()] + block_classes =[StableDiffusion3PrepareLatentsStep(), StableDiffusion3SetTimestepsStep()] block_names = ["prepare_latents", "set_timesteps"] -class SD3Img2ImgBeforeDenoiseStep(SequentialPipelineBlocks): +class StableDiffusion3Img2ImgBeforeDenoiseStep(SequentialPipelineBlocks): model_name = "stable-diffusion-3" block_classes =[ - SD3PrepareLatentsStep(), - SD3Img2ImgSetTimestepsStep(), - SD3Img2ImgPrepareLatentsStep(), + StableDiffusion3PrepareLatentsStep(), + StableDiffusion3Img2ImgSetTimestepsStep(), + StableDiffusion3Img2ImgPrepareLatentsStep(), ] block_names = ["prepare_latents", "set_timesteps", "prepare_img2img_latents"] -class SD3AutoBeforeDenoiseStep(AutoPipelineBlocks): +class StableDiffusion3AutoBeforeDenoiseStep(AutoPipelineBlocks): model_name = "stable-diffusion-3" - block_classes =[SD3Img2ImgBeforeDenoiseStep, SD3BeforeDenoiseStep] + block_classes =[StableDiffusion3Img2ImgBeforeDenoiseStep, StableDiffusion3BeforeDenoiseStep] block_names = ["img2img", "text2image"] block_trigger_inputs = ["image_latents", None] -class SD3Img2ImgInputStep(SequentialPipelineBlocks): +class StableDiffusion3Img2ImgInputStep(SequentialPipelineBlocks): model_name = "stable-diffusion-3" - block_classes =[SD3TextInputStep(), SD3AdditionalInputsStep()] + block_classes =[StableDiffusion3TextInputStep(), StableDiffusion3AdditionalInputsStep()] block_names =["text_inputs", "additional_inputs"] -class SD3AutoInputStep(AutoPipelineBlocks): +class StableDiffusion3AutoInputStep(AutoPipelineBlocks): model_name = "stable-diffusion-3" - block_classes = [SD3Img2ImgInputStep, SD3TextInputStep] + block_classes = [StableDiffusion3Img2ImgInputStep, StableDiffusion3TextInputStep] block_names = ["img2img", "text2image"] block_trigger_inputs = ["image_latents", None] -class SD3CoreDenoiseStep(SequentialPipelineBlocks): +class StableDiffusion3CoreDenoiseStep(SequentialPipelineBlocks): model_name = "stable-diffusion-3" - block_classes =[SD3AutoInputStep, SD3AutoBeforeDenoiseStep, SD3DenoiseStep] + block_classes =[StableDiffusion3AutoInputStep, StableDiffusion3AutoBeforeDenoiseStep, StableDiffusion3DenoiseStep] block_names =["input", "before_denoise", "denoise"] @property def outputs(self): @@ -96,15 +96,15 @@ def outputs(self): AUTO_BLOCKS = InsertableDict([ - ("text_encoder", SD3TextEncoderStep()), - ("vae_encoder", SD3AutoVaeEncoderStep()), - ("denoise", SD3CoreDenoiseStep()), - ("decode", SD3DecodeStep()), + ("text_encoder", StableDiffusion3TextEncoderStep()), + ("vae_encoder", StableDiffusion3AutoVaeEncoderStep()), + ("denoise", StableDiffusion3CoreDenoiseStep()), + ("decode", StableDiffusion3DecodeStep()), ] ) -class SD3AutoBlocks(SequentialPipelineBlocks): +class StableDiffusion3AutoBlocks(SequentialPipelineBlocks): model_name = "stable-diffusion-3" block_classes = AUTO_BLOCKS.values() block_names = AUTO_BLOCKS.keys() diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py index 657cda1a08ad..a3a017d38e15 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py @@ -20,14 +20,14 @@ logger = logging.get_logger(__name__) -class SD3ModularPipeline(ModularPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin): +class StableDiffusion3ModularPipeline(ModularPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin): """ A ModularPipeline for Stable Diffusion 3. >[!WARNING] > This is an experimental feature and is likely to change in the future. """ - default_blocks_name = "SD3AutoBlocks" + default_blocks_name = "StableDiffusion3AutoBlocks" @property def default_height(self): diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 6f23acf9e9fd..d6c4d3972b96 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -391,7 +391,7 @@ def from_config(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class SD3AutoBlocks(metaclass=DummyObject): +class StableDiffusion3AutoBlocks(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -406,7 +406,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class SD3ModularPipeline(metaclass=DummyObject): +class StableDiffusion3ModularPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): diff --git a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py index 0a8c44717092..51eab6fcd0ea 100644 --- a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py +++ b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py @@ -20,7 +20,7 @@ import torch from diffusers.image_processor import VaeImageProcessor -from diffusers.modular_pipelines.stable_diffusion_3 import SD3AutoBlocks, SD3ModularPipeline +from diffusers.modular_pipelines.stable_diffusion_3 import StableDiffusion3AutoBlocks, StableDiffusion3ModularPipeline from ...testing_utils import floats_tensor, torch_device from ..test_modular_pipelines_common import ModularPipelineTesterMixin @@ -28,18 +28,18 @@ SD3_TEXT2IMAGE_WORKFLOWS = { "text2image":[ - ("text_encoder", "SD3TextEncoderStep"), - ("denoise.input", "SD3TextInputStep"), - ("denoise.before_denoise.prepare_latents", "SD3PrepareLatentsStep"), - ("denoise.before_denoise.set_timesteps", "SD3SetTimestepsStep"), - ("denoise.denoise", "SD3DenoiseStep"), - ("decode", "SD3DecodeStep"), + ("text_encoder", "StableDiffusion3TextEncoderStep"), + ("denoise.input", "StableDiffusion3TextInputStep"), + ("denoise.before_denoise.prepare_latents", "StableDiffusion3PrepareLatentsStep"), + ("denoise.before_denoise.set_timesteps", "StableDiffusion3SetTimestepsStep"), + ("denoise.denoise", "StableDiffusion3DenoiseStep"), + ("decode", "StableDiffusion3DecodeStep"), ] } -class TestSD3ModularPipelineFast(ModularPipelineTesterMixin): - pipeline_class = SD3ModularPipeline - pipeline_blocks_class = SD3AutoBlocks +class TestStableDiffusion3ModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = StableDiffusion3ModularPipeline + pipeline_blocks_class = StableDiffusion3AutoBlocks pretrained_model_name_or_path = "AlanPonnachan/tiny-sd3-modular" params = frozenset(["prompt", "height", "width", "guidance_scale"]) @@ -95,22 +95,22 @@ def test_float16_inference(self): SD3_IMAGE2IMAGE_WORKFLOWS = { "image2image":[ - ("text_encoder", "SD3TextEncoderStep"), - ("vae_encoder.preprocess", "SD3ProcessImagesInputStep"), - ("vae_encoder.encode", "SD3VaeEncoderStep"), - ("denoise.input.text_inputs", "SD3TextInputStep"), - ("denoise.input.additional_inputs", "SD3AdditionalInputsStep"), - ("denoise.before_denoise.prepare_latents", "SD3PrepareLatentsStep"), - ("denoise.before_denoise.set_timesteps", "SD3Img2ImgSetTimestepsStep"), - ("denoise.before_denoise.prepare_img2img_latents", "SD3Img2ImgPrepareLatentsStep"), - ("denoise.denoise", "SD3DenoiseStep"), - ("decode", "SD3DecodeStep"), + ("text_encoder", "StableDiffusion3TextEncoderStep"), + ("vae_encoder.preprocess", "StableDiffusion3ProcessImagesInputStep"), + ("vae_encoder.encode", "StableDiffusion3VaeEncoderStep"), + ("denoise.input.text_inputs", "StableDiffusion3TextInputStep"), + ("denoise.input.additional_inputs", "StableDiffusion3AdditionalInputsStep"), + ("denoise.before_denoise.prepare_latents", "StableDiffusion3PrepareLatentsStep"), + ("denoise.before_denoise.set_timesteps", "StableDiffusion3Img2ImgSetTimestepsStep"), + ("denoise.before_denoise.prepare_img2img_latents", "StableDiffusion3Img2ImgPrepareLatentsStep"), + ("denoise.denoise", "StableDiffusion3DenoiseStep"), + ("decode", "StableDiffusion3DecodeStep"), ] } -class TestSD3Img2ImgModularPipelineFast(ModularPipelineTesterMixin): - pipeline_class = SD3ModularPipeline - pipeline_blocks_class = SD3AutoBlocks +class TestStableDiffusion3Img2ImgModularPipelineFast(ModularPipelineTesterMixin): + pipeline_class = StableDiffusion3ModularPipeline + pipeline_blocks_class = StableDiffusion3AutoBlocks pretrained_model_name_or_path = "AlanPonnachan/tiny-sd3-modular" params = frozenset(["prompt", "height", "width", "guidance_scale", "image"]) From 02bb2af815504e61459f7c03fd7506eee25457a8 Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Sat, 28 Mar 2026 05:17:00 +0000 Subject: [PATCH 06/14] guidance refactoring --- .../stable_diffusion_3/denoise.py | 103 +++-- .../stable_diffusion_3/encoders.py | 386 +++++++++++++----- .../stable_diffusion_3/inputs.py | 15 +- .../stable_diffusion_3/modular_pipeline.py | 6 + 4 files changed, 363 insertions(+), 147 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py index 0e7ad8abccd7..738b7155eb42 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py @@ -16,6 +16,8 @@ import torch +from ...configuration_utils import FrozenDict +from ...guiders import ClassifierFreeGuidance from ...models.transformers import SD3Transformer2DModel from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging @@ -37,11 +39,19 @@ class StableDiffusion3LoopDenoiser(ModularPipelineBlocks): @property def expected_components(self) -> list[ComponentSpec]: - return [ComponentSpec("transformer", SD3Transformer2DModel)] + return[ + ComponentSpec( + "guider", + ClassifierFreeGuidance, + config=FrozenDict({"guidance_scale": 7.0}), + default_creation_method="from_config", + ), + ComponentSpec("transformer", SD3Transformer2DModel), + ] @property def description(self) -> str: - return "Step within the denoising loop that denoise the latents." + return "Step within the denoising loop that denoises the latents." @property def inputs(self) -> list[tuple[str, Any]]: @@ -50,14 +60,14 @@ def inputs(self) -> list[tuple[str, Any]]: InputParam("latents", required=True, type_hint=torch.Tensor), InputParam("prompt_embeds", required=True, type_hint=torch.Tensor), InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor), + InputParam("negative_prompt_embeds", type_hint=torch.Tensor), + InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor), InputParam("do_classifier_free_guidance", type_hint=bool), InputParam("guidance_scale", default=7.0), InputParam("skip_guidance_layers", type_hint=list), InputParam("skip_layer_guidance_scale", default=2.8), InputParam("skip_layer_guidance_stop", default=0.2), InputParam("skip_layer_guidance_start", default=0.01), - InputParam("original_prompt_embeds", type_hint=torch.Tensor), - InputParam("original_pooled_prompt_embeds", type_hint=torch.Tensor), InputParam("num_inference_steps", type_hint=int), ] @@ -65,40 +75,57 @@ def inputs(self) -> list[tuple[str, Any]]: def __call__( self, components: StableDiffusion3ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor ) -> PipelineState: - latent_model_input = torch.cat([block_state.latents] * 2) if block_state.do_classifier_free_guidance else block_state.latents - timestep = t.expand(latent_model_input.shape[0]) - - noise_pred = components.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=block_state.prompt_embeds, - pooled_projections=block_state.pooled_prompt_embeds, - joint_attention_kwargs=getattr(block_state, "joint_attention_kwargs", None), - return_dict=False, - )[0] - - if block_state.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + block_state.guidance_scale * (noise_pred_text - noise_pred_uncond) - - should_skip_layers = ( - getattr(block_state, "skip_guidance_layers", None) is not None - and i > getattr(block_state, "num_inference_steps", 50) * getattr(block_state, "skip_layer_guidance_start", 0.01) - and i < getattr(block_state, "num_inference_steps", 50) * getattr(block_state, "skip_layer_guidance_stop", 0.2) - ) - - if should_skip_layers: - timestep_skip = t.expand(block_state.latents.shape[0]) - noise_pred_skip_layers = components.transformer( - hidden_states=block_state.latents, - timestep=timestep_skip, - encoder_hidden_states=block_state.original_prompt_embeds, - pooled_projections=block_state.original_pooled_prompt_embeds, - joint_attention_kwargs=getattr(block_state, "joint_attention_kwargs", None), - return_dict=False, - skip_layers=block_state.skip_guidance_layers, - )[0] - noise_pred = noise_pred + (noise_pred_text - noise_pred_skip_layers) * getattr(block_state, "skip_layer_guidance_scale", 2.8) + guider_inputs = { + "encoder_hidden_states": ( + getattr(block_state, "prompt_embeds", None), + getattr(block_state, "negative_prompt_embeds", None), + ), + "pooled_projections": ( + getattr(block_state, "pooled_prompt_embeds", None), + getattr(block_state, "negative_pooled_prompt_embeds", None), + ), + } + + components.guider.guidance_scale = block_state.guidance_scale + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + guider_state = components.guider.prepare_inputs(guider_inputs) + + for guider_state_batch in guider_state: + components.guider.prepare_models(components.transformer) + cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()} + + timestep = t.expand(block_state.latents.shape[0]) + guider_state_batch.noise_pred = components.transformer( + hidden_states=block_state.latents, + timestep=timestep, + joint_attention_kwargs=getattr(block_state, "joint_attention_kwargs", None), + return_dict=False, + **cond_kwargs, + )[0] + + components.guider.cleanup_models(components.transformer) + + guider_output = components.guider(guider_state) + noise_pred = guider_output.pred + + should_skip_layers = ( + getattr(block_state, "skip_guidance_layers", None) is not None + and i > getattr(block_state, "num_inference_steps", 50) * getattr(block_state, "skip_layer_guidance_start", 0.01) + and i < getattr(block_state, "num_inference_steps", 50) * getattr(block_state, "skip_layer_guidance_stop", 0.2) + ) + + if should_skip_layers and block_state.do_classifier_free_guidance: + timestep_skip = t.expand(block_state.latents.shape[0]) + noise_pred_skip_layers = components.transformer( + hidden_states=block_state.latents, + timestep=timestep_skip, + encoder_hidden_states=getattr(block_state, "prompt_embeds", None), + pooled_projections=getattr(block_state, "pooled_prompt_embeds", None), + joint_attention_kwargs=getattr(block_state, "joint_attention_kwargs", None), + return_dict=False, + skip_layers=block_state.skip_guidance_layers, + )[0] + noise_pred = noise_pred + (guider_output.pred_cond - noise_pred_skip_layers) * getattr(block_state, "skip_layer_guidance_scale", 2.8) block_state.noise_pred = noise_pred return components, block_state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py index a8b654abb456..ff350e4445d7 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py @@ -27,6 +27,7 @@ logger = logging.get_logger(__name__) +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" ): @@ -52,6 +53,271 @@ def encode_vae_image(vae: AutoencoderKL, image: torch.Tensor, generator: torch.G image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor return image_latents +# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds with self -> components +def _get_t5_prompt_embeds( + components, + prompt: str | list[str] = None, + num_images_per_prompt: int = 1, + max_sequence_length: int = 256, + device: torch.device | None = None, + dtype: torch.dtype | None = None, +): + device = device or components._execution_device + dtype = dtype or components.text_encoder.dtype + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + if components.text_encoder_3 is None: + return torch.zeros( + ( + batch_size * num_images_per_prompt, + max_sequence_length, + components.transformer.config.joint_attention_dim, + ), + device=device, + dtype=dtype, + ) + + text_inputs = components.tokenizer_3( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = components.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = components.tokenizer_3.batch_decode(untruncated_ids[:, components.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + + prompt_embeds = components.text_encoder_3(text_input_ids.to(device))[0] + + dtype = components.text_encoder_3.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + +# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds with self -> components +def _get_clip_prompt_embeds( + components, + prompt: str | list[str], + num_images_per_prompt: int = 1, + device: torch.device | None = None, + clip_skip: int | None = None, + clip_model_index: int = 0, +): + device = device or components._execution_device + + clip_tokenizers = [components.tokenizer, components.tokenizer_2] + clip_text_encoders = [components.text_encoder, components.text_encoder_2] + + tokenizer = clip_tokenizers[clip_model_index] + text_encoder = clip_text_encoders[clip_model_index] + + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=components.tokenizer_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode(untruncated_ids[:, components.tokenizer_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {components.tokenizer_max_length} tokens: {removed_text}" + ) + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds = prompt_embeds.to(dtype=components.text_encoder.dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) + pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + + return prompt_embeds, pooled_prompt_embeds + +# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt with self -> components, self._get_clip_prompt_embeds -> _get_clip_prompt_embeds, self._get_t5_prompt_embeds -> _get_t5_prompt_embeds +def encode_prompt( + components, + prompt: str | list[str], + prompt_2: str | list[str], + prompt_3: str | list[str], + device: torch.device | None = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: str | list[str] | None = None, + negative_prompt_2: str | list[str] | None = None, + negative_prompt_3: str | list[str] | None = None, + prompt_embeds: torch.FloatTensor | None = None, + negative_prompt_embeds: torch.FloatTensor | None = None, + pooled_prompt_embeds: torch.FloatTensor | None = None, + negative_pooled_prompt_embeds: torch.FloatTensor | None = None, + clip_skip: int | None = None, + max_sequence_length: int = 256, + lora_scale: float | None = None, +): + device = device or components._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(components, SD3LoraLoaderMixin): + components._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if components.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(components.text_encoder, lora_scale) + if components.text_encoder_2 is not None and USE_PEFT_BACKEND: + scale_lora_layers(components.text_encoder_2, lora_scale) + + prompt = [prompt] if isinstance(prompt, str) else prompt + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 + + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 + + prompt_embed, pooled_prompt_embed = _get_clip_prompt_embeds( + components, + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=0, + ) + prompt_2_embed, pooled_prompt_2_embed = _get_clip_prompt_embeds( + components, + prompt=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=clip_skip, + clip_model_index=1, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = _get_t5_prompt_embeds( + components, + prompt=prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt + + # normalize str to list + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt_2 = ( + batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size *[negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + ) + + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + + negative_prompt_embed, negative_pooled_prompt_embed = _get_clip_prompt_embeds( + components, + negative_prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=0, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = _get_clip_prompt_embeds( + components, + negative_prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) + + t5_negative_prompt_embed = _get_t5_prompt_embeds( + components, + prompt=negative_prompt_3, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + device=device, + ) + + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + + negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) + negative_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + + if components.text_encoder is not None: + if isinstance(components, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder, lora_scale) + + if components.text_encoder_2 is not None: + if isinstance(components, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(components.text_encoder_2, lora_scale) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + class StableDiffusion3ProcessImagesInputStep(ModularPipelineBlocks): model_name = "stable-diffusion-3" @@ -202,106 +468,6 @@ def intermediate_outputs(self) -> list[OutputParam]: OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor), ] - @staticmethod - def _get_t5_prompt_embeds(components, prompt, max_sequence_length, device): - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - - if components.text_encoder_3 is None: - return torch.zeros( - (batch_size, max_sequence_length, components.transformer.config.joint_attention_dim), - device=device, - dtype=components.text_encoder.dtype, - ) - - text_inputs = components.tokenizer_3( - prompt, padding="max_length", max_length=max_sequence_length, - truncation=True, add_special_tokens=True, return_tensors="pt", - ) - prompt_embeds = components.text_encoder_3(text_inputs.input_ids.to(device))[0] - return prompt_embeds.to(dtype=components.text_encoder_3.dtype, device=device) - - @staticmethod - def _get_clip_prompt_embeds(components, prompt, device, clip_skip, clip_model_index): - clip_tokenizers = [components.tokenizer, components.tokenizer_2] - clip_text_encoders =[components.text_encoder, components.text_encoder_2] - - tokenizer = clip_tokenizers[clip_model_index] - text_encoder = clip_text_encoders[clip_model_index] - - prompt = [prompt] if isinstance(prompt, str) else prompt - text_inputs = tokenizer( - prompt, padding="max_length", max_length=tokenizer.model_max_length, - truncation=True, return_tensors="pt", - ) - - prompt_embeds = text_encoder(text_inputs.input_ids.to(device), output_hidden_states=True) - pooled_prompt_embeds = prompt_embeds[0] - - if clip_skip is None: - prompt_embeds = prompt_embeds.hidden_states[-2] - else: - prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] - - return prompt_embeds.to(dtype=components.text_encoder.dtype, device=device), pooled_prompt_embeds - - @staticmethod - def encode_prompt(components, block_state, device, do_classifier_free_guidance, lora_scale): - if lora_scale is not None and isinstance(components, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: - if components.text_encoder is not None: - scale_lora_layers(components.text_encoder, lora_scale) - if components.text_encoder_2 is not None: - scale_lora_layers(components.text_encoder_2, lora_scale) - - prompt_embeds = block_state.prompt_embeds - pooled_prompt_embeds = block_state.pooled_prompt_embeds - - if prompt_embeds is None: - prompt = [block_state.prompt] if isinstance(block_state.prompt, str) else block_state.prompt - prompt_2 = block_state.prompt_2 or prompt - prompt_3 = block_state.prompt_3 or prompt - - prompt_embed, pooled_embed = StableDiffusion3TextEncoderStep._get_clip_prompt_embeds(components, prompt, device, block_state.clip_skip, 0) - prompt_2_embed, pooled_2_embed = StableDiffusion3TextEncoderStep._get_clip_prompt_embeds(components, prompt_2, device, block_state.clip_skip, 1) - clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) - - t5_prompt_embed = StableDiffusion3TextEncoderStep._get_t5_prompt_embeds(components, prompt_3, block_state.max_sequence_length, device) - clip_prompt_embeds = torch.nn.functional.pad(clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])) - - prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) - pooled_prompt_embeds = torch.cat([pooled_embed, pooled_2_embed], dim=-1) - - negative_prompt_embeds = block_state.negative_prompt_embeds - negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds - - if do_classifier_free_guidance and negative_prompt_embeds is None: - batch_size = prompt_embeds.shape[0] - neg_prompt = block_state.negative_prompt or "" - neg_prompt_2 = block_state.negative_prompt_2 or neg_prompt - neg_prompt_3 = block_state.negative_prompt_3 or neg_prompt - - neg_prompt = batch_size * [neg_prompt] if isinstance(neg_prompt, str) else neg_prompt - neg_prompt_2 = batch_size * [neg_prompt_2] if isinstance(neg_prompt_2, str) else neg_prompt_2 - neg_prompt_3 = batch_size * [neg_prompt_3] if isinstance(neg_prompt_3, str) else neg_prompt_3 - - neg_embed, neg_pooled_embed = StableDiffusion3TextEncoderStep._get_clip_prompt_embeds(components, neg_prompt, device, None, 0) - neg_2_embed, neg_2_pooled_embed = StableDiffusion3TextEncoderStep._get_clip_prompt_embeds(components, neg_prompt_2, device, None, 1) - neg_clip_embeds = torch.cat([neg_embed, neg_2_embed], dim=-1) - - t5_neg_embed = StableDiffusion3TextEncoderStep._get_t5_prompt_embeds(components, neg_prompt_3, block_state.max_sequence_length, device) - neg_clip_embeds = torch.nn.functional.pad(neg_clip_embeds, (0, t5_neg_embed.shape[-1] - neg_clip_embeds.shape[-1])) - - negative_prompt_embeds = torch.cat([neg_clip_embeds, t5_neg_embed], dim=-2) - negative_pooled_prompt_embeds = torch.cat([neg_pooled_embed, neg_2_pooled_embed], dim=-1) - - if lora_scale is not None and isinstance(components, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: - if components.text_encoder is not None: - unscale_lora_layers(components.text_encoder, lora_scale) - if components.text_encoder_2 is not None: - unscale_lora_layers(components.text_encoder_2, lora_scale) - - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - @torch.no_grad() def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) @@ -310,8 +476,24 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS do_classifier_free_guidance = block_state.guidance_scale > 1.0 lora_scale = block_state.joint_attention_kwargs.get("scale", None) if block_state.joint_attention_kwargs else None - prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.encode_prompt( - components, block_state, block_state.device, do_classifier_free_guidance, lora_scale + prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = encode_prompt( + components=components, + prompt=block_state.prompt, + prompt_2=getattr(block_state, "prompt_2", None), + prompt_3=getattr(block_state, "prompt_3", None), + device=block_state.device, + num_images_per_prompt=getattr(block_state, "num_images_per_prompt", 1), + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=getattr(block_state, "negative_prompt", None), + negative_prompt_2=getattr(block_state, "negative_prompt_2", None), + negative_prompt_3=getattr(block_state, "negative_prompt_3", None), + prompt_embeds=getattr(block_state, "prompt_embeds", None), + negative_prompt_embeds=getattr(block_state, "negative_prompt_embeds", None), + pooled_prompt_embeds=getattr(block_state, "pooled_prompt_embeds", None), + negative_pooled_prompt_embeds=getattr(block_state, "negative_pooled_prompt_embeds", None), + clip_skip=getattr(block_state, "clip_skip", None), + max_sequence_length=getattr(block_state, "max_sequence_length", 256), + lora_scale=lora_scale, ) block_state.prompt_embeds = prompt_embeds diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py index 97755443078f..20e6c7f00d29 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py @@ -50,8 +50,8 @@ def intermediate_outputs(self) -> list[str]: OutputParam("do_classifier_free_guidance", type_hint=bool), OutputParam("prompt_embeds", type_hint=torch.Tensor), OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor), - OutputParam("original_prompt_embeds", type_hint=torch.Tensor), - OutputParam("original_pooled_prompt_embeds", type_hint=torch.Tensor), + OutputParam("negative_prompt_embeds", type_hint=torch.Tensor), + OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor), ] @torch.no_grad() @@ -69,9 +69,6 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt) pooled_prompt_embeds = pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) - block_state.original_prompt_embeds = prompt_embeds - block_state.original_pooled_prompt_embeds = pooled_prompt_embeds - if block_state.do_classifier_free_guidance and block_state.negative_prompt_embeds is not None: _, neg_seq_len, _ = block_state.negative_prompt_embeds.shape negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) @@ -80,11 +77,15 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt) negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) - block_state.prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) - block_state.pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) + block_state.prompt_embeds = prompt_embeds + block_state.pooled_prompt_embeds = pooled_prompt_embeds + block_state.negative_prompt_embeds = negative_prompt_embeds + block_state.negative_pooled_prompt_embeds = negative_pooled_prompt_embeds else: block_state.prompt_embeds = prompt_embeds block_state.pooled_prompt_embeds = pooled_prompt_embeds + block_state.negative_prompt_embeds = None + block_state.negative_pooled_prompt_embeds = None self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py index a3a017d38e15..0e893714b70d 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py @@ -49,6 +49,12 @@ def patch_size(self): return self.transformer.config.patch_size return 2 + @property + def tokenizer_max_length(self): + if getattr(self, "tokenizer", None) is not None: + return self.tokenizer.model_max_length + return 77 + @property def vae_scale_factor(self): vae_scale_factor = 8 From 24618def12ef9cfec3850e2b072b88b1f873183a Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Sat, 28 Mar 2026 05:28:59 +0000 Subject: [PATCH 07/14] Apply style fixes --- src/diffusers/__init__.py | 4 +- src/diffusers/modular_pipelines/__init__.py | 2 +- .../stable_diffusion_3/before_denoise.py | 26 +++++---- .../stable_diffusion_3/decoders.py | 11 +++- .../stable_diffusion_3/denoise.py | 28 ++++++---- .../stable_diffusion_3/encoders.py | 55 +++++++++++++------ .../stable_diffusion_3/inputs.py | 52 ++++++++++++------ .../modular_blocks_stable_diffusion_3.py | 22 ++++---- .../dummy_torch_and_transformers_objects.py | 3 +- ...est_modular_pipeline_stable_diffusion_3.py | 15 +++-- 10 files changed, 140 insertions(+), 78 deletions(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index b08524e12b47..0b8421eb4d2c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -448,10 +448,10 @@ "QwenImageLayeredAutoBlocks", "QwenImageLayeredModularPipeline", "QwenImageModularPipeline", - "StableDiffusionXLAutoBlocks", - "StableDiffusionXLModularPipeline", "StableDiffusion3AutoBlocks", "StableDiffusion3ModularPipeline", + "StableDiffusionXLAutoBlocks", + "StableDiffusionXLModularPipeline", "Wan22Blocks", "Wan22Image2VideoBlocks", "Wan22Image2VideoModularPipeline", diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 1cd4c14bc844..ea10761af6ba 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -46,7 +46,7 @@ "InsertableDict", ] _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"] - _import_structure["stable_diffusion_3"] =["StableDiffusion3AutoBlocks", "StableDiffusion3ModularPipeline"] + _import_structure["stable_diffusion_3"] = ["StableDiffusion3AutoBlocks", "StableDiffusion3ModularPipeline"] _import_structure["wan"] = [ "WanBlocks", "Wan22Blocks", diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py index fd854d5cd659..13c518a67cb6 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py @@ -64,7 +64,7 @@ class StableDiffusion3SetTimestepsStep(ModularPipelineBlocks): @property def expected_components(self) -> list[ComponentSpec]: - return[ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] @property def description(self) -> str: @@ -72,7 +72,7 @@ def description(self) -> str: @property def inputs(self) -> list[InputParam]: - return[ + return [ InputParam("num_inference_steps", default=50), InputParam("timesteps"), InputParam("sigmas"), @@ -83,7 +83,7 @@ def inputs(self) -> list[InputParam]: @property def intermediate_outputs(self) -> list[OutputParam]: - return[ + return [ OutputParam("timesteps", type_hint=torch.Tensor), OutputParam("num_inference_steps", type_hint=int), ] @@ -103,7 +103,7 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS block_state.num_inference_steps, block_state.sigmas, block_state.device, - getattr(block_state, "mu", None) + getattr(block_state, "mu", None), ) block_state.timesteps = timesteps @@ -118,7 +118,7 @@ class StableDiffusion3Img2ImgSetTimestepsStep(ModularPipelineBlocks): @property def expected_components(self) -> list[ComponentSpec]: - return[ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] @property def description(self) -> str: @@ -126,7 +126,7 @@ def description(self) -> str: @property def inputs(self) -> list[InputParam]: - return[ + return [ InputParam("num_inference_steps", default=50), InputParam("timesteps"), InputParam("sigmas"), @@ -138,7 +138,7 @@ def inputs(self) -> list[InputParam]: @property def intermediate_outputs(self) -> list[OutputParam]: - return[ + return [ OutputParam("timesteps", type_hint=torch.Tensor), OutputParam("num_inference_steps", type_hint=int), ] @@ -167,7 +167,7 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS block_state.num_inference_steps, block_state.sigmas, block_state.device, - getattr(block_state, "mu", None) + getattr(block_state, "mu", None), ) timesteps, num_inference_steps = self.get_timesteps( @@ -190,7 +190,7 @@ def description(self) -> str: @property def inputs(self) -> list[InputParam]: - return[ + return [ InputParam("height", type_hint=int), InputParam("width", type_hint=int), InputParam("latents", type_hint=torch.Tensor | None), @@ -202,7 +202,7 @@ def inputs(self) -> list[InputParam]: @property def intermediate_outputs(self) -> list[OutputParam]: - return[OutputParam("latents", type_hint=torch.Tensor)] + return [OutputParam("latents", type_hint=torch.Tensor)] @torch.no_grad() def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState) -> PipelineState: @@ -219,7 +219,9 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS int(block_state.height) // components.vae_scale_factor, int(block_state.width) // components.vae_scale_factor, ) - block_state.latents = randn_tensor(shape, generator=block_state.generator, device=block_state.device, dtype=block_state.dtype) + block_state.latents = randn_tensor( + shape, generator=block_state.generator, device=block_state.device, dtype=block_state.dtype + ) self.set_block_state(state, block_state) return components, state @@ -234,7 +236,7 @@ def expected_components(self) -> list[ComponentSpec]: @property def inputs(self) -> list[InputParam]: - return[ + return [ InputParam("latents", required=True, type_hint=torch.Tensor), InputParam("image_latents", required=True, type_hint=torch.Tensor), InputParam("timesteps", required=True, type_hint=torch.Tensor), diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py index 4b500cd2c95e..0f79447aa2f0 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py @@ -31,14 +31,19 @@ class StableDiffusion3DecodeStep(ModularPipelineBlocks): @property def expected_components(self) -> list[ComponentSpec]: - return[ + return [ ComponentSpec("vae", AutoencoderKL), - ComponentSpec("image_processor", VaeImageProcessor, config=FrozenDict({"vae_scale_factor": 8}), default_creation_method="from_config"), + ComponentSpec( + "image_processor", + VaeImageProcessor, + config=FrozenDict({"vae_scale_factor": 8}), + default_creation_method="from_config", + ), ] @property def inputs(self) -> list[InputParam]: - return[ + return [ InputParam("output_type", default="pil"), InputParam("latents", required=True, type_hint=torch.Tensor), ] diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py index 738b7155eb42..77e3aa2235e7 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py @@ -39,7 +39,7 @@ class StableDiffusion3LoopDenoiser(ModularPipelineBlocks): @property def expected_components(self) -> list[ComponentSpec]: - return[ + return [ ComponentSpec( "guider", ClassifierFreeGuidance, @@ -55,7 +55,7 @@ def description(self) -> str: @property def inputs(self) -> list[tuple[str, Any]]: - return[ + return [ InputParam("joint_attention_kwargs", type_hint=dict), InputParam("latents", required=True, type_hint=torch.Tensor), InputParam("prompt_embeds", required=True, type_hint=torch.Tensor), @@ -110,8 +110,10 @@ def __call__( should_skip_layers = ( getattr(block_state, "skip_guidance_layers", None) is not None - and i > getattr(block_state, "num_inference_steps", 50) * getattr(block_state, "skip_layer_guidance_start", 0.01) - and i < getattr(block_state, "num_inference_steps", 50) * getattr(block_state, "skip_layer_guidance_stop", 0.2) + and i + > getattr(block_state, "num_inference_steps", 50) * getattr(block_state, "skip_layer_guidance_start", 0.01) + and i + < getattr(block_state, "num_inference_steps", 50) * getattr(block_state, "skip_layer_guidance_stop", 0.2) ) if should_skip_layers and block_state.do_classifier_free_guidance: @@ -125,7 +127,9 @@ def __call__( return_dict=False, skip_layers=block_state.skip_guidance_layers, )[0] - noise_pred = noise_pred + (guider_output.pred_cond - noise_pred_skip_layers) * getattr(block_state, "skip_layer_guidance_scale", 2.8) + noise_pred = noise_pred + (guider_output.pred_cond - noise_pred_skip_layers) * getattr( + block_state, "skip_layer_guidance_scale", 2.8 + ) block_state.noise_pred = noise_pred return components, block_state @@ -140,7 +144,7 @@ def expected_components(self) -> list[ComponentSpec]: @property def intermediate_outputs(self) -> list[OutputParam]: - return[OutputParam("latents", type_hint=torch.Tensor)] + return [OutputParam("latents", type_hint=torch.Tensor)] @torch.no_grad() def __call__(self, components: StableDiffusion3ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): @@ -163,14 +167,14 @@ class StableDiffusion3DenoiseLoopWrapper(LoopSequentialPipelineBlocks): @property def loop_expected_components(self) -> list[ComponentSpec]: - return[ + return [ ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), ComponentSpec("transformer", SD3Transformer2DModel), ] @property def loop_inputs(self) -> list[InputParam]: - return[ + return [ InputParam("timesteps", required=True, type_hint=torch.Tensor), InputParam("num_inference_steps", required=True, type_hint=int), ] @@ -178,12 +182,16 @@ def loop_inputs(self) -> list[InputParam]: @torch.no_grad() def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - block_state.num_warmup_steps = max(len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0) + block_state.num_warmup_steps = max( + len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + ) with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: for i, t in enumerate(block_state.timesteps): components, block_state = self.loop_step(components, block_state, i=i, t=t) - if i == len(block_state.timesteps) - 1 or ((i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0): + if i == len(block_state.timesteps) - 1 or ( + (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + ): progress_bar.update() self.set_block_state(state, block_state) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py index ff350e4445d7..393bc684da7d 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py @@ -27,6 +27,7 @@ logger = logging.get_logger(__name__) + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" @@ -40,9 +41,10 @@ def retrieve_latents( else: raise AttributeError("Could not access latents of provided encoder_output") + def encode_vae_image(vae: AutoencoderKL, image: torch.Tensor, generator: torch.Generator, sample_mode="sample"): if isinstance(generator, list): - image_latents =[ + image_latents = [ retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode) for i in range(image.shape[0]) ] @@ -53,6 +55,7 @@ def encode_vae_image(vae: AutoencoderKL, image: torch.Tensor, generator: torch.G image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor return image_latents + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds with self -> components def _get_t5_prompt_embeds( components, @@ -91,7 +94,9 @@ def _get_t5_prompt_embeds( untruncated_ids = components.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = components.tokenizer_3.batch_decode(untruncated_ids[:, components.tokenizer_max_length - 1 : -1]) + removed_text = components.tokenizer_3.batch_decode( + untruncated_ids[:, components.tokenizer_max_length - 1 : -1] + ) logger.warning( "The following part of your input was truncated because `max_sequence_length` is set to " f" {max_sequence_length} tokens: {removed_text}" @@ -110,6 +115,7 @@ def _get_t5_prompt_embeds( return prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds with self -> components def _get_clip_prompt_embeds( components, @@ -166,6 +172,7 @@ def _get_clip_prompt_embeds( return prompt_embeds, pooled_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt with self -> components, self._get_clip_prompt_embeds -> _get_clip_prompt_embeds, self._get_t5_prompt_embeds -> _get_t5_prompt_embeds def encode_prompt( components, @@ -256,7 +263,7 @@ def encode_prompt( batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 ) negative_prompt_3 = ( - batch_size *[negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 ) if prompt is not None and type(prompt) is not type(negative_prompt): @@ -303,7 +310,8 @@ def encode_prompt( ) negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) - negative_pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 ) if components.text_encoder is not None: @@ -318,6 +326,7 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + class StableDiffusion3ProcessImagesInputStep(ModularPipelineBlocks): model_name = "stable-diffusion-3" @@ -327,7 +336,7 @@ def description(self) -> str: @property def expected_components(self) -> list[ComponentSpec]: - return[ + return [ ComponentSpec( "image_processor", VaeImageProcessor, @@ -338,11 +347,11 @@ def expected_components(self) -> list[ComponentSpec]: @property def inputs(self) -> list[InputParam]: - return[InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")] + return [InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")] @property def intermediate_outputs(self) -> list[OutputParam]: - return[OutputParam(name="processed_image")] + return [OutputParam(name="processed_image")] @staticmethod def check_inputs(height, width, vae_scale_factor, patch_size): @@ -362,8 +371,10 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS if block_state.resized_image is None: image = block_state.image self.check_inputs( - height=block_state.height, width=block_state.width, - vae_scale_factor=components.vae_scale_factor, patch_size=components.patch_size + height=block_state.height, + width=block_state.width, + vae_scale_factor=components.vae_scale_factor, + patch_size=components.patch_size, ) height = block_state.height or components.default_height width = block_state.width or components.default_width @@ -376,10 +387,13 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS self.set_block_state(state, block_state) return components, state + class StableDiffusion3VaeEncoderStep(ModularPipelineBlocks): model_name = "stable-diffusion-3" - def __init__(self, input_name: str = "processed_image", output_name: str = "image_latents", sample_mode: str = "sample"): + def __init__( + self, input_name: str = "processed_image", output_name: str = "image_latents", sample_mode: str = "sample" + ): self._image_input_name = input_name self._image_latents_output_name = output_name self.sample_mode = sample_mode @@ -395,12 +409,16 @@ def expected_components(self) -> list[ComponentSpec]: @property def inputs(self) -> list[InputParam]: - return[InputParam(self._image_input_name), InputParam("generator")] + return [InputParam(self._image_input_name), InputParam("generator")] @property def intermediate_outputs(self) -> list[OutputParam]: - return[ - OutputParam(self._image_latents_output_name, type_hint=torch.Tensor, description="The latents representing the reference image") + return [ + OutputParam( + self._image_latents_output_name, + type_hint=torch.Tensor, + description="The latents representing the reference image", + ) ] @torch.no_grad() @@ -422,6 +440,7 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS self.set_block_state(state, block_state) return components, state + class StableDiffusion3TextEncoderStep(ModularPipelineBlocks): model_name = "stable-diffusion-3" @@ -431,7 +450,7 @@ def description(self) -> str: @property def expected_components(self) -> list[ComponentSpec]: - return[ + return [ ComponentSpec("text_encoder", CLIPTextModelWithProjection), ComponentSpec("tokenizer", CLIPTokenizer), ComponentSpec("text_encoder_2", CLIPTextModelWithProjection), @@ -442,7 +461,7 @@ def expected_components(self) -> list[ComponentSpec]: @property def inputs(self) -> list[InputParam]: - return[ + return [ InputParam("prompt"), InputParam("prompt_2"), InputParam("prompt_3"), @@ -461,7 +480,7 @@ def inputs(self) -> list[InputParam]: @property def intermediate_outputs(self) -> list[OutputParam]: - return[ + return [ OutputParam("prompt_embeds", type_hint=torch.Tensor), OutputParam("negative_prompt_embeds", type_hint=torch.Tensor), OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor), @@ -474,7 +493,9 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS block_state.device = components._execution_device do_classifier_free_guidance = block_state.guidance_scale > 1.0 - lora_scale = block_state.joint_attention_kwargs.get("scale", None) if block_state.joint_attention_kwargs else None + lora_scale = ( + block_state.joint_attention_kwargs.get("scale", None) if block_state.joint_attention_kwargs else None + ) prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = encode_prompt( components=components, diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py index 20e6c7f00d29..623a60116721 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py @@ -23,16 +23,19 @@ logger = logging.get_logger(__name__) + class StableDiffusion3TextInputStep(ModularPipelineBlocks): model_name = "stable-diffusion-3" @property def description(self) -> str: - return "Text input processing step that standardizes text embeddings for SD3, applying CFG duplication if needed." + return ( + "Text input processing step that standardizes text embeddings for SD3, applying CFG duplication if needed." + ) @property def inputs(self) -> list[InputParam]: - return[ + return [ InputParam("num_images_per_prompt", default=1), InputParam("guidance_scale", default=7.0), InputParam("skip_guidance_layers", type_hint=list), @@ -44,7 +47,7 @@ def inputs(self) -> list[InputParam]: @property def intermediate_outputs(self) -> list[str]: - return[ + return [ OutputParam("batch_size", type_hint=int), OutputParam("dtype", type_hint=torch.dtype), OutputParam("do_classifier_free_guidance", type_hint=bool), @@ -67,15 +70,23 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS prompt_embeds = prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt) - pooled_prompt_embeds = pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) + pooled_prompt_embeds = pooled_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, -1 + ) if block_state.do_classifier_free_guidance and block_state.negative_prompt_embeds is not None: _, neg_seq_len, _ = block_state.negative_prompt_embeds.shape negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, neg_seq_len, -1) + negative_prompt_embeds = negative_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, neg_seq_len, -1 + ) - negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt) - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, -1) + negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt + ) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, -1 + ) block_state.prompt_embeds = prompt_embeds block_state.pooled_prompt_embeds = pooled_prompt_embeds @@ -90,12 +101,17 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS self.set_block_state(state, block_state) return components, state + class StableDiffusion3AdditionalInputsStep(ModularPipelineBlocks): model_name = "stable-diffusion-3" - def __init__(self, image_latent_inputs: list[str] = ["image_latents"], additional_batch_inputs: list[str] =[]): - self._image_latent_inputs = image_latent_inputs if isinstance(image_latent_inputs, list) else [image_latent_inputs] - self._additional_batch_inputs = additional_batch_inputs if isinstance(additional_batch_inputs, list) else[additional_batch_inputs] + def __init__(self, image_latent_inputs: list[str] = ["image_latents"], additional_batch_inputs: list[str] = []): + self._image_latent_inputs = ( + image_latent_inputs if isinstance(image_latent_inputs, list) else [image_latent_inputs] + ) + self._additional_batch_inputs = ( + additional_batch_inputs if isinstance(additional_batch_inputs, list) else [additional_batch_inputs] + ) super().__init__() @property @@ -104,7 +120,7 @@ def description(self) -> str: @property def inputs(self) -> list[InputParam]: - inputs =[ + inputs = [ InputParam("num_images_per_prompt", default=1), InputParam("batch_size", required=True), InputParam("height"), @@ -116,7 +132,7 @@ def inputs(self) -> list[InputParam]: @property def intermediate_outputs(self) -> list[OutputParam]: - return[ + return [ OutputParam("image_height", type_hint=int), OutputParam("image_width", type_hint=int), ] @@ -139,8 +155,10 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS block_state.image_width = width tensor = repeat_tensor_to_batch_size( - input_name=input_name, input_tensor=tensor, - num_images_per_prompt=block_state.num_images_per_prompt, batch_size=block_state.batch_size + input_name=input_name, + input_tensor=tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, ) setattr(block_state, input_name, tensor) @@ -149,8 +167,10 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS if tensor is None: continue tensor = repeat_tensor_to_batch_size( - input_name=input_name, input_tensor=tensor, - num_images_per_prompt=block_state.num_images_per_prompt, batch_size=block_state.batch_size + input_name=input_name, + input_tensor=tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, ) setattr(block_state, input_name, tensor) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py index e823a58ea723..3dcab757250c 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py @@ -45,20 +45,20 @@ class StableDiffusion3Img2ImgVaeEncoderStep(SequentialPipelineBlocks): class StableDiffusion3AutoVaeEncoderStep(AutoPipelineBlocks): model_name = "stable-diffusion-3" - block_classes =[StableDiffusion3Img2ImgVaeEncoderStep] + block_classes = [StableDiffusion3Img2ImgVaeEncoderStep] block_names = ["img2img"] - block_trigger_inputs =["image"] + block_trigger_inputs = ["image"] class StableDiffusion3BeforeDenoiseStep(SequentialPipelineBlocks): model_name = "stable-diffusion-3" - block_classes =[StableDiffusion3PrepareLatentsStep(), StableDiffusion3SetTimestepsStep()] + block_classes = [StableDiffusion3PrepareLatentsStep(), StableDiffusion3SetTimestepsStep()] block_names = ["prepare_latents", "set_timesteps"] class StableDiffusion3Img2ImgBeforeDenoiseStep(SequentialPipelineBlocks): model_name = "stable-diffusion-3" - block_classes =[ + block_classes = [ StableDiffusion3PrepareLatentsStep(), StableDiffusion3Img2ImgSetTimestepsStep(), StableDiffusion3Img2ImgPrepareLatentsStep(), @@ -68,15 +68,15 @@ class StableDiffusion3Img2ImgBeforeDenoiseStep(SequentialPipelineBlocks): class StableDiffusion3AutoBeforeDenoiseStep(AutoPipelineBlocks): model_name = "stable-diffusion-3" - block_classes =[StableDiffusion3Img2ImgBeforeDenoiseStep, StableDiffusion3BeforeDenoiseStep] + block_classes = [StableDiffusion3Img2ImgBeforeDenoiseStep, StableDiffusion3BeforeDenoiseStep] block_names = ["img2img", "text2image"] block_trigger_inputs = ["image_latents", None] class StableDiffusion3Img2ImgInputStep(SequentialPipelineBlocks): model_name = "stable-diffusion-3" - block_classes =[StableDiffusion3TextInputStep(), StableDiffusion3AdditionalInputsStep()] - block_names =["text_inputs", "additional_inputs"] + block_classes = [StableDiffusion3TextInputStep(), StableDiffusion3AdditionalInputsStep()] + block_names = ["text_inputs", "additional_inputs"] class StableDiffusion3AutoInputStep(AutoPipelineBlocks): @@ -88,14 +88,16 @@ class StableDiffusion3AutoInputStep(AutoPipelineBlocks): class StableDiffusion3CoreDenoiseStep(SequentialPipelineBlocks): model_name = "stable-diffusion-3" - block_classes =[StableDiffusion3AutoInputStep, StableDiffusion3AutoBeforeDenoiseStep, StableDiffusion3DenoiseStep] - block_names =["input", "before_denoise", "denoise"] + block_classes = [StableDiffusion3AutoInputStep, StableDiffusion3AutoBeforeDenoiseStep, StableDiffusion3DenoiseStep] + block_names = ["input", "before_denoise", "denoise"] + @property def outputs(self): return [OutputParam.template("latents")] -AUTO_BLOCKS = InsertableDict([ +AUTO_BLOCKS = InsertableDict( + [ ("text_encoder", StableDiffusion3TextEncoderStep()), ("vae_encoder", StableDiffusion3AutoVaeEncoderStep()), ("denoise", StableDiffusion3CoreDenoiseStep()), diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index d6c4d3972b96..b50d691e531b 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -391,6 +391,7 @@ def from_config(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) + class StableDiffusion3AutoBlocks(metaclass=DummyObject): _backends = ["torch", "transformers"] @@ -414,7 +415,7 @@ def __init__(self, *args, **kwargs): @classmethod def from_config(cls, *args, **kwargs): - requires_backends(cls,["torch", "transformers"]) + requires_backends(cls, ["torch", "transformers"]) @classmethod def from_pretrained(cls, *args, **kwargs): diff --git a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py index 51eab6fcd0ea..f361f6a92a1d 100644 --- a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py +++ b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py @@ -27,7 +27,7 @@ SD3_TEXT2IMAGE_WORKFLOWS = { - "text2image":[ + "text2image": [ ("text_encoder", "StableDiffusion3TextEncoderStep"), ("denoise.input", "StableDiffusion3TextInputStep"), ("denoise.before_denoise.prepare_latents", "StableDiffusion3PrepareLatentsStep"), @@ -37,6 +37,7 @@ ] } + class TestStableDiffusion3ModularPipelineFast(ModularPipelineTesterMixin): pipeline_class = StableDiffusion3ModularPipeline pipeline_blocks_class = StableDiffusion3AutoBlocks @@ -63,7 +64,7 @@ def get_pipeline(self, components_manager=None, torch_dtype=torch.float32): return super().get_pipeline(components_manager, torch_dtype) def test_save_from_pretrained(self, tmp_path): - pipes =[] + pipes = [] base_pipe = self.get_pipeline().to(torch_device) pipes.append(base_pipe) @@ -73,7 +74,7 @@ def test_save_from_pretrained(self, tmp_path): pipe.to(torch_device) pipes.append(pipe) - image_slices =[] + image_slices = [] for p in pipes: inputs = self.get_dummy_inputs() image = p(**inputs, output="images") @@ -93,8 +94,9 @@ def test_load_expected_components_from_save_pretrained(self, tmp_path): def test_float16_inference(self): super().test_float16_inference(9e-2) + SD3_IMAGE2IMAGE_WORKFLOWS = { - "image2image":[ + "image2image": [ ("text_encoder", "StableDiffusion3TextEncoderStep"), ("vae_encoder.preprocess", "StableDiffusion3ProcessImagesInputStep"), ("vae_encoder.encode", "StableDiffusion3VaeEncoderStep"), @@ -108,6 +110,7 @@ def test_float16_inference(self): ] } + class TestStableDiffusion3Img2ImgModularPipelineFast(ModularPipelineTesterMixin): pipeline_class = StableDiffusion3ModularPipeline pipeline_blocks_class = StableDiffusion3AutoBlocks @@ -142,7 +145,7 @@ def get_dummy_inputs(self, seed=0): return inputs def test_save_from_pretrained(self, tmp_path): - pipes =[] + pipes = [] base_pipe = self.get_pipeline().to(torch_device) pipes.append(base_pipe) @@ -153,7 +156,7 @@ def test_save_from_pretrained(self, tmp_path): pipe.image_processor = VaeImageProcessor(vae_scale_factor=8) pipes.append(pipe) - image_slices =[] + image_slices = [] for p in pipes: inputs = self.get_dummy_inputs() image = p(**inputs, output="images") From 27cb9f76cc28ced3794b8aa03f136213c9e0ef26 Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Sun, 29 Mar 2026 05:37:44 +0000 Subject: [PATCH 08/14] set default height and width --- .../modular_pipelines/stable_diffusion_3/before_denoise.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py index 13c518a67cb6..24cab1ab1a44 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py @@ -210,6 +210,9 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS block_state.device = components._execution_device batch_size = block_state.batch_size * block_state.num_images_per_prompt + block_state.height = block_state.height or components.default_height + block_state.width = block_state.width or components.default_width + if block_state.latents is not None: block_state.latents = block_state.latents.to(device=block_state.device, dtype=block_state.dtype) else: From 6fd0f30308f216e7dedb52025af8788b124495c5 Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Wed, 1 Apr 2026 17:43:55 +0000 Subject: [PATCH 09/14] - skip layer refactoring - add autodocstring to assembled blocks --- .../stable_diffusion_3/denoise.py | 39 ++++++------------- .../modular_blocks_stable_diffusion_3.py | 9 +++++ 2 files changed, 20 insertions(+), 28 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py index 77e3aa2235e7..e5a5c51b9009 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py @@ -62,9 +62,7 @@ def inputs(self) -> list[tuple[str, Any]]: InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor), InputParam("negative_prompt_embeds", type_hint=torch.Tensor), InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor), - InputParam("do_classifier_free_guidance", type_hint=bool), InputParam("guidance_scale", default=7.0), - InputParam("skip_guidance_layers", type_hint=list), InputParam("skip_layer_guidance_scale", default=2.8), InputParam("skip_layer_guidance_stop", default=0.2), InputParam("skip_layer_guidance_start", default=0.01), @@ -86,7 +84,15 @@ def __call__( ), } - components.guider.guidance_scale = block_state.guidance_scale + if hasattr(components.guider, "guidance_scale"): + components.guider.guidance_scale = block_state.guidance_scale + if hasattr(components.guider, "skip_layer_guidance_scale"): + components.guider.skip_layer_guidance_scale = block_state.skip_layer_guidance_scale + if hasattr(components.guider, "skip_layer_guidance_start"): + components.guider.skip_layer_guidance_start = block_state.skip_layer_guidance_start + if hasattr(components.guider, "skip_layer_guidance_stop"): + components.guider.skip_layer_guidance_stop = block_state.skip_layer_guidance_stop + components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) guider_state = components.guider.prepare_inputs(guider_inputs) @@ -106,32 +112,8 @@ def __call__( components.guider.cleanup_models(components.transformer) guider_output = components.guider(guider_state) - noise_pred = guider_output.pred - - should_skip_layers = ( - getattr(block_state, "skip_guidance_layers", None) is not None - and i - > getattr(block_state, "num_inference_steps", 50) * getattr(block_state, "skip_layer_guidance_start", 0.01) - and i - < getattr(block_state, "num_inference_steps", 50) * getattr(block_state, "skip_layer_guidance_stop", 0.2) - ) - - if should_skip_layers and block_state.do_classifier_free_guidance: - timestep_skip = t.expand(block_state.latents.shape[0]) - noise_pred_skip_layers = components.transformer( - hidden_states=block_state.latents, - timestep=timestep_skip, - encoder_hidden_states=getattr(block_state, "prompt_embeds", None), - pooled_projections=getattr(block_state, "pooled_prompt_embeds", None), - joint_attention_kwargs=getattr(block_state, "joint_attention_kwargs", None), - return_dict=False, - skip_layers=block_state.skip_guidance_layers, - )[0] - noise_pred = noise_pred + (guider_output.pred_cond - noise_pred_skip_layers) * getattr( - block_state, "skip_layer_guidance_scale", 2.8 - ) + block_state.noise_pred = guider_output.pred - block_state.noise_pred = noise_pred return components, block_state @@ -198,6 +180,7 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS return components, state +# auto_docstring class StableDiffusion3DenoiseStep(StableDiffusion3DenoiseLoopWrapper): block_classes = [StableDiffusion3LoopDenoiser, StableDiffusion3LoopAfterDenoiser] block_names = ["denoiser", "after_denoiser"] diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py index 3dcab757250c..6b4c2b277426 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py @@ -37,12 +37,14 @@ logger = logging.get_logger(__name__) +# auto_docstring class StableDiffusion3Img2ImgVaeEncoderStep(SequentialPipelineBlocks): model_name = "stable-diffusion-3" block_classes = [StableDiffusion3ProcessImagesInputStep(), StableDiffusion3VaeEncoderStep()] block_names = ["preprocess", "encode"] +# auto_docstring class StableDiffusion3AutoVaeEncoderStep(AutoPipelineBlocks): model_name = "stable-diffusion-3" block_classes = [StableDiffusion3Img2ImgVaeEncoderStep] @@ -50,12 +52,14 @@ class StableDiffusion3AutoVaeEncoderStep(AutoPipelineBlocks): block_trigger_inputs = ["image"] +# auto_docstring class StableDiffusion3BeforeDenoiseStep(SequentialPipelineBlocks): model_name = "stable-diffusion-3" block_classes = [StableDiffusion3PrepareLatentsStep(), StableDiffusion3SetTimestepsStep()] block_names = ["prepare_latents", "set_timesteps"] +# auto_docstring class StableDiffusion3Img2ImgBeforeDenoiseStep(SequentialPipelineBlocks): model_name = "stable-diffusion-3" block_classes = [ @@ -66,6 +70,7 @@ class StableDiffusion3Img2ImgBeforeDenoiseStep(SequentialPipelineBlocks): block_names = ["prepare_latents", "set_timesteps", "prepare_img2img_latents"] +# auto_docstring class StableDiffusion3AutoBeforeDenoiseStep(AutoPipelineBlocks): model_name = "stable-diffusion-3" block_classes = [StableDiffusion3Img2ImgBeforeDenoiseStep, StableDiffusion3BeforeDenoiseStep] @@ -73,12 +78,14 @@ class StableDiffusion3AutoBeforeDenoiseStep(AutoPipelineBlocks): block_trigger_inputs = ["image_latents", None] +# auto_docstring class StableDiffusion3Img2ImgInputStep(SequentialPipelineBlocks): model_name = "stable-diffusion-3" block_classes = [StableDiffusion3TextInputStep(), StableDiffusion3AdditionalInputsStep()] block_names = ["text_inputs", "additional_inputs"] +# auto_docstring class StableDiffusion3AutoInputStep(AutoPipelineBlocks): model_name = "stable-diffusion-3" block_classes = [StableDiffusion3Img2ImgInputStep, StableDiffusion3TextInputStep] @@ -86,6 +93,7 @@ class StableDiffusion3AutoInputStep(AutoPipelineBlocks): block_trigger_inputs = ["image_latents", None] +# auto_docstring class StableDiffusion3CoreDenoiseStep(SequentialPipelineBlocks): model_name = "stable-diffusion-3" block_classes = [StableDiffusion3AutoInputStep, StableDiffusion3AutoBeforeDenoiseStep, StableDiffusion3DenoiseStep] @@ -106,6 +114,7 @@ def outputs(self): ) +# auto_docstring class StableDiffusion3AutoBlocks(SequentialPipelineBlocks): model_name = "stable-diffusion-3" block_classes = AUTO_BLOCKS.values() From 956f3635e845d7423032fcfa9eabbb6501c5a7de Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Thu, 2 Apr 2026 17:17:24 +0000 Subject: [PATCH 10/14] add description and run autostring script --- .../stable_diffusion_3/before_denoise.py | 171 ++++++-- .../stable_diffusion_3/decoders.py | 6 +- .../stable_diffusion_3/denoise.py | 24 +- .../stable_diffusion_3/encoders.py | 42 +- .../stable_diffusion_3/inputs.py | 28 +- .../modular_blocks_stable_diffusion_3.py | 393 ++++++++++++++++++ 6 files changed, 586 insertions(+), 78 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py index 24cab1ab1a44..1de2af37d3d4 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py @@ -15,7 +15,10 @@ import torch -from ...pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import calculate_shift, retrieve_timesteps +from ...pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import ( + calculate_shift, + retrieve_timesteps, +) from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging from ...utils.torch_utils import randn_tensor @@ -23,7 +26,6 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from .modular_pipeline import StableDiffusion3ModularPipeline - logger = logging.get_logger(__name__) @@ -41,7 +43,9 @@ def _get_initial_timesteps_and_optionals( ): scheduler_kwargs = {} if scheduler.config.get("use_dynamic_shifting", None) and mu is None: - image_seq_len = (height // vae_scale_factor // patch_size) * (width // vae_scale_factor // patch_size) + image_seq_len = (height // vae_scale_factor // patch_size) * ( + width // vae_scale_factor // patch_size + ) mu = calculate_shift( image_seq_len, scheduler.config.get("base_image_seq_len", 256), @@ -73,12 +77,33 @@ def description(self) -> str: @property def inputs(self) -> list[InputParam]: return [ - InputParam("num_inference_steps", default=50), - InputParam("timesteps"), - InputParam("sigmas"), - InputParam("height", type_hint=int), - InputParam("width", type_hint=int), - InputParam("mu", type_hint=float), + InputParam( + "num_inference_steps", + default=50, + description="The number of denoising steps.", + ), + InputParam( + "timesteps", + description="Custom timesteps to use for the denoising process.", + ), + InputParam( + "sigmas", description="Custom sigmas to use for the denoising process." + ), + InputParam( + "height", + type_hint=int, + description="The height in pixels of the generated image.", + ), + InputParam( + "width", + type_hint=int, + description="The width in pixels of the generated image.", + ), + InputParam( + "mu", + type_hint=float, + description="The mu value used for dynamic shifting. If not provided, it is dynamically calculated.", + ), ] @property @@ -89,7 +114,9 @@ def intermediate_outputs(self) -> list[OutputParam]: ] @torch.no_grad() - def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState) -> PipelineState: + def __call__( + self, components: StableDiffusion3ModularPipeline, state: PipelineState + ) -> PipelineState: block_state = self.get_block_state(state) block_state.device = components._execution_device @@ -127,13 +154,38 @@ def description(self) -> str: @property def inputs(self) -> list[InputParam]: return [ - InputParam("num_inference_steps", default=50), - InputParam("timesteps"), - InputParam("sigmas"), - InputParam("strength", default=0.6), - InputParam("height", type_hint=int), - InputParam("width", type_hint=int), - InputParam("mu", type_hint=float), + InputParam( + "num_inference_steps", + default=50, + description="The number of denoising steps.", + ), + InputParam( + "timesteps", + description="Custom timesteps to use for the denoising process.", + ), + InputParam( + "sigmas", description="Custom sigmas to use for the denoising process." + ), + InputParam( + "strength", + default=0.6, + description="Indicates extent to transform the reference image.", + ), + InputParam( + "height", + type_hint=int, + description="The height in pixels of the generated image.", + ), + InputParam( + "width", + type_hint=int, + description="The width in pixels of the generated image.", + ), + InputParam( + "mu", + type_hint=float, + description="The mu value used for dynamic shifting. If not provided, it is dynamically calculated.", + ), ] @property @@ -153,7 +205,9 @@ def get_timesteps(scheduler, num_inference_steps, strength): return timesteps, num_inference_steps - t_start @torch.no_grad() - def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState) -> PipelineState: + def __call__( + self, components: StableDiffusion3ModularPipeline, state: PipelineState + ) -> PipelineState: block_state = self.get_block_state(state) block_state.device = components._execution_device @@ -191,13 +245,42 @@ def description(self) -> str: @property def inputs(self) -> list[InputParam]: return [ - InputParam("height", type_hint=int), - InputParam("width", type_hint=int), - InputParam("latents", type_hint=torch.Tensor | None), - InputParam("num_images_per_prompt", type_hint=int, default=1), - InputParam("generator"), - InputParam("batch_size", required=True, type_hint=int), - InputParam("dtype", type_hint=torch.dtype), + InputParam( + "height", + type_hint=int, + description="The height in pixels of the generated image.", + ), + InputParam( + "width", + type_hint=int, + description="The width in pixels of the generated image.", + ), + InputParam( + "latents", + type_hint=torch.Tensor | None, + description="Pre-generated noisy latents to be used as inputs for image generation.", + ), + InputParam( + "num_images_per_prompt", + type_hint=int, + default=1, + description="The number of images to generate per prompt.", + ), + InputParam( + "generator", + description="One or a list of torch generator(s) to make generation deterministic.", + ), + InputParam( + "batch_size", + required=True, + type_hint=int, + description="The batch size for latent generation.", + ), + InputParam( + "dtype", + type_hint=torch.dtype, + description="The data type for the latents.", + ), ] @property @@ -205,7 +288,9 @@ def intermediate_outputs(self) -> list[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor)] @torch.no_grad() - def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState) -> PipelineState: + def __call__( + self, components: StableDiffusion3ModularPipeline, state: PipelineState + ) -> PipelineState: block_state = self.get_block_state(state) block_state.device = components._execution_device batch_size = block_state.batch_size * block_state.num_images_per_prompt @@ -214,7 +299,9 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS block_state.width = block_state.width or components.default_width if block_state.latents is not None: - block_state.latents = block_state.latents.to(device=block_state.device, dtype=block_state.dtype) + block_state.latents = block_state.latents.to( + device=block_state.device, dtype=block_state.dtype + ) else: shape = ( batch_size, @@ -223,7 +310,10 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS int(block_state.width) // components.vae_scale_factor, ) block_state.latents = randn_tensor( - shape, generator=block_state.generator, device=block_state.device, dtype=block_state.dtype + shape, + generator=block_state.generator, + device=block_state.device, + dtype=block_state.dtype, ) self.set_block_state(state, block_state) @@ -240,9 +330,24 @@ def expected_components(self) -> list[ComponentSpec]: @property def inputs(self) -> list[InputParam]: return [ - InputParam("latents", required=True, type_hint=torch.Tensor), - InputParam("image_latents", required=True, type_hint=torch.Tensor), - InputParam("timesteps", required=True, type_hint=torch.Tensor), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to be scaled by the scheduler.", + ), + InputParam( + "image_latents", + required=True, + type_hint=torch.Tensor, + description="The image latents encoded by the VAE.", + ), + InputParam( + "timesteps", + required=True, + type_hint=torch.Tensor, + description="The timesteps schedule.", + ), ] @property @@ -250,7 +355,9 @@ def intermediate_outputs(self) -> list[OutputParam]: return [OutputParam("initial_noise", type_hint=torch.Tensor)] @torch.no_grad() - def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState) -> PipelineState: + def __call__( + self, components: StableDiffusion3ModularPipeline, state: PipelineState + ) -> PipelineState: block_state = self.get_block_state(state) latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0]) block_state.initial_noise = block_state.latents diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py index 0f79447aa2f0..75346e829e64 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py @@ -43,9 +43,9 @@ def expected_components(self) -> list[ComponentSpec]: @property def inputs(self) -> list[InputParam]: - return [ - InputParam("output_type", default="pil"), - InputParam("latents", required=True, type_hint=torch.Tensor), + return[ + InputParam("output_type", default="pil", description="The output format of the generated image (e.g., 'pil', 'pt', 'np')."), + InputParam("latents", required=True, type_hint=torch.Tensor, description="The denoised latents to be decoded."), ] @property diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py index e5a5c51b9009..d73531d5b825 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py @@ -55,18 +55,18 @@ def description(self) -> str: @property def inputs(self) -> list[tuple[str, Any]]: - return [ - InputParam("joint_attention_kwargs", type_hint=dict), - InputParam("latents", required=True, type_hint=torch.Tensor), - InputParam("prompt_embeds", required=True, type_hint=torch.Tensor), - InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor), - InputParam("negative_prompt_embeds", type_hint=torch.Tensor), - InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor), - InputParam("guidance_scale", default=7.0), - InputParam("skip_layer_guidance_scale", default=2.8), - InputParam("skip_layer_guidance_stop", default=0.2), - InputParam("skip_layer_guidance_start", default=0.01), - InputParam("num_inference_steps", type_hint=int), + return[ + InputParam("joint_attention_kwargs", type_hint=dict, description="A kwargs dictionary passed along to the AttentionProcessor."), + InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process."), + InputParam("prompt_embeds", required=True, type_hint=torch.Tensor, description="Text embeddings for guidance."), + InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="Pooled text embeddings for guidance."), + InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings for guidance."), + InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings for guidance."), + InputParam("guidance_scale", default=7.0, description="Guidance scale as defined in Classifier-Free Diffusion Guidance."), + InputParam("skip_layer_guidance_scale", default=2.8, description="The scale of the guidance for the skipped layers."), + InputParam("skip_layer_guidance_stop", default=0.2, description="The step fraction at which the guidance for skipped layers stops."), + InputParam("skip_layer_guidance_start", default=0.01, description="The step fraction at which the guidance for skipped layers starts."), + InputParam("num_inference_steps", type_hint=int, description="The number of denoising steps."), ] @torch.no_grad() diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py index 393bc684da7d..569bc42c3a0d 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py @@ -347,7 +347,12 @@ def expected_components(self) -> list[ComponentSpec]: @property def inputs(self) -> list[InputParam]: - return [InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")] + return[ + InputParam("resized_image", description="The pre-resized image input."), + InputParam("image", description="The input image to be used as the starting point for the image-to-image process."), + InputParam("height", description="The height in pixels of the generated image."), + InputParam("width", description="The width in pixels of the generated image.") + ] @property def intermediate_outputs(self) -> list[OutputParam]: @@ -409,7 +414,10 @@ def expected_components(self) -> list[ComponentSpec]: @property def inputs(self) -> list[InputParam]: - return [InputParam(self._image_input_name), InputParam("generator")] + return[ + InputParam(self._image_input_name, description="The processed image input to be encoded."), + InputParam("generator", description="One or a list of torch generator(s) to make generation deterministic.") + ] @property def intermediate_outputs(self) -> list[OutputParam]: @@ -461,21 +469,21 @@ def expected_components(self) -> list[ComponentSpec]: @property def inputs(self) -> list[InputParam]: - return [ - InputParam("prompt"), - InputParam("prompt_2"), - InputParam("prompt_3"), - InputParam("negative_prompt"), - InputParam("negative_prompt_2"), - InputParam("negative_prompt_3"), - InputParam("prompt_embeds", type_hint=torch.Tensor), - InputParam("negative_prompt_embeds", type_hint=torch.Tensor), - InputParam("pooled_prompt_embeds", type_hint=torch.Tensor), - InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor), - InputParam("guidance_scale", default=7.0), - InputParam("clip_skip", type_hint=int), - InputParam("max_sequence_length", type_hint=int, default=256), - InputParam("joint_attention_kwargs"), + return[ + InputParam("prompt", description="The prompt or prompts to guide the image generation."), + InputParam("prompt_2", description="The prompt or prompts to be sent to tokenizer_2 and text_encoder_2."), + InputParam("prompt_3", description="The prompt or prompts to be sent to tokenizer_3 and text_encoder_3."), + InputParam("negative_prompt", description="The prompt or prompts not to guide the image generation."), + InputParam("negative_prompt_2", description="The prompt or prompts not to guide the image generation for tokenizer_2."), + InputParam("negative_prompt_3", description="The prompt or prompts not to guide the image generation for tokenizer_3."), + InputParam("prompt_embeds", type_hint=torch.Tensor, description="Pre-generated text embeddings."), + InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative text embeddings."), + InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated pooled text embeddings."), + InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative pooled text embeddings."), + InputParam("guidance_scale", default=7.0, description="Guidance scale as defined in Classifier-Free Diffusion Guidance."), + InputParam("clip_skip", type_hint=int, description="Number of layers to be skipped from CLIP while computing the prompt embeddings."), + InputParam("max_sequence_length", type_hint=int, default=256, description="Maximum sequence length to use with the prompt."), + InputParam("joint_attention_kwargs", description="A kwargs dictionary passed along to the AttentionProcessor."), ] @property diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py index 623a60116721..78f48bb74ae5 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py @@ -35,14 +35,14 @@ def description(self) -> str: @property def inputs(self) -> list[InputParam]: - return [ - InputParam("num_images_per_prompt", default=1), - InputParam("guidance_scale", default=7.0), - InputParam("skip_guidance_layers", type_hint=list), - InputParam("prompt_embeds", required=True, type_hint=torch.Tensor), - InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor), - InputParam("negative_prompt_embeds", type_hint=torch.Tensor), - InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor), + return[ + InputParam("num_images_per_prompt", default=1, description="The number of images to generate per prompt."), + InputParam("guidance_scale", default=7.0, description="Guidance scale as defined in Classifier-Free Diffusion Guidance."), + InputParam("skip_guidance_layers", type_hint=list, description="A list of integers that specify layers to skip during guidance."), + InputParam("prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated text embeddings."), + InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated pooled text embeddings."), + InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative text embeddings."), + InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative pooled text embeddings."), ] @property @@ -120,14 +120,14 @@ def description(self) -> str: @property def inputs(self) -> list[InputParam]: - inputs = [ - InputParam("num_images_per_prompt", default=1), - InputParam("batch_size", required=True), - InputParam("height"), - InputParam("width"), + inputs =[ + InputParam("num_images_per_prompt", default=1, description="The number of images to generate per prompt."), + InputParam("batch_size", required=True, description="The batch size."), + InputParam("height", description="The height in pixels of the generated image."), + InputParam("width", description="The width in pixels of the generated image."), ] for name in self._image_latent_inputs + self._additional_batch_inputs: - inputs.append(InputParam(name)) + inputs.append(InputParam(name, description=f"Latent input {name} to be processed.")) return inputs @property diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py index 6b4c2b277426..7cfe86904e38 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py @@ -39,6 +39,29 @@ # auto_docstring class StableDiffusion3Img2ImgVaeEncoderStep(SequentialPipelineBlocks): + """ + Components: + image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`) + + Inputs: + resized_image (`None`, *optional*): + The pre-resized image input. + image (`None`, *optional*): + The input image to be used as the starting point for the image-to-image process. + height (`None`, *optional*): + The height in pixels of the generated image. + width (`None`, *optional*): + The width in pixels of the generated image. + generator (`None`, *optional*): + One or a list of torch generator(s) to make generation deterministic. + + Outputs: + processed_image (`None`): + TODO: Add description. + image_latents (`Tensor`): + The latents representing the reference image + """ + model_name = "stable-diffusion-3" block_classes = [StableDiffusion3ProcessImagesInputStep(), StableDiffusion3VaeEncoderStep()] block_names = ["preprocess", "encode"] @@ -46,6 +69,29 @@ class StableDiffusion3Img2ImgVaeEncoderStep(SequentialPipelineBlocks): # auto_docstring class StableDiffusion3AutoVaeEncoderStep(AutoPipelineBlocks): + """ + Components: + image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`) + + Inputs: + resized_image (`None`, *optional*): + The pre-resized image input. + image (`None`, *optional*): + The input image to be used as the starting point for the image-to-image process. + height (`None`, *optional*): + The height in pixels of the generated image. + width (`None`, *optional*): + The width in pixels of the generated image. + generator (`None`, *optional*): + One or a list of torch generator(s) to make generation deterministic. + + Outputs: + processed_image (`None`): + TODO: Add description. + image_latents (`Tensor`): + The latents representing the reference image + """ + model_name = "stable-diffusion-3" block_classes = [StableDiffusion3Img2ImgVaeEncoderStep] block_names = ["img2img"] @@ -54,6 +100,43 @@ class StableDiffusion3AutoVaeEncoderStep(AutoPipelineBlocks): # auto_docstring class StableDiffusion3BeforeDenoiseStep(SequentialPipelineBlocks): + """ + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + latents (`Tensor | NoneType`, *optional*): + Pre-generated noisy latents to be used as inputs for image generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`None`, *optional*): + One or a list of torch generator(s) to make generation deterministic. + batch_size (`int`): + The batch size for latent generation. + dtype (`dtype`, *optional*): + The data type for the latents. + num_inference_steps (`None`, *optional*, defaults to 50): + The number of denoising steps. + timesteps (`None`, *optional*): + Custom timesteps to use for the denoising process. + sigmas (`None`, *optional*): + Custom sigmas to use for the denoising process. + mu (`float`, *optional*): + The mu value used for dynamic shifting. If not provided, it is dynamically calculated. + + Outputs: + latents (`Tensor`): + TODO: Add description. + timesteps (`Tensor`): + TODO: Add description. + num_inference_steps (`int`): + TODO: Add description. + """ + model_name = "stable-diffusion-3" block_classes = [StableDiffusion3PrepareLatentsStep(), StableDiffusion3SetTimestepsStep()] block_names = ["prepare_latents", "set_timesteps"] @@ -61,6 +144,49 @@ class StableDiffusion3BeforeDenoiseStep(SequentialPipelineBlocks): # auto_docstring class StableDiffusion3Img2ImgBeforeDenoiseStep(SequentialPipelineBlocks): + """ + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + latents (`Tensor | NoneType`, *optional*): + Pre-generated noisy latents to be used as inputs for image generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`None`, *optional*): + One or a list of torch generator(s) to make generation deterministic. + batch_size (`int`): + The batch size for latent generation. + dtype (`dtype`, *optional*): + The data type for the latents. + num_inference_steps (`None`, *optional*, defaults to 50): + The number of denoising steps. + timesteps (`None`, *optional*): + Custom timesteps to use for the denoising process. + sigmas (`None`, *optional*): + Custom sigmas to use for the denoising process. + strength (`None`, *optional*, defaults to 0.6): + Indicates extent to transform the reference image. + mu (`float`, *optional*): + The mu value used for dynamic shifting. If not provided, it is dynamically calculated. + image_latents (`Tensor`): + The image latents encoded by the VAE. + + Outputs: + latents (`Tensor`): + TODO: Add description. + timesteps (`Tensor`): + TODO: Add description. + num_inference_steps (`int`): + TODO: Add description. + initial_noise (`Tensor`): + TODO: Add description. + """ + model_name = "stable-diffusion-3" block_classes = [ StableDiffusion3PrepareLatentsStep(), @@ -72,6 +198,49 @@ class StableDiffusion3Img2ImgBeforeDenoiseStep(SequentialPipelineBlocks): # auto_docstring class StableDiffusion3AutoBeforeDenoiseStep(AutoPipelineBlocks): + """ + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) + + Inputs: + height (`int`, *optional*): + The height in pixels of the generated image. + width (`int`, *optional*): + The width in pixels of the generated image. + latents (`Tensor | NoneType`, *optional*): + Pre-generated noisy latents to be used as inputs for image generation. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`None`, *optional*): + One or a list of torch generator(s) to make generation deterministic. + batch_size (`int`): + The batch size for latent generation. + dtype (`dtype`, *optional*): + The data type for the latents. + num_inference_steps (`None`, *optional*, defaults to 50): + The number of denoising steps. + timesteps (`None`, *optional*): + Custom timesteps to use for the denoising process. + sigmas (`None`, *optional*): + Custom sigmas to use for the denoising process. + strength (`None`, *optional*, defaults to 0.6): + Indicates extent to transform the reference image. + mu (`float`, *optional*): + The mu value used for dynamic shifting. If not provided, it is dynamically calculated. + image_latents (`Tensor`, *optional*): + The image latents encoded by the VAE. + + Outputs: + latents (`Tensor`): + TODO: Add description. + timesteps (`Tensor`): + TODO: Add description. + num_inference_steps (`int`): + TODO: Add description. + initial_noise (`Tensor`): + TODO: Add description. + """ + model_name = "stable-diffusion-3" block_classes = [StableDiffusion3Img2ImgBeforeDenoiseStep, StableDiffusion3BeforeDenoiseStep] block_names = ["img2img", "text2image"] @@ -80,6 +249,50 @@ class StableDiffusion3AutoBeforeDenoiseStep(AutoPipelineBlocks): # auto_docstring class StableDiffusion3Img2ImgInputStep(SequentialPipelineBlocks): + """ + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + The number of images to generate per prompt. + guidance_scale (`None`, *optional*, defaults to 7.0): + Guidance scale as defined in Classifier-Free Diffusion Guidance. + skip_guidance_layers (`list`, *optional*): + A list of integers that specify layers to skip during guidance. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. + pooled_prompt_embeds (`Tensor`): + Pre-generated pooled text embeddings. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. + negative_pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative pooled text embeddings. + height (`None`, *optional*): + The height in pixels of the generated image. + width (`None`, *optional*): + The width in pixels of the generated image. + image_latents (`None`, *optional*): + Latent input image_latents to be processed. + + Outputs: + batch_size (`int`): + TODO: Add description. + dtype (`dtype`): + TODO: Add description. + do_classifier_free_guidance (`bool`): + TODO: Add description. + prompt_embeds (`Tensor`): + TODO: Add description. + pooled_prompt_embeds (`Tensor`): + TODO: Add description. + negative_prompt_embeds (`Tensor`): + TODO: Add description. + negative_pooled_prompt_embeds (`Tensor`): + TODO: Add description. + image_height (`int`): + TODO: Add description. + image_width (`int`): + TODO: Add description. + """ + model_name = "stable-diffusion-3" block_classes = [StableDiffusion3TextInputStep(), StableDiffusion3AdditionalInputsStep()] block_names = ["text_inputs", "additional_inputs"] @@ -87,6 +300,50 @@ class StableDiffusion3Img2ImgInputStep(SequentialPipelineBlocks): # auto_docstring class StableDiffusion3AutoInputStep(AutoPipelineBlocks): + """ + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + The number of images to generate per prompt. + guidance_scale (`None`, *optional*, defaults to 7.0): + Guidance scale as defined in Classifier-Free Diffusion Guidance. + skip_guidance_layers (`list`, *optional*): + A list of integers that specify layers to skip during guidance. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. + pooled_prompt_embeds (`Tensor`): + Pre-generated pooled text embeddings. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. + negative_pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative pooled text embeddings. + height (`None`, *optional*): + The height in pixels of the generated image. + width (`None`, *optional*): + The width in pixels of the generated image. + image_latents (`None`, *optional*): + Latent input image_latents to be processed. + + Outputs: + batch_size (`int`): + TODO: Add description. + dtype (`dtype`): + TODO: Add description. + do_classifier_free_guidance (`bool`): + TODO: Add description. + prompt_embeds (`Tensor`): + TODO: Add description. + pooled_prompt_embeds (`Tensor`): + TODO: Add description. + negative_prompt_embeds (`Tensor`): + TODO: Add description. + negative_pooled_prompt_embeds (`Tensor`): + TODO: Add description. + image_height (`int`): + TODO: Add description. + image_width (`int`): + TODO: Add description. + """ + model_name = "stable-diffusion-3" block_classes = [StableDiffusion3Img2ImgInputStep, StableDiffusion3TextInputStep] block_names = ["img2img", "text2image"] @@ -95,6 +352,60 @@ class StableDiffusion3AutoInputStep(AutoPipelineBlocks): # auto_docstring class StableDiffusion3CoreDenoiseStep(SequentialPipelineBlocks): + """ + Components: + scheduler (`FlowMatchEulerDiscreteScheduler`) guider (`ClassifierFreeGuidance`) transformer + (`SD3Transformer2DModel`) + + Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + The number of images to generate per prompt. + guidance_scale (`None`, *optional*, defaults to 7.0): + Guidance scale as defined in Classifier-Free Diffusion Guidance. + skip_guidance_layers (`list`, *optional*): + A list of integers that specify layers to skip during guidance. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. + pooled_prompt_embeds (`Tensor`): + Pre-generated pooled text embeddings. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. + negative_pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative pooled text embeddings. + height (`None`, *optional*): + The height in pixels of the generated image. + width (`None`, *optional*): + The width in pixels of the generated image. + image_latents (`None`, *optional*): + Latent input image_latents to be processed. + latents (`Tensor | NoneType`, *optional*): + Pre-generated noisy latents to be used as inputs for image generation. + generator (`None`, *optional*): + One or a list of torch generator(s) to make generation deterministic. + num_inference_steps (`None`, *optional*, defaults to 50): + The number of denoising steps. + timesteps (`None`, *optional*): + Custom timesteps to use for the denoising process. + sigmas (`None`, *optional*): + Custom sigmas to use for the denoising process. + strength (`None`, *optional*, defaults to 0.6): + Indicates extent to transform the reference image. + mu (`float`, *optional*): + The mu value used for dynamic shifting. If not provided, it is dynamically calculated. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary passed along to the AttentionProcessor. + skip_layer_guidance_scale (`None`, *optional*, defaults to 2.8): + The scale of the guidance for the skipped layers. + skip_layer_guidance_stop (`None`, *optional*, defaults to 0.2): + The step fraction at which the guidance for skipped layers stops. + skip_layer_guidance_start (`None`, *optional*, defaults to 0.01): + The step fraction at which the guidance for skipped layers starts. + + Outputs: + latents (`Tensor`): + Denoised latents. + """ + model_name = "stable-diffusion-3" block_classes = [StableDiffusion3AutoInputStep, StableDiffusion3AutoBeforeDenoiseStep, StableDiffusion3DenoiseStep] block_names = ["input", "before_denoise", "denoise"] @@ -116,6 +427,88 @@ def outputs(self): # auto_docstring class StableDiffusion3AutoBlocks(SequentialPipelineBlocks): + """ + Supported workflows: + - `text2image`: requires `prompt` + - `image2image`: requires `image`, `prompt` + + Components: + text_encoder (`CLIPTextModelWithProjection`) tokenizer (`CLIPTokenizer`) text_encoder_2 + (`CLIPTextModelWithProjection`) tokenizer_2 (`CLIPTokenizer`) text_encoder_3 (`T5EncoderModel`) tokenizer_3 + (`T5TokenizerFast`) image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`) scheduler + (`FlowMatchEulerDiscreteScheduler`) guider (`ClassifierFreeGuidance`) transformer (`SD3Transformer2DModel`) + + Inputs: + prompt (`None`, *optional*): + The prompt or prompts to guide the image generation. + prompt_2 (`None`, *optional*): + The prompt or prompts to be sent to tokenizer_2 and text_encoder_2. + prompt_3 (`None`, *optional*): + The prompt or prompts to be sent to tokenizer_3 and text_encoder_3. + negative_prompt (`None`, *optional*): + The prompt or prompts not to guide the image generation. + negative_prompt_2 (`None`, *optional*): + The prompt or prompts not to guide the image generation for tokenizer_2. + negative_prompt_3 (`None`, *optional*): + The prompt or prompts not to guide the image generation for tokenizer_3. + prompt_embeds (`Tensor`, *optional*): + Pre-generated text embeddings. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. + pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated pooled text embeddings. + negative_pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative pooled text embeddings. + guidance_scale (`None`, *optional*, defaults to 7.0): + Guidance scale as defined in Classifier-Free Diffusion Guidance. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. + max_sequence_length (`int`, *optional*, defaults to 256): + Maximum sequence length to use with the prompt. + joint_attention_kwargs (`None`, *optional*): + A kwargs dictionary passed along to the AttentionProcessor. + resized_image (`None`, *optional*): + The pre-resized image input. + image (`None`, *optional*): + The input image to be used as the starting point for the image-to-image process. + height (`None`, *optional*): + The height in pixels of the generated image. + width (`None`, *optional*): + The width in pixels of the generated image. + generator (`None`, *optional*): + One or a list of torch generator(s) to make generation deterministic. + num_images_per_prompt (`None`, *optional*, defaults to 1): + The number of images to generate per prompt. + skip_guidance_layers (`list`, *optional*): + A list of integers that specify layers to skip during guidance. + image_latents (`None`, *optional*): + Latent input image_latents to be processed. + latents (`Tensor | NoneType`, *optional*): + Pre-generated noisy latents to be used as inputs for image generation. + num_inference_steps (`None`, *optional*, defaults to 50): + The number of denoising steps. + timesteps (`None`, *optional*): + Custom timesteps to use for the denoising process. + sigmas (`None`, *optional*): + Custom sigmas to use for the denoising process. + strength (`None`, *optional*, defaults to 0.6): + Indicates extent to transform the reference image. + mu (`float`, *optional*): + The mu value used for dynamic shifting. If not provided, it is dynamically calculated. + skip_layer_guidance_scale (`None`, *optional*, defaults to 2.8): + The scale of the guidance for the skipped layers. + skip_layer_guidance_stop (`None`, *optional*, defaults to 0.2): + The step fraction at which the guidance for skipped layers stops. + skip_layer_guidance_start (`None`, *optional*, defaults to 0.01): + The step fraction at which the guidance for skipped layers starts. + output_type (`None`, *optional*, defaults to pil): + The output format of the generated image (e.g., 'pil', 'pt', 'np'). + + Outputs: + images (`list`): + Generated images. + """ + model_name = "stable-diffusion-3" block_classes = AUTO_BLOCKS.values() block_names = AUTO_BLOCKS.keys() From e375cfa1a548eaf79c3b067f2aeae21505140d92 Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Thu, 2 Apr 2026 17:19:44 +0000 Subject: [PATCH 11/14] styling --- .../stable_diffusion_3/__init__.py | 5 +- .../stable_diffusion_3/decoders.py | 20 +- .../stable_diffusion_3/denoise.py | 126 ++++++-- .../stable_diffusion_3/encoders.py | 275 ++++++++++++++---- .../stable_diffusion_3/inputs.py | 133 +++++++-- .../modular_blocks_stable_diffusion_3.py | 32 +- .../stable_diffusion_3/modular_pipeline.py | 5 +- ...est_modular_pipeline_stable_diffusion_3.py | 26 +- 8 files changed, 477 insertions(+), 145 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py index d7bc6020a816..51cb69ed1e8b 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/__init__.py @@ -9,7 +9,6 @@ is_transformers_available, ) - _dummy_objects = {} _import_structure = {} @@ -21,7 +20,9 @@ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: - _import_structure["modular_blocks_stable_diffusion_3"] = ["StableDiffusion3AutoBlocks"] + _import_structure["modular_blocks_stable_diffusion_3"] = [ + "StableDiffusion3AutoBlocks" + ] _import_structure["modular_pipeline"] = ["StableDiffusion3ModularPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py index 75346e829e64..079181635e24 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/decoders.py @@ -22,7 +22,6 @@ from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam - logger = logging.get_logger(__name__) @@ -43,9 +42,18 @@ def expected_components(self) -> list[ComponentSpec]: @property def inputs(self) -> list[InputParam]: - return[ - InputParam("output_type", default="pil", description="The output format of the generated image (e.g., 'pil', 'pt', 'np')."), - InputParam("latents", required=True, type_hint=torch.Tensor, description="The denoised latents to be decoded."), + return [ + InputParam( + "output_type", + default="pil", + description="The output format of the generated image (e.g., 'pil', 'pt', 'np').", + ), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The denoised latents to be decoded.", + ), ] @property @@ -58,7 +66,9 @@ def __call__(self, components, state: PipelineState) -> PipelineState: vae = components.vae if not block_state.output_type == "latent": - latents = (block_state.latents / vae.config.scaling_factor) + vae.config.shift_factor + latents = ( + block_state.latents / vae.config.scaling_factor + ) + vae.config.shift_factor block_state.images = vae.decode(latents, return_dict=False)[0] block_state.images = components.image_processor.postprocess( block_state.images, output_type=block_state.output_type diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py index d73531d5b825..cbc2092f853b 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py @@ -30,7 +30,6 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from .modular_pipeline import StableDiffusion3ModularPipeline - logger = logging.get_logger(__name__) @@ -55,23 +54,74 @@ def description(self) -> str: @property def inputs(self) -> list[tuple[str, Any]]: - return[ - InputParam("joint_attention_kwargs", type_hint=dict, description="A kwargs dictionary passed along to the AttentionProcessor."), - InputParam("latents", required=True, type_hint=torch.Tensor, description="The initial latents to use for the denoising process."), - InputParam("prompt_embeds", required=True, type_hint=torch.Tensor, description="Text embeddings for guidance."), - InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="Pooled text embeddings for guidance."), - InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings for guidance."), - InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings for guidance."), - InputParam("guidance_scale", default=7.0, description="Guidance scale as defined in Classifier-Free Diffusion Guidance."), - InputParam("skip_layer_guidance_scale", default=2.8, description="The scale of the guidance for the skipped layers."), - InputParam("skip_layer_guidance_stop", default=0.2, description="The step fraction at which the guidance for skipped layers stops."), - InputParam("skip_layer_guidance_start", default=0.01, description="The step fraction at which the guidance for skipped layers starts."), - InputParam("num_inference_steps", type_hint=int, description="The number of denoising steps."), + return [ + InputParam( + "joint_attention_kwargs", + type_hint=dict, + description="A kwargs dictionary passed along to the AttentionProcessor.", + ), + InputParam( + "latents", + required=True, + type_hint=torch.Tensor, + description="The initial latents to use for the denoising process.", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Text embeddings for guidance.", + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pooled text embeddings for guidance.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + description="Negative text embeddings for guidance.", + ), + InputParam( + "negative_pooled_prompt_embeds", + type_hint=torch.Tensor, + description="Negative pooled text embeddings for guidance.", + ), + InputParam( + "guidance_scale", + default=7.0, + description="Guidance scale as defined in Classifier-Free Diffusion Guidance.", + ), + InputParam( + "skip_layer_guidance_scale", + default=2.8, + description="The scale of the guidance for the skipped layers.", + ), + InputParam( + "skip_layer_guidance_stop", + default=0.2, + description="The step fraction at which the guidance for skipped layers stops.", + ), + InputParam( + "skip_layer_guidance_start", + default=0.01, + description="The step fraction at which the guidance for skipped layers starts.", + ), + InputParam( + "num_inference_steps", + type_hint=int, + description="The number of denoising steps.", + ), ] @torch.no_grad() def __call__( - self, components: StableDiffusion3ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor + self, + components: StableDiffusion3ModularPipeline, + block_state: BlockState, + i: int, + t: torch.Tensor, ) -> PipelineState: guider_inputs = { "encoder_hidden_states": ( @@ -87,24 +137,37 @@ def __call__( if hasattr(components.guider, "guidance_scale"): components.guider.guidance_scale = block_state.guidance_scale if hasattr(components.guider, "skip_layer_guidance_scale"): - components.guider.skip_layer_guidance_scale = block_state.skip_layer_guidance_scale + components.guider.skip_layer_guidance_scale = ( + block_state.skip_layer_guidance_scale + ) if hasattr(components.guider, "skip_layer_guidance_start"): - components.guider.skip_layer_guidance_start = block_state.skip_layer_guidance_start + components.guider.skip_layer_guidance_start = ( + block_state.skip_layer_guidance_start + ) if hasattr(components.guider, "skip_layer_guidance_stop"): - components.guider.skip_layer_guidance_stop = block_state.skip_layer_guidance_stop + components.guider.skip_layer_guidance_stop = ( + block_state.skip_layer_guidance_stop + ) - components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t) + components.guider.set_state( + step=i, num_inference_steps=block_state.num_inference_steps, timestep=t + ) guider_state = components.guider.prepare_inputs(guider_inputs) for guider_state_batch in guider_state: components.guider.prepare_models(components.transformer) - cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()} + cond_kwargs = { + input_name: getattr(guider_state_batch, input_name) + for input_name in guider_inputs.keys() + } timestep = t.expand(block_state.latents.shape[0]) guider_state_batch.noise_pred = components.transformer( hidden_states=block_state.latents, timestep=timestep, - joint_attention_kwargs=getattr(block_state, "joint_attention_kwargs", None), + joint_attention_kwargs=getattr( + block_state, "joint_attention_kwargs", None + ), return_dict=False, **cond_kwargs, )[0] @@ -129,7 +192,13 @@ def intermediate_outputs(self) -> list[OutputParam]: return [OutputParam("latents", type_hint=torch.Tensor)] @torch.no_grad() - def __call__(self, components: StableDiffusion3ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor): + def __call__( + self, + components: StableDiffusion3ModularPipeline, + block_state: BlockState, + i: int, + t: torch.Tensor, + ): latents_dtype = block_state.latents.dtype block_state.latents = components.scheduler.step( block_state.noise_pred, @@ -162,17 +231,24 @@ def loop_inputs(self) -> list[InputParam]: ] @torch.no_grad() - def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState) -> PipelineState: + def __call__( + self, components: StableDiffusion3ModularPipeline, state: PipelineState + ) -> PipelineState: block_state = self.get_block_state(state) block_state.num_warmup_steps = max( - len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 + len(block_state.timesteps) + - block_state.num_inference_steps * components.scheduler.order, + 0, ) with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: for i, t in enumerate(block_state.timesteps): - components, block_state = self.loop_step(components, block_state, i=i, t=t) + components, block_state = self.loop_step( + components, block_state, i=i, t=t + ) if i == len(block_state.timesteps) - 1 or ( - (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0 + (i + 1) > block_state.num_warmup_steps + and (i + 1) % components.scheduler.order == 0 ): progress_bar.update() diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py index 569bc42c3a0d..159a90e248f3 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py @@ -13,7 +13,12 @@ # limitations under the License. import torch -from transformers import CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel, T5TokenizerFast +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, + T5EncoderModel, + T5TokenizerFast, +) from ...configuration_utils import FrozenDict from ...image_processor import VaeImageProcessor @@ -24,13 +29,14 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from .modular_pipeline import StableDiffusion3ModularPipeline - logger = logging.get_logger(__name__) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( - encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" + encoder_output: torch.Tensor, + generator: torch.Generator | None = None, + sample_mode: str = "sample", ): if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) @@ -42,17 +48,30 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -def encode_vae_image(vae: AutoencoderKL, image: torch.Tensor, generator: torch.Generator, sample_mode="sample"): +def encode_vae_image( + vae: AutoencoderKL, + image: torch.Tensor, + generator: torch.Generator, + sample_mode="sample", +): if isinstance(generator, list): image_latents = [ - retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode) + retrieve_latents( + vae.encode(image[i : i + 1]), + generator=generator[i], + sample_mode=sample_mode, + ) for i in range(image.shape[0]) ] image_latents = torch.cat(image_latents, dim=0) else: - image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode) + image_latents = retrieve_latents( + vae.encode(image), generator=generator, sample_mode=sample_mode + ) - image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor + image_latents = ( + image_latents - vae.config.shift_factor + ) * vae.config.scaling_factor return image_latents @@ -91,9 +110,13 @@ def _get_t5_prompt_embeds( return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = components.tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids + untruncated_ids = components.tokenizer_3( + prompt, padding="longest", return_tensors="pt" + ).input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): removed_text = components.tokenizer_3.batch_decode( untruncated_ids[:, components.tokenizer_max_length - 1 : -1] ) @@ -145,9 +168,15 @@ def _get_clip_prompt_embeds( ) text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = tokenizer.batch_decode(untruncated_ids[:, components.tokenizer_max_length - 1 : -1]) + untruncated_ids = tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode( + untruncated_ids[:, components.tokenizer_max_length - 1 : -1] + ) logger.warning( "The following part of your input was truncated because CLIP can only handle sequences up to" f" {components.tokenizer_max_length} tokens: {removed_text}" @@ -168,7 +197,9 @@ def _get_clip_prompt_embeds( prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) - pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1) + pooled_prompt_embeds = pooled_prompt_embeds.view( + batch_size * num_images_per_prompt, -1 + ) return prompt_embeds, pooled_prompt_embeds @@ -246,11 +277,14 @@ def encode_prompt( ) clip_prompt_embeds = torch.nn.functional.pad( - clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + clip_prompt_embeds, + (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]), ) prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) - pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1) + pooled_prompt_embeds = torch.cat( + [pooled_prompt_embed, pooled_prompt_2_embed], dim=-1 + ) if do_classifier_free_guidance and negative_prompt_embeds is None: negative_prompt = negative_prompt or "" @@ -258,12 +292,20 @@ def encode_prompt( negative_prompt_3 = negative_prompt_3 or negative_prompt # normalize str to list - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + negative_prompt = ( + batch_size * [negative_prompt] + if isinstance(negative_prompt, str) + else negative_prompt + ) negative_prompt_2 = ( - batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2 + batch_size * [negative_prompt_2] + if isinstance(negative_prompt_2, str) + else negative_prompt_2 ) negative_prompt_3 = ( - batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3 + batch_size * [negative_prompt_3] + if isinstance(negative_prompt_3, str) + else negative_prompt_3 ) if prompt is not None and type(prompt) is not type(negative_prompt): @@ -286,15 +328,19 @@ def encode_prompt( clip_skip=None, clip_model_index=0, ) - negative_prompt_2_embed, negative_pooled_prompt_2_embed = _get_clip_prompt_embeds( - components, - negative_prompt_2, - device=device, - num_images_per_prompt=num_images_per_prompt, - clip_skip=None, - clip_model_index=1, + negative_prompt_2_embed, negative_pooled_prompt_2_embed = ( + _get_clip_prompt_embeds( + components, + negative_prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + clip_skip=None, + clip_model_index=1, + ) + ) + negative_clip_prompt_embeds = torch.cat( + [negative_prompt_embed, negative_prompt_2_embed], dim=-1 ) - negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1) t5_negative_prompt_embed = _get_t5_prompt_embeds( components, @@ -306,10 +352,16 @@ def encode_prompt( negative_clip_prompt_embeds = torch.nn.functional.pad( negative_clip_prompt_embeds, - (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ( + 0, + t5_negative_prompt_embed.shape[-1] + - negative_clip_prompt_embeds.shape[-1], + ), ) - negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2) + negative_prompt_embeds = torch.cat( + [negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2 + ) negative_pooled_prompt_embeds = torch.cat( [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 ) @@ -324,7 +376,12 @@ def encode_prompt( # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(components.text_encoder_2, lora_scale) - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + return ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) class StableDiffusion3ProcessImagesInputStep(ModularPipelineBlocks): @@ -347,11 +404,18 @@ def expected_components(self) -> list[ComponentSpec]: @property def inputs(self) -> list[InputParam]: - return[ + return [ InputParam("resized_image", description="The pre-resized image input."), - InputParam("image", description="The input image to be used as the starting point for the image-to-image process."), - InputParam("height", description="The height in pixels of the generated image."), - InputParam("width", description="The width in pixels of the generated image.") + InputParam( + "image", + description="The input image to be used as the starting point for the image-to-image process.", + ), + InputParam( + "height", description="The height in pixels of the generated image." + ), + InputParam( + "width", description="The width in pixels of the generated image." + ), ] @property @@ -361,17 +425,25 @@ def intermediate_outputs(self) -> list[OutputParam]: @staticmethod def check_inputs(height, width, vae_scale_factor, patch_size): if height is not None and height % (vae_scale_factor * patch_size) != 0: - raise ValueError(f"Height must be divisible by {vae_scale_factor * patch_size} but is {height}") + raise ValueError( + f"Height must be divisible by {vae_scale_factor * patch_size} but is {height}" + ) if width is not None and width % (vae_scale_factor * patch_size) != 0: - raise ValueError(f"Width must be divisible by {vae_scale_factor * patch_size} but is {width}") + raise ValueError( + f"Width must be divisible by {vae_scale_factor * patch_size} but is {width}" + ) @torch.no_grad() - def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState): + def __call__( + self, components: StableDiffusion3ModularPipeline, state: PipelineState + ): block_state = self.get_block_state(state) if block_state.resized_image is None and block_state.image is None: - raise ValueError("`resized_image` and `image` cannot be None at the same time") + raise ValueError( + "`resized_image` and `image` cannot be None at the same time" + ) if block_state.resized_image is None: image = block_state.image @@ -387,7 +459,9 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS width, height = block_state.resized_image[0].size image = block_state.resized_image - block_state.processed_image = components.image_processor.preprocess(image=image, height=height, width=width) + block_state.processed_image = components.image_processor.preprocess( + image=image, height=height, width=width + ) self.set_block_state(state, block_state) return components, state @@ -397,7 +471,10 @@ class StableDiffusion3VaeEncoderStep(ModularPipelineBlocks): model_name = "stable-diffusion-3" def __init__( - self, input_name: str = "processed_image", output_name: str = "image_latents", sample_mode: str = "sample" + self, + input_name: str = "processed_image", + output_name: str = "image_latents", + sample_mode: str = "sample", ): self._image_input_name = input_name self._image_latents_output_name = output_name @@ -414,9 +491,15 @@ def expected_components(self) -> list[ComponentSpec]: @property def inputs(self) -> list[InputParam]: - return[ - InputParam(self._image_input_name, description="The processed image input to be encoded."), - InputParam("generator", description="One or a list of torch generator(s) to make generation deterministic.") + return [ + InputParam( + self._image_input_name, + description="The processed image input to be encoded.", + ), + InputParam( + "generator", + description="One or a list of torch generator(s) to make generation deterministic.", + ), ] @property @@ -430,7 +513,9 @@ def intermediate_outputs(self) -> list[OutputParam]: ] @torch.no_grad() - def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState) -> PipelineState: + def __call__( + self, components: StableDiffusion3ModularPipeline, state: PipelineState + ) -> PipelineState: block_state = self.get_block_state(state) image = getattr(block_state, self._image_input_name) @@ -441,7 +526,10 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS dtype = components.vae.dtype image = image.to(device=device, dtype=dtype) image_latents = encode_vae_image( - image=image, vae=components.vae, generator=block_state.generator, sample_mode=self.sample_mode + image=image, + vae=components.vae, + generator=block_state.generator, + sample_mode=self.sample_mode, ) setattr(block_state, self._image_latents_output_name, image_latents) @@ -469,21 +557,71 @@ def expected_components(self) -> list[ComponentSpec]: @property def inputs(self) -> list[InputParam]: - return[ - InputParam("prompt", description="The prompt or prompts to guide the image generation."), - InputParam("prompt_2", description="The prompt or prompts to be sent to tokenizer_2 and text_encoder_2."), - InputParam("prompt_3", description="The prompt or prompts to be sent to tokenizer_3 and text_encoder_3."), - InputParam("negative_prompt", description="The prompt or prompts not to guide the image generation."), - InputParam("negative_prompt_2", description="The prompt or prompts not to guide the image generation for tokenizer_2."), - InputParam("negative_prompt_3", description="The prompt or prompts not to guide the image generation for tokenizer_3."), - InputParam("prompt_embeds", type_hint=torch.Tensor, description="Pre-generated text embeddings."), - InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative text embeddings."), - InputParam("pooled_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated pooled text embeddings."), - InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative pooled text embeddings."), - InputParam("guidance_scale", default=7.0, description="Guidance scale as defined in Classifier-Free Diffusion Guidance."), - InputParam("clip_skip", type_hint=int, description="Number of layers to be skipped from CLIP while computing the prompt embeddings."), - InputParam("max_sequence_length", type_hint=int, default=256, description="Maximum sequence length to use with the prompt."), - InputParam("joint_attention_kwargs", description="A kwargs dictionary passed along to the AttentionProcessor."), + return [ + InputParam( + "prompt", + description="The prompt or prompts to guide the image generation.", + ), + InputParam( + "prompt_2", + description="The prompt or prompts to be sent to tokenizer_2 and text_encoder_2.", + ), + InputParam( + "prompt_3", + description="The prompt or prompts to be sent to tokenizer_3 and text_encoder_3.", + ), + InputParam( + "negative_prompt", + description="The prompt or prompts not to guide the image generation.", + ), + InputParam( + "negative_prompt_2", + description="The prompt or prompts not to guide the image generation for tokenizer_2.", + ), + InputParam( + "negative_prompt_3", + description="The prompt or prompts not to guide the image generation for tokenizer_3.", + ), + InputParam( + "prompt_embeds", + type_hint=torch.Tensor, + description="Pre-generated text embeddings.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + description="Pre-generated negative text embeddings.", + ), + InputParam( + "pooled_prompt_embeds", + type_hint=torch.Tensor, + description="Pre-generated pooled text embeddings.", + ), + InputParam( + "negative_pooled_prompt_embeds", + type_hint=torch.Tensor, + description="Pre-generated negative pooled text embeddings.", + ), + InputParam( + "guidance_scale", + default=7.0, + description="Guidance scale as defined in Classifier-Free Diffusion Guidance.", + ), + InputParam( + "clip_skip", + type_hint=int, + description="Number of layers to be skipped from CLIP while computing the prompt embeddings.", + ), + InputParam( + "max_sequence_length", + type_hint=int, + default=256, + description="Maximum sequence length to use with the prompt.", + ), + InputParam( + "joint_attention_kwargs", + description="A kwargs dictionary passed along to the AttentionProcessor.", + ), ] @property @@ -496,16 +634,25 @@ def intermediate_outputs(self) -> list[OutputParam]: ] @torch.no_grad() - def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState) -> PipelineState: + def __call__( + self, components: StableDiffusion3ModularPipeline, state: PipelineState + ) -> PipelineState: block_state = self.get_block_state(state) block_state.device = components._execution_device do_classifier_free_guidance = block_state.guidance_scale > 1.0 lora_scale = ( - block_state.joint_attention_kwargs.get("scale", None) if block_state.joint_attention_kwargs else None + block_state.joint_attention_kwargs.get("scale", None) + if block_state.joint_attention_kwargs + else None ) - prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = encode_prompt( + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = encode_prompt( components=components, prompt=block_state.prompt, prompt_2=getattr(block_state, "prompt_2", None), @@ -519,7 +666,9 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS prompt_embeds=getattr(block_state, "prompt_embeds", None), negative_prompt_embeds=getattr(block_state, "negative_prompt_embeds", None), pooled_prompt_embeds=getattr(block_state, "pooled_prompt_embeds", None), - negative_pooled_prompt_embeds=getattr(block_state, "negative_pooled_prompt_embeds", None), + negative_pooled_prompt_embeds=getattr( + block_state, "negative_pooled_prompt_embeds", None + ), clip_skip=getattr(block_state, "clip_skip", None), max_sequence_length=getattr(block_state, "max_sequence_length", 256), lora_scale=lora_scale, diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py index 78f48bb74ae5..a01be86816a5 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py @@ -17,10 +17,12 @@ from ...utils import logging from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import InputParam, OutputParam -from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_to_batch_size +from ..qwenimage.inputs import ( + calculate_dimension_from_latents, + repeat_tensor_to_batch_size, +) from .modular_pipeline import StableDiffusion3ModularPipeline - logger = logging.get_logger(__name__) @@ -29,20 +31,48 @@ class StableDiffusion3TextInputStep(ModularPipelineBlocks): @property def description(self) -> str: - return ( - "Text input processing step that standardizes text embeddings for SD3, applying CFG duplication if needed." - ) + return "Text input processing step that standardizes text embeddings for SD3, applying CFG duplication if needed." @property def inputs(self) -> list[InputParam]: - return[ - InputParam("num_images_per_prompt", default=1, description="The number of images to generate per prompt."), - InputParam("guidance_scale", default=7.0, description="Guidance scale as defined in Classifier-Free Diffusion Guidance."), - InputParam("skip_guidance_layers", type_hint=list, description="A list of integers that specify layers to skip during guidance."), - InputParam("prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated text embeddings."), - InputParam("pooled_prompt_embeds", required=True, type_hint=torch.Tensor, description="Pre-generated pooled text embeddings."), - InputParam("negative_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative text embeddings."), - InputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Pre-generated negative pooled text embeddings."), + return [ + InputParam( + "num_images_per_prompt", + default=1, + description="The number of images to generate per prompt.", + ), + InputParam( + "guidance_scale", + default=7.0, + description="Guidance scale as defined in Classifier-Free Diffusion Guidance.", + ), + InputParam( + "skip_guidance_layers", + type_hint=list, + description="A list of integers that specify layers to skip during guidance.", + ), + InputParam( + "prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pre-generated text embeddings.", + ), + InputParam( + "pooled_prompt_embeds", + required=True, + type_hint=torch.Tensor, + description="Pre-generated pooled text embeddings.", + ), + InputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + description="Pre-generated negative text embeddings.", + ), + InputParam( + "negative_pooled_prompt_embeds", + type_hint=torch.Tensor, + description="Pre-generated negative pooled text embeddings.", + ), ] @property @@ -58,7 +88,9 @@ def intermediate_outputs(self) -> list[str]: ] @torch.no_grad() - def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState) -> PipelineState: + def __call__( + self, components: StableDiffusion3ModularPipeline, state: PipelineState + ) -> PipelineState: block_state = self.get_block_state(state) block_state.batch_size = block_state.prompt_embeds.shape[0] @@ -66,23 +98,38 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS block_state.do_classifier_free_guidance = block_state.guidance_scale > 1.0 _, seq_len, _ = block_state.prompt_embeds.shape - prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1) + prompt_embeds = block_state.prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) + prompt_embeds = prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) - pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat(1, block_state.num_images_per_prompt) + pooled_prompt_embeds = block_state.pooled_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt + ) pooled_prompt_embeds = pooled_prompt_embeds.view( block_state.batch_size * block_state.num_images_per_prompt, -1 ) - if block_state.do_classifier_free_guidance and block_state.negative_prompt_embeds is not None: + if ( + block_state.do_classifier_free_guidance + and block_state.negative_prompt_embeds is not None + ): _, neg_seq_len, _ = block_state.negative_prompt_embeds.shape - negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt, 1 + ) negative_prompt_embeds = negative_prompt_embeds.view( - block_state.batch_size * block_state.num_images_per_prompt, neg_seq_len, -1 + block_state.batch_size * block_state.num_images_per_prompt, + neg_seq_len, + -1, ) - negative_pooled_prompt_embeds = block_state.negative_pooled_prompt_embeds.repeat( - 1, block_state.num_images_per_prompt + negative_pooled_prompt_embeds = ( + block_state.negative_pooled_prompt_embeds.repeat( + 1, block_state.num_images_per_prompt + ) ) negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view( block_state.batch_size * block_state.num_images_per_prompt, -1 @@ -105,12 +152,20 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS class StableDiffusion3AdditionalInputsStep(ModularPipelineBlocks): model_name = "stable-diffusion-3" - def __init__(self, image_latent_inputs: list[str] = ["image_latents"], additional_batch_inputs: list[str] = []): + def __init__( + self, + image_latent_inputs: list[str] = ["image_latents"], + additional_batch_inputs: list[str] = [], + ): self._image_latent_inputs = ( - image_latent_inputs if isinstance(image_latent_inputs, list) else [image_latent_inputs] + image_latent_inputs + if isinstance(image_latent_inputs, list) + else [image_latent_inputs] ) self._additional_batch_inputs = ( - additional_batch_inputs if isinstance(additional_batch_inputs, list) else [additional_batch_inputs] + additional_batch_inputs + if isinstance(additional_batch_inputs, list) + else [additional_batch_inputs] ) super().__init__() @@ -120,14 +175,24 @@ def description(self) -> str: @property def inputs(self) -> list[InputParam]: - inputs =[ - InputParam("num_images_per_prompt", default=1, description="The number of images to generate per prompt."), + inputs = [ + InputParam( + "num_images_per_prompt", + default=1, + description="The number of images to generate per prompt.", + ), InputParam("batch_size", required=True, description="The batch size."), - InputParam("height", description="The height in pixels of the generated image."), - InputParam("width", description="The width in pixels of the generated image."), + InputParam( + "height", description="The height in pixels of the generated image." + ), + InputParam( + "width", description="The width in pixels of the generated image." + ), ] for name in self._image_latent_inputs + self._additional_batch_inputs: - inputs.append(InputParam(name, description=f"Latent input {name} to be processed.")) + inputs.append( + InputParam(name, description=f"Latent input {name} to be processed.") + ) return inputs @property @@ -137,7 +202,9 @@ def intermediate_outputs(self) -> list[OutputParam]: OutputParam("image_width", type_hint=int), ] - def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineState) -> PipelineState: + def __call__( + self, components: StableDiffusion3ModularPipeline, state: PipelineState + ) -> PipelineState: block_state = self.get_block_state(state) for input_name in self._image_latent_inputs: @@ -145,7 +212,9 @@ def __call__(self, components: StableDiffusion3ModularPipeline, state: PipelineS if tensor is None: continue - height, width = calculate_dimension_from_latents(tensor, components.vae_scale_factor) + height, width = calculate_dimension_from_latents( + tensor, components.vae_scale_factor + ) block_state.height = block_state.height or height block_state.width = block_state.width or width diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py index 7cfe86904e38..4432b6463785 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py @@ -28,11 +28,7 @@ StableDiffusion3TextEncoderStep, StableDiffusion3VaeEncoderStep, ) -from .inputs import ( - StableDiffusion3AdditionalInputsStep, - StableDiffusion3TextInputStep, -) - +from .inputs import StableDiffusion3AdditionalInputsStep, StableDiffusion3TextInputStep logger = logging.get_logger(__name__) @@ -63,7 +59,10 @@ class StableDiffusion3Img2ImgVaeEncoderStep(SequentialPipelineBlocks): """ model_name = "stable-diffusion-3" - block_classes = [StableDiffusion3ProcessImagesInputStep(), StableDiffusion3VaeEncoderStep()] + block_classes = [ + StableDiffusion3ProcessImagesInputStep(), + StableDiffusion3VaeEncoderStep(), + ] block_names = ["preprocess", "encode"] @@ -138,7 +137,10 @@ class StableDiffusion3BeforeDenoiseStep(SequentialPipelineBlocks): """ model_name = "stable-diffusion-3" - block_classes = [StableDiffusion3PrepareLatentsStep(), StableDiffusion3SetTimestepsStep()] + block_classes = [ + StableDiffusion3PrepareLatentsStep(), + StableDiffusion3SetTimestepsStep(), + ] block_names = ["prepare_latents", "set_timesteps"] @@ -242,7 +244,10 @@ class StableDiffusion3AutoBeforeDenoiseStep(AutoPipelineBlocks): """ model_name = "stable-diffusion-3" - block_classes = [StableDiffusion3Img2ImgBeforeDenoiseStep, StableDiffusion3BeforeDenoiseStep] + block_classes = [ + StableDiffusion3Img2ImgBeforeDenoiseStep, + StableDiffusion3BeforeDenoiseStep, + ] block_names = ["img2img", "text2image"] block_trigger_inputs = ["image_latents", None] @@ -294,7 +299,10 @@ class StableDiffusion3Img2ImgInputStep(SequentialPipelineBlocks): """ model_name = "stable-diffusion-3" - block_classes = [StableDiffusion3TextInputStep(), StableDiffusion3AdditionalInputsStep()] + block_classes = [ + StableDiffusion3TextInputStep(), + StableDiffusion3AdditionalInputsStep(), + ] block_names = ["text_inputs", "additional_inputs"] @@ -407,7 +415,11 @@ class StableDiffusion3CoreDenoiseStep(SequentialPipelineBlocks): """ model_name = "stable-diffusion-3" - block_classes = [StableDiffusion3AutoInputStep, StableDiffusion3AutoBeforeDenoiseStep, StableDiffusion3DenoiseStep] + block_classes = [ + StableDiffusion3AutoInputStep, + StableDiffusion3AutoBeforeDenoiseStep, + StableDiffusion3DenoiseStep, + ] block_names = ["input", "before_denoise", "denoise"] @property diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py index 0e893714b70d..645ad930b426 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_pipeline.py @@ -16,11 +16,12 @@ from ...utils import logging from ..modular_pipeline import ModularPipeline - logger = logging.get_logger(__name__) -class StableDiffusion3ModularPipeline(ModularPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin): +class StableDiffusion3ModularPipeline( + ModularPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin +): """ A ModularPipeline for Stable Diffusion 3. diff --git a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py index f361f6a92a1d..f1cf6ce4630f 100644 --- a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py +++ b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py @@ -20,17 +20,22 @@ import torch from diffusers.image_processor import VaeImageProcessor -from diffusers.modular_pipelines.stable_diffusion_3 import StableDiffusion3AutoBlocks, StableDiffusion3ModularPipeline +from diffusers.modular_pipelines.stable_diffusion_3 import ( + StableDiffusion3AutoBlocks, + StableDiffusion3ModularPipeline, +) from ...testing_utils import floats_tensor, torch_device from ..test_modular_pipelines_common import ModularPipelineTesterMixin - SD3_TEXT2IMAGE_WORKFLOWS = { "text2image": [ ("text_encoder", "StableDiffusion3TextEncoderStep"), ("denoise.input", "StableDiffusion3TextInputStep"), - ("denoise.before_denoise.prepare_latents", "StableDiffusion3PrepareLatentsStep"), + ( + "denoise.before_denoise.prepare_latents", + "StableDiffusion3PrepareLatentsStep", + ), ("denoise.before_denoise.set_timesteps", "StableDiffusion3SetTimestepsStep"), ("denoise.denoise", "StableDiffusion3DenoiseStep"), ("decode", "StableDiffusion3DecodeStep"), @@ -102,9 +107,18 @@ def test_float16_inference(self): ("vae_encoder.encode", "StableDiffusion3VaeEncoderStep"), ("denoise.input.text_inputs", "StableDiffusion3TextInputStep"), ("denoise.input.additional_inputs", "StableDiffusion3AdditionalInputsStep"), - ("denoise.before_denoise.prepare_latents", "StableDiffusion3PrepareLatentsStep"), - ("denoise.before_denoise.set_timesteps", "StableDiffusion3Img2ImgSetTimestepsStep"), - ("denoise.before_denoise.prepare_img2img_latents", "StableDiffusion3Img2ImgPrepareLatentsStep"), + ( + "denoise.before_denoise.prepare_latents", + "StableDiffusion3PrepareLatentsStep", + ), + ( + "denoise.before_denoise.set_timesteps", + "StableDiffusion3Img2ImgSetTimestepsStep", + ), + ( + "denoise.before_denoise.prepare_img2img_latents", + "StableDiffusion3Img2ImgPrepareLatentsStep", + ), ("denoise.denoise", "StableDiffusion3DenoiseStep"), ("decode", "StableDiffusion3DecodeStep"), ] From 8626d27a4308a072f226f3448e53b2eaf9be777d Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Thu, 2 Apr 2026 17:53:09 +0000 Subject: [PATCH 12/14] add descriptions for outputparams and styling --- .../stable_diffusion_3/before_denoise.py | 40 ++++++++++-- .../stable_diffusion_3/denoise.py | 8 ++- .../stable_diffusion_3/encoders.py | 6 +- .../stable_diffusion_3/inputs.py | 54 +++++++++++++--- .../modular_blocks_stable_diffusion_3.py | 62 +++++++++---------- ...est_modular_pipeline_stable_diffusion_3.py | 1 + 6 files changed, 123 insertions(+), 48 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py index 1de2af37d3d4..0a44449aafdb 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py @@ -109,8 +109,16 @@ def inputs(self) -> list[InputParam]: @property def intermediate_outputs(self) -> list[OutputParam]: return [ - OutputParam("timesteps", type_hint=torch.Tensor), - OutputParam("num_inference_steps", type_hint=int), + OutputParam( + "timesteps", + type_hint=torch.Tensor, + description="The timesteps schedule for the denoising process.", + ), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The final number of inference steps.", + ), ] @torch.no_grad() @@ -191,8 +199,16 @@ def inputs(self) -> list[InputParam]: @property def intermediate_outputs(self) -> list[OutputParam]: return [ - OutputParam("timesteps", type_hint=torch.Tensor), - OutputParam("num_inference_steps", type_hint=int), + OutputParam( + "timesteps", + type_hint=torch.Tensor, + description="The timesteps schedule for the denoising process.", + ), + OutputParam( + "num_inference_steps", + type_hint=int, + description="The final number of inference steps.", + ), ] @staticmethod @@ -285,7 +301,13 @@ def inputs(self) -> list[InputParam]: @property def intermediate_outputs(self) -> list[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor)] + return [ + OutputParam( + "latents", + type_hint=torch.Tensor, + description="The prepared latent tensors to be denoised.", + ) + ] @torch.no_grad() def __call__( @@ -352,7 +374,13 @@ def inputs(self) -> list[InputParam]: @property def intermediate_outputs(self) -> list[OutputParam]: - return [OutputParam("initial_noise", type_hint=torch.Tensor)] + return [ + OutputParam( + "initial_noise", + type_hint=torch.Tensor, + description="The initial noise applied to the image latents.", + ) + ] @torch.no_grad() def __call__( diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py index cbc2092f853b..f5886d4ac40e 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/denoise.py @@ -189,7 +189,13 @@ def expected_components(self) -> list[ComponentSpec]: @property def intermediate_outputs(self) -> list[OutputParam]: - return [OutputParam("latents", type_hint=torch.Tensor)] + return [ + OutputParam( + "latents", + type_hint=torch.Tensor, + description="The denoised latent tensors.", + ) + ] @torch.no_grad() def __call__( diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py index 159a90e248f3..0dc33ebb617c 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py @@ -420,7 +420,11 @@ def inputs(self) -> list[InputParam]: @property def intermediate_outputs(self) -> list[OutputParam]: - return [OutputParam(name="processed_image")] + return [ + OutputParam( + name="processed_image", description="The pre-processed image tensor." + ) + ] @staticmethod def check_inputs(height, width, vae_scale_factor, patch_size): diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py index a01be86816a5..e5489f42e70e 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py @@ -78,13 +78,41 @@ def inputs(self) -> list[InputParam]: @property def intermediate_outputs(self) -> list[str]: return [ - OutputParam("batch_size", type_hint=int), - OutputParam("dtype", type_hint=torch.dtype), - OutputParam("do_classifier_free_guidance", type_hint=bool), - OutputParam("prompt_embeds", type_hint=torch.Tensor), - OutputParam("pooled_prompt_embeds", type_hint=torch.Tensor), - OutputParam("negative_prompt_embeds", type_hint=torch.Tensor), - OutputParam("negative_pooled_prompt_embeds", type_hint=torch.Tensor), + OutputParam( + "batch_size", + type_hint=int, + description="The batch size for the inference.", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="The expected data type for latents.", + ), + OutputParam( + "do_classifier_free_guidance", + type_hint=bool, + description="Flag indicating if CFG is enabled.", + ), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + description="The processed text embeddings.", + ), + OutputParam( + "pooled_prompt_embeds", + type_hint=torch.Tensor, + description="The processed pooled text embeddings.", + ), + OutputParam( + "negative_prompt_embeds", + type_hint=torch.Tensor, + description="The processed negative text embeddings.", + ), + OutputParam( + "negative_pooled_prompt_embeds", + type_hint=torch.Tensor, + description="The processed negative pooled text embeddings.", + ), ] @torch.no_grad() @@ -198,8 +226,16 @@ def inputs(self) -> list[InputParam]: @property def intermediate_outputs(self) -> list[OutputParam]: return [ - OutputParam("image_height", type_hint=int), - OutputParam("image_width", type_hint=int), + OutputParam( + "image_height", + type_hint=int, + description="The height of the generated image.", + ), + OutputParam( + "image_width", + type_hint=int, + description="The width of the generated image.", + ), ] def __call__( diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py index 4432b6463785..d63cb4e2e55c 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py @@ -53,7 +53,7 @@ class StableDiffusion3Img2ImgVaeEncoderStep(SequentialPipelineBlocks): Outputs: processed_image (`None`): - TODO: Add description. + The pre-processed image tensor. image_latents (`Tensor`): The latents representing the reference image """ @@ -86,7 +86,7 @@ class StableDiffusion3AutoVaeEncoderStep(AutoPipelineBlocks): Outputs: processed_image (`None`): - TODO: Add description. + The pre-processed image tensor. image_latents (`Tensor`): The latents representing the reference image """ @@ -129,11 +129,11 @@ class StableDiffusion3BeforeDenoiseStep(SequentialPipelineBlocks): Outputs: latents (`Tensor`): - TODO: Add description. + The prepared latent tensors to be denoised. timesteps (`Tensor`): - TODO: Add description. + The timesteps schedule for the denoising process. num_inference_steps (`int`): - TODO: Add description. + The final number of inference steps. """ model_name = "stable-diffusion-3" @@ -180,13 +180,13 @@ class StableDiffusion3Img2ImgBeforeDenoiseStep(SequentialPipelineBlocks): Outputs: latents (`Tensor`): - TODO: Add description. + The prepared latent tensors to be denoised. timesteps (`Tensor`): - TODO: Add description. + The timesteps schedule for the denoising process. num_inference_steps (`int`): - TODO: Add description. + The final number of inference steps. initial_noise (`Tensor`): - TODO: Add description. + The initial noise applied to the image latents. """ model_name = "stable-diffusion-3" @@ -234,13 +234,13 @@ class StableDiffusion3AutoBeforeDenoiseStep(AutoPipelineBlocks): Outputs: latents (`Tensor`): - TODO: Add description. + The prepared latent tensors to be denoised. timesteps (`Tensor`): - TODO: Add description. + The timesteps schedule for the denoising process. num_inference_steps (`int`): - TODO: Add description. + The final number of inference steps. initial_noise (`Tensor`): - TODO: Add description. + The initial noise applied to the image latents. """ model_name = "stable-diffusion-3" @@ -279,23 +279,23 @@ class StableDiffusion3Img2ImgInputStep(SequentialPipelineBlocks): Outputs: batch_size (`int`): - TODO: Add description. + The batch size for the inference. dtype (`dtype`): - TODO: Add description. + The expected data type for latents. do_classifier_free_guidance (`bool`): - TODO: Add description. + Flag indicating if CFG is enabled. prompt_embeds (`Tensor`): - TODO: Add description. + The processed text embeddings. pooled_prompt_embeds (`Tensor`): - TODO: Add description. + The processed pooled text embeddings. negative_prompt_embeds (`Tensor`): - TODO: Add description. + The processed negative text embeddings. negative_pooled_prompt_embeds (`Tensor`): - TODO: Add description. + The processed negative pooled text embeddings. image_height (`int`): - TODO: Add description. + The height of the generated image. image_width (`int`): - TODO: Add description. + The width of the generated image. """ model_name = "stable-diffusion-3" @@ -333,23 +333,23 @@ class StableDiffusion3AutoInputStep(AutoPipelineBlocks): Outputs: batch_size (`int`): - TODO: Add description. + The batch size for the inference. dtype (`dtype`): - TODO: Add description. + The expected data type for latents. do_classifier_free_guidance (`bool`): - TODO: Add description. + Flag indicating if CFG is enabled. prompt_embeds (`Tensor`): - TODO: Add description. + The processed text embeddings. pooled_prompt_embeds (`Tensor`): - TODO: Add description. + The processed pooled text embeddings. negative_prompt_embeds (`Tensor`): - TODO: Add description. + The processed negative text embeddings. negative_pooled_prompt_embeds (`Tensor`): - TODO: Add description. + The processed negative pooled text embeddings. image_height (`int`): - TODO: Add description. + The height of the generated image. image_width (`int`): - TODO: Add description. + The width of the generated image. """ model_name = "stable-diffusion-3" diff --git a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py index f1cf6ce4630f..3db63bfa16b0 100644 --- a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py +++ b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py @@ -28,6 +28,7 @@ from ...testing_utils import floats_tensor, torch_device from ..test_modular_pipelines_common import ModularPipelineTesterMixin + SD3_TEXT2IMAGE_WORKFLOWS = { "text2image": [ ("text_encoder", "StableDiffusion3TextEncoderStep"), From 0417997c60587d601d5a5f2980545e31fb68d38e Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Sun, 5 Apr 2026 05:48:55 +0000 Subject: [PATCH 13/14] 1. fix imports 2. refactored encoders and inputs 3. refactored for more flat structure 4. styling --- .../stable_diffusion_3/before_denoise.py | 62 ++- .../stable_diffusion_3/encoders.py | 360 ++++++++---------- .../stable_diffusion_3/inputs.py | 125 ++++-- .../modular_blocks_stable_diffusion_3.py | 300 +++++---------- ...est_modular_pipeline_stable_diffusion_3.py | 26 +- 5 files changed, 404 insertions(+), 469 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py index 0a44449aafdb..462f2d93d97d 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/before_denoise.py @@ -12,13 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import torch -from ...pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import ( - calculate_shift, - retrieve_timesteps, -) from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging from ...utils.torch_utils import randn_tensor @@ -29,6 +26,63 @@ logger = logging.get_logger(__name__) +# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.15, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + def _get_initial_timesteps_and_optionals( transformer, scheduler, diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py index 0dc33ebb617c..8c5c4f6f2273 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py @@ -75,33 +75,30 @@ def encode_vae_image( return image_latents -# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds with self -> components def _get_t5_prompt_embeds( - components, + text_encoder: T5EncoderModel | None, + tokenizer: T5TokenizerFast | None, prompt: str | list[str] = None, - num_images_per_prompt: int = 1, max_sequence_length: int = 256, device: torch.device | None = None, - dtype: torch.dtype | None = None, + joint_attention_dim: int = 4096, ): - device = device or components._execution_device - dtype = dtype or components.text_encoder.dtype + device = device or ( + text_encoder.device if text_encoder is not None else torch.device("cpu") + ) + dtype = text_encoder.dtype if text_encoder is not None else torch.float32 prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) - if components.text_encoder_3 is None: + if text_encoder is None or tokenizer is None: return torch.zeros( - ( - batch_size * num_images_per_prompt, - max_sequence_length, - components.transformer.config.joint_attention_dim, - ), + (batch_size, max_sequence_length, joint_attention_dim), device=device, dtype=dtype, ) - text_inputs = components.tokenizer_3( + text_inputs = tokenizer( prompt, padding="max_length", max_length=max_sequence_length, @@ -110,59 +107,56 @@ def _get_t5_prompt_embeds( return_tensors="pt", ) text_input_ids = text_inputs.input_ids - untruncated_ids = components.tokenizer_3( + untruncated_ids = tokenizer( prompt, padding="longest", return_tensors="pt" ).input_ids if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( text_input_ids, untruncated_ids ): - removed_text = components.tokenizer_3.batch_decode( - untruncated_ids[:, components.tokenizer_max_length - 1 : -1] + removed_text = tokenizer.batch_decode( + untruncated_ids[:, tokenizer.model_max_length - 1 : -1] ) logger.warning( - "The following part of your input was truncated because `max_sequence_length` is set to " + f"The following part of your input was truncated because `max_sequence_length` is set to " f" {max_sequence_length} tokens: {removed_text}" ) - prompt_embeds = components.text_encoder_3(text_input_ids.to(device))[0] - - dtype = components.text_encoder_3.dtype + prompt_embeds = text_encoder(text_input_ids.to(device))[0] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - return prompt_embeds -# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_clip_prompt_embeds with self -> components def _get_clip_prompt_embeds( - components, + text_encoder: CLIPTextModelWithProjection | None, + tokenizer: CLIPTokenizer | None, prompt: str | list[str], - num_images_per_prompt: int = 1, device: torch.device | None = None, clip_skip: int | None = None, - clip_model_index: int = 0, + hidden_size: int = 768, ): - device = device or components._execution_device - - clip_tokenizers = [components.tokenizer, components.tokenizer_2] - clip_text_encoders = [components.text_encoder, components.text_encoder_2] - - tokenizer = clip_tokenizers[clip_model_index] - text_encoder = clip_text_encoders[clip_model_index] + device = device or ( + text_encoder.device if text_encoder is not None else torch.device("cpu") + ) + dtype = text_encoder.dtype if text_encoder is not None else torch.float32 prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) + if text_encoder is None or tokenizer is None: + prompt_embeds = torch.zeros( + (batch_size, 77, hidden_size), device=device, dtype=dtype + ) + pooled_prompt_embeds = torch.zeros( + (batch_size, hidden_size), device=device, dtype=dtype + ) + return prompt_embeds, pooled_prompt_embeds + text_inputs = tokenizer( prompt, padding="max_length", - max_length=components.tokenizer_max_length, + max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt", ) @@ -175,11 +169,11 @@ def _get_clip_prompt_embeds( text_input_ids, untruncated_ids ): removed_text = tokenizer.batch_decode( - untruncated_ids[:, components.tokenizer_max_length - 1 : -1] + untruncated_ids[:, tokenizer.model_max_length - 1 : -1] ) logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {components.tokenizer_max_length} tokens: {removed_text}" + f"The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" ) prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) pooled_prompt_embeds = prompt_embeds[0] @@ -189,192 +183,158 @@ def _get_clip_prompt_embeds( else: prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] - prompt_embeds = prompt_embeds.to(dtype=components.text_encoder.dtype, device=device) - - _, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt) - pooled_prompt_embeds = pooled_prompt_embeds.view( - batch_size * num_images_per_prompt, -1 - ) + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) return prompt_embeds, pooled_prompt_embeds -# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_prompt with self -> components, self._get_clip_prompt_embeds -> _get_clip_prompt_embeds, self._get_t5_prompt_embeds -> _get_t5_prompt_embeds def encode_prompt( components, prompt: str | list[str], - prompt_2: str | list[str], - prompt_3: str | list[str], + prompt_2: str | list[str] | None = None, + prompt_3: str | list[str] | None = None, device: torch.device | None = None, - num_images_per_prompt: int = 1, - do_classifier_free_guidance: bool = True, negative_prompt: str | list[str] | None = None, negative_prompt_2: str | list[str] | None = None, negative_prompt_3: str | list[str] | None = None, - prompt_embeds: torch.FloatTensor | None = None, - negative_prompt_embeds: torch.FloatTensor | None = None, - pooled_prompt_embeds: torch.FloatTensor | None = None, - negative_pooled_prompt_embeds: torch.FloatTensor | None = None, clip_skip: int | None = None, max_sequence_length: int = 256, lora_scale: float | None = None, ): device = device or components._execution_device - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it if lora_scale is not None and isinstance(components, SD3LoraLoaderMixin): components._lora_scale = lora_scale - - # dynamically adjust the LoRA scale if components.text_encoder is not None and USE_PEFT_BACKEND: scale_lora_layers(components.text_encoder, lora_scale) if components.text_encoder_2 is not None and USE_PEFT_BACKEND: scale_lora_layers(components.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt - if prompt is not None: - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - prompt_2 = prompt_2 or prompt - prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - - prompt_3 = prompt_3 or prompt - prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 - - prompt_embed, pooled_prompt_embed = _get_clip_prompt_embeds( - components, - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - clip_skip=clip_skip, - clip_model_index=0, - ) - prompt_2_embed, pooled_prompt_2_embed = _get_clip_prompt_embeds( - components, - prompt=prompt_2, - device=device, - num_images_per_prompt=num_images_per_prompt, - clip_skip=clip_skip, - clip_model_index=1, - ) - clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) - - t5_prompt_embed = _get_t5_prompt_embeds( - components, - prompt=prompt_3, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) - - clip_prompt_embeds = torch.nn.functional.pad( - clip_prompt_embeds, - (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]), - ) + batch_size = len(prompt) - prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) - pooled_prompt_embeds = torch.cat( - [pooled_prompt_embed, pooled_prompt_2_embed], dim=-1 - ) + prompt_2 = prompt_2 or prompt + prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2 - if do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt_2 or negative_prompt - negative_prompt_3 = negative_prompt_3 or negative_prompt + prompt_3 = prompt_3 or prompt + prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3 - # normalize str to list - negative_prompt = ( - batch_size * [negative_prompt] - if isinstance(negative_prompt, str) - else negative_prompt - ) - negative_prompt_2 = ( - batch_size * [negative_prompt_2] - if isinstance(negative_prompt_2, str) - else negative_prompt_2 - ) - negative_prompt_3 = ( - batch_size * [negative_prompt_3] - if isinstance(negative_prompt_3, str) - else negative_prompt_3 - ) + prompt_embed, pooled_prompt_embed = _get_clip_prompt_embeds( + components.text_encoder, + components.tokenizer, + prompt=prompt, + device=device, + clip_skip=clip_skip, + hidden_size=768, + ) + prompt_2_embed, pooled_prompt_2_embed = _get_clip_prompt_embeds( + components.text_encoder_2, + components.tokenizer_2, + prompt=prompt_2, + device=device, + clip_skip=clip_skip, + hidden_size=1280, + ) + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + + t5_prompt_embed = _get_t5_prompt_embeds( + components.text_encoder_3, + components.tokenizer_3, + prompt=prompt_3, + max_sequence_length=max_sequence_length, + device=device, + joint_attention_dim=( + components.transformer.config.joint_attention_dim + if getattr(components, "transformer", None) is not None + else 4096 + ), + ) - if prompt is not None and type(prompt) is not type(negative_prompt): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, + (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]), + ) + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat( + [pooled_prompt_embed, pooled_prompt_2_embed], dim=-1 + ) - negative_prompt_embed, negative_pooled_prompt_embed = _get_clip_prompt_embeds( - components, - negative_prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - clip_skip=None, - clip_model_index=0, - ) - negative_prompt_2_embed, negative_pooled_prompt_2_embed = ( - _get_clip_prompt_embeds( - components, - negative_prompt_2, - device=device, - num_images_per_prompt=num_images_per_prompt, - clip_skip=None, - clip_model_index=1, - ) - ) - negative_clip_prompt_embeds = torch.cat( - [negative_prompt_embed, negative_prompt_2_embed], dim=-1 - ) + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + negative_prompt_3 = negative_prompt_3 or negative_prompt - t5_negative_prompt_embed = _get_t5_prompt_embeds( - components, - prompt=negative_prompt_3, - num_images_per_prompt=num_images_per_prompt, - max_sequence_length=max_sequence_length, - device=device, - ) + negative_prompt = ( + batch_size * [negative_prompt] + if isinstance(negative_prompt, str) + else negative_prompt + ) + negative_prompt_2 = ( + batch_size * [negative_prompt_2] + if isinstance(negative_prompt_2, str) + else negative_prompt_2 + ) + negative_prompt_3 = ( + batch_size * [negative_prompt_3] + if isinstance(negative_prompt_3, str) + else negative_prompt_3 + ) - negative_clip_prompt_embeds = torch.nn.functional.pad( - negative_clip_prompt_embeds, - ( - 0, - t5_negative_prompt_embed.shape[-1] - - negative_clip_prompt_embeds.shape[-1], - ), - ) + negative_prompt_embed, negative_pooled_prompt_embed = _get_clip_prompt_embeds( + components.text_encoder, + components.tokenizer, + prompt=negative_prompt, + device=device, + clip_skip=None, + hidden_size=768, + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = _get_clip_prompt_embeds( + components.text_encoder_2, + components.tokenizer_2, + prompt=negative_prompt_2, + device=device, + clip_skip=None, + hidden_size=1280, + ) + negative_clip_prompt_embeds = torch.cat( + [negative_prompt_embed, negative_prompt_2_embed], dim=-1 + ) - negative_prompt_embeds = torch.cat( - [negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2 - ) - negative_pooled_prompt_embeds = torch.cat( - [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 - ) + t5_negative_prompt_embed = _get_t5_prompt_embeds( + components.text_encoder_3, + components.tokenizer_3, + prompt=negative_prompt_3, + max_sequence_length=max_sequence_length, + device=device, + joint_attention_dim=( + components.transformer.config.joint_attention_dim + if getattr(components, "transformer", None) is not None + else 4096 + ), + ) - if components.text_encoder is not None: - if isinstance(components, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(components.text_encoder, lora_scale) + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]), + ) + negative_prompt_embeds = torch.cat( + [negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2 + ) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) - if components.text_encoder_2 is not None: - if isinstance(components, SD3LoraLoaderMixin) and USE_PEFT_BACKEND: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(components.text_encoder_2, lora_scale) + if ( + components.text_encoder is not None + and isinstance(components, SD3LoraLoaderMixin) + and USE_PEFT_BACKEND + ): + unscale_lora_layers(components.text_encoder, lora_scale) + if ( + components.text_encoder_2 is not None + and isinstance(components, SD3LoraLoaderMixin) + and USE_PEFT_BACKEND + ): + unscale_lora_layers(components.text_encoder_2, lora_scale) return ( prompt_embeds, @@ -606,11 +566,6 @@ def inputs(self) -> list[InputParam]: type_hint=torch.Tensor, description="Pre-generated negative pooled text embeddings.", ), - InputParam( - "guidance_scale", - default=7.0, - description="Guidance scale as defined in Classifier-Free Diffusion Guidance.", - ), InputParam( "clip_skip", type_hint=int, @@ -644,10 +599,9 @@ def __call__( block_state = self.get_block_state(state) block_state.device = components._execution_device - do_classifier_free_guidance = block_state.guidance_scale > 1.0 lora_scale = ( block_state.joint_attention_kwargs.get("scale", None) - if block_state.joint_attention_kwargs + if getattr(block_state, "joint_attention_kwargs", None) else None ) @@ -662,17 +616,9 @@ def __call__( prompt_2=getattr(block_state, "prompt_2", None), prompt_3=getattr(block_state, "prompt_3", None), device=block_state.device, - num_images_per_prompt=getattr(block_state, "num_images_per_prompt", 1), - do_classifier_free_guidance=do_classifier_free_guidance, negative_prompt=getattr(block_state, "negative_prompt", None), negative_prompt_2=getattr(block_state, "negative_prompt_2", None), negative_prompt_3=getattr(block_state, "negative_prompt_3", None), - prompt_embeds=getattr(block_state, "prompt_embeds", None), - negative_prompt_embeds=getattr(block_state, "negative_prompt_embeds", None), - pooled_prompt_embeds=getattr(block_state, "pooled_prompt_embeds", None), - negative_pooled_prompt_embeds=getattr( - block_state, "negative_pooled_prompt_embeds", None - ), clip_skip=getattr(block_state, "clip_skip", None), max_sequence_length=getattr(block_state, "max_sequence_length", 256), lora_scale=lora_scale, diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py index e5489f42e70e..9fc3a21b178e 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/inputs.py @@ -17,15 +17,104 @@ from ...utils import logging from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import InputParam, OutputParam -from ..qwenimage.inputs import ( - calculate_dimension_from_latents, - repeat_tensor_to_batch_size, -) from .modular_pipeline import StableDiffusion3ModularPipeline logger = logging.get_logger(__name__) +# Copied from diffusers.modular_pipelines.qwenimage.inputs.repeat_tensor_to_batch_size +def repeat_tensor_to_batch_size( + input_name: str, + input_tensor: torch.Tensor, + batch_size: int, + num_images_per_prompt: int = 1, +) -> torch.Tensor: + """Repeat tensor elements to match the final batch size. + + This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt) + by repeating each element along dimension 0. + + The input tensor must have batch size 1 or batch_size. The function will: + - If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times + - If batch size equals batch_size: repeat each element num_images_per_prompt times + + Args: + input_name (str): Name of the input tensor (used for error messages) + input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size. + batch_size (int): The base batch size (number of prompts) + num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1. + + Returns: + torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt) + + Raises: + ValueError: If input_tensor is not a torch.Tensor or has invalid batch size + + Examples: + tensor = torch.tensor([[1, 2, 3]]) # shape: [1, 3] repeated = repeat_tensor_to_batch_size("image", tensor, + batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: + [4, 3] + + tensor = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: [2, 3] repeated = repeat_tensor_to_batch_size("image", + tensor, batch_size=2, num_images_per_prompt=2) repeated # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) + - shape: [4, 3] + """ + # make sure input is a tensor + if not isinstance(input_tensor, torch.Tensor): + raise ValueError(f"`{input_name}` must be a tensor") + + # make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts + if input_tensor.shape[0] == 1: + repeat_by = batch_size * num_images_per_prompt + elif input_tensor.shape[0] == batch_size: + repeat_by = num_images_per_prompt + else: + raise ValueError( + f"`{input_name}` must have have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}" + ) + + # expand the tensor to match the batch_size * num_images_per_prompt + input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0) + + return input_tensor + + +# Copied from diffusers.modular_pipelines.qwenimage.inputs.calculate_dimension_from_latents +def calculate_dimension_from_latents( + latents: torch.Tensor, vae_scale_factor: int +) -> tuple[int, int]: + """Calculate image dimensions from latent tensor dimensions. + + This function converts latent space dimensions to image space dimensions by multiplying the latent height and width + by the VAE scale factor. + + Args: + latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions. + Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width] + vae_scale_factor (int): The scale factor used by the VAE to compress images. + Typically 8 for most VAEs (image is 8x larger than latents in each dimension) + + Returns: + tuple[int, int]: The calculated image dimensions as (height, width) + + Raises: + ValueError: If latents tensor doesn't have 4 or 5 dimensions + + """ + # make sure the latents are not packed + if latents.ndim != 4 and latents.ndim != 5: + raise ValueError( + f"unpacked latents must have 4 or 5 dimensions, but got {latents.ndim}" + ) + + latent_height, latent_width = latents.shape[-2:] + + height = latent_height * vae_scale_factor + width = latent_width * vae_scale_factor + + return height, width + + class StableDiffusion3TextInputStep(ModularPipelineBlocks): model_name = "stable-diffusion-3" @@ -41,16 +130,6 @@ def inputs(self) -> list[InputParam]: default=1, description="The number of images to generate per prompt.", ), - InputParam( - "guidance_scale", - default=7.0, - description="Guidance scale as defined in Classifier-Free Diffusion Guidance.", - ), - InputParam( - "skip_guidance_layers", - type_hint=list, - description="A list of integers that specify layers to skip during guidance.", - ), InputParam( "prompt_embeds", required=True, @@ -88,11 +167,6 @@ def intermediate_outputs(self) -> list[str]: type_hint=torch.dtype, description="The expected data type for latents.", ), - OutputParam( - "do_classifier_free_guidance", - type_hint=bool, - description="Flag indicating if CFG is enabled.", - ), OutputParam( "prompt_embeds", type_hint=torch.Tensor, @@ -123,7 +197,6 @@ def __call__( block_state.batch_size = block_state.prompt_embeds.shape[0] block_state.dtype = block_state.prompt_embeds.dtype - block_state.do_classifier_free_guidance = block_state.guidance_scale > 1.0 _, seq_len, _ = block_state.prompt_embeds.shape prompt_embeds = block_state.prompt_embeds.repeat( @@ -140,10 +213,7 @@ def __call__( block_state.batch_size * block_state.num_images_per_prompt, -1 ) - if ( - block_state.do_classifier_free_guidance - and block_state.negative_prompt_embeds is not None - ): + if getattr(block_state, "negative_prompt_embeds", None) is not None: _, neg_seq_len, _ = block_state.negative_prompt_embeds.shape negative_prompt_embeds = block_state.negative_prompt_embeds.repeat( 1, block_state.num_images_per_prompt, 1 @@ -163,16 +233,15 @@ def __call__( block_state.batch_size * block_state.num_images_per_prompt, -1 ) - block_state.prompt_embeds = prompt_embeds - block_state.pooled_prompt_embeds = pooled_prompt_embeds block_state.negative_prompt_embeds = negative_prompt_embeds block_state.negative_pooled_prompt_embeds = negative_pooled_prompt_embeds else: - block_state.prompt_embeds = prompt_embeds - block_state.pooled_prompt_embeds = pooled_prompt_embeds block_state.negative_prompt_embeds = None block_state.negative_pooled_prompt_embeds = None + block_state.prompt_embeds = prompt_embeds + block_state.pooled_prompt_embeds = pooled_prompt_embeds + self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py index d63cb4e2e55c..29171e0c64d2 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/modular_blocks_stable_diffusion_3.py @@ -98,26 +98,31 @@ class StableDiffusion3AutoVaeEncoderStep(AutoPipelineBlocks): # auto_docstring -class StableDiffusion3BeforeDenoiseStep(SequentialPipelineBlocks): +class StableDiffusion3T2ICoreDenoiseStep(SequentialPipelineBlocks): """ Components: - scheduler (`FlowMatchEulerDiscreteScheduler`) + scheduler (`FlowMatchEulerDiscreteScheduler`) guider (`ClassifierFreeGuidance`) transformer + (`SD3Transformer2DModel`) Inputs: + num_images_per_prompt (`None`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. + pooled_prompt_embeds (`Tensor`): + Pre-generated pooled text embeddings. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. + negative_pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative pooled text embeddings. height (`int`, *optional*): The height in pixels of the generated image. width (`int`, *optional*): The width in pixels of the generated image. latents (`Tensor | NoneType`, *optional*): Pre-generated noisy latents to be used as inputs for image generation. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. generator (`None`, *optional*): One or a list of torch generator(s) to make generation deterministic. - batch_size (`int`): - The batch size for latent generation. - dtype (`dtype`, *optional*): - The data type for the latents. num_inference_steps (`None`, *optional*, defaults to 50): The number of denoising steps. timesteps (`None`, *optional*): @@ -126,99 +131,64 @@ class StableDiffusion3BeforeDenoiseStep(SequentialPipelineBlocks): Custom sigmas to use for the denoising process. mu (`float`, *optional*): The mu value used for dynamic shifting. If not provided, it is dynamically calculated. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary passed along to the AttentionProcessor. + guidance_scale (`None`, *optional*, defaults to 7.0): + Guidance scale as defined in Classifier-Free Diffusion Guidance. + skip_layer_guidance_scale (`None`, *optional*, defaults to 2.8): + The scale of the guidance for the skipped layers. + skip_layer_guidance_stop (`None`, *optional*, defaults to 0.2): + The step fraction at which the guidance for skipped layers stops. + skip_layer_guidance_start (`None`, *optional*, defaults to 0.01): + The step fraction at which the guidance for skipped layers starts. Outputs: latents (`Tensor`): - The prepared latent tensors to be denoised. - timesteps (`Tensor`): - The timesteps schedule for the denoising process. - num_inference_steps (`int`): - The final number of inference steps. + Denoised latents. """ model_name = "stable-diffusion-3" block_classes = [ + StableDiffusion3TextInputStep(), StableDiffusion3PrepareLatentsStep(), StableDiffusion3SetTimestepsStep(), + StableDiffusion3DenoiseStep(), ] - block_names = ["prepare_latents", "set_timesteps"] - - -# auto_docstring -class StableDiffusion3Img2ImgBeforeDenoiseStep(SequentialPipelineBlocks): - """ - Components: - scheduler (`FlowMatchEulerDiscreteScheduler`) - - Inputs: - height (`int`, *optional*): - The height in pixels of the generated image. - width (`int`, *optional*): - The width in pixels of the generated image. - latents (`Tensor | NoneType`, *optional*): - Pre-generated noisy latents to be used as inputs for image generation. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - generator (`None`, *optional*): - One or a list of torch generator(s) to make generation deterministic. - batch_size (`int`): - The batch size for latent generation. - dtype (`dtype`, *optional*): - The data type for the latents. - num_inference_steps (`None`, *optional*, defaults to 50): - The number of denoising steps. - timesteps (`None`, *optional*): - Custom timesteps to use for the denoising process. - sigmas (`None`, *optional*): - Custom sigmas to use for the denoising process. - strength (`None`, *optional*, defaults to 0.6): - Indicates extent to transform the reference image. - mu (`float`, *optional*): - The mu value used for dynamic shifting. If not provided, it is dynamically calculated. - image_latents (`Tensor`): - The image latents encoded by the VAE. - - Outputs: - latents (`Tensor`): - The prepared latent tensors to be denoised. - timesteps (`Tensor`): - The timesteps schedule for the denoising process. - num_inference_steps (`int`): - The final number of inference steps. - initial_noise (`Tensor`): - The initial noise applied to the image latents. - """ + block_names = ["text_inputs", "prepare_latents", "set_timesteps", "denoise"] - model_name = "stable-diffusion-3" - block_classes = [ - StableDiffusion3PrepareLatentsStep(), - StableDiffusion3Img2ImgSetTimestepsStep(), - StableDiffusion3Img2ImgPrepareLatentsStep(), - ] - block_names = ["prepare_latents", "set_timesteps", "prepare_img2img_latents"] + @property + def outputs(self): + return [OutputParam.template("latents")] # auto_docstring -class StableDiffusion3AutoBeforeDenoiseStep(AutoPipelineBlocks): +class StableDiffusion3I2ICoreDenoiseStep(SequentialPipelineBlocks): """ Components: - scheduler (`FlowMatchEulerDiscreteScheduler`) + scheduler (`FlowMatchEulerDiscreteScheduler`) guider (`ClassifierFreeGuidance`) transformer + (`SD3Transformer2DModel`) Inputs: - height (`int`, *optional*): + num_images_per_prompt (`None`, *optional*, defaults to 1): + The number of images to generate per prompt. + prompt_embeds (`Tensor`): + Pre-generated text embeddings. + pooled_prompt_embeds (`Tensor`): + Pre-generated pooled text embeddings. + negative_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative text embeddings. + negative_pooled_prompt_embeds (`Tensor`, *optional*): + Pre-generated negative pooled text embeddings. + height (`None`, *optional*): The height in pixels of the generated image. - width (`int`, *optional*): + width (`None`, *optional*): The width in pixels of the generated image. + image_latents (`None`, *optional*): + Latent input image_latents to be processed. latents (`Tensor | NoneType`, *optional*): Pre-generated noisy latents to be used as inputs for image generation. - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. generator (`None`, *optional*): One or a list of torch generator(s) to make generation deterministic. - batch_size (`int`): - The batch size for latent generation. - dtype (`dtype`, *optional*): - The data type for the latents. num_inference_steps (`None`, *optional*, defaults to 50): The number of denoising steps. timesteps (`None`, *optional*): @@ -229,137 +199,47 @@ class StableDiffusion3AutoBeforeDenoiseStep(AutoPipelineBlocks): Indicates extent to transform the reference image. mu (`float`, *optional*): The mu value used for dynamic shifting. If not provided, it is dynamically calculated. - image_latents (`Tensor`, *optional*): - The image latents encoded by the VAE. - - Outputs: - latents (`Tensor`): - The prepared latent tensors to be denoised. - timesteps (`Tensor`): - The timesteps schedule for the denoising process. - num_inference_steps (`int`): - The final number of inference steps. - initial_noise (`Tensor`): - The initial noise applied to the image latents. - """ - - model_name = "stable-diffusion-3" - block_classes = [ - StableDiffusion3Img2ImgBeforeDenoiseStep, - StableDiffusion3BeforeDenoiseStep, - ] - block_names = ["img2img", "text2image"] - block_trigger_inputs = ["image_latents", None] - - -# auto_docstring -class StableDiffusion3Img2ImgInputStep(SequentialPipelineBlocks): - """ - Inputs: - num_images_per_prompt (`None`, *optional*, defaults to 1): - The number of images to generate per prompt. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary passed along to the AttentionProcessor. guidance_scale (`None`, *optional*, defaults to 7.0): Guidance scale as defined in Classifier-Free Diffusion Guidance. - skip_guidance_layers (`list`, *optional*): - A list of integers that specify layers to skip during guidance. - prompt_embeds (`Tensor`): - Pre-generated text embeddings. - pooled_prompt_embeds (`Tensor`): - Pre-generated pooled text embeddings. - negative_prompt_embeds (`Tensor`, *optional*): - Pre-generated negative text embeddings. - negative_pooled_prompt_embeds (`Tensor`, *optional*): - Pre-generated negative pooled text embeddings. - height (`None`, *optional*): - The height in pixels of the generated image. - width (`None`, *optional*): - The width in pixels of the generated image. - image_latents (`None`, *optional*): - Latent input image_latents to be processed. + skip_layer_guidance_scale (`None`, *optional*, defaults to 2.8): + The scale of the guidance for the skipped layers. + skip_layer_guidance_stop (`None`, *optional*, defaults to 0.2): + The step fraction at which the guidance for skipped layers stops. + skip_layer_guidance_start (`None`, *optional*, defaults to 0.01): + The step fraction at which the guidance for skipped layers starts. Outputs: - batch_size (`int`): - The batch size for the inference. - dtype (`dtype`): - The expected data type for latents. - do_classifier_free_guidance (`bool`): - Flag indicating if CFG is enabled. - prompt_embeds (`Tensor`): - The processed text embeddings. - pooled_prompt_embeds (`Tensor`): - The processed pooled text embeddings. - negative_prompt_embeds (`Tensor`): - The processed negative text embeddings. - negative_pooled_prompt_embeds (`Tensor`): - The processed negative pooled text embeddings. - image_height (`int`): - The height of the generated image. - image_width (`int`): - The width of the generated image. + latents (`Tensor`): + Denoised latents. """ model_name = "stable-diffusion-3" block_classes = [ StableDiffusion3TextInputStep(), StableDiffusion3AdditionalInputsStep(), + StableDiffusion3PrepareLatentsStep(), + StableDiffusion3Img2ImgSetTimestepsStep(), + StableDiffusion3Img2ImgPrepareLatentsStep(), + StableDiffusion3DenoiseStep(), + ] + block_names = [ + "text_inputs", + "additional_inputs", + "prepare_latents", + "set_timesteps", + "prepare_img2img_latents", + "denoise", ] - block_names = ["text_inputs", "additional_inputs"] - - -# auto_docstring -class StableDiffusion3AutoInputStep(AutoPipelineBlocks): - """ - Inputs: - num_images_per_prompt (`None`, *optional*, defaults to 1): - The number of images to generate per prompt. - guidance_scale (`None`, *optional*, defaults to 7.0): - Guidance scale as defined in Classifier-Free Diffusion Guidance. - skip_guidance_layers (`list`, *optional*): - A list of integers that specify layers to skip during guidance. - prompt_embeds (`Tensor`): - Pre-generated text embeddings. - pooled_prompt_embeds (`Tensor`): - Pre-generated pooled text embeddings. - negative_prompt_embeds (`Tensor`, *optional*): - Pre-generated negative text embeddings. - negative_pooled_prompt_embeds (`Tensor`, *optional*): - Pre-generated negative pooled text embeddings. - height (`None`, *optional*): - The height in pixels of the generated image. - width (`None`, *optional*): - The width in pixels of the generated image. - image_latents (`None`, *optional*): - Latent input image_latents to be processed. - - Outputs: - batch_size (`int`): - The batch size for the inference. - dtype (`dtype`): - The expected data type for latents. - do_classifier_free_guidance (`bool`): - Flag indicating if CFG is enabled. - prompt_embeds (`Tensor`): - The processed text embeddings. - pooled_prompt_embeds (`Tensor`): - The processed pooled text embeddings. - negative_prompt_embeds (`Tensor`): - The processed negative text embeddings. - negative_pooled_prompt_embeds (`Tensor`): - The processed negative pooled text embeddings. - image_height (`int`): - The height of the generated image. - image_width (`int`): - The width of the generated image. - """ - model_name = "stable-diffusion-3" - block_classes = [StableDiffusion3Img2ImgInputStep, StableDiffusion3TextInputStep] - block_names = ["img2img", "text2image"] - block_trigger_inputs = ["image_latents", None] + @property + def outputs(self): + return [OutputParam.template("latents")] # auto_docstring -class StableDiffusion3CoreDenoiseStep(SequentialPipelineBlocks): +class StableDiffusion3AutoCoreDenoiseStep(AutoPipelineBlocks): """ Components: scheduler (`FlowMatchEulerDiscreteScheduler`) guider (`ClassifierFreeGuidance`) transformer @@ -368,10 +248,6 @@ class StableDiffusion3CoreDenoiseStep(SequentialPipelineBlocks): Inputs: num_images_per_prompt (`None`, *optional*, defaults to 1): The number of images to generate per prompt. - guidance_scale (`None`, *optional*, defaults to 7.0): - Guidance scale as defined in Classifier-Free Diffusion Guidance. - skip_guidance_layers (`list`, *optional*): - A list of integers that specify layers to skip during guidance. prompt_embeds (`Tensor`): Pre-generated text embeddings. pooled_prompt_embeds (`Tensor`): @@ -386,13 +262,13 @@ class StableDiffusion3CoreDenoiseStep(SequentialPipelineBlocks): The width in pixels of the generated image. image_latents (`None`, *optional*): Latent input image_latents to be processed. - latents (`Tensor | NoneType`, *optional*): + latents (`Tensor | NoneType`): Pre-generated noisy latents to be used as inputs for image generation. generator (`None`, *optional*): One or a list of torch generator(s) to make generation deterministic. - num_inference_steps (`None`, *optional*, defaults to 50): + num_inference_steps (`None`): The number of denoising steps. - timesteps (`None`, *optional*): + timesteps (`None`): Custom timesteps to use for the denoising process. sigmas (`None`, *optional*): Custom sigmas to use for the denoising process. @@ -402,6 +278,8 @@ class StableDiffusion3CoreDenoiseStep(SequentialPipelineBlocks): The mu value used for dynamic shifting. If not provided, it is dynamically calculated. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary passed along to the AttentionProcessor. + guidance_scale (`None`, *optional*, defaults to 7.0): + Guidance scale as defined in Classifier-Free Diffusion Guidance. skip_layer_guidance_scale (`None`, *optional*, defaults to 2.8): The scale of the guidance for the skipped layers. skip_layer_guidance_stop (`None`, *optional*, defaults to 0.2): @@ -416,11 +294,11 @@ class StableDiffusion3CoreDenoiseStep(SequentialPipelineBlocks): model_name = "stable-diffusion-3" block_classes = [ - StableDiffusion3AutoInputStep, - StableDiffusion3AutoBeforeDenoiseStep, - StableDiffusion3DenoiseStep, + StableDiffusion3I2ICoreDenoiseStep, + StableDiffusion3T2ICoreDenoiseStep, ] - block_names = ["input", "before_denoise", "denoise"] + block_names = ["img2img", "text2image"] + block_trigger_inputs = ["image_latents", None] @property def outputs(self): @@ -431,7 +309,7 @@ def outputs(self): [ ("text_encoder", StableDiffusion3TextEncoderStep()), ("vae_encoder", StableDiffusion3AutoVaeEncoderStep()), - ("denoise", StableDiffusion3CoreDenoiseStep()), + ("denoise", StableDiffusion3AutoCoreDenoiseStep()), ("decode", StableDiffusion3DecodeStep()), ] ) @@ -471,8 +349,6 @@ class StableDiffusion3AutoBlocks(SequentialPipelineBlocks): Pre-generated pooled text embeddings. negative_pooled_prompt_embeds (`Tensor`, *optional*): Pre-generated negative pooled text embeddings. - guidance_scale (`None`, *optional*, defaults to 7.0): - Guidance scale as defined in Classifier-Free Diffusion Guidance. clip_skip (`int`, *optional*): Number of layers to be skipped from CLIP while computing the prompt embeddings. max_sequence_length (`int`, *optional*, defaults to 256): @@ -491,15 +367,13 @@ class StableDiffusion3AutoBlocks(SequentialPipelineBlocks): One or a list of torch generator(s) to make generation deterministic. num_images_per_prompt (`None`, *optional*, defaults to 1): The number of images to generate per prompt. - skip_guidance_layers (`list`, *optional*): - A list of integers that specify layers to skip during guidance. image_latents (`None`, *optional*): Latent input image_latents to be processed. - latents (`Tensor | NoneType`, *optional*): + latents (`Tensor | NoneType`): Pre-generated noisy latents to be used as inputs for image generation. - num_inference_steps (`None`, *optional*, defaults to 50): + num_inference_steps (`None`): The number of denoising steps. - timesteps (`None`, *optional*): + timesteps (`None`): Custom timesteps to use for the denoising process. sigmas (`None`, *optional*): Custom sigmas to use for the denoising process. @@ -507,6 +381,8 @@ class StableDiffusion3AutoBlocks(SequentialPipelineBlocks): Indicates extent to transform the reference image. mu (`float`, *optional*): The mu value used for dynamic shifting. If not provided, it is dynamically calculated. + guidance_scale (`None`, *optional*, defaults to 7.0): + Guidance scale as defined in Classifier-Free Diffusion Guidance. skip_layer_guidance_scale (`None`, *optional*, defaults to 2.8): The scale of the guidance for the skipped layers. skip_layer_guidance_stop (`None`, *optional*, defaults to 0.2): diff --git a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py index 3db63bfa16b0..5519303af592 100644 --- a/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py +++ b/tests/modular_pipelines/stable_diffusion_3/test_modular_pipeline_stable_diffusion_3.py @@ -28,16 +28,12 @@ from ...testing_utils import floats_tensor, torch_device from ..test_modular_pipelines_common import ModularPipelineTesterMixin - SD3_TEXT2IMAGE_WORKFLOWS = { "text2image": [ ("text_encoder", "StableDiffusion3TextEncoderStep"), - ("denoise.input", "StableDiffusion3TextInputStep"), - ( - "denoise.before_denoise.prepare_latents", - "StableDiffusion3PrepareLatentsStep", - ), - ("denoise.before_denoise.set_timesteps", "StableDiffusion3SetTimestepsStep"), + ("denoise.text_inputs", "StableDiffusion3TextInputStep"), + ("denoise.prepare_latents", "StableDiffusion3PrepareLatentsStep"), + ("denoise.set_timesteps", "StableDiffusion3SetTimestepsStep"), ("denoise.denoise", "StableDiffusion3DenoiseStep"), ("decode", "StableDiffusion3DecodeStep"), ] @@ -106,18 +102,12 @@ def test_float16_inference(self): ("text_encoder", "StableDiffusion3TextEncoderStep"), ("vae_encoder.preprocess", "StableDiffusion3ProcessImagesInputStep"), ("vae_encoder.encode", "StableDiffusion3VaeEncoderStep"), - ("denoise.input.text_inputs", "StableDiffusion3TextInputStep"), - ("denoise.input.additional_inputs", "StableDiffusion3AdditionalInputsStep"), - ( - "denoise.before_denoise.prepare_latents", - "StableDiffusion3PrepareLatentsStep", - ), - ( - "denoise.before_denoise.set_timesteps", - "StableDiffusion3Img2ImgSetTimestepsStep", - ), + ("denoise.text_inputs", "StableDiffusion3TextInputStep"), + ("denoise.additional_inputs", "StableDiffusion3AdditionalInputsStep"), + ("denoise.prepare_latents", "StableDiffusion3PrepareLatentsStep"), + ("denoise.set_timesteps", "StableDiffusion3Img2ImgSetTimestepsStep"), ( - "denoise.before_denoise.prepare_img2img_latents", + "denoise.prepare_img2img_latents", "StableDiffusion3Img2ImgPrepareLatentsStep", ), ("denoise.denoise", "StableDiffusion3DenoiseStep"), From 3c8efd03a96780cd898b6cb9a952997760c31fd1 Mon Sep 17 00:00:00 2001 From: AlanPonnachan Date: Sun, 5 Apr 2026 06:24:32 +0000 Subject: [PATCH 14/14] fix dtype --- .../stable_diffusion_3/encoders.py | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py index 8c5c4f6f2273..83b5d592ac55 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_3/encoders.py @@ -82,11 +82,12 @@ def _get_t5_prompt_embeds( max_sequence_length: int = 256, device: torch.device | None = None, joint_attention_dim: int = 4096, + dtype: torch.dtype | None = None, ): device = device or ( text_encoder.device if text_encoder is not None else torch.device("cpu") ) - dtype = text_encoder.dtype if text_encoder is not None else torch.float32 + dtype = dtype or (text_encoder.dtype if text_encoder is not None else torch.float32) prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) @@ -135,11 +136,12 @@ def _get_clip_prompt_embeds( device: torch.device | None = None, clip_skip: int | None = None, hidden_size: int = 768, + dtype: torch.dtype | None = None, ): device = device or ( text_encoder.device if text_encoder is not None else torch.device("cpu") ) - dtype = text_encoder.dtype if text_encoder is not None else torch.float32 + dtype = dtype or (text_encoder.dtype if text_encoder is not None else torch.float32) prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) @@ -203,6 +205,16 @@ def encode_prompt( ): device = device or components._execution_device + expected_dtype = None + if components.text_encoder is not None: + expected_dtype = components.text_encoder.dtype + elif components.text_encoder_2 is not None: + expected_dtype = components.text_encoder_2.dtype + elif getattr(components, "transformer", None) is not None: + expected_dtype = components.transformer.dtype + else: + expected_dtype = torch.float32 + if lora_scale is not None and isinstance(components, SD3LoraLoaderMixin): components._lora_scale = lora_scale if components.text_encoder is not None and USE_PEFT_BACKEND: @@ -226,6 +238,7 @@ def encode_prompt( device=device, clip_skip=clip_skip, hidden_size=768, + dtype=expected_dtype, ) prompt_2_embed, pooled_prompt_2_embed = _get_clip_prompt_embeds( components.text_encoder_2, @@ -234,6 +247,7 @@ def encode_prompt( device=device, clip_skip=clip_skip, hidden_size=1280, + dtype=expected_dtype, ) clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) @@ -248,6 +262,7 @@ def encode_prompt( if getattr(components, "transformer", None) is not None else 4096 ), + dtype=expected_dtype, ) clip_prompt_embeds = torch.nn.functional.pad( @@ -286,6 +301,7 @@ def encode_prompt( device=device, clip_skip=None, hidden_size=768, + dtype=expected_dtype, ) negative_prompt_2_embed, negative_pooled_prompt_2_embed = _get_clip_prompt_embeds( components.text_encoder_2, @@ -294,6 +310,7 @@ def encode_prompt( device=device, clip_skip=None, hidden_size=1280, + dtype=expected_dtype, ) negative_clip_prompt_embeds = torch.cat( [negative_prompt_embed, negative_prompt_2_embed], dim=-1 @@ -310,6 +327,7 @@ def encode_prompt( if getattr(components, "transformer", None) is not None else 4096 ), + dtype=expected_dtype, ) negative_clip_prompt_embeds = torch.nn.functional.pad(