Enable NVFP4 fused grouped MLP #3048

Open
sraman-rgb wants to merge 17 commits into NVIDIA:main from sraman-rgb:nvfp4-grouped-mlp-main

Conversation

@sraman-rgb sraman-rgb commented May 27, 2026

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@sraman-rgb sraman-rgb requested a review from timmoon10 as a code owner May 27, 2026 17:38
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 27, 2026
@sraman-rgb sraman-rgb force-pushed the nvfp4-grouped-mlp-main branch from 9f3dc12 to e94bfb3 Compare May 27, 2026 17:45
greptile-apps Bot commented May 27, 2026

Greptile Summary

This PR enables NVFP4 support in the fused grouped MLP forward and backward kernels (ForwardGroupedMLP_CuTeGEMMGLU/Unary and their backward counterparts). The main additions are: NVFP4 quantization and kernel dispatch paths in fuser_forward/fuser_backward, a new ceil_div utility in utils.py, and consolidation of the NVFP4 helpers (_group_quantize_for_grouped_mlp, _nvfp4_amax, _nvfp4_single_tensor_from_grouped) from the forward/backward files into _common.py.

  • fuse_grouped_mlp_ops in _common.py now activates for recipe.nvfp4() in addition to recipe.mxfp8(), with an additional guard that disables NVFP4 fusion when recipe.disable_rht is set, since the NVFP4 path relies on graph-safe grouped quantize that requires RHT.
  • The _ForwardGroupedMLP_CuTeGEMMBase.fuser_forward and _BackwardGroupedMLP_CuTeGEMMDBase.fuser_backward both branch on use_nvfp4 to select between the NVFP4 bfloat16 dequant path (using general_grouped_gemm_for_grouped_tensor) and the existing MXFP8 CuTe DSL quant kernel path for FC2.
  • GroupedLinear.get_quantizer_roles is added to expose per-group quantizer role metadata for the new recipe integration.
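
The recipe gating described in the first bullet can be sketched in plain Python. This is only an illustration of the decision logic as the summary states it: `ToyRecipe` and `should_fuse_grouped_mlp` are hypothetical names invented here, not the Transformer Engine recipe class or the actual `fuse_grouped_mlp_ops` code, and the real function checks more conditions than the recipe type.

```python
from dataclasses import dataclass


@dataclass
class ToyRecipe:
    """Hypothetical stand-in for a quantization recipe (not the TE class)."""

    kind: str                  # e.g. "mxfp8" or "nvfp4"
    disable_rht: bool = False  # mirrors the flag named in the summary

    def mxfp8(self) -> bool:
        return self.kind == "mxfp8"

    def nvfp4(self) -> bool:
        return self.kind == "nvfp4"


def should_fuse_grouped_mlp(recipe: ToyRecipe) -> bool:
    """Return True when the fused grouped MLP path is eligible for this recipe."""
    if recipe.mxfp8():
        return True
    if recipe.nvfp4():
        # Per the summary, the NVFP4 fused path depends on RHT,
        # so a recipe that disables RHT falls back to the unfused path.
        return not recipe.disable_rht
    return False
```

Under these assumptions, an NVFP4 recipe with `disable_rht=True` is routed to the unfused fallback, while MXFP8 and NVFP4-with-RHT recipes take the fused path.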

Confidence Score: 5/5

Safe to merge; the NVFP4 fused grouped MLP paths are well-structured and previous critical bugs (floor division in scale views, CUDA graph sync) have been fixed in this revision.

The core forward and backward NVFP4 dispatch logic correctly uses ceil_div throughout, the shared NVFP4 utilities are properly consolidated into _common.py, and the single-group fast path and multi-group general-GEMM fallback are both guarded appropriately. The two findings are a missing .to(torch.float32) cast on fc2_alpha_tensor in the backward (dtype asymmetry with the forward, but unlikely to surface if amaxes are always float32 in practice) and a redundant set_usage call in the FC2 discrete-weight loop.

transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py — the fc2_alpha_tensor dtype cast and the overall NVFP4 dgrad/wgrad dispatch path deserve a close read.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py | Adds NVFP4 quantization dispatch for FC1 fused kernel and FC2 general-grouped-gemm path; includes minor redundant set_usage call in discrete-weight FC2 loop. |
| transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py | Adds NVFP4 dgrad/wgrad dispatch; fc2_alpha_tensor computation lacks .to(torch.float32) present in the forward analog, creating a dtype asymmetry. |
| transformer_engine/pytorch/ops/_common.py | Consolidates _group_quantize_for_grouped_mlp, _nvfp4_amax, _nvfp4_single_tensor_from_grouped from forward/backward files; fuse_grouped_mlp_ops now enables nvfp4 recipe; correct use of ceil_div throughout. |
| transformer_engine/pytorch/ops/basic/grouped_linear.py | Adds get_quantizer_roles method for per-group quantizer metadata and imports QuantizerRole; straightforward addition. |
| transformer_engine/pytorch/utils.py | Adds ceil_div utility with zero-denominator guard; used throughout the NVFP4 scale view computations to fix previous floor-division bugs. |

Reviews (14): Last reviewed commit: "Merge branch 'main' into nvfp4-grouped-m..."

Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
@sraman-rgb sraman-rgb force-pushed the nvfp4-grouped-mlp-main branch from e94bfb3 to 0a5186e Compare May 27, 2026 18:00
@timmoon10 timmoon10 added org-contribution and removed community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. labels May 27, 2026
@sraman-rgb sraman-rgb force-pushed the nvfp4-grouped-mlp-main branch from 0a5186e to 2f83517 Compare May 27, 2026 18:18
@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 27, 2026
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated
@sraman-rgb sraman-rgb force-pushed the nvfp4-grouped-mlp-main branch 2 times, most recently from e412b71 to ff9284e Compare May 27, 2026 20:08
Signed-off-by: sraman-rgb <270218152+sraman-rgb@users.noreply.github.com>
@sraman-rgb sraman-rgb force-pushed the nvfp4-grouped-mlp-main branch from ff9284e to 956a3c6 Compare May 27, 2026 20:25
@vthumbe1503 vthumbe1503 left a comment

LGTM mostly. Have left minor comments. We need to eventually abstract out the Dense logic into a different layer called DenseMLP (which can eventually use the fusion of Linear + Swiglu + Linear in te_ops), but that is beyond the scope of this PR.

Comment thread transformer_engine/pytorch/ops/basic/grouped_linear.py Outdated
Comment thread transformer_engine/pytorch/ops/basic/grouped_linear.py Outdated
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
Comment thread transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py Outdated
Comment thread transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py Outdated
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
@sraman-rgb sraman-rgb force-pushed the nvfp4-grouped-mlp-main branch from 3a41112 to 6c2ee78 Compare May 28, 2026 00:31
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
Comment thread transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
sraman-rgb and others added 2 commits May 27, 2026 19:40
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Siddhartha Raman Sundara Raman <sraman@nvidia.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Siddhartha Raman Sundara Raman <sraman@nvidia.com>
Comment thread transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
Comment thread transformer_engine/pytorch/ops/_common.py Outdated
Comment thread transformer_engine/pytorch/ops/_common.py
Comment thread transformer_engine/pytorch/ops/_common.py Outdated
Comment thread transformer_engine/pytorch/ops/_common.py
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py
Comment thread transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py Outdated
@timmoon10

/te-ci

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
@sraman-rgb sraman-rgb requested a review from ksivaman as a code owner May 28, 2026 15:01
Comment thread transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py
sraman-rgb and others added 6 commits May 28, 2026 14:55
Route the NVFP4 RHT grouped MLP test through the shared recipe helpers, compare the fused path against a TE unfused reference, and keep plain NVFP4 on the non-RHT fallback path.

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
These fused ops now support both MXFP8 and NVFP4, so the recipe-specific
suffix is misleading. Rename the four `CuTeGEMM` GLU/Unary forward/backward
classes and update callsites and tests.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Adds `ceil_div` next to `round_up_to_nearest_multiple` in `pytorch/utils.py`
and replaces the `(x + d - 1) // d` patterns in the fused grouped MLP ops
with it. Also fixes asymmetric floor-divs in the NVFP4 scale-view paths
(`data_(in_)k // k_sf_divisor`) that would underestimate the scale-block
count for padded layouts; the MXFP8 branches already used ceil-div for
the same dimension.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
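
The floor-versus-ceil difference that the `ceil_div` commit message describes can be seen in a small sketch. The helper below is written from the message's description of the `(x + d - 1) // d` pattern, not copied from `pytorch/utils.py`; rejecting all non-positive divisors is an assumption made here, and the values of `k` and `k_sf_divisor` are illustrative only.

```python
def ceil_div(x: int, d: int) -> int:
    """Return the smallest integer q with q * d >= x, for a positive divisor d."""
    if d <= 0:
        # Assumption: this sketch rejects zero and negative divisors outright.
        raise ValueError("divisor must be positive")
    return (x + d - 1) // d


# Illustrative numbers: a K dimension that is not a multiple of the
# scale-block size, as can happen with padded layouts.
k = 72
k_sf_divisor = 16

floor_blocks = k // k_sf_divisor         # 4: misses the partial trailing block
ceil_blocks = ceil_div(k, k_sf_divisor)  # 5: covers all 72 elements
```

When `k` is an exact multiple of the divisor the two results agree, which is why a floor division of this kind can go unnoticed until a padded shape appears.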
Functions like `_group_quantize_for_grouped_mlp` already constrain their
inputs (e.g. an NVFP4Quantizer always yields an NVFP4Tensor), so the
`getattr(..., default)` fallbacks for `_rowwise_data`/`_with_gemm_swizzled_scales`/etc.
were dead code that obscured intent. Replace those with direct attribute
access, drop the dead double-`getattr` for the non-underscored public name
that doesn't exist, and add brief comments on the type contract. The
`getattr` calls that remain are legitimate (polymorphic inputs, dynamic
attribute names, user-stamped optional flags on `torch.nn.Parameter`).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
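
A short sketch of why the `getattr(..., default)` fallbacks were worth removing, assuming a toy class in place of the real quantized tensor types. `ToyQuantizedTensor` and both helper functions are hypothetical; only the attribute name `_rowwise_data` comes from the commit message.

```python
class ToyQuantizedTensor:
    """Hypothetical stand-in for a tensor type that always has `_rowwise_data`."""

    def __init__(self, rowwise_data):
        self._rowwise_data = rowwise_data


def rowwise_data_with_fallback(tensor):
    """Old style: defensive lookups that hide a broken type contract."""
    data = getattr(tensor, "_rowwise_data", None)
    if data is None:
        # Second lookup for a public name that, per the commit message,
        # does not exist on the real classes.
        data = getattr(tensor, "rowwise_data", None)
    return data


def rowwise_data_direct(tensor: ToyQuantizedTensor):
    """New style: the caller guarantees the type, so access the attribute directly."""
    return tensor._rowwise_data
```

With the fallback version, passing the wrong object silently yields `None` and the failure surfaces later; the direct version raises `AttributeError` at the point where the contract is violated.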
Comment thread tests/pytorch/test_fusible_ops.py Outdated
Use explicit grouped linear quantizer roles and keep the NVFP4 RHT grouped MLP reference in plain PyTorch, with RHT applied only to reference wgrad.

Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
Comment thread tests/pytorch/test_fusible_ops.py Outdated
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
@sraman-rgb sraman-rgb force-pushed the nvfp4-grouped-mlp-main branch from afe0e58 to 8c2404d Compare May 29, 2026 23:55
Signed-off-by: Siddhartha Raman S <sraman@nvidia.com>
@timmoon10

/te-ci pytorch

@timmoon10 timmoon10 left a comment

LGTM, pending CI

Labels

community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. org-contribution

3 participants