[PyTorch] GroupedLinear does not propagate skip_fp8_weight_update during FP8 CUDA graph capture
Summary
When training Nemotron-3-Nano 30B-A3B with TransformerEngine CUDA graph replay enabled, MoE expert and router gradients are much smaller (about 200x and 3300x in mean per-parameter L2 norm, respectively) than in an otherwise identical run where graph replay is delayed beyond the debug window.
This appears related to GroupedLinear.forward() not passing the FP8 graph-capture skip_fp8_weight_update tensor through to _GroupedLinear.forward(). _GroupedLinear.forward() already accepts skip_fp8_weight_update and passes it to quantize_weight(..., skip_update_flag=...), but the module-level GroupedLinear.forward() call site still hardcodes:
None, # skip_fp8_weight_update
By contrast, Linear, LayerNormLinear, and LayerNormMLP retrieve the graph-capture skip tensor when FP8GlobalStateManager.fp8_graph_capturing() is true, force is_first_microbatch=False, and pass the tensor into their autograd functions. This looks similar to the intent of #1854, but GroupedLinear still has the hardcoded None call site on current main and v2.15.
Note: the exact accessor differs by TE version. In v2.12 / v2.14, the other modules use FP8GlobalStateManager.get_skip_fp8_weight_update_tensor(). On current main / v2.15, they use FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor directly.
Environment
- Model: Nemotron-3-Nano 30B-A3B
- Environment location: Docker / NGC container
- Container: nvcr.io/nvidia/nemo:26.02
- GPU: 1 node x 8x NVIDIA B200
- CUDA: 13.0
- PyTorch: 2.10
- Megatron-Bridge: v0.4.1
- Megatron-Core: v0.17.0rc0
- TransformerEngine: observed with TE 2.12; Megatron-Bridge v0.4.1 resolves TE as 2.14.0+71bbefbf; local source checks show the same GroupedLinear call site is still present on main and v2.15
Expected behavior
Enabling TransformerEngine CUDA graph replay should not materially change gradient numerics. MoE expert, router, and shared-expert gradient norms should be comparable between graph-replay-active and graph-replay-delayed runs.
Observed behavior
Using the Nemotron-3-Nano recipe with:
cfg.model.cuda_graph_impl = "transformer_engine"
cfg.model.cuda_graph_scope = ["mamba", "attn", "moe"]
cfg.model.use_te_rng_tracker = True
cfg.model.moe_token_dispatcher_type = "alltoall"
cfg.model.moe_expert_capacity_factor = 1.0
cfg.model.moe_token_drop_policy = "probs"
cfg.model.moe_pad_expert_input_to_capacity = True
we compared two runs that only differed in CUDA graph warmup:
# Graph replay enabled
model.cuda_graph_warmup_steps: 3
# Baseline: graph replay delayed until after this debug run exits
model.cuda_graph_warmup_steps: 1000
train.exit_interval: 1000
Per-parameter L2 gradient norms, averaged by parameter group, diverged strongly:
| Parameter group | Graph replay active mean | Graph replay delayed mean | Delayed / active |
| --- | --- | --- | --- |
| MoE experts | 3.14e-07 | 6.33e-05 | 202x |
| MoE router | 6.52e-07 | 2.18e-03 | 3347x |
| Shared expert | 3.35e-04 | 6.36e-03 | 19x |
| Self-attention | 4.54e-04 | 1.95e-02 | 43x |
| Mamba mixer | 1.38e-03 | 4.42e-03 | 3.2x |
| Embedding | 1.45e-01 | 1.18e-01 | comparable |
Many individual MoE expert weights received exactly zero gradient with CUDA graphs enabled. Router suppression was depth-dependent and reached ~88,000x in early layers.
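For reference, the per-group averaging and the exact-zero check described above could be computed roughly as in the sketch below. The substring patterns in `GROUP_PATTERNS` and the helper names are hypothetical and chosen for illustration; real parameter names depend on the model definition, and gradients are represented as flat lists of floats so the sketch does not depend on PyTorch.

```python
import math
from collections import defaultdict

# Hypothetical substring -> group mapping; adjust to the real parameter names.
# Order matters: "shared_experts" must be checked before "experts".
GROUP_PATTERNS = [
    ("shared_experts", "Shared expert"),
    ("experts", "MoE experts"),
    ("router", "MoE router"),
    ("self_attention", "Self-attention"),
    ("mixer", "Mamba mixer"),
    ("embedding", "Embedding"),
]


def l2_norm(values):
    """L2 norm of one parameter's flattened gradient."""
    return math.sqrt(sum(v * v for v in values))


def group_of(name):
    """Map a parameter name to its group via the first matching substring."""
    for pattern, group in GROUP_PATTERNS:
        if pattern in name:
            return group
    return "Other"


def mean_grad_norm_by_group(named_grads):
    """Average the per-parameter L2 gradient norms within each group."""
    norms = defaultdict(list)
    for name, grad in named_grads:
        norms[group_of(name)].append(l2_norm(grad))
    return {group: sum(vals) / len(vals) for group, vals in norms.items()}


def exact_zero_params(named_grads):
    """Names of parameters whose gradient is exactly zero everywhere."""
    return [name for name, grad in named_grads if all(v == 0.0 for v in grad)]
```

Running both helpers on the graph-replay-active and graph-replay-delayed runs and dividing the group means gives the "Delayed / active" column.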
Suspected mechanism
During CUDA graph capture/replay with cached FP8 weights, the TE graph wrapper sets a dynamic skip flag from the replay-time is_first_microbatch value:
skip_fp8_weight_update = not user_kwargs["is_first_microbatch"]
TE modules are expected to feed the resulting graph-capture tensor into weight quantization so that replay can skip FP8 weight updates after the first microbatch without baking in the captured microbatch state. In the fused wgrad accumulation path, is_first_microbatch also affects whether wgrad overwrites or accumulates into main_grad:
is_first_microbatch=True -> overwrite main_grad
is_first_microbatch=False -> accumulate into main_grad
GroupedLinear currently does not appear to retrieve/pass this graph-capture skip tensor, unlike the other TE linear modules. As a result, MoE expert GroupedLinear can appear to capture fixed first-microbatch / weight-update behavior and replay it for later microbatches.
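To make the overwrite/accumulate distinction concrete, here is a toy model of `main_grad` across microbatches. It is not TE code: `run_microbatches` is a hypothetical helper, gradients are plain floats, and the "frozen" case only illustrates the hypothesis that a captured graph replays first-microbatch behavior for every microbatch. It does not by itself account for the magnitudes in the table above.

```python
def run_microbatches(wgrads, first_flags):
    """Simulate main_grad across microbatches.

    is_first=True  -> overwrite main_grad with the new wgrad
    is_first=False -> accumulate the new wgrad into main_grad
    """
    main_grad = 0.0
    for wgrad, is_first in zip(wgrads, first_flags):
        if is_first:
            main_grad = wgrad
        else:
            main_grad += wgrad
    return main_grad


wgrads = [1.0, 2.0, 3.0, 4.0]

# Intended behavior: only the first microbatch overwrites.
correct = run_microbatches(wgrads, [True, False, False, False])

# Hypothesized failure: the first-microbatch flag is baked into the graph,
# so every replay overwrites and earlier microbatches are discarded.
frozen = run_microbatches(wgrads, [True, True, True, True])
```

In this toy case `correct` is the sum of all four microbatch gradients, while `frozen` keeps only the last one.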
The local workaround that fixes the GroupedLinear path is:
if FP8GlobalStateManager.fp8_graph_capturing():
    skip_fp8_weight_update = (
        FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor
    )
else:
    skip_fp8_weight_update = None
if skip_fp8_weight_update is not None:
    is_first_microbatch = False
# pass skip_fp8_weight_update into non_tensor_args instead of None
For older TE versions where the helper exists, the equivalent getter is:
skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
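For a workaround that has to run against both accessor styles, a small version-tolerant helper is one option. This is a sketch only: `get_graph_capture_skip_tensor` is a hypothetical name, and it assumes nothing about TE beyond the two accessors and `fp8_graph_capturing()` mentioned above. The manager is passed in as an argument so the logic can be exercised with stubs.

```python
def get_graph_capture_skip_tensor(manager):
    """Return the FP8 skip-update tensor during graph capture, else None.

    `manager` is expected to be FP8GlobalStateManager (or a stub that looks
    like it). Two accessor styles are tried:
      * older: manager.get_skip_fp8_weight_update_tensor()
      * newer: manager.quantization_state.skip_fp8_weight_update_tensor
    """
    if not manager.fp8_graph_capturing():
        return None
    getter = getattr(manager, "get_skip_fp8_weight_update_tensor", None)
    if getter is not None:
        return getter()
    return manager.quantization_state.skip_fp8_weight_update_tensor
```

The call site would then use `skip_fp8_weight_update = get_graph_capture_skip_tensor(FP8GlobalStateManager)` and force `is_first_microbatch = False` whenever the result is not None, as in the workaround above.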
Reproducer outline
A minimal reproducer can be built from the public Megatron-Bridge Nemotron-3-Nano recipe:
from megatron.bridge.recipes.nemotronh.nemotron_3_nano import nemotron_3_nano_pretrain_config
from megatron.bridge.training.mixed_precision import MixedPrecisionConfig
cfg = nemotron_3_nano_pretrain_config()
cfg.model.tensor_model_parallel_size = 1
cfg.model.expert_model_parallel_size = 8
cfg.model.expert_tensor_parallel_size = 1
cfg.model.pipeline_model_parallel_size = 1
cfg.model.sequence_parallel = False
cfg.train.global_batch_size = 512
cfg.train.micro_batch_size = 2
cfg.train.exit_interval = 1000
cfg.model.seq_length = 8192
cfg.dataset.seq_length = 8192
cfg.mixed_precision = MixedPrecisionConfig(
    bf16=True,
    fp8="hybrid",
    fp8_recipe="tensorwise",
    fp8_wgrad=True,
    fp8_dot_product_attention=False,
    fp8_multi_head_attention=False,
)
cfg.model.moe_token_dispatcher_type = "alltoall"
cfg.model.moe_expert_capacity_factor = 1.0
cfg.model.moe_token_drop_policy = "probs"
cfg.model.moe_pad_expert_input_to_capacity = True
cfg.model.cuda_graph_impl = "transformer_engine"
cfg.model.cuda_graph_scope = ["mamba", "attn", "moe"]
cfg.model.use_te_rng_tracker = True
# Compare two runs that differ only in this setting (use exactly one per run):
cfg.model.cuda_graph_warmup_steps = 3  # reproduces suppression
# cfg.model.cuda_graph_warmup_steps = 1000  # baseline, graph replay delayed
Request
Could GroupedLinear be updated to retrieve and pass the FP8 graph-capture skip_fp8_weight_update tensor during CUDA graph capture, matching the behavior of Linear, LayerNormLinear, and LayerNormMLP?
A regression test comparing GroupedLinear FP8 gradients under TE CUDA graph replay vs graph-delayed/eager execution would likely catch this. If possible, the test should cover fuse_wgrad_accumulation=True / main_grad accumulation across repeated microbatches, since that is where the gradient suppression is most visible in the Nemotron MoE run.
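The comparison step of such a test might look like the sketch below. It covers only the tolerance check on gradient norms, not the TE graph capture itself; `grad_norm_mismatches` is a hypothetical helper and the default tolerances are placeholders that would need tuning for FP8 numerics.

```python
def grad_norm_mismatches(graphed, eager, rtol=0.1, atol=1e-12):
    """Return names whose graphed gradient norm deviates from the eager one.

    `graphed` and `eager` map a parameter (or group) name to its gradient
    norm. A name is reported when |graphed - eager| > atol + rtol * |eager|.
    """
    mismatched = []
    for name, ref in eager.items():
        got = graphed[name]
        if abs(got - ref) > atol + rtol * abs(ref):
            mismatched.append(name)
    return mismatched


# Group means taken from the table in "Observed behavior".
eager = {"experts": 6.33e-05, "router": 2.18e-03, "embedding": 1.18e-01}
graphed = {"experts": 3.14e-07, "router": 6.52e-07, "embedding": 1.45e-01}

# With a deliberately loose 50% relative tolerance, the MoE groups still fail.
failing = grad_norm_mismatches(graphed, eager, rtol=0.5)
```

A test would assert that this list is empty after running the same GroupedLinear inputs through graph replay and eager execution over several microbatches.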