From 2d22e4c3c8a66be28557331394b74abb85537795 Mon Sep 17 00:00:00 2001 From: Aditya Borate <23110065@iitgn.ac.in> Date: Mon, 26 Jan 2026 23:53:47 +0530 Subject: [PATCH 01/21] Add Flux2KleinInpaintPipeline --- src/diffusers/__init__.py | 2 + src/diffusers/pipelines/__init__.py | 4 +- src/diffusers/pipelines/flux2/__init__.py | 2 + .../flux2/pipeline_flux2_klein_inpaint.py | 1007 +++++++++++++++++ 4 files changed, 1013 insertions(+), 2 deletions(-) create mode 100644 src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 24b9c12db6d4..3d9bbb53d34b 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -482,6 +482,7 @@ "EasyAnimateControlPipeline", "EasyAnimateInpaintPipeline", "EasyAnimatePipeline", + "Flux2KleinInpaintPipeline", "Flux2KleinPipeline", "Flux2Pipeline", "FluxControlImg2ImgPipeline", @@ -1211,6 +1212,7 @@ EasyAnimateControlPipeline, EasyAnimateInpaintPipeline, EasyAnimatePipeline, + Flux2KleinInpaintPipeline, Flux2KleinPipeline, Flux2Pipeline, FluxControlImg2ImgPipeline, diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 65378631a172..dc96ea4dde55 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -130,7 +130,7 @@ ] _import_structure["bria"] = ["BriaPipeline"] _import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"] - _import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline"] + _import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline", "Flux2KleinInpaintPipeline"] _import_structure["flux"] = [ "FluxControlPipeline", "FluxControlInpaintPipeline", @@ -678,7 +678,7 @@ FluxPriorReduxPipeline, ReduxImageEncoder, ) - from .flux2 import Flux2KleinPipeline, Flux2Pipeline + from .flux2 import Flux2KleinInpaintPipeline, Flux2KleinPipeline, Flux2Pipeline from .glm_image import GlmImagePipeline from .hidream_image import HiDreamImagePipeline from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline diff --git a/src/diffusers/pipelines/flux2/__init__.py b/src/diffusers/pipelines/flux2/__init__.py index f6e1d5206630..93ec5e704caf 100644 --- a/src/diffusers/pipelines/flux2/__init__.py +++ b/src/diffusers/pipelines/flux2/__init__.py @@ -24,6 +24,7 @@ else: _import_structure["pipeline_flux2"] = ["Flux2Pipeline"] _import_structure["pipeline_flux2_klein"] = ["Flux2KleinPipeline"] + _import_structure["pipeline_flux2_klein_inpaint"] = ["Flux2KleinInpaintPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: if not (is_transformers_available() and is_torch_available()): @@ -33,6 +34,7 @@ else: from .pipeline_flux2 import Flux2Pipeline from .pipeline_flux2_klein import Flux2KleinPipeline + from .pipeline_flux2_klein_inpaint import Flux2KleinInpaintPipeline else: import sys diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py new file mode 100644 index 000000000000..b4382ceca414 --- /dev/null +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py @@ -0,0 +1,1007 @@ +# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL +import torch +from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM + +from ...image_processor import PipelineImageInput +from ...loaders import Flux2LoraLoaderMixin +from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel +from ...schedulers import FlowMatchEulerDiscreteScheduler +from ...utils import is_torch_xla_available, logging, replace_example_docstring +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import DiffusionPipeline +from .image_processor import Flux2ImageProcessor +from .pipeline_output import Flux2PipelineOutput + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import Flux2KleinInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = Flux2KleinInpaintPipeline.from_pretrained( + ... "black-forest-labs/FLUX.2-klein-base-9B", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench" + >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + >>> source = load_image(img_url) + >>> mask = load_image(mask_url) + >>> image = pipe(prompt=prompt, image=source, mask_image=mask).images[0] + >>> image.save("flux2klein_inpainting.png") + ``` +""" + + +# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empirical_mu +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + +class Flux2KleinInpaintPipeline(DiffusionPipeline, Flux2LoraLoaderMixin): + r""" + Flux2 Klein pipeline for image inpainting. + + Reference: + [https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence](https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence) + + Args: + transformer ([`Flux2Transformer2DModel`]): + Conditional Transformer (MMDiT) architecture to denoise the encoded image latents. + scheduler ([`FlowMatchEulerDiscreteScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + vae ([`AutoencoderKLFlux2`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`Qwen3ForCausalLM`]): + [Qwen3ForCausalLM](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3ForCausalLM) + tokenizer (`Qwen2TokenizerFast`): + Tokenizer of class + [Qwen2TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2TokenizerFast). + """ + + model_cpu_offload_seq = "text_encoder->transformer->vae" + _callback_tensor_inputs = ["latents", "prompt_embeds"] + + def __init__( + self, + scheduler: FlowMatchEulerDiscreteScheduler, + vae: AutoencoderKLFlux2, + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + transformer: Flux2Transformer2DModel, + is_distilled: bool = False, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + scheduler=scheduler, + transformer=transformer, + ) + + self.register_to_config(is_distilled=is_distilled) + + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8 + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 32 + self.image_processor = Flux2ImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.latent_channels + ) + self.mask_processor = Flux2ImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=self.latent_channels, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + ) + self.tokenizer_max_length = 512 + self.default_sample_size = 128 + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds + def _get_qwen3_prompt_embeds( + text_encoder: Qwen3ForCausalLM, + tokenizer: Qwen2TokenizerFast, + prompt: Union[str, List[str]], + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + max_sequence_length: int = 512, + hidden_states_layers: List[int] = (9, 18, 27), + ): + dtype = text_encoder.dtype if dtype is None else dtype + device = text_encoder.device if device is None else device + + prompt = [prompt] if isinstance(prompt, str) else prompt + + all_input_ids = [] + all_attention_masks = [] + + for single_prompt in prompt: + messages = [{"role": "user", "content": single_prompt}] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + inputs = tokenizer( + text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=max_sequence_length, + ) + + all_input_ids.append(inputs["input_ids"]) + all_attention_masks.append(inputs["attention_mask"]) + + input_ids = torch.cat(all_input_ids, dim=0).to(device) + attention_mask = torch.cat(all_attention_masks, dim=0).to(device) + + # Forward pass through the model + output = text_encoder( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + use_cache=False, + ) + + # Only use outputs from intermediate layers and stack them + out = torch.stack([output.hidden_states[k] for k in hidden_states_layers], dim=1) + out = out.to(dtype=dtype, device=device) + + batch_size, num_channels, seq_len, hidden_dim = out.shape + prompt_embeds = out.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_channels * hidden_dim) + + return prompt_embeds + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_text_ids + def _prepare_text_ids( + x: torch.Tensor, # (B, L, D) or (L, D) + t_coord: Optional[torch.Tensor] = None, + ): + B, L, _ = x.shape + out_ids = [] + + for i in range(B): + t = torch.arange(1) if t_coord is None else t_coord[i] + h = torch.arange(1) + w = torch.arange(1) + l = torch.arange(L) + + coords = torch.cartesian_prod(t, h, w, l) + out_ids.append(coords) + + return torch.stack(out_ids) + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids + def _prepare_latent_ids( + latents: torch.Tensor, # (B, C, H, W) + ): + r""" + Generates 4D position coordinates (T, H, W, L) for latent tensors. + + Args: + latents (torch.Tensor): + Latent tensor of shape (B, C, H, W) + + Returns: + torch.Tensor: + Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0, + H=[0..H-1], W=[0..W-1], L=0 + """ + + batch_size, _, height, width = latents.shape + + t = torch.arange(1) # [0] - time dimension + h = torch.arange(height) + w = torch.arange(width) + l = torch.arange(1) # [0] - layer dimension + + # Create position IDs: (H*W, 4) + latent_ids = torch.cartesian_prod(t, h, w, l) + + # Expand to batch: (B, H*W, 4) + latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1) + + return latent_ids + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents + def _patchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) + latents = latents.permute(0, 1, 3, 5, 2, 4) + latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpatchify_latents + def _unpatchify_latents(latents): + batch_size, num_channels_latents, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width) + latents = latents.permute(0, 1, 4, 2, 5, 3) + latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2) + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._pack_latents + def _pack_latents(latents): + """ + pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels) + """ + + batch_size, num_channels, height, width = latents.shape + latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1) + + return latents + + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids + def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]: + """ + using position ids to scatter tokens into place + """ + x_list = [] + for data, pos in zip(x, x_ids): + _, ch = data.shape # noqa: F841 + h_ids = pos[:, 1].to(torch.int64) + w_ids = pos[:, 2].to(torch.int64) + + h = torch.max(h_ids) + 1 + w = torch.max(w_ids) + 1 + + flat_ids = h_ids * w + w_ids + + out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype) + out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data) + + # reshape from (H * W, C) to (H, W, C) and permute to (C, H, W) + + out = out.view(h, w, ch).permute(2, 0, 1) + x_list.append(out) + + return torch.stack(x_list, dim=0) + + # Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + prompt_embeds: Optional[torch.Tensor] = None, + max_sequence_length: int = 512, + text_encoder_out_layers: Tuple[int] = (9, 18, 27), + ): + device = device or self._execution_device + + if prompt is None: + prompt = "" + + prompt = [prompt] if isinstance(prompt, str) else prompt + + if prompt_embeds is None: + prompt_embeds = self._get_qwen3_prompt_embeds( + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + prompt=prompt, + device=device, + max_sequence_length=max_sequence_length, + hidden_states_layers=text_encoder_out_layers, + ) + + batch_size, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + text_ids = self._prepare_text_ids(prompt_embeds) + text_ids = text_ids.to(device) + return prompt_embeds, text_ids + + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): + if image.ndim != 4: + raise ValueError(f"Expected image dims 4, got {image.ndim}.") + + image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax") + image_latents = self._patchify_latents(image_latents) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps) + image_latents = (image_latents - latents_bn_mean) / latents_bn_std + + return image_latents + + def prepare_latents( + self, + image, + timestep, + batch_size, + num_latents_channels, + height, + width, + dtype, + device, + generator: torch.Generator, + latents: Optional[torch.Tensor] = None, + ): + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + shape = (batch_size, num_latents_channels * 4, height // 2, width // 2) + # Create a dummy tensor for _prepare_latent_ids + dummy_latents = torch.zeros(shape, device=device, dtype=dtype) + latent_image_ids = self._prepare_latent_ids(dummy_latents) + latent_image_ids = latent_image_ids.to(device) + + image = image.to(device=device, dtype=dtype) + if image.shape[1] != self.latent_channels: + image_latents = self._encode_vae_image(image=image, generator=generator) + else: + image_latents = image + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + # expand init_latents for batch_size + additional_image_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + else: + image_latents = torch.cat([image_latents], dim=0) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.scale_noise(image_latents, timestep, noise) + else: + noise = latents.to(device) + latents = noise + + noise = self._pack_latents(noise) + image_latents = self._pack_latents(image_latents) + latents = self._pack_latents(latents) + return latents, noise, image_latents, latent_image_ids + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + dtype, + device, + generator, + ): + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate(mask, size=(height, width)) + mask = mask.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + + masked_image = masked_image.to(device=device, dtype=dtype) + if masked_image.shape[1] != self.latent_channels: + masked_image_latents = self._encode_vae_image(image=masked_image, generator=generator) + else: + masked_image_latents = masked_image + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." + ) + masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + masked_image_latents = self._pack_latents(masked_image_latents) + + mask = mask.repeat(1, num_channels_latents, 1, 1) + mask = self._pack_latents(mask) + + return mask, masked_image_latents + + # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def check_inputs( + self, + prompt, + image, + mask_image, + strength, + height, + width, + output_type, + prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + padding_mask_crop=None, + guidance_scale=None, + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if ( + height is not None + and height % (self.vae_scale_factor * 2) != 0 + or width is not None + and width % (self.vae_scale_factor * 2) != 0 + ): + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if padding_mask_crop is not None: + if not isinstance(image, PIL.Image.Image): + raise ValueError( + f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}." + ) + if not isinstance(mask_image, PIL.Image.Image): + raise ValueError( + f"The mask image should be a PIL image when inpainting mask crop, but is of type" + f" {type(mask_image)}." + ) + if output_type != "pil": + raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + + if guidance_scale > 1.0 and self.config.is_distilled: + logger.warning(f"Guidance scale {guidance_scale} is ignored for step-wise distilled models.") + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and not self.config.is_distilled + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + masked_image_latents: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + padding_mask_crop: Optional[int] = None, + strength: float = 0.6, + num_inference_steps: int = 50, + sigmas: Optional[List[float]] = None, + guidance_scale: Optional[float] = 8.0, + num_images_per_prompt: int = 1, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 512, + text_encoder_out_layers: Tuple[int] = (9, 18, 27), + ): + r""" + Function invoked when calling the pipeline for inpainting. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + masked_image_latents (`torch.Tensor`, `List[torch.Tensor]`): + `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask + latents tensor will be generated by `mask_image`. + height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The height in pixels of the generated image. This is set to 1024 by default for the best results. + width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): + The width in pixels of the generated image. This is set to 1024 by default for the best results. + padding_mask_crop (`int`, *optional*, defaults to `None`): + The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to + image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region + with the same aspect ration of the image and contains all masked area, and then expand that area based + on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before + resizing to the original image size for inpainting. This is useful when the masked area is small while + the image is large and contain information irrelevant for inpainting, such as background. + strength (`float`, *optional*, defaults to 0.6): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 8.0): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. For step-wise distilled models, + `guidance_scale` is ignored. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. Note that "" is used as the negative prompt in this pipeline. + If not provided, will be generated from "". + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.flux2.Flux2PipelineOutput`] instead of a plain tuple. + attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. + text_encoder_out_layers (`Tuple[int]`): + Layer indices to use in the `text_encoder` to derive the final prompt embeddings. + + Examples: + + Returns: + [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`: [`~pipelines.flux2.Flux2PipelineOutput`] if + `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the + generated images. + """ + + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt=prompt, + image=image, + mask_image=mask_image, + strength=strength, + height=height, + width=width, + output_type=output_type, + prompt_embeds=prompt_embeds, + callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, + padding_mask_crop=padding_mask_crop, + guidance_scale=guidance_scale, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Preprocess mask and image + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode + ) + init_image = init_image.to(dtype=torch.float32) + + # 3. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 4. Prepare text embeddings + prompt_embeds, text_ids = self.encode_prompt( + prompt=prompt, + prompt_embeds=prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, + ) + + if self.do_classifier_free_guidance: + negative_prompt = "" + if prompt is not None and isinstance(prompt, list): + negative_prompt = [negative_prompt] * len(prompt) + negative_prompt_embeds, negative_text_ids = self.encode_prompt( + prompt=negative_prompt, + prompt_embeds=negative_prompt_embeds, + device=device, + num_images_per_prompt=num_images_per_prompt, + max_sequence_length=max_sequence_length, + text_encoder_out_layers=text_encoder_out_layers, + ) + + # 5. Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas + if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas: + sigmas = None + image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2) + mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps) + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + sigmas=sigmas, + mu=mu, + ) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # 6. Prepare latent variables + num_channels_latents = self.transformer.config.in_channels // 4 + + latents, noise, image_latents, latent_image_ids = self.prepare_latents( + init_image, + latent_timestep, + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + mask_condition = self.mask_processor.preprocess( + mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords + ) + + if masked_image_latents is None: + masked_image = init_image * (mask_condition < 0.5) + else: + masked_image = masked_image_latents + + mask, masked_image_latents = self.prepare_mask_latents( + mask_condition, + masked_image, + batch_size, + num_channels_latents, + num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + ) + + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + self._num_timesteps = len(timesteps) + + # 7. Denoising loop + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latents.shape[0]).to(latents.dtype) + + latent_model_input = latents.to(self.transformer.dtype) + + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, # (B, image_seq_len, C) + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, # B, text_seq_len, 4 + img_ids=latent_image_ids, # B, image_seq_len, 4 + joint_attention_kwargs=self.attention_kwargs, + return_dict=False, + )[0] + + if self.do_classifier_free_guidance: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=None, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self._attention_kwargs, + return_dict=False, + )[0] + noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + + # compute the previous noisy sample x_t -> x_t-1 + latents_dtype = latents.dtype + latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] + + init_latents_proper = image_latents + init_mask = mask + + if i < len(timesteps) - 1: + noise_timestep = timesteps[i + 1] + init_latents_proper = self.scheduler.scale_noise( + init_latents_proper, torch.tensor([noise_timestep]), noise + ) + + latents = (1 - init_mask) * init_latents_proper + init_mask * latents + + if latents.dtype != latents_dtype: + if torch.backends.mps.is_available(): + # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 + latents = latents.to(latents_dtype) + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + # 8. Post-processing + latents = self._unpack_latents_with_ids(latents, latent_image_ids) + + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + latents.device, latents.dtype + ) + latents = latents * latents_bn_std + latents_bn_mean + latents = self._unpatchify_latents(latents) + + if output_type == "latent": + image = latents + else: + image = self.vae.decode(latents, return_dict=False)[0] + image = self.image_processor.postprocess(image, output_type=output_type) + + if padding_mask_crop is not None: + image = [ + self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image + ] + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return Flux2PipelineOutput(images=image) \ No newline at end of file From d213e59fa81c33628f6d8b71b00c88cae5039da9 Mon Sep 17 00:00:00 2001 From: Aditya Borate <23110065@iitgn.ac.in> Date: Tue, 27 Jan 2026 00:52:41 +0530 Subject: [PATCH 02/21] Fixed mask channel mismatch and a bit of cleaning --- src/diffusers/pipelines/flux2/image_processor.py | 8 ++++++++ .../pipelines/flux2/pipeline_flux2_klein_inpaint.py | 11 ++++++----- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/flux2/image_processor.py b/src/diffusers/pipelines/flux2/image_processor.py index f1a8742491f7..ed90d87c5758 100644 --- a/src/diffusers/pipelines/flux2/image_processor.py +++ b/src/diffusers/pipelines/flux2/image_processor.py @@ -36,8 +36,12 @@ class Flux2ImageProcessor(VaeImageProcessor): VAE latent channels. do_normalize (`bool`, *optional*, defaults to `True`): Whether to normalize the image to [-1,1]. + do_binarize (`bool`, *optional*, defaults to `False`): + Whether to binarize the image to 0/1. do_convert_rgb (`bool`, *optional*, defaults to be `True`): Whether to convert the images to RGB format. + do_convert_grayscale (`bool`, *optional*, defaults to be `False`): + Whether to convert the images to grayscale format. """ @register_to_config @@ -47,14 +51,18 @@ def __init__( vae_scale_factor: int = 16, vae_latent_channels: int = 32, do_normalize: bool = True, + do_binarize: bool = False, do_convert_rgb: bool = True, + do_convert_grayscale: bool = False, ): super().__init__( do_resize=do_resize, vae_scale_factor=vae_scale_factor, vae_latent_channels=vae_latent_channels, do_normalize=do_normalize, + do_binarize=do_binarize, do_convert_rgb=do_convert_rgb, + do_convert_grayscale=do_convert_grayscale ) @staticmethod diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py index b4382ceca414..bd3aceee08a9 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py @@ -213,6 +213,7 @@ def __init__( vae_latent_channels=self.latent_channels, do_normalize=False, do_binarize=True, + do_convert_rgb=False, do_convert_grayscale=True, ) self.tokenizer_max_length = 512 @@ -510,10 +511,10 @@ def prepare_mask_latents( # latent height and width to be divisible by 2. height = 2 * (int(height) // (self.vae_scale_factor * 2)) width = 2 * (int(width) // (self.vae_scale_factor * 2)) - # resize the mask to latents shape as we concatenate the mask to the latents - # we do that before converting to dtype to avoid breaking in case we're using cpu_offload - # and half precision - mask = torch.nn.functional.interpolate(mask, size=(height, width)) + # resize the mask to patchified latents shape (height // 2, width // 2) since latents + # are patchified before packing. We do that before converting to dtype to avoid breaking + # in case we're using cpu_offload and half precision + mask = torch.nn.functional.interpolate(mask, size=(height // 2, width // 2)) mask = mask.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt @@ -546,7 +547,7 @@ def prepare_mask_latents( masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) masked_image_latents = self._pack_latents(masked_image_latents) - mask = mask.repeat(1, num_channels_latents, 1, 1) + mask = mask.repeat(1, num_channels_latents * 4, 1, 1) mask = self._pack_latents(mask) return mask, masked_image_latents From 738ac4338ab431972586bf9cd82a56c48b5baa3b Mon Sep 17 00:00:00 2001 From: Aditya Borate <23110065@iitgn.ac.in> Date: Tue, 27 Jan 2026 18:18:43 +0000 Subject: [PATCH 03/21] Added tests and minor refactors --- .../flux2/pipeline_flux2_klein_inpaint.py | 3 +- .../dummy_torch_and_transformers_objects.py | 15 ++ .../test_pipeline_flux2_klein_inpaint.py | 177 ++++++++++++++++++ 3 files changed, 193 insertions(+), 2 deletions(-) create mode 100644 tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py index bd3aceee08a9..7640620753e0 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py @@ -422,7 +422,6 @@ def encode_prompt( text_ids = text_ids.to(device) return prompt_embeds, text_ids - # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if image.ndim != 4: raise ValueError(f"Expected image dims 4, got {image.ndim}.") @@ -431,7 +430,7 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): image_latents = self._patchify_latents(image_latents) latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) - latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(image_latents.device, image_latents.dtype) image_latents = (image_latents - latents_bn_mean) / latents_bn_std return image_latents diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 63f381419fda..5d0e5961a7ff 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -962,6 +962,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class Flux2KleinInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class Flux2KleinPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py b/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py new file mode 100644 index 000000000000..c3fe50a9856e --- /dev/null +++ b/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py @@ -0,0 +1,177 @@ +import random +import unittest + +import numpy as np +import torch +from transformers import Qwen2TokenizerFast, Qwen3Config, Qwen3ForCausalLM + +from diffusers import ( + AutoencoderKLFlux2, + FlowMatchEulerDiscreteScheduler, + Flux2KleinInpaintPipeline, + Flux2Transformer2DModel, +) + +from ...testing_utils import ( + enable_full_determinism, + floats_tensor, + torch_device, +) +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class Flux2KleinInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = Flux2KleinInpaintPipeline + params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"]) + batch_params = frozenset(["prompt"]) + + test_xformers_attention = False + test_layerwise_casting = True + test_group_offloading = True + + supports_dduf = False + + def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1): + torch.manual_seed(0) + transformer = Flux2Transformer2DModel( + patch_size=1, + in_channels=4, + num_layers=num_layers, + num_single_layers=num_single_layers, + attention_head_dim=16, + num_attention_heads=2, + joint_attention_dim=16, + timestep_guidance_channels=256, + axes_dims_rope=[4, 4, 4, 4], + guidance_embeds=False, + ) + + # Create minimal Qwen3 config + config = Qwen3Config( + intermediate_size=16, + hidden_size=16, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + vocab_size=151936, + max_position_embeddings=512, + ) + torch.manual_seed(0) + text_encoder = Qwen3ForCausalLM(config) + + # Use a simple tokenizer for testing + tokenizer = Qwen2TokenizerFast.from_pretrained( + "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration" + ) + + torch.manual_seed(0) + vae = AutoencoderKLFlux2( + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownEncoderBlock2D",), + up_block_types=("UpDecoderBlock2D",), + block_out_channels=(4,), + layers_per_block=1, + latent_channels=1, + norm_num_groups=1, + use_quant_conv=False, + use_post_quant_conv=False, + ) + + scheduler = FlowMatchEulerDiscreteScheduler() + + return { + "scheduler": scheduler, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "transformer": transformer, + "vae": vae, + } + + def get_dummy_inputs(self, device, seed=0): + image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) + mask_image = torch.ones((1, 1, 32, 32)).to(device) + + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device="cpu").manual_seed(seed) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "image": image, + "mask_image": mask_image, + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 8.0, + "height": 32, + "width": 32, + "max_sequence_length": 64, + "strength": 0.8, + "output_type": "np", + "text_encoder_out_layers": (1,), + } + return inputs + + def test_flux2_klein_inpaint_different_prompts(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + + inputs = self.get_dummy_inputs(torch_device) + output_same_prompt = pipe(**inputs).images[0] + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt"] = "a different prompt" + output_different_prompts = pipe(**inputs).images[0] + + max_diff = np.abs(output_same_prompt - output_different_prompts).max() + + # Outputs should be different here + assert max_diff > 1e-6 + + def test_flux2_klein_inpaint_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 56)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 2) + + # Update image and mask to match height/width + image = floats_tensor((1, 3, height, width), rng=random.Random(0)).to(torch_device) + mask_image = torch.ones((1, 1, height, width)).to(torch_device) + + inputs.update({"height": height, "width": width, "image": image, "mask_image": mask_image}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + self.assertEqual( + (output_height, output_width), + (expected_height, expected_width), + f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}", + ) + + def test_flux2_klein_inpaint_strength(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + + # Test with strength=1.0 (full denoising) + inputs = self.get_dummy_inputs(torch_device) + inputs["strength"] = 1.0 + output_full_strength = pipe(**inputs).images[0] + + # Test with strength=0.5 (partial denoising) + inputs = self.get_dummy_inputs(torch_device) + inputs["strength"] = 0.5 + output_half_strength = pipe(**inputs).images[0] + + max_diff = np.abs(output_full_strength - output_half_strength).max() + + # Outputs should be different with different strength values + assert max_diff > 1e-6 + + @unittest.skip("Needs to be revisited") + def test_encode_prompt_works_in_isolation(self): + pass From 6fd76dd6489ca5c84f2457f5bbc84215820fec27 Mon Sep 17 00:00:00 2001 From: Aditya Borate <23110065@iitgn.ac.in> Date: Wed, 28 Jan 2026 18:13:32 +0000 Subject: [PATCH 04/21] Added support for reference images for inpainting --- .../flux2/pipeline_flux2_klein_inpaint.py | 258 ++++++++++++++++-- 1 file changed, 235 insertions(+), 23 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py index 7640620753e0..142ff538561f 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py @@ -43,6 +43,7 @@ EXAMPLE_DOC_STRING = """ Examples: + # Inpainting with text only ```py >>> import torch >>> from diffusers import Flux2KleinInpaintPipeline @@ -60,6 +61,37 @@ >>> image = pipe(prompt=prompt, image=source, mask_image=mask).images[0] >>> image.save("flux2klein_inpainting.png") ``` + + # Inpainting with image reference conditioning + ```py + >>> import torch + >>> from diffusers import Flux2KleinInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = Flux2KleinInpaintPipeline.from_pretrained( + ... "black-forest-labs/FLUX.2-klein-base-9B", torch_dtype=torch.bfloat16 + ... ) + >>> pipe.to("cuda") + + >>> prompt = "Replace this ball" + >>> img_url = "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/image/example_1.png" + >>> mask_url = ( + ... "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/mask/example_1.png" + ... ) + >>> image_reference_url = ( + ... "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/reference/example_1.jpg" + ... ) + + >>> source = load_image(img_url) + >>> mask = load_image(mask_url) + >>> image_reference = load_image(image_reference_url) + + >>> mask = pipe.mask_processor.blur(mask, blur_factor=12) + >>> image = pipe( + ... prompt=prompt, image=source, mask_image=mask, image_reference=image_reference, strength=1.0 + ... ).images[0] + >>> image.save("flux2klein_inpainting_ref.png") + ``` """ @@ -158,8 +190,8 @@ def retrieve_latents( class Flux2KleinInpaintPipeline(DiffusionPipeline, Flux2LoraLoaderMixin): r""" - Flux2 Klein pipeline for image inpainting. - + Flux2 Klein pipeline for image inpainting with optional reference image conditioning. + Reference: [https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence](https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence) @@ -330,6 +362,57 @@ def _prepare_latent_ids( return latent_ids + @staticmethod + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids + def _prepare_image_ids( + image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] + scale: int = 10, + ): + r""" + Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents. + + This function creates a unique coordinate for every pixel/patch across all input latent with different + dimensions. + + Args: + image_latents (List[torch.Tensor]): + A list of image latent feature tensors, typically of shape (C, H, W). + scale (int, optional): + A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th + latent is: 'scale + scale * i'. Defaults to 10. + + Returns: + torch.Tensor: + The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all + input latents. + + Coordinate Components (Dimension 4): + - T (Time): The unique index indicating which latent image the coordinate belongs to. + - H (Height): The row index within that latent image. + - W (Width): The column index within that latent image. + - L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1) + """ + + if not isinstance(image_latents, list): + raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") + + # create time offset for each reference image + t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] + t_coords = [t.view(-1) for t in t_coords] + + image_latent_ids = [] + for x, t in zip(image_latents, t_coords): + x = x.squeeze(0) + _, height, width = x.shape + + x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + image_latent_ids.append(x_ids) + + image_latent_ids = torch.cat(image_latent_ids, dim=0) + image_latent_ids = image_latent_ids.unsqueeze(0) + + return image_latent_ids + @staticmethod # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents def _patchify_latents(latents): @@ -386,7 +469,7 @@ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch x_list.append(out) return torch.stack(x_list, dim=0) - + # Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline.encode_prompt def encode_prompt( self, @@ -421,7 +504,7 @@ def encode_prompt( text_ids = self._prepare_text_ids(prompt_embeds) text_ids = text_ids.to(device) return prompt_embeds, text_ids - + def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if image.ndim != 4: raise ValueError(f"Expected image dims 4, got {image.ndim}.") @@ -453,12 +536,13 @@ def prepare_latents( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - + # VAE applies 8x compression on images but we must also account for packing which requires # latent height and width to be divisible by 2. height = 2 * (int(height) // (self.vae_scale_factor * 2)) width = 2 * (int(width) // (self.vae_scale_factor * 2)) shape = (batch_size, num_latents_channels * 4, height // 2, width // 2) + # Create a dummy tensor for _prepare_latent_ids dummy_latents = torch.zeros(shape, device=device, dtype=dtype) latent_image_ids = self._prepare_latent_ids(dummy_latents) @@ -493,6 +577,50 @@ def prepare_latents( latents = self._pack_latents(latents) return latents, noise, image_latents, latent_image_ids + def prepare_image_latents( + self, + images: List[torch.Tensor], + batch_size, + generator: torch.Generator, + device, + dtype, + ): + image_latents = [] + for image in images: + image = image.to(device=device, dtype=dtype) + + if image.shape[1] != self.latent_channels: + image_latent = self._encode_vae_image(image=image, generator=generator) + else: + image_latent = self._patchify_latents(image) + image_latents.append(image_latent) # (1, 128, H//2, W//2) + + image_latent_ids = self._prepare_image_ids(image_latents) + + # Pack each latent and concatenate + packed_latents = [] + for latent in image_latents: + packed = self._pack_latents(latent) # (1, seq_len, 128) + packed = packed.squeeze(0) # (seq_len, 128) - remove batch dim + packed_latents.append(packed) + + # Concatenate all reference tokens along sequence dimension + image_latents = torch.cat(packed_latents, dim=0) # (N*seq_len, 128) + image_latents = image_latents.unsqueeze(0) # (1, N*seq_len, 128) + + if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: + additional_per_prompt = batch_size // image_latents.shape[0] + image_latents = torch.cat([image_latents] * additional_per_prompt, dim=0) + image_latent_ids = torch.cat([image_latent_ids] * additional_per_prompt, dim=0) + elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image_reference` of batch size {image_latents.shape[0]} to {batch_size} text prompts." + ) + + image_latent_ids = image_latent_ids.to(device) + + return image_latents, image_latent_ids + def prepare_mask_latents( self, mask, @@ -545,7 +673,7 @@ def prepare_mask_latents( # aligning device to prevent device errors when concating it with the latent model input masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) masked_image_latents = self._pack_latents(masked_image_latents) - + mask = mask.repeat(1, num_channels_latents * 4, 1, 1) mask = self._pack_latents(mask) @@ -579,7 +707,7 @@ def check_inputs( ): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - + if ( height is not None and height % (self.vae_scale_factor * 2) != 0 @@ -644,13 +772,14 @@ def num_timesteps(self): @property def interrupt(self): return self._interrupt - + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, prompt: Union[str, List[str]] = None, image: PipelineImageInput = None, + image_reference: Optional[PipelineImageInput] = None, mask_image: PipelineImageInput = None, masked_image_latents: PipelineImageInput = None, height: Optional[int] = None, @@ -686,6 +815,13 @@ def __call__( or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image latents as `image`, but if passing latents directly it is not encoded again. + image_reference (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`, *optional*): + `Image`, numpy array or tensor representing an image batch to be used as the reference for the + masked area. This allows conditioning the inpainted region on a specific reference image. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image_reference`, but if passing latents directly it is not encoded again. mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a @@ -794,20 +930,68 @@ def __call__( self._attention_kwargs = attention_kwargs self._interrupt = False - # 2. Preprocess mask and image - if padding_mask_crop is not None: - crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) - resize_mode = "fill" + # 2. Preprocess image + multiple_of = self.vae_scale_factor * 2 + if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4: + image = torch.cat(image, dim=0) + img = image[0] if isinstance(image, list) else image + image_height, image_width = self.image_processor.get_default_height_width(img) + image_width = image_width // multiple_of * multiple_of + image_height = image_height // multiple_of * multiple_of + image = self.image_processor.resize(image, image_height, image_width) + + # Use the resolution of the input image + width = image_width + height = image_height + + # 2.1 Preprocess mask + if padding_mask_crop is not None: + crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop) + resize_mode = "fill" + else: + crops_coords = None + resize_mode = "default" + + original_image = image + init_image = self.image_processor.preprocess( + image, image_height, image_width, crops_coords=crops_coords, resize_mode=resize_mode + ) else: - crops_coords = None - resize_mode = "default" + raise ValueError("image must be provided correctly for inpainting") - original_image = image - init_image = self.image_processor.preprocess( - image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode - ) init_image = init_image.to(dtype=torch.float32) + # 2.2 Preprocess reference image + processed_image_reference = None + if image_reference is not None and not ( + isinstance(image_reference, torch.Tensor) and image_reference.size(1) == self.latent_channels + ): + if ( + isinstance(image_reference, list) + and isinstance(image_reference[0], torch.Tensor) + and image_reference[0].ndim == 4 + ): + image_reference = torch.cat(image_reference, dim=0) + + img_reference = image_reference[0] if isinstance(image_reference, list) else image_reference + image_reference_height, image_reference_width = self.image_processor.get_default_height_width( + img_reference + ) + image_reference_width = image_reference_width // multiple_of * multiple_of + image_reference_height = image_reference_height // multiple_of * multiple_of + image_reference = self.image_processor.resize( + image_reference, image_reference_height, image_reference_width + ) + processed_image_reference = self.image_processor.preprocess( + image_reference, + image_reference_height, + image_reference_width, + crops_coords=crops_coords, + resize_mode=resize_mode, + ) + processed_image_reference = processed_image_reference.to(dtype=torch.float32) + # 3. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 @@ -879,6 +1063,19 @@ def __call__( latents, ) + image_reference_latents = None + image_reference_ids = None + if processed_image_reference is not None: + # Convert preprocessed reference image to list format expected by prepare_image_latents + ref_images = [processed_image_reference[i : i + 1] for i in range(processed_image_reference.shape[0])] + image_reference_latents, image_reference_ids = self.prepare_image_latents( + ref_images, + batch_size * num_images_per_prompt, + generator, + device, + self.vae.dtype, + ) + mask_condition = self.mask_processor.preprocess( mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords ) @@ -904,6 +1101,11 @@ def __call__( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) + if image_reference_ids is not None: + combined_image_ids = torch.cat([latent_image_ids, image_reference_ids], dim=1) + else: + combined_image_ids = latent_image_ids + # 7. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): @@ -913,7 +1115,15 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - latent_model_input = latents.to(self.transformer.dtype) + latent_model_input = latents + img_ids = latent_image_ids + + # Concatenate reference image latents and IDs if provided + if image_reference_latents is not None: + latent_model_input = torch.cat([latent_model_input, image_reference_latents], dim=1) + img_ids = combined_image_ids + + latent_model_input = latent_model_input.to(self.transformer.dtype) with self.transformer.cache_context("cond"): noise_pred = self.transformer( @@ -922,10 +1132,11 @@ def __call__( guidance=None, encoder_hidden_states=prompt_embeds, txt_ids=text_ids, # B, text_seq_len, 4 - img_ids=latent_image_ids, # B, image_seq_len, 4 + img_ids=img_ids, # B, image_seq_len, 4 joint_attention_kwargs=self.attention_kwargs, return_dict=False, )[0] + noise_pred = noise_pred[:, : latents.size(1)] if self.do_classifier_free_guidance: with self.transformer.cache_context("uncond"): @@ -935,11 +1146,12 @@ def __call__( guidance=None, encoder_hidden_states=negative_prompt_embeds, txt_ids=negative_text_ids, - img_ids=latent_image_ids, + img_ids=img_ids, joint_attention_kwargs=self._attention_kwargs, return_dict=False, )[0] - noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) + neg_noise_pred = neg_noise_pred[:, : latents.size(1)] + noise_pred = neg_noise_pred + self.guidance_scale * (noise_pred - neg_noise_pred) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype @@ -1004,4 +1216,4 @@ def __call__( if not return_dict: return (image,) - return Flux2PipelineOutput(images=image) \ No newline at end of file + return Flux2PipelineOutput(images=image) From 2516f06e22e61ad6e300cca34eeccdebbf7dd917 Mon Sep 17 00:00:00 2001 From: Aditya Borate <23110065@iitgn.ac.in> Date: Wed, 28 Jan 2026 18:17:14 +0000 Subject: [PATCH 05/21] Style fixes --- .../pipelines/flux2/image_processor.py | 2 +- .../flux2/pipeline_flux2_klein_inpaint.py | 24 ++++++++++--------- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/flux2/image_processor.py b/src/diffusers/pipelines/flux2/image_processor.py index ed90d87c5758..1c1e669c58da 100644 --- a/src/diffusers/pipelines/flux2/image_processor.py +++ b/src/diffusers/pipelines/flux2/image_processor.py @@ -62,7 +62,7 @@ def __init__( do_normalize=do_normalize, do_binarize=do_binarize, do_convert_rgb=do_convert_rgb, - do_convert_grayscale=do_convert_grayscale + do_convert_grayscale=do_convert_grayscale, ) @staticmethod diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py index 142ff538561f..a7754f323175 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py @@ -74,13 +74,13 @@ >>> pipe.to("cuda") >>> prompt = "Replace this ball" - >>> img_url = "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/image/example_1.png" + >>> img_url = ( + ... "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/image/example_1.png" + ... ) >>> mask_url = ( ... "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/mask/example_1.png" ... ) - >>> image_reference_url = ( - ... "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/reference/example_1.jpg" - ... ) + >>> image_reference_url = "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/reference/example_1.jpg" >>> source = load_image(img_url) >>> mask = load_image(mask_url) @@ -513,7 +513,9 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): image_latents = self._patchify_latents(image_latents) latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) - latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + image_latents.device, image_latents.dtype + ) image_latents = (image_latents - latents_bn_mean) / latents_bn_std return image_latents @@ -816,12 +818,12 @@ def __call__( list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image latents as `image`, but if passing latents directly it is not encoded again. image_reference (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`, *optional*): - `Image`, numpy array or tensor representing an image batch to be used as the reference for the - masked area. This allows conditioning the inpainted region on a specific reference image. For both - numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list - or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a - list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image - latents as `image_reference`, but if passing latents directly it is not encoded again. + `Image`, numpy array or tensor representing an image batch to be used as the reference for the masked + area. This allows conditioning the inpainted region on a specific reference image. For both numpy array + and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list or tensors, + the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of arrays, + the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image latents as + `image_reference`, but if passing latents directly it is not encoded again. mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a From c44b69af643174bf4401f6064a29bd3dbc0ab502 Mon Sep 17 00:00:00 2001 From: Aditya Borate <23110065@iitgn.ac.in> Date: Thu, 29 Jan 2026 00:19:52 +0530 Subject: [PATCH 06/21] Fixed the example docstring --- .../pipelines/flux2/pipeline_flux2_klein_inpaint.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py index a7754f323175..50cf6be9f988 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py @@ -74,13 +74,13 @@ >>> pipe.to("cuda") >>> prompt = "Replace this ball" - >>> img_url = ( - ... "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/image/example_1.png" - ... ) + >>> img_url = "https://images.pexels.com/photos/39362/the-ball-stadion-football-the-pitch-39362.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500" >>> mask_url = ( - ... "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/mask/example_1.png" + ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/ball_mask.png?raw=true" + ... ) + >>> image_reference_url = ( + ... "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTah3x6OL_ECMBaZ5ZlJJhNsyC-OSMLWAI-xw&s" ... ) - >>> image_reference_url = "https://raw.githubusercontent.com/Fantasy-Studio/Paint-by-Example/main/examples/reference/example_1.jpg" >>> source = load_image(img_url) >>> mask = load_image(mask_url) From 9502d77b59bb4f407192cce7c7982edcfac87d56 Mon Sep 17 00:00:00 2001 From: Aditya Borate <23110065@iitgn.ac.in> Date: Thu, 29 Jan 2026 14:37:01 +0530 Subject: [PATCH 07/21] Corrected mask latent preparation for correct dimensional alignment --- .../pipelines/flux2/pipeline_flux2_klein_inpaint.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py index 50cf6be9f988..7363903ae3ee 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py @@ -640,10 +640,9 @@ def prepare_mask_latents( # latent height and width to be divisible by 2. height = 2 * (int(height) // (self.vae_scale_factor * 2)) width = 2 * (int(width) // (self.vae_scale_factor * 2)) - # resize the mask to patchified latents shape (height // 2, width // 2) since latents - # are patchified before packing. We do that before converting to dtype to avoid breaking - # in case we're using cpu_offload and half precision - mask = torch.nn.functional.interpolate(mask, size=(height // 2, width // 2)) + + # Interpolate to VAE latent size + mask = torch.nn.functional.interpolate(mask, size=(height, width)) mask = mask.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt @@ -676,8 +675,9 @@ def prepare_mask_latents( masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) masked_image_latents = self._pack_latents(masked_image_latents) - mask = mask.repeat(1, num_channels_latents * 4, 1, 1) - mask = self._pack_latents(mask) + mask = mask.repeat(1, self.latent_channels, 1, 1) # Repeat to 128 channels + mask = self._patchify_latents(mask) # Patchify: 128 -> 512 channels, spatial 64->32 + mask = self._pack_latents(mask) # Pack to (B, seq_len, 512) return mask, masked_image_latents From f85ee3b51315ca2ddb10d6a950a3d850a1c2acd8 Mon Sep 17 00:00:00 2001 From: Aditya Borate <23110065@iitgn.ac.in> Date: Sat, 14 Mar 2026 16:36:36 +0530 Subject: [PATCH 08/21] replace masked_image_latents context with clean_source_latents, fix mask spatial alignment and remove unused VAE encoding --- .../flux2/pipeline_flux2_klein_inpaint.py | 70 ++++++------------- 1 file changed, 22 insertions(+), 48 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py index 7363903ae3ee..3bd33dcb65dd 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py @@ -626,34 +626,22 @@ def prepare_image_latents( def prepare_mask_latents( self, mask, - masked_image, batch_size, - num_channels_latents, num_images_per_prompt, height, width, dtype, device, - generator, ): - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. - height = 2 * (int(height) // (self.vae_scale_factor * 2)) - width = 2 * (int(width) // (self.vae_scale_factor * 2)) - - # Interpolate to VAE latent size - mask = torch.nn.functional.interpolate(mask, size=(height, width)) + # Interpolate the mask directly to the final packed spatial size. + target_h = int(height) // (self.vae_scale_factor * 2) + target_w = int(width) // (self.vae_scale_factor * 2) + mask = torch.nn.functional.interpolate(mask, size=(target_h, target_w), mode="bilinear") mask = mask.to(device=device, dtype=dtype) batch_size = batch_size * num_images_per_prompt - masked_image = masked_image.to(device=device, dtype=dtype) - if masked_image.shape[1] != self.latent_channels: - masked_image_latents = self._encode_vae_image(image=masked_image, generator=generator) - else: - masked_image_latents = masked_image - - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + # duplicate mask for each generation per prompt, using mps friendly method if mask.shape[0] < batch_size: if not batch_size % mask.shape[0] == 0: raise ValueError( @@ -662,24 +650,11 @@ def prepare_mask_latents( " of masks that you pass is divisible by the total requested batch size." ) mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." - ) - masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1) - # aligning device to prevent device errors when concating it with the latent model input - masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - masked_image_latents = self._pack_latents(masked_image_latents) + # Pack to (B, seq_len, 1), will broadcast against (B, seq_len, C) latents + mask = self._pack_latents(mask) - mask = mask.repeat(1, self.latent_channels, 1, 1) # Repeat to 128 channels - mask = self._patchify_latents(mask) # Patchify: 128 -> 512 channels, spatial 64->32 - mask = self._pack_latents(mask) # Pack to (B, seq_len, 512) - - return mask, masked_image_latents + return mask # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength, device): @@ -1065,6 +1040,14 @@ def __call__( latents, ) + clean_source_latents, clean_source_latent_ids = self.prepare_image_latents( + [init_image], + batch_size * num_images_per_prompt, + generator, + device, + self.vae.dtype, + ) + image_reference_latents = None image_reference_ids = None if processed_image_reference is not None: @@ -1082,31 +1065,23 @@ def __call__( mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords ) - if masked_image_latents is None: - masked_image = init_image * (mask_condition < 0.5) - else: - masked_image = masked_image_latents - - mask, masked_image_latents = self.prepare_mask_latents( + mask = self.prepare_mask_latents( mask_condition, - masked_image, batch_size, - num_channels_latents, num_images_per_prompt, height, width, prompt_embeds.dtype, device, - generator, ) num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) + # Always include the clean-source position IDs and append reference IDs when present. + combined_image_ids = torch.cat([latent_image_ids, clean_source_latent_ids], dim=1) if image_reference_ids is not None: - combined_image_ids = torch.cat([latent_image_ids, image_reference_ids], dim=1) - else: - combined_image_ids = latent_image_ids + combined_image_ids = torch.cat([combined_image_ids, image_reference_ids], dim=1) # 7. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -1117,13 +1092,12 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - latent_model_input = latents - img_ids = latent_image_ids + latent_model_input = torch.cat([latents, clean_source_latents], dim=1) + img_ids = combined_image_ids # Concatenate reference image latents and IDs if provided if image_reference_latents is not None: latent_model_input = torch.cat([latent_model_input, image_reference_latents], dim=1) - img_ids = combined_image_ids latent_model_input = latent_model_input.to(self.transformer.dtype) From 9ffef8f95d27787dd9d8d36e15773bc8df9499d9 Mon Sep 17 00:00:00 2001 From: Aditya Borate <23110065@iitgn.ac.in> Date: Sun, 15 Mar 2026 21:04:51 +0530 Subject: [PATCH 09/21] Fix T-coordinate collision for conditioning --- .../flux2/pipeline_flux2_klein_inpaint.py | 34 ++++++------------- 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py index 3bd33dcb65dd..5376eae7101f 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py @@ -1040,27 +1040,19 @@ def __call__( latents, ) - clean_source_latents, clean_source_latent_ids = self.prepare_image_latents( - [init_image], + ref_images = [init_image[i : i + 1] for i in range(init_image.shape[0])] + if processed_image_reference is not None: + # Convert preprocessed reference image to list format + ref_images += [processed_image_reference[i : i + 1] for i in range(processed_image_reference.shape[0])] + + condition_image_latents, condition_image_ids = self.prepare_image_latents( + ref_images, batch_size * num_images_per_prompt, generator, device, self.vae.dtype, ) - image_reference_latents = None - image_reference_ids = None - if processed_image_reference is not None: - # Convert preprocessed reference image to list format expected by prepare_image_latents - ref_images = [processed_image_reference[i : i + 1] for i in range(processed_image_reference.shape[0])] - image_reference_latents, image_reference_ids = self.prepare_image_latents( - ref_images, - batch_size * num_images_per_prompt, - generator, - device, - self.vae.dtype, - ) - mask_condition = self.mask_processor.preprocess( mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords ) @@ -1078,10 +1070,8 @@ def __call__( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) self._num_timesteps = len(timesteps) - # Always include the clean-source position IDs and append reference IDs when present. - combined_image_ids = torch.cat([latent_image_ids, clean_source_latent_ids], dim=1) - if image_reference_ids is not None: - combined_image_ids = torch.cat([combined_image_ids, image_reference_ids], dim=1) + # Combine base latent position IDs with condition image position IDs. + combined_image_ids = torch.cat([latent_image_ids, condition_image_ids], dim=1) # 7. Denoising loop with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -1092,13 +1082,9 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - latent_model_input = torch.cat([latents, clean_source_latents], dim=1) + latent_model_input = torch.cat([latents, condition_image_latents], dim=1) img_ids = combined_image_ids - # Concatenate reference image latents and IDs if provided - if image_reference_latents is not None: - latent_model_input = torch.cat([latent_model_input, image_reference_latents], dim=1) - latent_model_input = latent_model_input.to(self.transformer.dtype) with self.transformer.cache_context("cond"): From b202267968d8e44eaab52b68b8596249662c772e Mon Sep 17 00:00:00 2001 From: Aditya Borate <23110065@iitgn.ac.in> Date: Sun, 15 Mar 2026 22:07:28 +0530 Subject: [PATCH 10/21] Changed the default strength from 0.6 to 0.8 --- src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py index 5376eae7101f..edcfbcc41bea 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py @@ -762,7 +762,7 @@ def __call__( height: Optional[int] = None, width: Optional[int] = None, padding_mask_crop: Optional[int] = None, - strength: float = 0.6, + strength: float = 0.8, num_inference_steps: int = 50, sigmas: Optional[List[float]] = None, guidance_scale: Optional[float] = 8.0, From d75204c45da3120587832410124060cfce4181e6 Mon Sep 17 00:00:00 2001 From: Aditya Borate <23110065@iitgn.ac.in> Date: Mon, 16 Mar 2026 00:54:07 +0530 Subject: [PATCH 11/21] Added reference image test and updated the frozenset --- .../test_pipeline_flux2_klein_inpaint.py | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py b/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py index c3fe50a9856e..ee0602d0c26b 100644 --- a/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py +++ b/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py @@ -25,7 +25,7 @@ class Flux2KleinInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = Flux2KleinInpaintPipeline - params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"]) + params = frozenset(["prompt", "image", "mask_image", "height", "width", "guidance_scale", "prompt_embeds"]) batch_params = frozenset(["prompt"]) test_xformers_attention = False @@ -172,6 +172,26 @@ def test_flux2_klein_inpaint_strength(self): # Outputs should be different with different strength values assert max_diff > 1e-6 + def test_flux2_klein_inpaint_image_reference(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + # Add a reference image to the inputs + ref_image = floats_tensor((1, 3, 32, 32), rng=random.Random(1)).to(torch_device) + inputs["image_reference"] = ref_image + + image = pipe(**inputs).images[0] + + expected_height = inputs["height"] - inputs["height"] % (pipe.vae_scale_factor * 2) + expected_width = inputs["width"] - inputs["width"] % (pipe.vae_scale_factor * 2) + + output_height, output_width, _ = image.shape + self.assertEqual( + (output_height, output_width), + (expected_height, expected_width), + f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)} when conditioned on a reference image.", + ) + @unittest.skip("Needs to be revisited") def test_encode_prompt_works_in_isolation(self): pass From bcfba17e9795ddf552723d972344febe47d7ea91 Mon Sep 17 00:00:00 2001 From: Aditya Borate <23110065@iitgn.ac.in> Date: Tue, 17 Mar 2026 11:54:18 +0530 Subject: [PATCH 12/21] Validated ref image, latent passing support and fixed ref image preprocessing --- .../flux2/pipeline_flux2_klein_inpaint.py | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py index edcfbcc41bea..9b4825c80e0b 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py @@ -505,6 +505,7 @@ def encode_prompt( text_ids = text_ids.to(device) return prompt_embeds, text_ids + # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if image.ndim != 4: raise ValueError(f"Expected image dims 4, got {image.ndim}.") @@ -673,6 +674,7 @@ def check_inputs( prompt, image, mask_image, + image_reference, strength, height, width, @@ -727,6 +729,13 @@ def check_inputs( if output_type != "pil": raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + if image_reference is not None: + if not isinstance(image_reference, (PIL.Image.Image, torch.Tensor, np.ndarray, list)): + raise ValueError( + f"`image_reference` has to be of type `PIL.Image.Image`, `torch.Tensor`, `np.ndarray`, or `list`" + f" but is {type(image_reference)}." + ) + if guidance_scale > 1.0 and self.config.is_distilled: logger.warning(f"Guidance scale {guidance_scale} is ignored for step-wise distilled models.") @@ -758,7 +767,6 @@ def __call__( image: PipelineImageInput = None, image_reference: Optional[PipelineImageInput] = None, mask_image: PipelineImageInput = None, - masked_image_latents: PipelineImageInput = None, height: Optional[int] = None, width: Optional[int] = None, padding_mask_crop: Optional[int] = None, @@ -806,9 +814,6 @@ def __call__( color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, 1)`, or `(H, W)`. - masked_image_latents (`torch.Tensor`, `List[torch.Tensor]`): - `Tensor` representing an image batch to mask `image` generated by VAE. If not provided, the mask - latents tensor will be generated by `mask_image`. height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): The height in pixels of the generated image. This is set to 1024 by default for the best results. width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor): @@ -820,7 +825,7 @@ def __call__( on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before resizing to the original image size for inpainting. This is useful when the masked area is small while the image is large and contain information irrelevant for inpainting, such as background. - strength (`float`, *optional*, defaults to 0.6): + strength (`float`, *optional*, defaults to 0.8): Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a starting point and more noise is added the higher the `strength`. The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising @@ -893,6 +898,7 @@ def __call__( prompt=prompt, image=image, mask_image=mask_image, + image_reference=image_reference, strength=strength, height=height, width=width, @@ -909,7 +915,12 @@ def __call__( # 2. Preprocess image multiple_of = self.vae_scale_factor * 2 - if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): + if isinstance(image, torch.Tensor) and image.ndim == 4 and image.size(1) == self.latent_channels: + init_image = image + original_image = image + crops_coords = None + resize_mode = "default" + elif image is not None: if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4: image = torch.cat(image, dim=0) img = image[0] if isinstance(image, list) else image @@ -964,8 +975,7 @@ def __call__( image_reference, image_reference_height, image_reference_width, - crops_coords=crops_coords, - resize_mode=resize_mode, + resize_mode="crop", ) processed_image_reference = processed_image_reference.to(dtype=torch.float32) @@ -1113,7 +1123,7 @@ def __call__( return_dict=False, )[0] neg_noise_pred = neg_noise_pred[:, : latents.size(1)] - noise_pred = neg_noise_pred + self.guidance_scale * (noise_pred - neg_noise_pred) + noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred) # compute the previous noisy sample x_t -> x_t-1 latents_dtype = latents.dtype From b8aaa196fee9cf7d855e0a470ce7ba7e5ec3bef8 Mon Sep 17 00:00:00 2001 From: Aditya Borate <23110065@iitgn.ac.in> Date: Fri, 20 Mar 2026 11:16:22 +0530 Subject: [PATCH 13/21] Refined preprocessing with 1MP resolution cap and timestep tracking --- .../flux2/pipeline_flux2_klein_inpaint.py | 65 +++++++++++++++---- 1 file changed, 53 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py index 9b4825c80e0b..6216644603f9 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py @@ -13,6 +13,7 @@ # limitations under the License. import inspect +import math from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np @@ -250,6 +251,7 @@ def __init__( ) self.tokenizer_max_length = 512 self.default_sample_size = 128 + self._current_timestep = None @staticmethod # Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds @@ -470,6 +472,27 @@ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch return torch.stack(x_list, dim=0) + @staticmethod + def _get_raw_image_size(image: PipelineImageInput) -> Tuple[int, int]: + """Helper to get (height, width) without rounding/scaling.""" + if isinstance(image, list): + image = image[0] + + if isinstance(image, PIL.Image.Image): + return image.height, image.width + elif isinstance(image, torch.Tensor): + return image.shape[-2], image.shape[-1] + elif isinstance(image, np.ndarray): + return ( + image.shape[-3] if image.ndim > 3 else image.shape[-2], + image.shape[-2] if image.ndim > 3 else image.shape[-1], + ) + + if hasattr(image, "shape"): + return image.shape[-2], image.shape[-1] + + raise ValueError(f"Unsupported image type: {type(image)}") + # Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline.encode_prompt def encode_prompt( self, @@ -755,6 +778,10 @@ def attention_kwargs(self): def num_timesteps(self): return self._num_timesteps + @property + def current_timestep(self): + return self._current_timestep + @property def interrupt(self): return self._interrupt @@ -924,10 +951,16 @@ def __call__( if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4: image = torch.cat(image, dim=0) img = image[0] if isinstance(image, list) else image - image_height, image_width = self.image_processor.get_default_height_width(img) - image_width = image_width // multiple_of * multiple_of - image_height = image_height // multiple_of * multiple_of - image = self.image_processor.resize(image, image_height, image_width) + raw_h, raw_w = self._get_raw_image_size(img) + + if raw_h * raw_w > 1024 * 1024: + scale = math.sqrt(1024 * 1024 / (raw_h * raw_w)) + image = self.image_processor.resize(image, int(raw_h * scale), int(raw_w * scale)) + img = image[0] if isinstance(image, list) else image + raw_h, raw_w = self._get_raw_image_size(img) + + image_width = (raw_w // multiple_of) * multiple_of + image_height = (raw_h // multiple_of) * multiple_of # Use the resolution of the input image width = image_width @@ -963,14 +996,19 @@ def __call__( image_reference = torch.cat(image_reference, dim=0) img_reference = image_reference[0] if isinstance(image_reference, list) else image_reference - image_reference_height, image_reference_width = self.image_processor.get_default_height_width( - img_reference - ) - image_reference_width = image_reference_width // multiple_of * multiple_of - image_reference_height = image_reference_height // multiple_of * multiple_of - image_reference = self.image_processor.resize( - image_reference, image_reference_height, image_reference_width - ) + raw_ref_h, raw_ref_w = self._get_raw_image_size(img_reference) + + if raw_ref_h * raw_ref_w > 1024 * 1024: + scale = math.sqrt(1024 * 1024 / (raw_ref_h * raw_ref_w)) + image_reference = self.image_processor.resize( + image_reference, int(raw_ref_h * scale), int(raw_ref_w * scale) + ) + img_reference = image_reference[0] if isinstance(image_reference, list) else image_reference + raw_ref_h, raw_ref_w = self._get_raw_image_size(img_reference) + + image_reference_width = (raw_ref_w // multiple_of) * multiple_of + image_reference_height = (raw_ref_h // multiple_of) * multiple_of + processed_image_reference = self.image_processor.preprocess( image_reference, image_reference_height, @@ -1089,6 +1127,7 @@ def __call__( if self.interrupt: continue + self._current_timestep = t # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) @@ -1161,6 +1200,8 @@ def __call__( if XLA_AVAILABLE: xm.mark_step() + self._current_timestep = None + # 8. Post-processing latents = self._unpack_latents_with_ids(latents, latent_image_ids) From edcbaba1dce932526b5993b2c791db7d86437890 Mon Sep 17 00:00:00 2001 From: Aditya Borate <23110065@iitgn.ac.in> Date: Sat, 21 Mar 2026 10:53:52 +0530 Subject: [PATCH 14/21] Updated typing, improved validation and changed the example docstring --- .../flux2/pipeline_flux2_klein_inpaint.py | 97 +++++++++++-------- .../test_pipeline_flux2_klein_inpaint.py | 2 +- 2 files changed, 56 insertions(+), 43 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py index 6216644603f9..0271956d814f 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py @@ -14,7 +14,7 @@ import inspect import math -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable import numpy as np import PIL @@ -75,12 +75,12 @@ >>> pipe.to("cuda") >>> prompt = "Replace this ball" - >>> img_url = "https://images.pexels.com/photos/39362/the-ball-stadion-football-the-pitch-39362.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500" + >>> img_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/klein_inpaint/the-ball-stadion-football-the-pitch-39362.jpeg" >>> mask_url = ( - ... "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/ball_mask.png?raw=true" + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/klein_inpaint/ball_mask.png" ... ) >>> image_reference_url = ( - ... "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTah3x6OL_ECMBaZ5ZlJJhNsyC-OSMLWAI-xw&s" + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/klein_inpaint/ball.jpg" ... ) >>> source = load_image(img_url) @@ -118,10 +118,10 @@ def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - sigmas: Optional[List[float]] = None, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, **kwargs, ): r""" @@ -177,7 +177,7 @@ def retrieve_timesteps( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( - encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" + encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample" ): if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": return encoder_output.latent_dist.sample(generator) @@ -258,11 +258,11 @@ def __init__( def _get_qwen3_prompt_embeds( text_encoder: Qwen3ForCausalLM, tokenizer: Qwen2TokenizerFast, - prompt: Union[str, List[str]], - dtype: Optional[torch.dtype] = None, - device: Optional[torch.device] = None, + prompt: str | list[str], + dtype: torch.dtype | None = None, + device: torch.device | None = None, max_sequence_length: int = 512, - hidden_states_layers: List[int] = (9, 18, 27), + hidden_states_layers: list[int] | tuple[int, ...] = (9, 18, 27), ): dtype = text_encoder.dtype if dtype is None else dtype device = text_encoder.device if device is None else device @@ -315,7 +315,7 @@ def _get_qwen3_prompt_embeds( # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_text_ids def _prepare_text_ids( x: torch.Tensor, # (B, L, D) or (L, D) - t_coord: Optional[torch.Tensor] = None, + t_coord: torch.Tensor | None = None, ): B, L, _ = x.shape out_ids = [] @@ -367,7 +367,7 @@ def _prepare_latent_ids( @staticmethod # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids def _prepare_image_ids( - image_latents: List[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] + image_latents: list[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] scale: int = 10, ): r""" @@ -377,7 +377,7 @@ def _prepare_image_ids( dimensions. Args: - image_latents (List[torch.Tensor]): + image_latents (list[torch.Tensor]): A list of image latent feature tensors, typically of shape (C, H, W). scale (int, optional): A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th @@ -473,7 +473,7 @@ def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch return torch.stack(x_list, dim=0) @staticmethod - def _get_raw_image_size(image: PipelineImageInput) -> Tuple[int, int]: + def _get_raw_image_size(image: PipelineImageInput) -> tuple[int, int]: """Helper to get (height, width) without rounding/scaling.""" if isinstance(image, list): image = image[0] @@ -496,12 +496,12 @@ def _get_raw_image_size(image: PipelineImageInput) -> Tuple[int, int]: # Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline.encode_prompt def encode_prompt( self, - prompt: Union[str, List[str]], - device: Optional[torch.device] = None, + prompt: str | list[str], + device: torch.device | None = None, num_images_per_prompt: int = 1, - prompt_embeds: Optional[torch.Tensor] = None, + prompt_embeds: torch.Tensor | None = None, max_sequence_length: int = 512, - text_encoder_out_layers: Tuple[int] = (9, 18, 27), + text_encoder_out_layers: tuple[int, ...] = (9, 18, 27), ): device = device or self._execution_device @@ -554,8 +554,8 @@ def prepare_latents( width, dtype, device, - generator: torch.Generator, - latents: Optional[torch.Tensor] = None, + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, ): if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -605,7 +605,7 @@ def prepare_latents( def prepare_image_latents( self, - images: List[torch.Tensor], + images: list[torch.Tensor], batch_size, generator: torch.Generator, device, @@ -751,6 +751,19 @@ def check_inputs( ) if output_type != "pil": raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.") + else: + if image is not None: + if not isinstance(image, (PIL.Image.Image, torch.Tensor, np.ndarray, list)): + raise ValueError( + f"`image` has to be of type `PIL.Image.Image`, `torch.Tensor`, `np.ndarray`, or `list`" + f" but is {type(image)}." + ) + if mask_image is not None: + if not isinstance(mask_image, (PIL.Image.Image, torch.Tensor, np.ndarray, list)): + raise ValueError( + f"`mask_image` has to be of type `PIL.Image.Image`, `torch.Tensor`, `np.ndarray`, or `list`" + f" but is {type(mask_image)}." + ) if image_reference is not None: if not isinstance(image_reference, (PIL.Image.Image, torch.Tensor, np.ndarray, list)): @@ -790,29 +803,29 @@ def interrupt(self): @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]] = None, - image: PipelineImageInput = None, - image_reference: Optional[PipelineImageInput] = None, - mask_image: PipelineImageInput = None, - height: Optional[int] = None, - width: Optional[int] = None, - padding_mask_crop: Optional[int] = None, + prompt: str | list[str] | None = None, + image: PipelineImageInput | None = None, + image_reference: PipelineImageInput | None = None, + mask_image: PipelineImageInput | None = None, + height: int | None = None, + width: int | None = None, + padding_mask_crop: int | None = None, strength: float = 0.8, num_inference_steps: int = 50, - sigmas: Optional[List[float]] = None, - guidance_scale: Optional[float] = 8.0, + sigmas: list[float] | None = None, + guidance_scale: float = 8.0, num_images_per_prompt: int = 1, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.Tensor] = None, - prompt_embeds: Optional[torch.Tensor] = None, - negative_prompt_embeds: Optional[torch.Tensor] = None, - output_type: Optional[str] = "pil", + generator: torch.Generator | list[torch.Generator] | None = None, + latents: torch.Tensor | None = None, + prompt_embeds: torch.Tensor | None = None, + negative_prompt_embeds: torch.Tensor | None = None, + output_type: str = "pil", return_dict: bool = True, - attention_kwargs: Optional[Dict[str, Any]] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], + attention_kwargs: dict[str, Any] | None = None, + callback_on_step_end: Callable[[int, int, dict], None] | None = None, + callback_on_step_end_tensor_inputs: list[str] = ["latents"], max_sequence_length: int = 512, - text_encoder_out_layers: Tuple[int] = (9, 18, 27), + text_encoder_out_layers: tuple[int, ...] = (9, 18, 27), ): r""" Function invoked when calling the pipeline for inpainting. diff --git a/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py b/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py index ee0602d0c26b..a8be061e1d00 100644 --- a/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py +++ b/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py @@ -25,7 +25,7 @@ class Flux2KleinInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = Flux2KleinInpaintPipeline - params = frozenset(["prompt", "image", "mask_image", "height", "width", "guidance_scale", "prompt_embeds"]) + params = frozenset(["prompt", "image", "image_reference", "mask_image", "height", "width", "guidance_scale", "prompt_embeds"]) batch_params = frozenset(["prompt"]) test_xformers_attention = False From 5e4f20e02d27131e1151f396a9535a39128d68be Mon Sep 17 00:00:00 2001 From: Aditya Borate <23110065@iitgn.ac.in> Date: Sat, 21 Mar 2026 05:55:51 +0000 Subject: [PATCH 15/21] Style fixes --- src/diffusers/pipelines/__init__.py | 9 +++++++-- .../pipelines/flux2/pipeline_flux2_klein_inpaint.py | 10 +++++----- .../utils/dummy_torch_and_transformers_objects.py | 8 ++++---- .../flux2/test_pipeline_flux2_klein_inpaint.py | 12 +++++++----- 4 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 52da80b58fc4..3d6929c57f61 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -129,7 +129,12 @@ ] _import_structure["bria"] = ["BriaPipeline"] _import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"] - _import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline", "Flux2KleinInpaintPipeline", "Flux2KleinKVPipeline"] + _import_structure["flux2"] = [ + "Flux2Pipeline", + "Flux2KleinPipeline", + "Flux2KleinInpaintPipeline", + "Flux2KleinKVPipeline", + ] _import_structure["flux"] = [ "FluxControlPipeline", "FluxControlInpaintPipeline", @@ -671,7 +676,7 @@ FluxPriorReduxPipeline, ReduxImageEncoder, ) - from .flux2 import Flux2KleinKVPipeline, Flux2KleinInpaintPipeline, Flux2KleinPipeline, Flux2Pipeline + from .flux2 import Flux2KleinInpaintPipeline, Flux2KleinKVPipeline, Flux2KleinPipeline, Flux2Pipeline from .glm_image import GlmImagePipeline from .helios import HeliosPipeline, HeliosPyramidPipeline from .hidream_image import HiDreamImagePipeline diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py index 0271956d814f..707ca1796376 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py @@ -136,15 +136,15 @@ def retrieve_timesteps( must be `None`. device (`str` or `torch.device`, *optional*): The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): + timesteps (`list[int]`, *optional*): Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, `num_inference_steps` and `sigmas` must be `None`. - sigmas (`List[float]`, *optional*): + sigmas (`list[float]`, *optional*): Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, `num_inference_steps` and `timesteps` must be `None`. Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + `tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the second element is the number of inference steps. """ if timesteps is not None and sigmas is not None: @@ -262,7 +262,7 @@ def _get_qwen3_prompt_embeds( dtype: torch.dtype | None = None, device: torch.device | None = None, max_sequence_length: int = 512, - hidden_states_layers: list[int] | tuple[int, ...] = (9, 18, 27), + hidden_states_layers: list[int] = (9, 18, 27), ): dtype = text_encoder.dtype if dtype is None else dtype device = text_encoder.device if device is None else device @@ -501,7 +501,7 @@ def encode_prompt( num_images_per_prompt: int = 1, prompt_embeds: torch.Tensor | None = None, max_sequence_length: int = 512, - text_encoder_out_layers: tuple[int, ...] = (9, 18, 27), + text_encoder_out_layers: tuple[int] = (9, 18, 27), ): device = device or self._execution_device diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index c6da694d5989..8c39a8cbe7cd 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -1202,7 +1202,7 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) -class Flux2KleinKVPipeline(metaclass=DummyObject): +class Flux2KleinInpaintPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): @@ -1215,9 +1215,9 @@ def from_config(cls, *args, **kwargs): @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) - - -class Flux2KleinInpaintPipeline(metaclass=DummyObject): + + +class Flux2KleinKVPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs): diff --git a/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py b/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py index a8be061e1d00..4dcf01e49b69 100644 --- a/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py +++ b/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py @@ -25,7 +25,9 @@ class Flux2KleinInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): pipeline_class = Flux2KleinInpaintPipeline - params = frozenset(["prompt", "image", "image_reference", "mask_image", "height", "width", "guidance_scale", "prompt_embeds"]) + params = frozenset( + ["prompt", "image", "image_reference", "mask_image", "height", "width", "guidance_scale", "prompt_embeds"] + ) batch_params = frozenset(["prompt"]) test_xformers_attention = False @@ -175,16 +177,16 @@ def test_flux2_klein_inpaint_strength(self): def test_flux2_klein_inpaint_image_reference(self): pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) inputs = self.get_dummy_inputs(torch_device) - + # Add a reference image to the inputs ref_image = floats_tensor((1, 3, 32, 32), rng=random.Random(1)).to(torch_device) inputs["image_reference"] = ref_image - + image = pipe(**inputs).images[0] - + expected_height = inputs["height"] - inputs["height"] % (pipe.vae_scale_factor * 2) expected_width = inputs["width"] - inputs["width"] % (pipe.vae_scale_factor * 2) - + output_height, output_width, _ = image.shape self.assertEqual( (output_height, output_width), From 60e1ed2b7f3871431954ce6735c4dacc00218c10 Mon Sep 17 00:00:00 2001 From: Aditya Borate Date: Thu, 26 Mar 2026 17:33:37 +0000 Subject: [PATCH 16/21] Fixed batch inference discrepancy and addressed review comments --- .../flux2/pipeline_flux2_klein_inpaint.py | 100 ++++++++++-------- .../test_pipeline_flux2_klein_inpaint.py | 2 +- 2 files changed, 59 insertions(+), 43 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py index 707ca1796376..2a3b1bbb77a0 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py @@ -251,7 +251,6 @@ def __init__( ) self.tokenizer_max_length = 512 self.default_sample_size = 128 - self._current_timestep = None @staticmethod # Copied from diffusers.pipelines.flux2.pipeline_flux2_klein.Flux2KleinPipeline._get_qwen3_prompt_embeds @@ -365,9 +364,9 @@ def _prepare_latent_ids( return latent_ids @staticmethod - # Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids def _prepare_image_ids( - image_latents: list[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...] + image_latents: list[torch.Tensor], # list of (B_i, C, H, W) before packing + batch_size: int, scale: int = 10, ): r""" @@ -398,20 +397,34 @@ def _prepare_image_ids( if not isinstance(image_latents, list): raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.") - # create time offset for each reference image - t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))] - t_coords = [t.view(-1) for t in t_coords] - - image_latent_ids = [] - for x, t in zip(image_latents, t_coords): - x = x.squeeze(0) - _, height, width = x.shape + all_image_latent_ids = [] + t_offset = scale + for x in image_latents: + b_i, _, height, width = x.shape + # Create IDs for a single image at this t_offset + t = torch.tensor([t_offset]).view(-1) x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) - image_latent_ids.append(x_ids) - image_latent_ids = torch.cat(image_latent_ids, dim=0) - image_latent_ids = image_latent_ids.unsqueeze(0) + if b_i == 1 or b_i == batch_size: + x_ids = x_ids.unsqueeze(0).expand(batch_size, -1, -1) + all_image_latent_ids.append(x_ids) + t_offset += scale + else: + # multiple images per sample in the batch + item_ids = [x_ids] + for _ in range(1, b_i): + t_offset += scale + t = torch.tensor([t_offset]).view(-1) + item_ids.append( + torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1)) + ) + x_ids = torch.cat(item_ids, dim=0) # (b_i * h * w, 4) + x_ids = x_ids.unsqueeze(0).expand(batch_size, -1, -1) + all_image_latent_ids.append(x_ids) + t_offset += scale + + image_latent_ids = torch.cat(all_image_latent_ids, dim=1) return image_latent_ids @@ -483,10 +496,9 @@ def _get_raw_image_size(image: PipelineImageInput) -> tuple[int, int]: elif isinstance(image, torch.Tensor): return image.shape[-2], image.shape[-1] elif isinstance(image, np.ndarray): - return ( - image.shape[-3] if image.ndim > 3 else image.shape[-2], - image.shape[-2] if image.ndim > 3 else image.shape[-1], - ) + if image.ndim >= 3: + return image.shape[-3], image.shape[-2] + return image.shape[-2], image.shape[-1] if hasattr(image, "shape"): return image.shape[-2], image.shape[-1] @@ -619,29 +631,29 @@ def prepare_image_latents( image_latent = self._encode_vae_image(image=image, generator=generator) else: image_latent = self._patchify_latents(image) - image_latents.append(image_latent) # (1, 128, H//2, W//2) + image_latents.append(image_latent) - image_latent_ids = self._prepare_image_ids(image_latents) + image_latent_ids = self._prepare_image_ids(image_latents, batch_size) - # Pack each latent and concatenate - packed_latents = [] + # Pack each latent and combine batch properly + final_latents = [] for latent in image_latents: - packed = self._pack_latents(latent) # (1, seq_len, 128) - packed = packed.squeeze(0) # (seq_len, 128) - remove batch dim - packed_latents.append(packed) + packed = self._pack_latents(latent) # (B_i, seq_len, 128) + b_i = packed.shape[0] - # Concatenate all reference tokens along sequence dimension - image_latents = torch.cat(packed_latents, dim=0) # (N*seq_len, 128) - image_latents = image_latents.unsqueeze(0) # (1, N*seq_len, 128) + if b_i == 1 and batch_size > 1: + packed = packed.repeat(batch_size, 1, 1) + elif b_i == batch_size: + pass + else: + # Concatenate all reference tokens along sequence dimension for each sample + seq_len = packed.shape[1] + packed = packed.reshape(1, b_i * seq_len, -1) + if batch_size > 1: + packed = packed.repeat(batch_size, 1, 1) + final_latents.append(packed) - if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: - additional_per_prompt = batch_size // image_latents.shape[0] - image_latents = torch.cat([image_latents] * additional_per_prompt, dim=0) - image_latent_ids = torch.cat([image_latent_ids] * additional_per_prompt, dim=0) - elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image_reference` of batch size {image_latents.shape[0]} to {batch_size} text prompts." - ) + image_latents = torch.cat(final_latents, dim=1) # (batch_size, total_seq_len, 128) image_latent_ids = image_latent_ids.to(device) @@ -707,6 +719,12 @@ def check_inputs( padding_mask_crop=None, guidance_scale=None, ): + if image is None: + raise ValueError("`image` has to be provided for inpainting.") + + if mask_image is None: + raise ValueError("`mask_image` has to be provided for inpainting.") + if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") @@ -951,6 +969,7 @@ def __call__( self._guidance_scale = guidance_scale self._attention_kwargs = attention_kwargs + self._current_timestep = None self._interrupt = False # 2. Preprocess image @@ -991,8 +1010,6 @@ def __call__( init_image = self.image_processor.preprocess( image, image_height, image_width, crops_coords=crops_coords, resize_mode=resize_mode ) - else: - raise ValueError("image must be provided correctly for inpainting") init_image = init_image.to(dtype=torch.float32) @@ -1101,17 +1118,16 @@ def __call__( latents, ) - ref_images = [init_image[i : i + 1] for i in range(init_image.shape[0])] + ref_images = [init_image] if processed_image_reference is not None: - # Convert preprocessed reference image to list format - ref_images += [processed_image_reference[i : i + 1] for i in range(processed_image_reference.shape[0])] + ref_images.append(processed_image_reference) condition_image_latents, condition_image_ids = self.prepare_image_latents( ref_images, batch_size * num_images_per_prompt, generator, device, - self.vae.dtype, + prompt_embeds.dtype, ) mask_condition = self.mask_processor.preprocess( diff --git a/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py b/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py index 4dcf01e49b69..62fbea8fc3c9 100644 --- a/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py +++ b/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py @@ -28,7 +28,7 @@ class Flux2KleinInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase) params = frozenset( ["prompt", "image", "image_reference", "mask_image", "height", "width", "guidance_scale", "prompt_embeds"] ) - batch_params = frozenset(["prompt"]) + batch_params = frozenset(["prompt", "image", "mask_image"]) test_xformers_attention = False test_layerwise_casting = True From a026f0f75be1786d8a21ef9d4d7c45d67cce4186 Mon Sep 17 00:00:00 2001 From: Aditya Borate Date: Thu, 26 Mar 2026 23:08:16 +0530 Subject: [PATCH 17/21] Fixed a typo MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Álvaro Somoza --- src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py index 2a3b1bbb77a0..0c4cf3b87c78 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py @@ -1097,7 +1097,7 @@ def __call__( if num_inference_steps < 1: raise ValueError( - f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline" + f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline " f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." ) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) From e8f590b1e8cf1c09b5f4b0252fa12b1fccdb4c48 Mon Sep 17 00:00:00 2001 From: Aditya Borate Date: Wed, 1 Apr 2026 20:34:28 +0530 Subject: [PATCH 18/21] Apply suggestion from @asomoza MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Álvaro Somoza --- src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py index 0c4cf3b87c78..1277a64287b8 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py @@ -1187,7 +1187,7 @@ def __call__( encoder_hidden_states=negative_prompt_embeds, txt_ids=negative_text_ids, img_ids=img_ids, - joint_attention_kwargs=self._attention_kwargs, + joint_attention_kwargs=self.attention_kwargs, return_dict=False, )[0] neg_noise_pred = neg_noise_pred[:, : latents.size(1)] From 2d83f134a92de1037eff7413e8f16cc2456f9b33 Mon Sep 17 00:00:00 2001 From: Aditya Borate <23110065@iitgn.ac.in> Date: Wed, 1 Apr 2026 20:38:58 +0530 Subject: [PATCH 19/21] Reused encoded latents and fix channel check consistency --- .../pipelines/flux2/pipeline_flux2_klein_inpaint.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py index 1277a64287b8..fc0c6a7ee2b4 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py @@ -611,9 +611,9 @@ def prepare_latents( latents = noise noise = self._pack_latents(noise) - image_latents = self._pack_latents(image_latents) + packed_image_latents = self._pack_latents(image_latents) latents = self._pack_latents(latents) - return latents, noise, image_latents, latent_image_ids + return latents, noise, packed_image_latents, image_latents, latent_image_ids def prepare_image_latents( self, @@ -627,10 +627,10 @@ def prepare_image_latents( for image in images: image = image.to(device=device, dtype=dtype) - if image.shape[1] != self.latent_channels: + if image.shape[1] != self.latent_channels * 4: image_latent = self._encode_vae_image(image=image, generator=generator) else: - image_latent = self._patchify_latents(image) + image_latent = image image_latents.append(image_latent) image_latent_ids = self._prepare_image_ids(image_latents, batch_size) @@ -1105,7 +1105,7 @@ def __call__( # 6. Prepare latent variables num_channels_latents = self.transformer.config.in_channels // 4 - latents, noise, image_latents, latent_image_ids = self.prepare_latents( + latents, noise, image_latents, image_latents_encoded, latent_image_ids = self.prepare_latents( init_image, latent_timestep, batch_size * num_images_per_prompt, @@ -1118,7 +1118,7 @@ def __call__( latents, ) - ref_images = [init_image] + ref_images = [image_latents_encoded] if processed_image_reference is not None: ref_images.append(processed_image_reference) From 41d8a98f085277e93e97c5250d02cc8b100fca02 Mon Sep 17 00:00:00 2001 From: Aditya Borate Date: Thu, 9 Apr 2026 19:57:09 +0000 Subject: [PATCH 20/21] fixed pre-encoded latent preprocessing for source and ref images --- .../flux2/pipeline_flux2_klein_inpaint.py | 25 +++++++++++++------ .../test_pipeline_flux2_klein_inpaint.py | 2 +- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py index fc0c6a7ee2b4..43e759c48c95 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py @@ -587,10 +587,15 @@ def prepare_latents( latent_image_ids = latent_image_ids.to(device) image = image.to(device=device, dtype=dtype) - if image.shape[1] != self.latent_channels: + if image.shape[1] != self.latent_channels * 4: image_latents = self._encode_vae_image(image=image, generator=generator) else: image_latents = image + latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype) + latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + image_latents.device, image_latents.dtype + ) + image_latents = (image_latents - latents_bn_mean) / latents_bn_std if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: # expand init_latents for batch_size @@ -600,8 +605,6 @@ def prepare_latents( raise ValueError( f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." ) - else: - image_latents = torch.cat([image_latents], dim=0) if latents is None: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) @@ -974,11 +977,13 @@ def __call__( # 2. Preprocess image multiple_of = self.vae_scale_factor * 2 - if isinstance(image, torch.Tensor) and image.ndim == 4 and image.size(1) == self.latent_channels: + if isinstance(image, torch.Tensor) and image.ndim == 4 and image.size(1) == self.latent_channels * 4: init_image = image original_image = image crops_coords = None resize_mode = "default" + height = image.shape[2] * self.vae_scale_factor * 2 + width = image.shape[3] * self.vae_scale_factor * 2 elif image is not None: if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4: image = torch.cat(image, dim=0) @@ -1011,12 +1016,10 @@ def __call__( image, image_height, image_width, crops_coords=crops_coords, resize_mode=resize_mode ) - init_image = init_image.to(dtype=torch.float32) - # 2.2 Preprocess reference image processed_image_reference = None if image_reference is not None and not ( - isinstance(image_reference, torch.Tensor) and image_reference.size(1) == self.latent_channels + isinstance(image_reference, torch.Tensor) and image_reference.size(1) == self.latent_channels * 4 ): if ( isinstance(image_reference, list) @@ -1045,7 +1048,13 @@ def __call__( image_reference_width, resize_mode="crop", ) - processed_image_reference = processed_image_reference.to(dtype=torch.float32) + else: + if image_reference is not None: + bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_reference.device, image_reference.dtype) + bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to( + image_reference.device, image_reference.dtype + ) + processed_image_reference = (image_reference - bn_mean) / bn_std # 3. Define call parameters if prompt is not None and isinstance(prompt, str): diff --git a/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py b/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py index 62fbea8fc3c9..807dcdda13bf 100644 --- a/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py +++ b/tests/pipelines/flux2/test_pipeline_flux2_klein_inpaint.py @@ -28,7 +28,7 @@ class Flux2KleinInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase) params = frozenset( ["prompt", "image", "image_reference", "mask_image", "height", "width", "guidance_scale", "prompt_embeds"] ) - batch_params = frozenset(["prompt", "image", "mask_image"]) + batch_params = frozenset(["prompt", "image", "image_reference", "mask_image"]) test_xformers_attention = False test_layerwise_casting = True From eac2a72adc7fb6d9c9c03fc495e2199fa6875c35 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 9 Apr 2026 20:30:53 +0000 Subject: [PATCH 21/21] Apply style fixes --- src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py index 43e759c48c95..38ccf8890452 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2_klein_inpaint.py @@ -76,9 +76,7 @@ >>> prompt = "Replace this ball" >>> img_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/klein_inpaint/the-ball-stadion-football-the-pitch-39362.jpeg" - >>> mask_url = ( - ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/klein_inpaint/ball_mask.png" - ... ) + >>> mask_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/klein_inpaint/ball_mask.png" >>> image_reference_url = ( ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/klein_inpaint/ball.jpg" ... )