
@xiaoxi-wangfj xiaoxi-wangfj commented Dec 26, 2025

Description

This PR introduces blockwise, scaling-aware FP8 transpose optimizations for FP8 MoE that enable a casting-free, FP8-centric MoE dataflow in TransformerEngine by eliminating unnecessary cast and re-quantization steps, while maintaining numerical stability in existing FP8 training workflows.

This PR is designed to be used in conjunction with PR NVIDIA/Megatron-LM#2764

Further optimizations are introduced in two additional follow-up PRs.

Background / Motivation

The design and theoretical background of this PR are described in the paper:
FP8-Flow-MoE: A Casting-Free FP8 Recipe without Double Quantization Error

The following figure illustrates the optimized MoE dataflow and highlights the key optimization points (marked as ①–⑤).

[Figure: FP8-Flow-MoE optimized MoE dataflow, with optimization points ①–⑤]

1. FP8 Quantization Before Dispatch (DeepEP → GroupedLinear)

Quantization is performed before DeepEP dispatch, and row-wise FP8 tensors are directly fed into GroupedLinear.

  • Keeps dispatch → permute → expert computation entirely in FP8
  • Float8BlockwiseQTensor is propagated with a COMPACT layout (for _rowwise_scale_inv) along the dispatch → permute → GroupedLinear path, avoiding layout-induced .T.contiguous() calls and reducing unnecessary memory copies.

(Shown as marker ① in the figure)
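
For illustration, the sketch below shows the row-wise 1×128 blockwise quantization that happens once at the MoE entry point. It is a simplified reference in plain PyTorch, not the TransformerEngine quantizer: the name quantize_rowwise_1x128 is hypothetical, E4M3 is assumed, and the hidden size is assumed to be a multiple of the block size.

```python
import torch

def quantize_rowwise_1x128(x: torch.Tensor, block: int = 128):
    """Reference row-wise blockwise FP8 quantization (1 x 128 tiles).

    Simplified stand-in for the quantization performed once before DeepEP
    dispatch; the real path produces a Float8BlockwiseQTensor whose
    _rowwise_scale_inv is kept in a COMPACT layout.
    """
    tokens, hidden = x.shape
    x_blocks = x.float().view(tokens, hidden // block, block)
    amax = x_blocks.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12)
    scale = torch.finfo(torch.float8_e4m3fn).max / amax       # per 1x128 tile
    data = (x_blocks * scale).to(torch.float8_e4m3fn)         # FP8 payload
    scale_inv = scale.reciprocal().squeeze(-1)                 # [tokens, hidden // block]
    return data.view(tokens, hidden), scale_inv

# The FP8 payload and its scale_inv are what travel through
# dispatch -> permute -> GroupedLinear; no BF16 activation copy is materialized.
fp8_x, scale_inv = quantize_rowwise_1x128(torch.randn(8, 256, dtype=torch.bfloat16))
```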

2. Scaling-Aware FP8 Transpose for Wgrad

GroupedLinear requires:

  • row-wise FP8 for Fprop/Dgrad
  • column-wise FP8 for Wgrad

To avoid a dequantize → transpose → requantize round trip, this PR introduces scaling_aware_fp8_transpose (a toy sketch of the idea follows this section), which:

  • Converts row-wise FP8 to column-wise FP8 via exponent manipulation only
  • Preserves scale consistency across layouts
  • Reduces CPU overhead

(Shown as marker ④ in the figure)
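
The exponent-manipulation idea is easiest to see in a toy example. Assuming the blockwise recipe constrains block scales to powers of two, re-expressing a stored FP8 value under a different block scale multiplies it by 2^k, which (away from zero/subnormals, the NaN encoding, and saturation) reduces to an addition on the 4-bit E4M3 exponent field. The sketch below is a heavily simplified stand-in, not the Triton kernel added in this PR; rescale_e4m3_by_pow2 is a hypothetical name.

```python
import torch

def rescale_e4m3_by_pow2(fp8_vals: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Multiply E4M3 values by 2**k using only exponent-field arithmetic.

    Toy illustration of the rescaling inside scaling_aware_fp8_transpose;
    it ignores zero/subnormals, the NaN encoding, and saturation, which the
    real kernel must handle.
    """
    bits = fp8_vals.view(torch.uint8).to(torch.int16)
    exponent = (bits >> 3) & 0xF                   # 4-bit E4M3 exponent field
    new_exp = (exponent + k).clamp(1, 14)          # stay in the normal range
    out = (bits & 0x87) | (new_exp << 3)           # keep sign + mantissa bits
    return out.to(torch.uint8).view(torch.float8_e4m3fn)

# Example: moving a value from a row-wise block with scale_inv = 2**-3 to a
# column-wise block with scale_inv = 2**-1 multiplies the stored value by 2**-2.
x = torch.tensor([1.5, 3.0]).to(torch.float8_e4m3fn)
print(rescale_e4m3_by_pow2(x, torch.tensor(-2)).float())  # tensor([0.3750, 0.7500])
```

In the PR, this rescaling is fused with the layout change, so the row-wise FP8 payload and its scales are converted directly into column-wise FP8 data and scales for the wgrad GEMM, without materializing a high-precision intermediate.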

3. Fused Permute + Padding / Unpermute + Unpadding

We fuse two memory movement operators along the MoE path:

  • permute + pad in the forward pass
  • unpermute + unpad in the backward pass

For details of this optimization, please refer to PR #1921

(Shown as marker ② in the figure)
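
The sketch below gives the unfused reference semantics of the forward operator: each token is written straight to its padded, expert-grouped destination row in one scatter, instead of a permute followed by a separate pad. It is an illustrative stand-in (permute_and_pad_reference is a hypothetical name), not the fused kernel from PR #1921.

```python
import torch

def permute_and_pad_reference(tokens, expert_idx, num_experts, align=16):
    """Reference semantics of the fused permute + pad forward pass.

    Tokens are grouped by expert and each group is padded to a multiple of
    `align` rows; the fused kernel writes each token straight to its padded
    destination instead of materializing an intermediate permuted buffer.
    """
    counts = torch.bincount(expert_idx, minlength=num_experts)
    padded = (counts + align - 1) // align * align         # padded rows per expert
    pad_offsets = torch.cumsum(padded, 0) - padded         # padded start of each expert
    grp_offsets = torch.cumsum(counts, 0) - counts         # unpadded start of each expert

    order = torch.argsort(expert_idx, stable=True)         # tokens grouped by expert
    sorted_expert = expert_idx[order]
    rank = torch.arange(order.numel(), device=order.device) - grp_offsets[sorted_expert]
    dst = pad_offsets[sorted_expert] + rank                # padded destination rows

    out = tokens.new_zeros(int(padded.sum()), tokens.size(1))
    out[dst] = tokens[order]
    return out, dst   # dst is reused by the backward unpermute + unpad

tokens = torch.randn(5, 4)
out, dst = permute_and_pad_reference(tokens, torch.tensor([1, 0, 1, 2, 0]), num_experts=3, align=4)
```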

4. Fused Activation + Quantization

Activation and FP8 quantization are fused into a single kernel that produces FP8 outputs directly, enabling FP8 persistence along the expert path.

(Shown as marker ③ in the figure)
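
Below is an unfused reference of what the fused kernel computes, assuming a SwiGLU expert activation and the same 1×128 row-wise blockwise FP8 quantization as above; swiglu_quant_reference is a hypothetical name, not the Megatron-LM kernel. Fusing the two steps means only the FP8 result needs to be written out and persisted for backward.

```python
import torch
import torch.nn.functional as F

def swiglu_quant_reference(x12: torch.Tensor, block: int = 128):
    """Unfused reference for fused SwiGLU + FP8 quantization.

    `x12` holds [x1, x2] concatenated on the last dim. The fused kernel
    computes silu(x1) * x2 and the row-wise 1x128 FP8 quantization in one
    pass, emitting the FP8 payload and per-block scale_inv directly.
    """
    x1, x2 = x12.float().chunk(2, dim=-1)
    y = F.silu(x1) * x2                                    # SwiGLU gate * value
    tokens, hidden = y.shape
    yb = y.view(tokens, hidden // block, block)
    amax = yb.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12)
    scale = torch.finfo(torch.float8_e4m3fn).max / amax
    fp8 = (yb * scale).to(torch.float8_e4m3fn).view(tokens, hidden)
    return fp8, scale.reciprocal().squeeze(-1)
```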

5. Fine-Grained Recomputation at the moe_expert Level

Because the entire dispatch → permute → GroupedLinear path stays in FP8, we enable fine-grained recomputation at the moe_expert level:

  • Saves ~50% of peak activation memory and avoids recomputing the router, compared to recomputing the entire MoE module (moe-level recompute)

(Shown as marker ⑤ in the figure)
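
Conceptually, this corresponds to wrapping only the expert computation in activation checkpointing while the router runs outside the checkpointed region. The toy sketch below (a hypothetical module, not the Megatron-LM implementation) shows the structure; in the real dataflow the checkpointed inputs are already FP8, which is what keeps the saved activations small.

```python
import torch
from torch.utils.checkpoint import checkpoint

class MoEExpertRecompute(torch.nn.Module):
    """Toy sketch of moe_expert-level recomputation.

    Only the expert path (a stand-in for dispatch -> permute ->
    GroupedLinear -> combine) is checkpointed; the router runs outside the
    checkpointed region and is never recomputed.
    """

    def __init__(self, hidden: int, num_experts: int):
        super().__init__()
        self.router = torch.nn.Linear(hidden, num_experts)
        self.expert_mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden, 4 * hidden),
            torch.nn.SiLU(),
            torch.nn.Linear(4 * hidden, hidden),
        )

    def _expert_path(self, x, probs):
        # Stand-in for the FP8 dispatch/permute/GroupedLinear/combine path.
        return self.expert_mlp(x) * probs.amax(dim=-1, keepdim=True)

    def forward(self, x):
        probs = torch.softmax(self.router(x), dim=-1)   # runs once, not recomputed
        return checkpoint(self._expert_path, x, probs, use_reentrant=False)
```

In the memory results below, this is enabled via selective checkpointing with recompute-modules = moe_expert.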

Performance Results

We evaluate FP8-Flow-MoE on DeepSeek-V3 (671B) to validate scalability and robustness under realistic large-scale training conditions.

Throughput

Measured throughput (TGS, tokens/GPU/s) under different expert-parallelism (EP) settings on DeepSeek-V3 (671B):

  • vs. BF16
    +6% (EP8), +8% (EP16), +16% (EP32)

  • vs. TransformerEngine blockwise FP8 recipe
    +3% (EP8), +8% (EP16), up to +21% (EP32)

Memory Efficiency

With AC = selective checkpointing and recompute-modules = moe_expert:

  • At EP8:
    • ~8 GB lower peak memory vs. BF16
    • ~16.5 GB lower peak memory vs. blockwise FP8

Numerical Accuracy

We trained for >200B tokens. The loss deviation of FP8-Flow-MoE stays within 0.19% compared to both BF16 baselines, with no observed instability or divergence.

Limitations

  • Currently validated on the NVIDIA Hopper architecture with the blockwise FP8 recipe

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Megatron-LM: Added fused FP8 kernels for activation + quantization in fused_bias_swiglu.py and fused_weighted_swiglu_quant.py
  • Megatron-LM: Integrated FP8 dispatch and expert recomputation support in Megatron-LM fused_a2a.py
  • TransformerEngine: Added support for Float8BlockwiseQTensor inputs in grouped_linear.py
  • TransformerEngine: Added scaling_aware_fp8_transpose operator in triton/blockwise_scaling_aware_fp8_transpose.py

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…ization

Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
1. Add an FP8 row-wise scaling-aware transpose op for the wgrad columnwise path.
2. support Float8BlockwiseQTensor input in grouped_linear.
3. _rowwise_scale_inv is propagated with a COMPACT layout along the `dispatch → permute → GroupedLinear` path.

Signed-off-by: xiaoxi-wangfj <690912414@qq.com>
Co-authored-by: dantesuu@gmail.com
Co-authored-by: xzhu@zhejianglab.org
Co-authored-by: 123sssmmm@gmail.com
@rich-junwang

Quick question: does this work with MXFP8, or does it only apply to FP8? Thanks.
