[PyTorch] Add Casting-Free FP8-Flow-MoE Blockwise Optimizations #2544
Description
This PR introduces blockwise, scaling-aware FP8 transpose optimizations for FP8 MoE in TransformerEngine. They enable a casting-free, FP8-centric MoE dataflow by eliminating unnecessary cast and re-quantization steps, while maintaining numerical stability in existing FP8 training workflows.
This PR is designed to be used in conjunction with PR NVIDIA/Megatron-LM#2764
Further optimizations are introduced via two additional PRs:
Background / Motivation
The design and theoretical background of this PR are described in the paper:
FP8-Flow-MoE: A Casting-Free FP8 Recipe without Double Quantization Error
The following figure illustrates the optimized MoE dataflow and highlights the key optimization points (marked as ①–⑤).
1. FP8 Quantization Before Dispatch (DeepEP → GroupedLinear)
Quantization is performed before DeepEP dispatch, and row-wise FP8 tensors are directly fed into GroupedLinear.
- The `dispatch → permute → expert computation` path runs entirely in FP8.
- `Float8BlockwiseQTensor` is propagated with a COMPACT layout (for `_rowwise_scale_inv`) along the `dispatch → permute → GroupedLinear` path, avoiding layout-induced `.T.contiguous()` calls and reducing unnecessary memory copies.

(Shown as marker ① in the figure)
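As a rough illustration (not the actual implementation), here is a minimal PyTorch sketch of row-wise 1x128 blockwise quantization performed before dispatch; the block size, FP8 max value, and helper name `blockwise_fp8_quantize` are assumptions, and the real path uses TransformerEngine's `Float8BlockwiseQTensor` together with the DeepEP dispatch kernels rather than this reference code:

```python
import torch

FP8_MAX = 448.0   # max representable magnitude of float8_e4m3fn
BLOCK = 128       # assumed 1x128 quantization block along the hidden dimension

def blockwise_fp8_quantize(x: torch.Tensor):
    """Quantize (tokens, hidden) activations to FP8 with per-1x128 reciprocal scales."""
    tokens, hidden = x.shape
    assert hidden % BLOCK == 0
    blocks = x.float().view(tokens, hidden // BLOCK, BLOCK)
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12)
    scale = FP8_MAX / amax                                   # per-block scale
    q = (blocks * scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    # Keep the reciprocal scales ("scale_inv") next to the FP8 payload so the
    # tensor can travel through dispatch/permute without ever dequantizing.
    return q.view(tokens, hidden), scale.reciprocal().squeeze(-1)

x = torch.randn(8, 512, dtype=torch.bfloat16)
fp8_data, rowwise_scale_inv = blockwise_fp8_quantize(x)
# fp8_data (+ rowwise_scale_inv) is what gets dispatched and fed to GroupedLinear,
# instead of dispatching BF16 and re-quantizing after permutation.
```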
2. Scaling-Aware FP8 Transpose for Wgrad
GroupedLinear requires a transposed (column-wise) FP8 operand for the wgrad GEMM. To avoid `dequantize → transpose → requantize`, this PR introduces `scaling_aware_fp8_transpose`, which transposes blockwise-quantized FP8 tensors directly while handling their per-block scaling factors. (Shown as marker ④ in the figure)
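The sketch below shows the underlying idea under simplifying assumptions: the FP8 payload and its per-block reciprocal scales are transposed together, so no dequantize → requantize round trip is needed. The function name, shapes, and 1x128 block layout are illustrative; the actual Triton `scaling_aware_fp8_transpose` kernel additionally produces the scale layout expected by the wgrad GEMM.

```python
import torch

def scaling_aware_fp8_transpose_ref(fp8_data: torch.Tensor, scale_inv: torch.Tensor):
    """Simplified reference of the idea behind scaling_aware_fp8_transpose.

    fp8_data:  (M, N) float8_e4m3fn payload quantized in 1x128 row blocks
    scale_inv: (M, N // 128) per-block reciprocal scales
    Returns the transposed payload plus the matching scales, without ever
    going back to high precision.
    """
    data_t = fp8_data.t().contiguous()        # raw FP8 bytes are just rearranged
    scale_inv_t = scale_inv.t().contiguous()  # same scales, transposed block layout
    return data_t, scale_inv_t

M, N = 256, 512
data = torch.randn(M, N).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
s_inv = torch.rand(M, N // 128) + 0.5
data_t, s_inv_t = scaling_aware_fp8_transpose_ref(data, s_inv)
assert data_t.shape == (N, M) and s_inv_t.shape == (N // 128, M)
```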
3. Fused Permute + Padding / Unpermute + Unpadding
We fuse two memory movement operators along the MoE path:
- `permute + pad` in the forward pass
- `unpermute + unpad` in the backward pass

For details of this optimization, please refer to PR #1921; a simplified reference of the permute + pad semantics is sketched below.
(Shown as marker ② in the figure)
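The sketch is an unfused PyTorch reference of these semantics; the fused kernel performs the same data movement in a single pass, and the alignment value and helper name are illustrative assumptions:

```python
import torch

def permute_with_padding_ref(tokens, expert_idx, num_experts, align=16):
    """Unfused reference of permute + pad: sort tokens into expert-major order,
    then pad each expert's group to a multiple of `align` (alignment is illustrative)."""
    order = torch.argsort(expert_idx, stable=True)          # permute step
    permuted = tokens[order]
    counts = torch.bincount(expert_idx, minlength=num_experts)
    padded = ((counts + align - 1) // align) * align        # pad step
    out = tokens.new_zeros(int(padded.sum()), tokens.shape[1])
    src_off = torch.cumsum(counts, 0) - counts
    dst_off = torch.cumsum(padded, 0) - padded
    for e in range(num_experts):                            # copy each expert's tokens
        s, d, c = int(src_off[e]), int(dst_off[e]), int(counts[e])
        out[d:d + c] = permuted[s:s + c]
    return out, order

tokens = torch.randn(10, 64)
expert_idx = torch.randint(0, 4, (10,))
padded_tokens, order = permute_with_padding_ref(tokens, expert_idx, num_experts=4)
```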
4. Fused Activation + Quantization
Activation and FP8 quantization are fused into a single kernel that produces FP8 outputs directly, enabling FP8 persistence of the activations.
(Shown as marker ③ in the figure)
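A minimal reference of the fused kernel's semantics, assuming a SwiGLU activation followed by 1x128 blockwise FP8 quantization; the actual fused Triton kernel computes this in one pass without materializing the high-precision activation:

```python
import torch
import torch.nn.functional as F

FP8_MAX, BLOCK = 448.0, 128   # assumed e4m3 max and 1x128 quantization blocks

def swiglu_quant_ref(gate_up: torch.Tensor):
    """Reference semantics of the fused activation + quantization kernel:
    SwiGLU followed immediately by blockwise FP8 quantization."""
    gate, up = gate_up.chunk(2, dim=-1)
    act = F.silu(gate) * up                                  # SwiGLU
    blocks = act.reshape(act.shape[0], -1, BLOCK)
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12)
    scale = FP8_MAX / amax
    q = (blocks * scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q.reshape(act.shape), scale.reciprocal().squeeze(-1)

gate_up = torch.randn(32, 2 * 512)
fp8_act, scale_inv = swiglu_quant_ref(gate_up)   # FP8 output, no BF16 round trip
```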
5. Fine-Grained Recompute
Because the entire `dispatch → permute → GroupedLinear` path stays in FP8, we enable fine-grained recomputation at the `moe_expert` level instead of the coarser `moe` level. (Shown as marker ⑤ in the figure)
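Conceptually, `moe_expert`-level recompute wraps only the expert computation in activation checkpointing, as in this simplified sketch using plain `torch.utils.checkpoint` on a toy expert module (not the actual Megatron-LM/TransformerEngine integration):

```python
import torch
from torch.utils.checkpoint import checkpoint

class ToyExpert(torch.nn.Module):
    """Stand-in for the expert computation (GroupedLinear + fused activation)."""
    def __init__(self, hidden=512, ffn=2048):
        super().__init__()
        self.fc1 = torch.nn.Linear(hidden, ffn)
        self.fc2 = torch.nn.Linear(ffn, hidden)

    def forward(self, x):
        return self.fc2(torch.nn.functional.silu(self.fc1(x)))

expert = ToyExpert()
tokens = torch.randn(64, 512, requires_grad=True)

# moe_expert-level recompute: only the expert block is checkpointed, so just its
# input (already FP8 in the real dataflow) is saved and the expert forward is
# replayed during backward; a coarser `moe`-level recompute would also replay
# dispatch and permutation.
out = checkpoint(expert, tokens, use_reentrant=False)
out.sum().backward()
```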
Performance Results
We evaluate FP8-Flow-MoE on DeepSeek-V3 (671B) to validate scalability and robustness under realistic large-scale training conditions.
Throughput
Measured throughput (TGS, tokens/GPU/s) under different expert-parallel (EP) sizes on DeepSeek-V3 (671B):
- vs. BF16: +6% (EP8), +8% (EP16), +16% (EP32)
- vs. TransformerEngine blockwise FP8 recipe: +3% (EP8), +8% (EP16), up to +21% (EP32)
Memory Efficiency
With AC = selective checkpointing and recompute-modules = moe_expert:
Numerical Accuracy
We trained for >200B tokens. The loss deviation of FP8-Flow-MoE stays within 0.19% compared to both BF16 baselines, with no observed instability or divergence.
Limitations
Type of change
Changes
Please list the changes introduced in this PR:
- Fused activation + FP8 quantization kernels in `fused_bias_swiglu.py` and `fused_weighted_swiglu_quant.py`
- FP8 dispatch support in `fused_a2a.py`
- `Float8BlockwiseQTensor` inputs in `grouped_linear.py`
- `scaling_aware_fp8_transpose` operator in `triton/blockwise_scaling_aware_fp8_transpose.py`

Checklist: