Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions src/diffusers/pipelines/flux2/pipeline_flux2.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,26 @@
UPSAMPLING_MAX_IMAGE_SIZE = 768**2


def _warn_if_bnb_int8_bfloat16(text_encoder, transformer) -> None:
    """
    Emit a warning for each component that combines bitsandbytes 8-bit loading with a bfloat16 dtype.

    A component is flagged when its `is_loaded_in_8bit` attribute is truthy and its `dtype` attribute equals
    `torch.bfloat16`. Components that are `None`, or that lack either attribute, are silently skipped. The
    warning explains that this combination loses precision during MatMul operations and can corrupt outputs,
    and recommends NF4 4-bit quantization instead.

    Args:
        text_encoder: The pipeline's text encoder, or `None`.
        transformer: The pipeline's transformer, or `None`.
    """
    components = (("text_encoder", text_encoder), ("transformer", transformer))
    for name, component in components:
        # Nothing to inspect for absent components or ones not loaded in 8-bit.
        if component is None or not getattr(component, "is_loaded_in_8bit", False):
            continue

        # A missing `dtype` attribute resolves to None, which never matches bfloat16.
        if getattr(component, "dtype", None) == torch.bfloat16:
            logger.warning(
                f"`{name}` is loaded with bitsandbytes 8-bit quantization and `torch.bfloat16` dtype. "
                "This combination causes precision loss during MatMul operations and can result in corrupted "
                "or noisy outputs for FLUX models. It is highly recommended to use NF4 4-bit quantization "
                "instead (e.g. `load_in_4bit=True, bnb_4bit_quant_type='nf4', bnb_4bit_compute_dtype=torch.bfloat16`)."
            )


# Adapted from
# https://github.com/black-forest-labs/flux2/blob/5a5d316b1b42f6b59a8c9194b77c8256be848432/src/flux2/text_encoder.py#L68
def format_input(
Expand Down Expand Up @@ -300,6 +320,8 @@ def __init__(
self.system_message_upsampling_i2i = SYSTEM_MESSAGE_UPSAMPLING_I2I
self.upsampling_max_image_size = UPSAMPLING_MAX_IMAGE_SIZE

_warn_if_bnb_int8_bfloat16(text_encoder, transformer)

@staticmethod
def _get_mistral_3_small_prompt_embeds(
text_encoder: Mistral3ForConditionalGeneration,
Expand Down Expand Up @@ -579,12 +601,11 @@ def encode_prompt(
):
device = device or self._execution_device

if prompt is None:
prompt = ""

prompt = [prompt] if isinstance(prompt, str) else prompt

if prompt_embeds is None:
if prompt is None:
prompt = ""
prompt = [prompt] if isinstance(prompt, str) else prompt

prompt_embeds = self._get_mistral_3_small_prompt_embeds(
text_encoder=self.text_encoder,
tokenizer=self.tokenizer,
Expand All @@ -595,6 +616,10 @@ def encode_prompt(
hidden_states_layers=text_encoder_out_layers,
)

# Normalize prompt_embeds dtype and device to match the pipeline's precision.
target_dtype = self.transformer.dtype if self.transformer is not None else self.text_encoder.dtype
prompt_embeds = prompt_embeds.to(device=device, dtype=target_dtype)

batch_size, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
Expand Down
Loading