[JAX] [PyT] [Common] Enable D=256 BWD cuDNN fused attn for Blackwell CC 10.x #3056
[JAX] [PyT] [Common] Enable D=256 BWD cuDNN fused attn for Blackwell CC 10.x #3056KshitijLakhani wants to merge 9 commits into
Conversation
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
51ad582 to
d177ecf
Compare
for more information, see https://pre-commit.ci
…n fused attn Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR enables cuDNN 9.23 / FE 1.24's dedicated deterministic SDPA backward kernel for head-dim 256 on Blackwell (SM10.x / CC 10.x) GPUs. The C++ backend selector, JAX test skip guards, and PyTorch test cases are all updated in concert.
Confidence Score: 4/5The C++ backend gate is additive and well-guarded; existing paths are unchanged and the new path only activates under a very specific combination of hardware, cuDNN version, and kernel parameters. The core logic in fused_attn.cpp looks correct and conservative. The two test-side issues (a duplicate comment line and a JAX skip-guard that allows some bias configs to pass without actually exercising the new kernel) are minor and do not affect production correctness. No existing behavior is altered. The JAX test's _check_configs bias skip logic (tests/jax/test_fused_attn.py, lines 465-475) deserves a second look to ensure the skip conditions exactly match the C++ gate. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[nvte_get_fused_attn_backend called] --> B{dtype FP16/BF16?}
B -- No --> Z[Other backend]
B -- Yes --> C{head_dim_qk/v <= 128?}
C -- Yes --> ARB[NVTE_F16_Arbitrary backend]
C -- No --> D{d_qk==d_v==256 AND is_training\nAND sm_arch in 100-109\nAND cuDNN >= 9.23?}
D -- Yes --> E{bias==NO_BIAS\ndropout==0\nsoftmax==VANILLA?}
E -- No --> SKIP[Skip: Fall through to next condition]
E -- Yes --> F{window_size == -1,-1\nOR causal mask + right_win in -1,0?}
F -- No --> SKIP
F -- Yes --> ARB
D -- No --> G{d_qk/v <= 256 AND Hopper?\nOR Blackwell fprop?\nOR other existing rules}
G -- Yes --> ARB
G -- No --> SKIP
Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
| # vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only. | ||
| # (for non-causal masks) full-window attention. |
There was a problem hiding this comment.
The comment block ends with a repeated phrase: line 383 (# (for non-causal masks) full-window attention.) is a verbatim fragment of line 382, left over from editing. It should be removed.
| # vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only. | |
| # (for non-causal masks) full-window attention. | |
| # vanilla type of softmax only, no dropout, no ALiBi, and (for non-causal masks) full-window attention only. |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| # Non-learnable bias is fine (bias is allowed as an input); only dBias is | ||
| # unsupported. The JAX runner asks for dBias iff the bias shape is [1, h, s, s] | ||
| # (see test_backward), so gate on that. | ||
| unsupported = None | ||
| if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS: | ||
| unsupported = "pre-scale bias" | ||
| elif self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS: | ||
| unsupported = ( | ||
| "bias gradients (dBias); frozen/non-learnable bias inputs" | ||
| " (i.e. non-1HSS bias shapes) are supported" | ||
| ) |
There was a problem hiding this comment.
JAX skip logic diverges from C++ backend gate for non-1HSS bias
The comment says "frozen/non-learnable bias inputs (i.e. non-1HSS bias shapes) are supported" and the skip block deliberately allows those configs to proceed. However, the C++ gate in fused_attn.cpp requires bias_type == NVTE_NO_BIAS for the new D=256 BWD path, meaning any config with attn_bias_type != NO_BIAS && bias_shape != _1HSS will silently fall back to a different backend rather than exercising the new kernel. The test will not fail, but it also will not validate the D=256 BWD path for those configs, and the inline comment creates a misleading expectation that such configs are actually routed through it.
Description
Support for D=256 BWD for Blackwell CC 10x via the C++ API (which TE uses) was added in cuDNN 9.23 + cuDNN FE 1.24. Enabling this support in TE attention
Type of change
Changes
Add guard when picking the backend (sub backend) in TE common.
Add tests for D=256 case in TE PyT and TE JAX
Checklist: