Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/diffusers/models/transformers/transformer_ernie_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,15 +408,15 @@ def forward(
)
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1))

# Attention mask: True = valid (attend), False = padding (mask out), matches sdpa bool convention
valid_text = (
torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1)
if Tmax > 0
else torch.zeros((B, 0), device=device, dtype=torch.bool)
)
attention_mask = torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1)[
:, None, None, :
]
# Attention mask: True = valid (attend), False = padding (mask out), matches sdpa bool convention.
# Skip building it when there's no padding so flash-attn (which rejects non-None masks) can run.
if Tmax > 0 and not bool((text_lens == Tmax).all()):
valid_text = torch.arange(Tmax, device=device).view(1, Tmax) < text_lens.view(B, 1)
attention_mask = torch.cat([torch.ones((B, N_img), device=device, dtype=torch.bool), valid_text], dim=1)[
:, None, None, :
]
else:
attention_mask = None

# AdaLN
sample = self.time_proj(timestep)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from ...testing_utils import torch_device
from ..testing_utils import (
AttentionBackendTesterMixin,
BaseModelTesterConfig,
ModelTesterMixin,
TorchCompileTesterMixin,
Expand Down Expand Up @@ -130,3 +131,7 @@ def test_compile_works_with_aot(self, tmp_path):
@pytest.mark.skip(reason="Fullgraph is broken.")
def test_compile_on_different_shapes(self):
    # Unconditionally skipped by the marker above (reason: "Fullgraph is broken.").
    # The body only delegates to the inherited test, so removing the marker
    # re-enables the parent implementation unchanged.
    super().test_compile_on_different_shapes()


class TestErnieImageTransformerAttentionBackend(ErnieImageTransformerTesterConfig, AttentionBackendTesterMixin):
    """Attention backend tests for ErnieImage Transformer.

    This class declares no members of its own; it only combines
    ``ErnieImageTransformerTesterConfig`` with ``AttentionBackendTesterMixin``.
    NOTE(review): the tests that actually run are presumably inherited from the
    mixin, which is not visible here -- confirm against ``testing_utils``.
    """
Loading