From 84ae4b402372698fd176318e1187e3d50a84ce6a Mon Sep 17 00:00:00 2001 From: Mahbod Date: Mon, 25 May 2026 10:56:59 +0200 Subject: [PATCH] Skip building attn_mask when text_lens are uniform (#13801) ErnieImageTransformer2DModel.forward built a bool attention mask from text_lens on every call, including the common case where every sample already has full-length text. flash-attn 2 rejects any non-None attn_mask, so set_attention_backend('flash') crashed even though the all-True mask was effectively a no-op. Z-Image's _prepare_for_attention takes the same shortcut. Closes #13801 --- .../transformers/transformer_ernie_image.py | 18 +++++++++--------- .../test_models_transformer_ernie_image.py | 5 +++++ 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py index abb79b527589..d2f30fe24573 100644 --- a/src/diffusers/models/transformers/transformer_ernie_image.py +++ b/src/diffusers/models/transformers/transformer_ernie_image.py @@ -408,15 +408,15 @@ def forward( ) rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)) - # Attention mask: True = valid (attend), False = padding (mask out), matches sdpa bool convention - valid_text = ( - torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1) - if Tmax > 0 - else torch.zeros((B, 0), device=device, dtype=torch.bool) - ) - attention_mask = torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1)[ - :, None, None, : - ] + # Attention mask: True = valid (attend), False = padding (mask out), matches sdpa bool convention. + # Skip building it when there's no padding so flash-attn (which rejects non-None masks) can run. + if Tmax > 0 and not bool((text_lens == Tmax).all()): + valid_text = torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1) + attention_mask = torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1)[ + :, None, None, : + ] + else: + attention_mask = None # AdaLN sample = self.time_proj(timestep) diff --git a/tests/models/transformers/test_models_transformer_ernie_image.py b/tests/models/transformers/test_models_transformer_ernie_image.py index bff0894df08b..156000bb3835 100644 --- a/tests/models/transformers/test_models_transformer_ernie_image.py +++ b/tests/models/transformers/test_models_transformer_ernie_image.py @@ -23,6 +23,7 @@ from ...testing_utils import torch_device from ..testing_utils import ( + AttentionBackendTesterMixin, BaseModelTesterConfig, ModelTesterMixin, TorchCompileTesterMixin, @@ -130,3 +131,7 @@ def test_compile_works_with_aot(self, tmp_path): @pytest.mark.skip(reason="Fullgraph is broken.") def test_compile_on_different_shapes(self): super().test_compile_on_different_shapes() + + +class TestErnieImageTransformerAttentionBackend(ErnieImageTransformerTesterConfig, AttentionBackendTesterMixin): + """Attention backend tests for ErnieImage Transformer."""