fix(ernie_image): pass attn_mask=None when text is unpadded #13804

Open
Ace3Z wants to merge 1 commit into huggingface:main from Ace3Z:fix/ernie-image-flashattn-13801

@Ace3Z Ace3Z commented May 25, 2026

Closes #13801.

pipeline.transformer.set_attention_backend("flash") on an ErnieImagePipeline raises ValueError: attn_mask is not yet supported for flash-attn 2., while the equivalent call on ZImagePipeline and Flux2KleinPipeline works fine.

ErnieImageTransformer2DModel.forward was unconditionally building an attention mask (image ones concatenated with the text validity vector) and passing it to dispatch_attention_fn as attn_mask. When every text in the batch is at max length the mask is all ones and isn't doing anything, but the flash-attn 2 backend rejects any non-None attn_mask (the check at attention_dispatch.py:1106).

ZImageTransformer2DModel._build_inputs already does the right thing: it sets attn_mask = None when every sequence is at max_seqlen. Applied the same convention here:

if Tmax > 0 and bool((text_lens < Tmax).any()):
    valid_text = ...
    attention_mask = torch.cat([torch.ones(...), valid_text], dim=1)[:, None, None, :]
else:
    attention_mask = None
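
A minimal, self-contained sketch of the conditional mask construction above. The two parts elided in the snippet (`valid_text = ...` and `torch.ones(...)`) are filled in here with assumed shapes and dtypes for illustration only; the helper name `build_attention_mask` and its parameters are hypothetical and are not taken from the actual diff.

```python
import torch


def build_attention_mask(text_lens: torch.Tensor, num_image_tokens: int, t_max: int):
    """Return a [B, 1, 1, N_img + T_max] boolean mask, or None if no text is padded.

    Hypothetical helper: shapes and dtypes are assumptions, not the real model code.
    """
    if t_max > 0 and bool((text_lens < t_max).any()):
        # valid_text[b, t] is True for real text tokens and False for padding.
        positions = torch.arange(t_max, device=text_lens.device)
        valid_text = positions[None, :] < text_lens[:, None]
        # Image tokens are never padded, so they are always attended to.
        image_ones = torch.ones(
            text_lens.shape[0], num_image_tokens, dtype=torch.bool, device=text_lens.device
        )
        return torch.cat([image_ones, valid_text], dim=1)[:, None, None, :]
    # Every sequence is at max length: skip the mask so mask-free backends can run.
    return None


unpadded = build_attention_mask(torch.tensor([8, 8]), num_image_tokens=4, t_max=8)
padded = build_attention_mask(torch.tensor([8, 5]), num_image_tokens=4, t_max=8)
```

Note the parenthesization: the per-element comparison `text_lens < t_max` comes first and `.any()` reduces it afterwards.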

ErnieImageSingleStreamAttnProcessor and ErnieImageSharedAdaLNBlock already accept attention_mask: torch.Tensor | None, so this is a transformer level change with no downstream edits.

Added one regression test in the existing test_models_transformer_ernie_image.py. It uses the existing tester config's get_init_dict and get_dummy_inputs (which produce unpadded text by default), intercepts dispatch_attention_fn, and asserts the attn_mask it sees is None. Passes on this branch, fails on main.
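
The interception pattern can be sketched roughly as below. This is not the test added in the PR: `ToyTransformer`, the `attention_dispatch` namespace, and `seen_attn_masks` are hypothetical stand-ins so the example runs without diffusers, and plain lists take the place of tensors. Only the idea of spying on the dispatch call and inspecting `attn_mask` carries over.

```python
from types import SimpleNamespace
from unittest import mock


def _real_dispatch(query, key, value, attn_mask=None):
    # Placeholder for the real attention computation.
    return query


# Hypothetical stand-in for the module that owns dispatch_attention_fn.
attention_dispatch = SimpleNamespace(dispatch_attention_fn=_real_dispatch)


class ToyTransformer:
    """Toy model that builds a mask only when some text is padded."""

    def forward(self, text_lens, t_max):
        needs_mask = t_max > 0 and any(n < t_max for n in text_lens)
        attn_mask = (
            [[i < n for i in range(t_max)] for n in text_lens] if needs_mask else None
        )
        return attention_dispatch.dispatch_attention_fn("q", "k", "v", attn_mask=attn_mask)


def seen_attn_masks(text_lens, t_max):
    """Run one forward pass and return every attn_mask the dispatcher received."""
    seen = []

    def spy(query, key, value, attn_mask=None, **kwargs):
        seen.append(attn_mask)
        return query

    with mock.patch.object(
        attention_dispatch, "dispatch_attention_fn", side_effect=spy
    ) as patched:
        ToyTransformer().forward(text_lens, t_max)
    assert patched.called, "dispatch_attention_fn was never reached"
    return seen
```

With unpadded input, `seen_attn_masks([8, 8], 8)` records a single `None`; with `[8, 5]` it records a real mask.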

For the padded case I ran the forward pass on the same model with the same seed and compared output hashes before vs after the change: byte identical (4442ce01ccddd5bf either way for text_lens=[8,5]), as expected since the mask is still built whenever any text is padded. For the unpadded case, the all-True mask the unfixed code passed is mathematically equivalent to None for attention math, so the output is unchanged there too.
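
To illustrate why an all-True mask is equivalent to no mask, here is a small sketch using a naive softmax attention written out by hand. It is an assumption-laden toy, not the model's attention path: the `attention` function and the tensor shapes are made up for the example, and fused kernels in real backends are not guaranteed to be bitwise identical between their masked and unmasked code paths.

```python
import math

import torch


def attention(q, k, v, attn_mask=None):
    """Naive scaled dot-product attention; attn_mask is boolean, True means 'attend'."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if attn_mask is not None:
        # Masked-out keys get -inf so softmax assigns them zero weight.
        scores = scores.masked_fill(~attn_mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


torch.manual_seed(0)
q, k, v = (torch.randn(2, 1, 6, 4) for _ in range(3))

# All-True mask: nothing is filled, so the scores are untouched.
all_true = torch.ones(2, 1, 1, 6, dtype=torch.bool)
same_as_none = torch.equal(attention(q, k, v), attention(q, k, v, all_true))

# A mask that hides the last key changes the result.
partial = all_true.clone()
partial[..., -1] = False
differs = not torch.equal(attention(q, k, v), attention(q, k, v, partial))
```

In this toy setup the all-True mask never modifies the score matrix, so the masked and unmasked outputs compare equal element for element.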

Also stubbed the actual flash-attn 2 backend dispatch check (if attn_mask is not None: raise ValueError) and ran it against the model. On main it reproduces the issue's exact ValueError. On this branch the unpadded path runs clean.
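
A rough sketch of that kind of stub check, with assumptions stated up front: `stub_flash_backend` and `run_unpadded` are hypothetical names, the error text is a placeholder rather than the library's message, and lists stand in for tensors. The `build_mask_always` flag mimics the old behavior (always build a mask) versus the new one (build it only when something is padded).

```python
def stub_flash_backend(query, key, value, attn_mask=None):
    """Stand-in for a backend that cannot take a mask (hypothetical, not diffusers code)."""
    if attn_mask is not None:
        raise ValueError("attn_mask is not supported by this stub backend.")
    return value


def run_unpadded(build_mask_always: bool) -> str:
    """Run one unpadded batch through the stub and report what happened."""
    t_max, text_lens = 8, [8, 8]
    any_padded = any(n < t_max for n in text_lens)
    if build_mask_always or any_padded:
        mask = [[i < n for i in range(t_max)] for n in text_lens]
    else:
        mask = None
    try:
        stub_flash_backend("q", "k", "v", attn_mask=mask)
        return "ok"
    except ValueError as err:
        return f"ValueError: {err}"
```

Calling `run_unpadded(True)` hits the stub's `ValueError`, while `run_unpadded(False)` returns `"ok"`.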

@github-actions github-actions Bot added size/M PR with diff < 200 LOC models tests fixes-issue and removed size/M PR with diff < 200 LOC labels May 25, 2026
@Ace3Z Ace3Z force-pushed the fix/ernie-image-flashattn-13801 branch from e0bcc10 to b42802d Compare May 25, 2026 08:48
@github-actions github-actions Bot added the size/S PR with diff < 50 LOC label May 25, 2026
@Ace3Z Ace3Z force-pushed the fix/ernie-image-flashattn-13801 branch from b42802d to 00080ef Compare May 25, 2026 10:11
`ErnieImageTransformer2DModel.forward` unconditionally builds an
attention mask (image-ones concatenated with the text validity vector)
and forwards it to `dispatch_attention_fn` as `attn_mask`. When all text
sequences are at maximum length the mask is all-ones and unnecessary,
but flash-attn 2 rejects any non-None `attn_mask` and raises
`ValueError: attn_mask is not supported for flash-attn 2.` -- so calling
`pipeline.transformer.set_attention_backend("flash")` on an
`ErnieImagePipeline` crashes (issue huggingface#13801), while the equivalent code
on `ZImagePipeline` and `Flux2KleinPipeline` works fine.

Make the mask construction conditional on whether any text is actually
padded (`(text_lens < Tmax).any()`). When no padding is needed the mask
is `None`, matching the convention used by
`ZImageTransformer2DModel._build_inputs` (where `attn_mask = None` when
every sequence is at `max_seqlen`). The attention processor and layer
already accept `attention_mask: torch.Tensor | None = None`, so this is
a transformer-level fix with no downstream changes.

Verified numerically: for both unpadded (`text_lens=[8,8]`) and padded
(`text_lens=[8,5]`) inputs the forward pass produces a byte-identical
output tensor before and after the change.

Closes huggingface#13801
@Ace3Z Ace3Z force-pushed the fix/ernie-image-flashattn-13801 branch from 00080ef to 8a9deea Compare May 25, 2026 10:17
Development

Successfully merging this pull request may close these issues.

Incompatibility between FlashAttention and ERNIE Image
