diff --git a/src/diffusers/models/transformers/transformer_ernie_image.py b/src/diffusers/models/transformers/transformer_ernie_image.py index abb79b527589..d4315d40ebf5 100644 --- a/src/diffusers/models/transformers/transformer_ernie_image.py +++ b/src/diffusers/models/transformers/transformer_ernie_image.py @@ -408,15 +408,15 @@ def forward( ) rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)) - # Attention mask: True = valid (attend), False = padding (mask out), matches sdpa bool convention - valid_text = ( - torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1) - if Tmax > 0 - else torch.zeros((B, 0), device=device, dtype=torch.bool) - ) - attention_mask = torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1)[ - :, None, None, : - ] + # Only build the mask when there's real padding. flash-attn 2 rejects + # any non-None attn_mask, so we leave it None for unpadded inputs. + if Tmax > 0 and bool((text_lens < Tmax).any()): + valid_text = torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1) + attention_mask = torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1)[ + :, None, None, : + ] + else: + attention_mask = None # AdaLN sample = self.time_proj(timestep) diff --git a/tests/models/transformers/test_models_transformer_ernie_image.py b/tests/models/transformers/test_models_transformer_ernie_image.py index bff0894df08b..8f085efc59dd 100644 --- a/tests/models/transformers/test_models_transformer_ernie_image.py +++ b/tests/models/transformers/test_models_transformer_ernie_image.py @@ -101,7 +101,23 @@ def get_dummy_inputs(self, height: int = 16, width: int = 16, batch_size: int = class TestErnieImageTransformer(ErnieImageTransformerTesterConfig, ModelTesterMixin): - pass + def test_attention_mask_is_none_when_text_is_unpadded(self): + # Regression for #13801: unpadded text should produce attn_mask=None. + from unittest import mock + + from diffusers.models.transformers import transformer_ernie_image as t + + model = self.model_class(**self.get_init_dict()).to(torch_device).eval() + captured = [] + + def spy(query, *a, attn_mask=None, **k): + captured.append(attn_mask) + return torch.zeros_like(query) + + with torch.no_grad(), mock.patch.object(t, "dispatch_attention_fn", side_effect=spy): + model(**self.get_dummy_inputs()) + + assert captured and all(m is None for m in captured) class TestErnieImageTransformerTraining(ErnieImageTransformerTesterConfig, TrainingTesterMixin):