[PyTorch] GroupedLinear does not propagate skip_fp8_weight_update during FP8 CUDA graph capture #3051

@allenphilipj


Summary

When training Nemotron-3-Nano 30B-A3B with TransformerEngine CUDA graph replay enabled, MoE expert and router gradients are much smaller than in an otherwise identical run where graph replay is delayed beyond the debug window.

This appears related to GroupedLinear.forward() not passing the FP8 graph-capture skip_fp8_weight_update tensor through to _GroupedLinear.forward(). _GroupedLinear.forward() already accepts skip_fp8_weight_update and passes it to quantize_weight(..., skip_update_flag=...), but the module-level GroupedLinear.forward() call site still hardcodes:

None,  # skip_fp8_weight_update

By contrast, Linear, LayerNormLinear, and LayerNormMLP retrieve the graph-capture skip tensor when FP8GlobalStateManager.fp8_graph_capturing() is true, force is_first_microbatch=False, and pass the tensor into their autograd functions. This looks similar to the intent of #1854, but GroupedLinear still has the hardcoded None call site on current main and v2.15.

Note: the exact accessor differs by TE version. In v2.12 / v2.14, the other modules use FP8GlobalStateManager.get_skip_fp8_weight_update_tensor(). On current main / v2.15, they use FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor directly.

Environment

  • Model: Nemotron-3-Nano 30B-A3B
  • Environment location: Docker / NGC container
  • Container: nvcr.io/nvidia/nemo:26.02
  • GPU: 1 node x 8x NVIDIA B200
  • CUDA: 13.0
  • PyTorch: 2.10
  • Megatron-Bridge: v0.4.1
  • Megatron-Core: v0.17.0rc0
  • TransformerEngine: observed with TE 2.12; Megatron-Bridge v0.4.1 resolves TE as 2.14.0+71bbefbf; local source checks show the same GroupedLinear call site is still present on main and v2.15

Expected behavior

Enabling TransformerEngine CUDA graph replay should not materially change gradient numerics. MoE expert, router, and shared-expert gradient norms should be comparable between graph-replay-active and graph-replay-delayed runs.

Observed behavior

Using the Nemotron-3-Nano recipe with:

cfg.model.cuda_graph_impl = "transformer_engine"
cfg.model.cuda_graph_scope = ["mamba", "attn", "moe"]
cfg.model.use_te_rng_tracker = True
cfg.model.moe_token_dispatcher_type = "alltoall"
cfg.model.moe_expert_capacity_factor = 1.0
cfg.model.moe_token_drop_policy = "probs"
cfg.model.moe_pad_expert_input_to_capacity = True

we compared two runs that differed only in the CUDA graph warmup setting:

# Graph replay enabled
model.cuda_graph_warmup_steps: 3

# Baseline: graph replay delayed until after this debug run exits
model.cuda_graph_warmup_steps: 1000
train.exit_interval: 1000

Per-parameter L2 gradient norms, averaged by parameter group, diverged strongly:

| Parameter group | Graph replay active (mean) | Graph replay delayed (mean) | Delayed / active |
|---|---|---|---|
| MoE experts | 3.14e-07 | 6.33e-05 | 202x |
| MoE router | 6.52e-07 | 2.18e-03 | 3347x |
| Shared expert | 3.35e-04 | 6.36e-03 | 19x |
| Self-attention | 4.54e-04 | 1.95e-02 | 43x |
| Mamba mixer | 1.38e-03 | 4.42e-03 | 3.2x |
| Embedding | 1.45e-01 | 1.18e-01 | comparable |

Many individual MoE expert weights received exactly zero gradient with CUDA graphs enabled. Router suppression was depth-dependent and reached ~88,000x in early layers.
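
For reference, the per-group means above can be computed along the following lines. This is a minimal sketch, not the script used for this report: the substring patterns in GROUP_PATTERNS and the helper names are hypothetical, and gradients are represented as flat lists of floats so the example runs without PyTorch.

```python
import math
from collections import defaultdict

# Hypothetical substring -> group mapping; real parameter names depend on the model.
# "shared_expert" is listed before "experts" so shared experts are not misgrouped.
GROUP_PATTERNS = [
    ("shared_expert", "Shared expert"),
    ("experts", "MoE experts"),
    ("router", "MoE router"),
]


def l2_norm(values):
    """L2 norm of a flat sequence of gradient values."""
    return math.sqrt(sum(v * v for v in values))


def group_of(name):
    """Return the first matching group for a parameter name, else 'other'."""
    for pattern, group in GROUP_PATTERNS:
        if pattern in name:
            return group
    return "other"


def mean_grad_norm_by_group(grads):
    """grads: dict of parameter name -> flat list of gradient values."""
    norms = defaultdict(list)
    for name, values in grads.items():
        norms[group_of(name)].append(l2_norm(values))
    return {group: sum(n) / len(n) for group, n in norms.items()}


def delayed_over_active(active, delayed):
    """Per-group ratio of delayed to active mean norms (skips zero denominators)."""
    return {
        group: delayed[group] / active[group]
        for group in active
        if group in delayed and active[group] > 0.0
    }
```

Running mean_grad_norm_by_group on the gradients of each run and passing both results to delayed_over_active gives the last column of the table. Groups whose active mean is exactly zero are skipped rather than reported as infinite.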

Suspected mechanism

During CUDA graph capture/replay with cached FP8 weights, the TE graph wrapper sets a dynamic skip flag from the replay-time is_first_microbatch value:

skip_fp8_weight_update = not user_kwargs["is_first_microbatch"]

TE modules are expected to feed the resulting graph-capture tensor into weight quantization so that replay can skip FP8 weight updates after the first microbatch without baking in the captured microbatch state. In the fused wgrad accumulation path, is_first_microbatch also affects whether wgrad overwrites or accumulates into main_grad:

is_first_microbatch=True  -> overwrite main_grad
is_first_microbatch=False -> accumulate into main_grad
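
To illustrate why a baked-in flag matters on this path, here is a toy model of the overwrite/accumulate semantics. It is only a sketch under stated assumptions: fused_wgrad_step and run_microbatches are hypothetical stand-ins written for this issue, not TE functions, and the example does not claim to reproduce exactly what the captured graph does.

```python
def fused_wgrad_step(main_grad, wgrad, is_first_microbatch):
    """Toy stand-in: overwrite main_grad on the first microbatch, else accumulate."""
    if is_first_microbatch:
        return list(wgrad)
    return [m + w for m, w in zip(main_grad, wgrad)]


def run_microbatches(wgrads, flag_for_step):
    """Apply fused_wgrad_step over all microbatches with a per-step flag."""
    main_grad = [0.0] * len(wgrads[0])
    for step, wgrad in enumerate(wgrads):
        main_grad = fused_wgrad_step(main_grad, wgrad, flag_for_step(step))
    return main_grad


wgrads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

# Flag follows the real microbatch index: gradients sum across microbatches.
dynamic = run_microbatches(wgrads, lambda step: step == 0)

# Flag frozen at True (as if captured once and replayed): each step overwrites.
frozen = run_microbatches(wgrads, lambda step: True)

print(dynamic)  # [9.0, 12.0]
print(frozen)   # [5.0, 6.0]
```

With the flag frozen, only the last microbatch's contribution survives, so the accumulated gradient shrinks relative to the correct sum as the number of microbatches grows.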

GroupedLinear currently does not appear to retrieve or pass this graph-capture skip tensor, unlike the other TE linear modules. As a result, the MoE expert GroupedLinear appears to capture fixed first-microbatch / weight-update behavior and replay it for later microbatches.

The local workaround that fixes the GroupedLinear path is:

if FP8GlobalStateManager.fp8_graph_capturing():
    skip_fp8_weight_update = (
        FP8GlobalStateManager.quantization_state.skip_fp8_weight_update_tensor
    )
else:
    skip_fp8_weight_update = None

if skip_fp8_weight_update is not None:
    is_first_microbatch = False

# pass skip_fp8_weight_update into non_tensor_args instead of None

For older TE versions where the helper exists, the equivalent getter is:

skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
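
Since the accessor differs between versions, a local patch that must work across both could wrap the lookup in a small helper. The sketch below is an assumption-laden illustration: get_graph_capture_skip_tensor is a hypothetical name, it takes the manager as an argument, and the two stand-in classes exist only so the example runs without TE installed. The attribute and method names it probes are the ones quoted in this issue.

```python
def get_graph_capture_skip_tensor(manager):
    """Return the graph-capture skip tensor, or None outside FP8 graph capture.

    `manager` is expected to behave like FP8GlobalStateManager (assumption).
    """
    if not manager.fp8_graph_capturing():
        return None
    # Older layout (v2.12 / v2.14 per this issue): dedicated getter.
    getter = getattr(manager, "get_skip_fp8_weight_update_tensor", None)
    if callable(getter):
        return getter()
    # Newer layout (main / v2.15 per this issue): attribute on quantization_state.
    state = getattr(manager, "quantization_state", None)
    return getattr(state, "skip_fp8_weight_update_tensor", None)


# Stand-ins for the two layouts, used only to exercise the helper.
class OldStyleManager:
    @staticmethod
    def fp8_graph_capturing():
        return True

    @staticmethod
    def get_skip_fp8_weight_update_tensor():
        return "skip-tensor-old"


class NewStyleManager:
    class quantization_state:
        skip_fp8_weight_update_tensor = "skip-tensor-new"

    @staticmethod
    def fp8_graph_capturing():
        return True


print(get_graph_capture_skip_tensor(OldStyleManager))  # skip-tensor-old
print(get_graph_capture_skip_tensor(NewStyleManager))  # skip-tensor-new
```

In a real patch the returned value would be passed in place of the hardcoded None, and is_first_microbatch would be forced to False whenever the result is not None, as in the workaround above.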

Reproducer outline

A minimal reproducer can be built from the public Megatron-Bridge Nemotron-3-Nano recipe:

from megatron.bridge.recipes.nemotronh.nemotron_3_nano import nemotron_3_nano_pretrain_config
from megatron.bridge.training.mixed_precision import MixedPrecisionConfig

cfg = nemotron_3_nano_pretrain_config()

cfg.model.tensor_model_parallel_size = 1
cfg.model.expert_model_parallel_size = 8
cfg.model.expert_tensor_parallel_size = 1
cfg.model.pipeline_model_parallel_size = 1
cfg.model.sequence_parallel = False

cfg.train.global_batch_size = 512
cfg.train.micro_batch_size = 2
cfg.train.exit_interval = 1000
cfg.model.seq_length = 8192
cfg.dataset.seq_length = 8192

cfg.mixed_precision = MixedPrecisionConfig(
    bf16=True,
    fp8="hybrid",
    fp8_recipe="tensorwise",
    fp8_wgrad=True,
    fp8_dot_product_attention=False,
    fp8_multi_head_attention=False,
)

cfg.model.moe_token_dispatcher_type = "alltoall"
cfg.model.moe_expert_capacity_factor = 1.0
cfg.model.moe_token_drop_policy = "probs"
cfg.model.moe_pad_expert_input_to_capacity = True

cfg.model.cuda_graph_impl = "transformer_engine"
cfg.model.cuda_graph_scope = ["mamba", "attn", "moe"]
cfg.model.use_te_rng_tracker = True

# Compare two runs, setting exactly one of the following per run:
cfg.model.cuda_graph_warmup_steps = 3       # reproduces suppression
# cfg.model.cuda_graph_warmup_steps = 1000  # baseline, graph replay delayed

Request

Could GroupedLinear be updated to retrieve and pass the FP8 graph-capture skip_fp8_weight_update tensor during CUDA graph capture, matching the behavior of Linear, LayerNormLinear, and LayerNormMLP?

A regression test comparing GroupedLinear FP8 gradients under TE CUDA graph replay vs graph-delayed/eager execution would likely catch this. If possible, the test should cover fuse_wgrad_accumulation=True / main_grad accumulation across repeated microbatches, since that is where the gradient suppression is most visible in the Nemotron MoE run.
