Enable NVFP4 fused grouped MLP #3048
9f3dc12 to e94bfb3
Greptile Summary
This PR enables NVFP4 support in the fused grouped MLP forward and backward kernels (
Confidence Score: 5/5
Safe to merge; the NVFP4 fused grouped MLP paths are well-structured and previous critical bugs (floor division in scale views, CUDA graph sync) have been fixed in this revision.
The core forward and backward NVFP4 dispatch logic correctly uses `ceil_div` throughout, the shared NVFP4 utilities are properly consolidated into `_common.py`, and the single-group fast path and multi-group general-GEMM fallback are both guarded appropriately. The two findings are a missing `.to(torch.float32)` cast on `fc2_alpha_tensor` in the backward (dtype asymmetry with the forward, but unlikely to surface if amaxes are always float32 in practice) and a redundant `set_usage` call in the FC2 discrete-weight loop.
`transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py`: the `fc2_alpha_tensor` dtype cast and the overall NVFP4 dgrad/wgrad dispatch path deserve a close read.
Important Files Changed
Reviews (14): Last reviewed commit: "Merge branch 'main' into nvfp4-grouped-m..."
e94bfb3 to 0a5186e
0a5186e to 2f83517
e412b71 to ff9284e
Signed-off-by: sraman-rgb <270218152+sraman-rgb@users.noreply.github.com>
ff9284e to 956a3c6
for more information, see https://pre-commit.ci
vthumbe1503 left a comment
LGTM mostly. I have left minor comments. We eventually need to abstract the Dense logic out into a separate layer called DenseMLP (which could eventually use the fusion of Linear + Swiglu + Linear in te_ops), but that is beyond the scope of this PR.
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
3a41112 to 6c2ee78
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Siddhartha Raman Sundara Raman <sraman@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Siddhartha Raman Sundara Raman <sraman@nvidia.com>
/te-ci
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Route the NVFP4 RHT grouped MLP test through the shared recipe helpers, compare the fused path against a TE unfused reference, and keep plain NVFP4 on the non-RHT fallback path.
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
for more information, see https://pre-commit.ci
These fused ops now support both MXFP8 and NVFP4, so the recipe-specific suffix is misleading. Rename the four `CuTeGEMM` GLU/Unary forward/backward classes and update callsites and tests.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Adds `ceil_div` next to `round_up_to_nearest_multiple` in `pytorch/utils.py` and replaces the `(x + d - 1) // d` patterns in the fused grouped MLP ops with it. Also fixes asymmetric floor-divs in the NVFP4 scale-view paths (`data_(in_)k // k_sf_divisor`) that would underestimate the scale-block count for padded layouts; the MXFP8 branches already used ceil-div for the same dimension.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
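For readers following the floor-div versus ceil-div point in this commit, here is a minimal sketch of the difference. The body of `ceil_div` below is an assumed implementation based on the `(x + d - 1) // d` pattern the commit message mentions; the values `k = 40` and `k_sf_divisor = 16` are hypothetical and chosen only to show a dimension that is not a multiple of the divisor.

```python
def ceil_div(x: int, d: int) -> int:
    """Integer division rounding up (assumed equivalent of the helper named in the commit)."""
    return (x + d - 1) // d


# Hypothetical values for illustration: one scale factor per 16 elements along K.
k = 40
k_sf_divisor = 16

# Floor division drops the partial trailing block, so the scale view is too small.
floor_blocks = k // k_sf_divisor

# Ceil division counts the partial block, so every element has a scale factor.
ceil_blocks = ceil_div(k, k_sf_divisor)

print(floor_blocks, ceil_blocks)
```

When the dimension is an exact multiple of the divisor the two forms agree, which is probably why the floor-div version can go unnoticed until a layout with a partial trailing block shows up.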
Functions like `_group_quantize_for_grouped_mlp` already constrain their inputs (e.g. an NVFP4Quantizer always yields an NVFP4Tensor), so the `getattr(..., default)` fallbacks for `_rowwise_data`/`_with_gemm_swizzled_scales`/etc. were dead code that obscured intent. Replace those with direct attribute access, drop the dead double-`getattr` for the non-underscored public name that doesn't exist, and add brief comments on the type contract. The `getattr` calls that remain are legitimate (polymorphic inputs, dynamic attribute names, user-stamped optional flags on `torch.nn.Parameter`).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
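The trade-off described in this commit can be sketched in plain Python. `FakeQuantizedTensor` and both helper functions are hypothetical stand-ins written for this illustration; they are not Transformer Engine classes or APIs, and only the attribute name `_rowwise_data` is taken from the commit message.

```python
class FakeQuantizedTensor:
    """Hypothetical stand-in for a quantized tensor; not the real TE class."""

    def __init__(self, rowwise_data):
        self._rowwise_data = rowwise_data


def rowwise_with_fallback(t):
    # Old pattern: a missing attribute silently becomes None.
    return getattr(t, "_rowwise_data", None)


def rowwise_direct(t):
    # New pattern: the type contract guarantees the attribute, so a
    # wrong-typed input raises AttributeError instead of hiding the bug.
    return t._rowwise_data


t = FakeQuantizedTensor(rowwise_data=[1, 2, 3])
assert rowwise_with_fallback(t) == rowwise_direct(t) == [1, 2, 3]

# With an input that violates the contract, only direct access reports it.
assert rowwise_with_fallback(object()) is None
try:
    rowwise_direct(object())
    raised = False
except AttributeError:
    raised = True
print("direct access raised AttributeError:", raised)
```

Both forms return the same value when the contract holds, so the fallback adds no behavior in that case and only masks contract violations.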
Use explicit grouped linear quantizer roles and keep the NVFP4 RHT grouped MLP reference in plain PyTorch, with RHT applied only to reference wgrad.
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
afe0e58 to 8c2404d
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
/te-ci pytorch
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: