diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py
index abb79b527589..d2f30fe24573 100644
--- a/src/diffusers/models/transformers/transformer_ernie_image.py
+++ b/src/diffusers/models/transformers/transformer_ernie_image.py
@@ -408,15 +408,15 @@ def forward(
         )
         rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1))
 
-        # Attention mask: True = valid (attend), False = padding (mask out), matches sdpa bool convention
-        valid_text = (
-            torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1)
-            if Tmax > 0
-            else torch.zeros((B, 0), device=device, dtype=torch.bool)
-        )
-        attention_mask = torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1)[
-            :, None, None, :
-        ]
+        # Attention mask: True = valid (attend), False = padding (mask out), matches sdpa bool convention.
+        # Left as None when Tmax == 0 or all text_lens == Tmax, so flash-attn (which rejects non-None masks) can run.
+        if Tmax > 0 and not bool((text_lens == Tmax).all()):  # NOTE(review): tensor-value branch; check torch.compile
+            valid_text = torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1)
+            attention_mask = torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1)[
+                :, None, None, :
+            ]
+        else:
+            attention_mask = None
 
         # AdaLN
         sample = self.time_proj(timestep)
diff --git a/tests/models/transformers/test_models_transformer_ernie_image.py b/tests/models/transformers/test_models_transformer_ernie_image.py
index bff0894df08b..156000bb3835 100644
--- a/tests/models/transformers/test_models_transformer_ernie_image.py
+++ b/tests/models/transformers/test_models_transformer_ernie_image.py
@@ -23,6 +23,7 @@
 
 from ...testing_utils import torch_device
 from ..testing_utils import (
+    AttentionBackendTesterMixin,
     BaseModelTesterConfig,
     ModelTesterMixin,
     TorchCompileTesterMixin,
@@ -130,3 +131,7 @@ def test_compile_works_with_aot(self, tmp_path):
     @pytest.mark.skip(reason="Fullgraph is broken.")
     def test_compile_on_different_shapes(self):
         super().test_compile_on_different_shapes()
+
+
+class TestErnieImageTransformerAttentionBackend(ErnieImageTransformerTesterConfig, AttentionBackendTesterMixin):
+    """Attention backend tests for the ErnieImage transformer, provided by AttentionBackendTesterMixin."""