def _warn_if_bnb_int8_bfloat16(text_encoder, transformer) -> None:
    """
    Log a warning for each component that pairs bitsandbytes 8-bit loading with `torch.bfloat16`.

    A component is flagged when it exposes a truthy `is_loaded_in_8bit` attribute and its `dtype` attribute
    compares equal to `torch.bfloat16`. Components that are `None`, or that lack either attribute, are skipped
    silently. The check is purely advisory: nothing is raised and the components are not modified.

    Args:
        text_encoder: The pipeline's text encoder, or `None` if it is not loaded.
        transformer: The pipeline's transformer, or `None` if it is not loaded.
    """
    # A dict keeps the fixed reporting order: text encoder first, then transformer.
    candidates = {"text_encoder": text_encoder, "transformer": transformer}

    for component_name, component in candidates.items():
        # Components that are absent or not 8-bit quantized need no check.
        if component is None or not getattr(component, "is_loaded_in_8bit", False):
            continue

        # A missing `dtype` falls back to `None`, which never equals bfloat16.
        if getattr(component, "dtype", None) == torch.bfloat16:
            logger.warning(
                f"`{component_name}` is loaded with bitsandbytes 8-bit quantization and `torch.bfloat16` dtype. "
                "This combination causes precision loss during MatMul operations and can result in corrupted "
                "or noisy outputs for FLUX models. It is highly recommended to use NF4 4-bit quantization "
                "instead (e.g. `load_in_4bit=True, bnb_4bit_quant_type='nf4', bnb_4bit_compute_dtype=torch.bfloat16`)."
            )