fix(ernie_image): pass attn_mask=None when text is unpadded#13804
Open
Ace3Z wants to merge 1 commit into
Open
Conversation
e0bcc10 to
b42802d
Compare
b42802d to
00080ef
Compare
`ErnieImageTransformer2DModel.forward` unconditionally builds an
attention mask (image-ones concatenated with the text validity vector)
and forwards it to `dispatch_attention_fn` as `attn_mask`. When all text
sequences are at maximum length the mask is all-ones and unnecessary,
but flash-attn 2 rejects any `attn_mask` argument and raises
`ValueError: attn_mask is not supported for flash-attn 2.` -- so calling
`pipeline.transformer.set_attention_backend("flash")` on an
`ErnieImagePipeline` crashes (issue huggingface#13801), while the equivalent code
on `ZImagePipeline` and `Flux2KleinPipeline` works fine.
Make the mask construction conditional on whether any text is actually
padded (`text_lens.any() < Tmax`). When no padding is needed the mask
is `None`, matching the convention used by
`ZImageTransformer2DModel._build_inputs` (where `attn_mask = None` when
every sequence is at `max_seqlen`). The attention processor and layer
already accept `attention_mask: torch.Tensor | None = None`, so this is
a transformer-level fix with no downstream changes.
Verified numerically: for both unpadded (`text_lens=[8,8]`) and padded
(`text_lens=[8,5]`) inputs the forward pass produces a byte-identical
output tensor before and after the change.
Closes huggingface#13801
00080ef to
8a9deea
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Closes #13801.
pipeline.transformer.set_attention_backend("flash")on anErnieImagePipelineraisesValueError: attn_mask is not yet supported for flash-attn 2., while the equivalent call onZImagePipelineandFlux2KleinPipelineworks fine.ErnieImageTransformer2DModel.forwardwas unconditionally building an attention mask (image ones concatenated with the text validity vector) and passing it todispatch_attention_fnasattn_mask. When every text in the batch is at max length the mask is all ones and isn't doing anything, but the flash-attn 2 backend rejects any non-Noneattn_mask(the check atattention_dispatch.py:1106).ZImageTransformer2DModel._build_inputsalready does the right thing: it setsattn_mask = Nonewhen every sequence is atmax_seqlen. Applied the same convention here:ErnieImageSingleStreamAttnProcessorandErnieImageSharedAdaLNBlockalready acceptattention_mask: torch.Tensor | None, so this is a transformer level change with no downstream edits.Added one regression test in the existing
test_models_transformer_ernie_image.py. It uses the existing tester config'sget_init_dictandget_dummy_inputs(which produce unpadded text by default), interceptsdispatch_attention_fn, and asserts theattn_maskit sees is None. Passes on this branch, fails on main.For the padded case I ran the forward pass on the same model with the same seed and compared output hashes before vs after the change: byte identical (
4442ce01ccddd5bfeither way fortext_lens=[8,5]). The all-True mask the unfixed code passed is mathematically equivalent to None for attention math, so output is unchanged.Also stubbed the actual flash-attn 2 backend dispatch check (
if attn_mask is not None: raise ValueError) and ran it against the model. On main it reproduces the issue's exactValueError. On this branch the unpadded path runs clean.