Fix TEGroupedLinear quantization for expert parallelism (EP > 1) #833
yueshen2016 wants to merge 1 commit into main
Conversation
📝 Walkthrough: The changes refactor Mixture-of-Experts (MoE) calibration handling in PyTorch quantization across three modules. They add explicit MoE calibration validation and local expert amax synchronization in model_calib.py, remove the specialized _QuantMoELayer class from megatron.py, and improve argument parsing robustness in transformer_engine.py's grouped linear quantization path for varying input configurations.
Codecov Report: ✅ All modified and coverable lines are covered by tests.

@@ Coverage Diff @@
## main #833 +/- ##
=======================================
Coverage 73.72% 73.72%
=======================================
Files 196 196
Lines 20457 20457
=======================================
Hits 15082 15082
Misses 5375 5375
What does this PR do?
Type of change: Bug fix / Compatibility update
Overview: Fix te_grouped_quantized_linear_fn argument parsing for TEGroupedLinear quantization when the parallelism configuration results in fewer local experts per GPU.

Problem
TransformerEngine changed the _GroupedLinear.forward signature in PR #2377 (released in TE 2.10):
Old signature (TE < 2.10): forward(ctx, inp, m_splits: List[int], use_bias, is_first_microbatch, ...)
New signature (TE >= 2.10): forward(ctx, inp, non_tensor_args: Tuple, *weights_and_biases) where non_tensor_args = (m_splits, use_bias, is_first_microbatch, ...)
Without this fix, ModelOpt's quantization code fails with newer TE versions because it tries to access m_splits directly from args[idx + 1], but in TE >= 2.10, that position contains the non_tensor_args tuple instead.
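To make the difference concrete, here is a small illustrative sketch (placeholder values only, with ctx and the trailing weights/biases omitted) of the two argument layouts that code intercepting _GroupedLinear.forward sees; it is not ModelOpt's actual parsing code:

```python
# Placeholder stand-ins for the real forward arguments.
inp = "input_tensor"                                   # packed input tensor
m_splits, use_bias, is_first_microbatch = [128, 256, 64], False, None

old_args = (inp, m_splits, use_bias, is_first_microbatch)    # TE < 2.10 layout
new_args = (inp, (m_splits, use_bias, is_first_microbatch))  # TE >= 2.10 layout

# Old parsing assumption: a List[int] sits right after the input.
assert isinstance(old_args[1], list)
# With TE >= 2.10 the same position holds the whole non_tensor_args tuple,
# so code reading args[idx + 1] as m_splits now gets the wrong object.
assert isinstance(new_args[1], tuple) and new_args[1][0] == m_splits
```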
Root Cause
The code assumed m_splits was always directly accessible at args[idx + 1], but TransformerEngine PR #2377 changed the signature to pack all non-tensor arguments into a tuple.
Taking Qwen3-30B-A3B (with num_gemms=21, threshold=44) as an example:

Solution
Added version checking to handle both signatures:
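A minimal sketch of the idea, assuming packaging is available and transformer_engine exposes __version__; the function and variable names below are illustrative rather than the exact code added in this PR:

```python
import transformer_engine as te
from packaging import version

# Decide once, at import time, which _GroupedLinear.forward signature is in use.
TE_PACKS_NON_TENSOR_ARGS = version.parse(te.__version__) >= version.parse("2.10")

def get_m_splits(args, idx):
    """Return m_splits regardless of which _GroupedLinear.forward signature is in use."""
    if TE_PACKS_NON_TENSOR_ARGS:
        # TE >= 2.10: args[idx + 1] is the non_tensor_args tuple; m_splits is its first entry.
        return args[idx + 1][0]
    # TE < 2.10: m_splits sits directly at args[idx + 1].
    return args[idx + 1]
```

Resolving the layout once up front keeps the per-call parsing to a single branch instead of re-inspecting the TE version on every forward.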
Usage
Works seamlessly with any TransformerEngine version:
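For illustration, a hedged usage sketch using ModelOpt's public quantize API; model and calib_batches are placeholders the caller provides, and FP8_DEFAULT_CFG is only one example config:

```python
import modelopt.torch.quantization as mtq

def forward_loop(model):
    # Feed a handful of calibration samples through the model.
    for batch in calib_batches:   # placeholder iterable of calibration inputs
        model(batch)

# No TE-version-specific handling is needed at the call site; the argument
# parsing described above is internal to ModelOpt.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
```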
Testing
Before your PR is "Ready for review"
Additional Information