From 8a9deea23293d2b183b75b0a226191087742bb0f Mon Sep 17 00:00:00 2001 From: Mahbod Date: Mon, 25 May 2026 02:13:43 +0200 Subject: [PATCH] fix(ernie_image): pass attn_mask=None when text is unpadded `ErnieImageTransformer2DModel.forward` unconditionally builds an attention mask (image-ones concatenated with the text validity vector) and forwards it to `dispatch_attention_fn` as `attn_mask`. When all text sequences are at maximum length the mask is all-ones and unnecessary, but flash-attn 2 rejects any `attn_mask` argument and raises `ValueError: attn_mask is not supported for flash-attn 2.` -- so calling `pipeline.transformer.set_attention_backend("flash")` on an `ErnieImagePipeline` crashes (issue #13801), while the equivalent code on `ZImagePipeline` and `Flux2KleinPipeline` works fine. Make the mask construction conditional on whether any text is actually padded (`text_lens.any() < Tmax`). When no padding is needed the mask is `None`, matching the convention used by `ZImageTransformer2DModel._build_inputs` (where `attn_mask = None` when every sequence is at `max_seqlen`). The attention processor and layer already accept `attention_mask: torch.Tensor | None = None`, so this is a transformer-level fix with no downstream changes. Verified numerically: for both unpadded (`text_lens=[8,8]`) and padded (`text_lens=[8,5]`) inputs the forward pass produces a byte-identical output tensor before and after the change. Closes #13801 --- .../transformers/transformer_ernie_image.py | 18 +++++++++--------- .../test_models_transformer_ernie_image.py | 18 +++++++++++++++++- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py index abb79b527589..d4315d40ebf5 100644 --- a/src/diffusers/models/transformers/transformer_ernie_image.py +++ b/src/diffusers/models/transformers/transformer_ernie_image.py @@ -408,15 +408,15 @@ def forward( ) rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)) - # Attention mask: True = valid (attend), False = padding (mask out), matches sdpa bool convention - valid_text = ( - torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1) - if Tmax > 0 - else torch.zeros((B, 0), device=device, dtype=torch.bool) - ) - attention_mask = torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1)[ - :, None, None, : - ] + # Only build the mask when there's real padding. flash-attn 2 rejects + # any non-None attn_mask, so we leave it None for unpadded inputs. + if Tmax > 0 and bool((text_lens < Tmax).any()): + valid_text = torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1) + attention_mask = torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1)[ + :, None, None, : + ] + else: + attention_mask = None # AdaLN sample = self.time_proj(timestep) diff --git a/tests/models/transformers/test_models_transformer_ernie_image.py b/tests/models/transformers/test_models_transformer_ernie_image.py index bff0894df08b..8f085efc59dd 100644 --- a/tests/models/transformers/test_models_transformer_ernie_image.py +++ b/tests/models/transformers/test_models_transformer_ernie_image.py @@ -101,7 +101,23 @@ def get_dummy_inputs(self, height: int = 16, width: int = 16, batch_size: int = class TestErnieImageTransformer(ErnieImageTransformerTesterConfig, ModelTesterMixin): - pass + def test_attention_mask_is_none_when_text_is_unpadded(self): + # Regression for #13801: unpadded text should produce attn_mask=None. + from unittest import mock + + from diffusers.models.transformers import transformer_ernie_image as t + + model = self.model_class(**self.get_init_dict()).to(torch_device).eval() + captured = [] + + def spy(query, *a, attn_mask=None, **k): + captured.append(attn_mask) + return torch.zeros_like(query) + + with torch.no_grad(), mock.patch.object(t, "dispatch_attention_fn", side_effect=spy): + model(**self.get_dummy_inputs()) + + assert captured and all(m is None for m in captured) class TestErnieImageTransformerTraining(ErnieImageTransformerTesterConfig, TrainingTesterMixin):