From 2832c19a63da086e767efbfd2e90e514fcb24b41 Mon Sep 17 00:00:00 2001
From: Satyam Ashtikar
Date: Sat, 23 May 2026 15:52:00 +0530
Subject: [PATCH] fix(flux2): warn on bnb int8+bfloat16 and fix encode_prompt dtype

Two related fixes for Flux2Pipeline:

1. Add a warning for bitsandbytes 8-bit + bfloat16 quantization.
   This combination causes precision loss and corrupted images in FLUX
   models. The warning alerts users immediately at pipeline
   initialization and suggests using NF4 4-bit quantization instead.

2. Fix encode_prompt() for pre-computed embedding workflows.
   - Skips prompt string formatting if embeddings are already provided.
   - Automatically casts pre-computed embeddings to the exact precision
     (dtype) expected by the pipeline.
   This prevents silent image corruption when loading embeddings from a
   different pipeline.
---
 .../pipelines/flux2/pipeline_flux2.py         | 35 ++++++++++++++++---
 1 file changed, 30 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py
index b1645b4ae244..9421fef37d53 100644
--- a/src/diffusers/pipelines/flux2/pipeline_flux2.py
+++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py
@@ -60,6 +60,26 @@
 UPSAMPLING_MAX_IMAGE_SIZE = 768**2
 
 
+def _warn_if_bnb_int8_bfloat16(text_encoder, transformer) -> None:
+    """
+    Warns if the text encoder or transformer is loaded with bitsandbytes 8-bit quantization and bfloat16 dtype.
+
+    This combination causes precision loss during MatMul operations resulting in corrupted outputs.
+    """
+    for name, module in [("text_encoder", text_encoder), ("transformer", transformer)]:
+        if module is None:
+            continue
+        if getattr(module, "is_loaded_in_8bit", False):
+            module_dtype = getattr(module, "dtype", None)
+            if module_dtype == torch.bfloat16:
+                logger.warning(
+                    f"`{name}` is loaded with bitsandbytes 8-bit quantization and `torch.bfloat16` dtype. "
+                    "This combination causes precision loss during MatMul operations and can result in corrupted "
+                    "or noisy outputs for FLUX models. It is highly recommended to use NF4 4-bit quantization "
+                    "instead (e.g. `load_in_4bit=True, bnb_4bit_quant_type='nf4', bnb_4bit_compute_dtype=torch.bfloat16`)."
+                )
+
+
 # Adapted from
 # https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68
 def format_input(
@@ -300,6 +320,8 @@ def __init__(
         self.system_message_upsampling_i2i = SYSTEM_MESSAGE_UPSAMPLING_I2I
         self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE
 
+        _warn_if_bnb_int8_bfloat16(text_encoder, transformer)
+
     @staticmethod
     def _get_mistral_3_small_prompt_embeds(
         text_encoder: Mistral3ForConditionalGeneration,
@@ -579,12 +601,11 @@ def encode_prompt(
     ):
         device = device or self._execution_device
 
-        if prompt is None:
-            prompt = ""
-
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-
         if prompt_embeds is None:
+            if prompt is None:
+                prompt = ""
+            prompt = [prompt] if isinstance(prompt, str) else prompt
+
             prompt_embeds = self._get_mistral_3_small_prompt_embeds(
                 text_encoder=self.text_encoder,
                 tokenizer=self.tokenizer,
@@ -595,6 +616,10 @@ def encode_prompt(
                 hidden_states_layers=text_encoder_out_layers,
             )
 
+        # Normalize prompt_embeds dtype and device to match the pipeline's precision.
+        target_dtype = self.transformer.dtype if self.transformer is not None else self.text_encoder.dtype
+        prompt_embeds = prompt_embeds.to(device=device, dtype=target_dtype)
+
         batch_size, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)