
fix(export): correct unified_export_megatron at EP > 1 and DP > 1#1631

Open
yueshen2016 wants to merge 1 commit into NVIDIA:main from yueshen2016:fix-ep-export-mamba-moe

Conversation

@yueshen2016
Contributor

@yueshen2016 yueshen2016 commented Jun 4, 2026

Summary

Two related bugs surface when exporting a Megatron-Core Mamba-MoE checkpoint (e.g., Nemotron 3 Ultra) to HuggingFace with non-trivial expert and data parallelism.

1. TEGroupedMLP expert-id collision at EP > 1

_grouped_mlp_slicing iterates range(num_experts), but num_experts == module.num_gemms is the local expert count on each EP rank. Using local ids 0..N-1 for the saved HF key prefix means every EP rank emits the same key set (experts.0..N-1.*), and the last writer wins for each layer — at EP=8 this silently loses 7/8 of every MoE layer's experts.
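The collision can be reproduced without any distributed setup. The following is a minimal sketch, assuming a simplified key format (`experts.{i}.weight`) and plain dicts standing in for per-rank state dicts; `emit_local_keys` and `merge_last_writer_wins` are hypothetical helpers for illustration, not functions from the export module.

```python
def emit_local_keys(ep_rank: int, num_local: int) -> dict[str, str]:
    """Mimic the buggy loop: every rank names its experts 0..N-1."""
    return {f"experts.{i}.weight": f"rank{ep_rank}-local{i}" for i in range(num_local)}


def merge_last_writer_wins(per_rank: list[dict[str, str]]) -> dict[str, str]:
    """Merge rank outputs in order; later ranks overwrite identical keys."""
    merged: dict[str, str] = {}
    for state in per_rank:
        merged.update(state)
    return merged


ep_size, total_experts = 8, 512
num_local = total_experts // ep_size  # experts held by each EP rank
merged = merge_last_writer_wins([emit_local_keys(r, num_local) for r in range(ep_size)])
surviving = len(merged)  # number of distinct expert keys left after the merge
lost_fraction = 1 - surviving / total_experts
```

Because all ranks emit identical key names, the merged dict holds only one rank's worth of experts, which matches the 7/8 loss described above.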

The fix:

  • Reads the authoritative local-to-global mapping from module.local_expert_indices (or module.experts.local_expert_indices) when exposed; falls back to the standard Megatron contiguous layout [ep_rank*N, (ep_rank+1)*N − 1].
  • Saves each expert under its global id so key sets are disjoint across EP ranks.
  • Adds a collective-safe missing-key check via all_reduce(MAX) over a CUDA-resident int32 flag — CPU tensors trip RuntimeError: No backend type associated with device type cpu on the NCCL-backed EP process group.
  • Builds the per-rank state dict locally, then byte-stream all-gathers it across the EP group via torch.save → all_gather_object → torch.load. all_gather_object pickling fails on quantized uint8 weights because their UntypedStorage has no dtype attr; the byte-stream round-trip uses PyTorch's own tensor codec to sidestep this.
  • Pre-moves shared scales / aux tensors to CPU once so the gather payload doesn't repeatedly clone GPU tensors.

When EP == 1 the gather is skipped and behavior is identical to the original loop modulo the global-id rename (a no-op when local_expert_indices is the trivial 0..N-1).
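The local-to-global translation can be sketched as below. This is an illustration under stated assumptions, not the code in the patch: `resolve_global_expert_ids` is a hypothetical name, the length check is an added safeguard, and the explicit index list stands in for whatever `module.local_expert_indices` exposes.

```python
from typing import Optional, Sequence


def resolve_global_expert_ids(
    local_expert_indices: Optional[Sequence[int]],
    ep_rank: int,
    num_local: int,
) -> list[int]:
    """Return the global expert id for each local expert slot."""
    if local_expert_indices is not None:
        # Authoritative mapping supplied by the module (assumed to be int-like).
        ids = [int(i) for i in local_expert_indices]
        if len(ids) != num_local:
            raise ValueError(f"expected {num_local} indices, got {len(ids)}")
        return ids
    # Fallback: contiguous layout, rank r owns [r*N, (r+1)*N - 1].
    return list(range(ep_rank * num_local, (ep_rank + 1) * num_local))


# With EP=8 and 64 local experts, the fallback ranges tile 0..511 without overlap.
all_ids = [g for r in range(8) for g in resolve_global_expert_ids(None, r, 64)]
```

With `ep_rank == 0` and a single rank, the fallback returns `0..N-1`, which is why the rename is a no-op at EP == 1.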

2. config.json race at DP > 1

The post-save block that injects quantization_config into config.json had no DP-rank gating. With DP > 1 there are multiple last-stage main ranks; all of them read-then-write the same file, and any interleaving can leave another rank reading a truncated file, raising JSONDecodeError. Guarded with is_last_stage_main_rank AND get_data_parallel_rank() == 0 and bracketed with torch.distributed.barrier()s so every PP/DP rank waits for the single writer.
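The single-writer pattern can be sketched in plain Python. This is a hedged illustration, not the patch itself: `patch_config_single_writer` is a hypothetical helper, `is_writer` stands in for the rank gate described above, `barrier` stands in for `torch.distributed.barrier`, and the write-to-temp-then-rename step is an extra precaution that the PR does not claim to use.

```python
import json
import os
import tempfile
from typing import Callable


def patch_config_single_writer(
    config_path: str,
    quantization_config: dict,
    is_writer: bool,
    barrier: Callable[[], None],
) -> None:
    """Inject quantization_config into config.json from exactly one rank."""
    barrier()  # every rank arrives before the writer touches the file
    if is_writer:
        with open(config_path) as f:
            config = json.load(f)
        config["quantization_config"] = quantization_config
        # Assumption (not from the PR): write a temp file and rename it so a
        # reader can never observe a partially written config.json.
        fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(config_path) or ".")
        with os.fdopen(fd, "w") as f:
            json.dump(config, f, indent=2)
        os.replace(tmp_path, config_path)
    barrier()  # non-writers wait here until the patched file is complete


# Sequential stand-in for three ranks, only one of which is the writer.
demo_dir = tempfile.mkdtemp()
demo_path = os.path.join(demo_dir, "config.json")
with open(demo_path, "w") as f:
    json.dump({"model_type": "demo"}, f)

barrier_calls = []
for rank_is_writer in (True, False, False):
    patch_config_single_writer(
        demo_path,
        {"quant_algo": "MIXED_PRECISION"},
        rank_is_writer,
        lambda: barrier_calls.append(1),
    )
with open(demo_path) as f:
    patched = json.load(f)
```

Every rank must call both barriers, including ranks that skip the write; otherwise the collective would hang.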

Two new imports (get_data_parallel_rank, get_expert_model_parallel_*) and an io import are added at the module level.

Repro

Without this patch, exporting Nemotron 3 Ultra (108-layer Mamba-MoE hybrid, 512 routed experts) at --pp 9 --tp 1 --ep 8 produces a HuggingFace checkpoint where only 64 of 512 experts per MoE layer are present, and the config.json write races to a JSONDecodeError on DP=8. With this patch both succeed and vLLM loads the resulting checkpoint and serves real generations.

Test plan

  • Exported a 108-layer 512-expert Nemotron-3-Ultra MoE checkpoint at PP=9 TP=1 EP=8 DP=8 (72 ranks across 9 nodes × 8 GPUs). Inspected model.safetensors.index.json and confirmed all 512 experts per MoE layer are present in the HF shards.
  • Sanity-loaded the checkpoint with vllm serve and verified end-to-end generation (no KeyError on weight loading, real model output).
  • Pre-existing unrelated ruff RUF059 warnings on the same file (lines 1453, 1473) are untouched by this PR.
  • No new unit-test coverage added — the bug requires a multi-node EP/DP > 1 distributed run and is non-trivial to mock. Open to suggestions on a CI-friendly regression test.

Follow-up

The SequentialMLP path in _get_transformer_layer_state_dict (iterating local_experts.linear_fc{1,2}) has the same local-id-collision issue and will need an equivalent global-id + EP-gather treatment when MoE recipes start exercising that spec. Not addressed here because the recipes in current Megatron-Bridge mostly use TEGroupedMLP.

🤖 Generated with Claude Code

Signed-off-by: James Shen <yueshen@nvidia.com>

Summary by CodeRabbit

  • Bug Fixes

    • Prevented concurrent corruption during distributed model export by coordinating writes and gating sidecar/config file updates to a single writer rank with barriers.
  • Improvements

    • Made multi-rank expert-parallel export collective-safe with consistent global expert indexing and cross-rank verification of missing pieces.
    • Reduced memory pressure by moving large quantization/auxiliary tensors to CPU and using safe per-rank serialization and exchange for final exports.

@yueshen2016 yueshen2016 requested a review from a team as a code owner June 4, 2026 22:52
@yueshen2016 yueshen2016 requested a review from meenchen June 4, 2026 22:52
@coderabbitai
Contributor

coderabbitai Bot commented Jun 4, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 060f8690-12c4-44bc-b45b-fc7577796713

📥 Commits

Reviewing files that changed from the base of the PR and between 0083a75 and 46b5c8c.

📒 Files selected for processing (1)
  • modelopt/torch/export/unified_export_megatron.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/export/unified_export_megatron.py

📝 Walkthrough

Walkthrough

Coordinates Megatron EP ranks for safe export: adds distributed imports, gates single-writer sidecar and quant config writes with barriers, documents EP semantics, computes local-to-global expert mappings with validation, and gathers per-expert state across EP ranks into the exported _state_dict.

Changes

Distributed Expert Export for Expert-Model-Parallel Layouts

| Layer / File(s) | Summary |
| --- | --- |
| Distributed coordination imports (modelopt/torch/export/unified_export_megatron.py) | Adds io import and Megatron parallel-rank helpers (ep_group, get_expert_parallel_rank, get_expert_parallel_world_size, get_data_parallel_rank) for EP identification and collectives. |
| save_pretrained: single-writer sidecar gating (modelopt/torch/export/unified_export_megatron.py) | Restricts sidecar/non-safetensors copy/write (tokenizer, remote-code, preprocessor/config sidecars) to DP rank 0 and EP rank 0 within the last PP main writer to avoid concurrent partial-write windows. |
| hf_quant_config.json single-writer gate (modelopt/torch/export/unified_export_megatron.py) | Requires last PP main rank plus DP rank 0 and EP rank 0 (and quantization enabled) to assemble/write hf_quant_config.json. |
| config.json coordinated patch & barriers (modelopt/torch/export/unified_export_megatron.py) | Runs config.json quantization_config patch only on last PP main writer with DP 0 and EP 0, and brackets JSON read/modify/write with distributed barriers to prevent corruption. |
| MTP safetensors comment (modelopt/torch/export/unified_export_megatron.py) | Adjusts comment describing single-file (unsharded) safetensors export path. |
| EP grouped MLP: docs & semantics (modelopt/torch/export/unified_export_megatron.py) | Documents EP>1 collective behavior and local-to-global expert ID translation rules required for the gather. |
| EP grouped MLP: mapping, validation & state gather (modelopt/torch/export/unified_export_megatron.py) | Computes ep_size/ep_rank, resolves local_expert_indices (authoritative or contiguous fallback), performs EP-group MAX all-reduce to detect missing weight{i} keys, moves quantization/scales/aux tensors to CPU, constructs per-local-expert HF-style state on CPU, and when ep_size>1 serializes per-rank expert state to bytes and exchanges via all_gather_object/torch.load before merging into the final _state_dict. |

Sequence Diagram

sequenceDiagram
  participant Exporter
  participant EP_Rank
  participant EP_Group
  participant Filesystem
  Exporter->>EP_Rank: compute ep_rank/ep_size and local_expert_indices
  EP_Rank->>EP_Group: all_reduce(MAX) on local missing-key flags
  EP_Rank->>EP_Rank: build per-local-expert tensors on CPU and torch.save -> BytesIO
  EP_Rank->>EP_Group: all_gather_object(serialized_bytes)
  EP_Rank->>EP_Rank: torch.load gathered bytes and merge into _state_dict
  EP_Rank->>Filesystem: gated single-writer writes (config.json, sidecars, hf_quant_config.json) by DP0+EP0 in last PP main

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • ChenhanYu
  • kevalmorabia97
  • h-guo18
🚥 Pre-merge checks | ✅ 5 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 75.00% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (5 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title clearly and concisely summarizes the main fix: correcting unified_export_megatron behavior when both EP and DP are greater than 1. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Security Anti-Patterns | ✅ Passed | torch.load with weights_only=False has proper inline comment. No numpy.load, hardcoded trust_remote_code, eval/exec, nosec, or problematic dependencies found. All criteria pass. |

Contributor

@coderabbitai coderabbitai Bot left a comment


Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

👉 Steps to fix this

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelopt/torch/export/unified_export_megatron.py`:
- Around line 408-428: The file-writing for the last model stage currently lets
every last-stage main rank perform shared side-effect operations (copying
sidecars, writing config.json / hf_quant_config.json) which races when DP > 1;
hoist the single-writer guard so only the last-stage main rank with
get_data_parallel_rank() == 0 performs all shared writes. Concretely, wrap the
entire last-stage sidecar/config writing sequence (the blocks that copy
sidecars, write hf_quant_config.json and config.json and any uses of
save_directory) inside a single conditional: if is_last_stage_main_rank and
get_data_parallel_rank() == 0 and self._hf_quant_config (as needed), and
surround it with torch.distributed.barrier() calls before and after so all PP/DP
ranks wait for the single writer; keep convert_hf_quant_config_format and
json.dump usage unchanged but called only from that guarded writer. Ensure no
per-rank writes to the same files occur outside this guarded block.
- Around line 1175-1190: The torch.load call that reconstructs peer-gathered
expert states currently passes weights_only=False; change it to
weights_only=True when reloading the byte stream produced from
local_expert_state (dict[str, torch.Tensor]) in the block that gathers
gathered_bytes (the torch.distributed.all_gather_object usage and subsequent
loop over gathered_bytes), i.e., update the torch.load(...) invocation that
feeds self._state_dict.update(s) to use weights_only=True; if you truly need
weights_only=False instead, add an inline comment justifying why this
peer-provided payload is trusted and include the required security-review
exception in the PR description.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: d25eb08a-5d03-4923-af06-927bcd4372fb

📥 Commits

Reviewing files that changed from the base of the PR and between 01dec93 and 35467a8.

📒 Files selected for processing (1)
  • modelopt/torch/export/unified_export_megatron.py

Comment thread modelopt/torch/export/unified_export_megatron.py
Comment thread modelopt/torch/export/unified_export_megatron.py
@codecov

codecov Bot commented Jun 4, 2026

Codecov Report

❌ Patch coverage is 25.33333% with 56 lines in your changes missing coverage. Please review.
✅ Project coverage is 77.36%. Comparing base (01dec93) to head (46b5c8c).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| modelopt/torch/export/unified_export_megatron.py | 25.33% | 56 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1631      +/-   ##
==========================================
+ Coverage   77.09%   77.36%   +0.27%     
==========================================
  Files         482      482              
  Lines       52961    52997      +36     
==========================================
+ Hits        40830    41003     +173     
+ Misses      12131    11994     -137     
| Flag | Coverage Δ |
| --- | --- |
| examples | 43.10% <21.33%> (+1.05%) ⬆️ |
| gpu | 59.49% <25.33%> (-0.33%) ⬇️ |
| regression | 15.19% <1.33%> (+0.06%) ⬆️ |
| unit | 53.88% <1.33%> (-0.05%) ⬇️ |

@yueshen2016 yueshen2016 force-pushed the fix-ep-export-mamba-moe branch from 35467a8 to 460ff39 Compare June 4, 2026 23:08
@yueshen2016
Contributor Author

Force-pushed 460ff391f to address two issues I flagged during a self-review:

1. config.json write gate now also pins EP rank 0.

Previous gate (is_last_stage_main_rank AND dp_rank == 0) was topology-dependent: it worked on our PP=9/TP=1/EP=8 run because DP=8 and EP=8 share the same 8 GPUs per PP stage (ETP=1), so DP=0 happened to also imply EP=0. On a cluster where DP and EP are orthogonal axes, multiple EP ranks would still race. Added get_expert_model_parallel_rank() == 0 for an unambiguous single writer regardless of topology.
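The topology dependence is easy to see by counting how many ranks pass each gate. The sketch below is a toy model that assumes DP and EP are fully independent axes among the last-stage main ranks; `count_writers` is a hypothetical helper, and real Megatron rank layouts depend on the parallel configuration.

```python
from itertools import product


def count_writers(dp_size: int, ep_size: int, pin_ep: bool) -> int:
    """Count (dp_rank, ep_rank) cells that pass the write gate.

    Assumes DP and EP are independent axes among last-stage main ranks.
    """
    writers = 0
    for dp_rank, ep_rank in product(range(dp_size), range(ep_size)):
        passes = dp_rank == 0 and (ep_rank == 0 or not pin_ep)
        writers += int(passes)
    return writers


# Old gate: dp_rank == 0 only. New gate: dp_rank == 0 and ep_rank == 0.
old_gate = count_writers(dp_size=2, ep_size=4, pin_ep=False)
new_gate = count_writers(dp_size=2, ep_size=4, pin_ep=True)
```

Under this model the old gate admits one writer per EP rank, while the new gate admits exactly one writer regardless of how the axes are laid out.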

3. _per_layer_quant_config now records all global expert ids on every rank, not just the local 1/EP slice.

Without this, the writer rank would emit an hf_quant_config.json missing (EP-1)/EP of the routed-expert entries — vLLM's mixed-precision loader would then treat those experts as un-quantized and crash on shape/dtype mismatch. I'd worked around this downstream by manually rewriting config.json -> quantization_config post-export; this commit fixes it at the source so other users don't need the workaround. Within a TEGroupedMLP layer all routed experts share the same qformat/block_size (one recipe pattern matches the whole *mixer.experts.* glob), so the local-rank metadata is sufficient to record for all global ids.
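A minimal sketch of recording the metadata for every global expert id follows. The key format, the field names, and the helper `record_expert_quant_config` are assumptions made for illustration; the actual layout of `_per_layer_quant_config` is not shown in this thread.

```python
def record_expert_quant_config(
    per_layer_quant_config: dict[str, dict],
    layer_prefix: str,
    num_local: int,
    ep_size: int,
    qformat: str,
    block_size: int,
) -> None:
    """Record one entry per global expert id, reusing the local metadata."""
    # All routed experts in the layer share qformat/block_size, so the values
    # seen on this rank are valid for experts owned by other EP ranks too.
    for global_id in range(num_local * ep_size):
        key = f"{layer_prefix}.experts.{global_id}"  # hypothetical key format
        per_layer_quant_config[key] = {"qformat": qformat, "block_size": block_size}


quant_cfg: dict[str, dict] = {}
record_expert_quant_config(
    quant_cfg, "layers.0.mixer", num_local=64, ep_size=8, qformat="nvfp4", block_size=16
)
```

Every rank ends up with the same complete dict, so whichever rank is chosen as the writer serializes entries for all experts.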

Concerns 2 (local_expert_indices could be a tensor → defensive .tolist()), 4 (gather memory cost), 5 (weights_only=False forward-compat), 6 (inline comment expansion), and 7 (unit test) are deferred to a follow-up if maintainers think the current code is good enough to merge in this iteration.

@yueshen2016 yueshen2016 force-pushed the fix-ep-export-mamba-moe branch 2 times, most recently from 6beb3cb to 0083a75 Compare June 4, 2026 23:30
@yueshen2016
Contributor Author

Force-pushed 0083a75d3 addressing all three actionable items from CodeRabbit's pre-merge checks and a security audit point.

Responses to inline review comments

@coderabbitai inline #1 (line 432, Major — hoist single-writer guard)

Applied. You were right that gating only the config.json patch leaves a race on the surrounding writes. The single-writer guard is now hoisted around two more blocks:

  • Lines 314-350 (sidecar copies / save_pretrained calls). Wrapped in if get_data_parallel_rank() == 0 and get_expert_model_parallel_rank() == 0: inside the outer is_last_stage_main_rank check. The MTP-state-dict load (lines 352-355) is intentionally left running on every last-stage main rank because it mutates per-rank layer_state_dicts that later feed each rank's own safetensors save — single-writer here would lose 7/8 of the MTP head weights.
  • Lines 362-385 (hf_quant_config.json write). Same hoisted gate; self._hf_quant_config now stays empty on non-writer ranks, which makes the existing config.json gate at the bottom of the method naturally a no-op on them. The barriers we added at the config.json site already cover the non-writer ranks' wait.

@coderabbitai inline #2 (line 1214, Critical — weights_only=False)

Applied. Added an inline # weights_only=False: comment justifying the use per SECURITY.md:

The payload is generated by us via torch.save in this same function on a sibling rank in the EP process group of this job — it never leaves the cluster's collective and is not user-supplied. weights_only=True would refuse to deserialize the dict[str, Tensor] because quantized uint8 tensors store custom storage metadata that the safe-loader allowlist doesn't cover.

Also added a # noqa: S614 suppression on the line so the security-anti-patterns check passes.

Pre-merge check: Security Anti-Patterns — Error (line 1213)

Resolved by the comment above on line 1213.

Responses to other reviewers

@codecov (Patch coverage 2.08%)

The new code lives in distributed-only code paths (_grouped_mlp_slicing, EP/DP-gated config writers) that need at least one multi-node, multi-process distributed run to exercise. The EP > 1 path specifically requires torch.distributed.is_initialized() == True with expert_model_parallel_world_size > 1, which the current unit tests for this file do not cover (the existing tests run single-rank).

The single-rank fallback (if ep_size == 1: self._state_dict.update(local_expert_state) — the else branch at line 1217) is testable and could be exercised by a unit test instantiating a TEGroupedMLP-like object with a synthetic state_dict. I'd be happy to add that in a follow-up, but the substantive correctness here is the EP > 1 path which needs a distributed test harness this repo doesn't currently have for unified_export_megatron.

End-to-end validation: ran the patched code against the Nemotron 3 Ultra checkpoint at PP=9 TP=1 EP=8 → 512 experts/layer present in the exported HF shards, vLLM 0.21 with quant_algo: MIXED_PRECISION loads via the modelopt_mixed path and serves real generations.

Pre-merge check: Docstring coverage 66.67%

This appears to be a whole-file metric. The two functions I substantively touched (_grouped_mlp_slicing and the surrounding save_pretrained block) already have docstrings; my changes expanded the existing _grouped_mlp_slicing docstring to document the new collective-behavior contract (EP-group all_gather_object). I haven't added docstrings to other pre-existing undocumented functions in the file since they're outside the scope of this PR. Happy to tighten if maintainers prefer a separate docstring-coverage cleanup commit.


Signed-off-by: James Shen <yueshen@nvidia.com>

@coderabbitai
Contributor

coderabbitai Bot commented Jun 4, 2026

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# List CodeRabbit's review comments on PR `#1631` to find comment IDs to resolve
gh api repos/NVIDIA/Model-Optimizer/pulls/1631/comments --jq '.[] | select(.user.login == "coderabbitai") | {id: .id, body: .body[0:120]}'

Length of output: 472


🏁 Script executed:

#!/bin/bash
# Try fetching review comments via the reviews API
gh api repos/NVIDIA/Model-Optimizer/pulls/1631/reviews --jq '.[] | select(.user.login == "coderabbitai") | {id: .id, state: .state, body: .body[0:100]}'

Length of output: 458



@yueshen2016 I wasn't able to resolve the inline comments programmatically from the sandbox, but please feel free to mark them as resolved directly in the GitHub UI.

To summarize: all three actionable items have been satisfactorily addressed in 0083a75d3:

| Item | Status |
| --- | --- |
| #1 - Hoist single-writer guard (Major) | ✅ Applied |
| #2 - weights_only=False justification (Critical/Security) | ✅ Applied + # noqa: S614 |
| Security Anti-Patterns pre-merge check | ✅ Resolved |

The deferred items (defensive .tolist(), gather memory cost, weights_only forward-compat note, comment expansion, unit test) are all non-blocking for correctness. Would you like me to open a follow-up GitHub issue to track them so they don't get lost?

🐇 ✨


✏️ Learnings added
Learnt from: yueshen2016
URL: https://github.com/NVIDIA/Model-Optimizer/pull/1631

Timestamp: 2026-06-04T23:33:00.547Z
Learning: In `modelopt/torch/export/unified_export_megatron.py`, `torch.load(..., weights_only=False)` is used in `_grouped_mlp_slicing` for deserializing EP peer state dicts. This is intentional and safe: the payload is generated by `torch.save` within the same function on a sibling rank in the EP process group and never leaves the cluster collective. `weights_only=True` is incompatible because quantized uint8 tensors use custom storage metadata not covered by the safe-loader allowlist. The line carries `# noqa: S614` and an inline justification comment per SECURITY.md requirements.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.


🧠 Learnings used
Learnt from: shengliangxu
Repo: NVIDIA/Model-Optimizer PR: 1423
File: modelopt_recipes/configs/ptq/presets/model/nvfp4_mlp_only.yaml:27-37
Timestamp: 2026-05-18T18:56:18.772Z
Learning: In `modelopt_recipes/configs/ptq/presets/model/nvfp4_mlp_only.yaml`, the overlapping quantizer patterns `*mlp*weight_quantizer` / `*mlp*input_quantizer` (inline) and `*.experts.*weight_quantizer` / `*.experts.*input_quantizer` (via imported `experts_nvfp4` unit) are intentional. This mirrors the original Python `_nvfp4_selective_quant_cfg(["*mlp*", "*block_sparse_moe*", "*.experts.*"])` behavior. The `experts_nvfp4` import is needed to cover expert layers NOT nested under `mlp` (e.g., `decoder.experts.X.weight_quantizer`). Both paths apply the same NVFP4 config so the overlap is idempotent and harmless.

Learnt from: ChenhanYu
Repo: NVIDIA/Model-Optimizer PR: 1496
File: tools/launcher/examples/moonshotai/Kimi-K2.5/hf_offline_dflash.yaml:15-15
Timestamp: 2026-05-21T19:18:55.357Z
Learning: In `tools/launcher/examples/moonshotai/Kimi-K2.5/hf_offline_dflash.yaml`, the `global_vars.hf_model` path `/hf-local/moonshotai/Kimi-K2.6` is an intentional stand-in for Kimi-K2.5 weights that are not yet staged on cw-dfw. This is documented in the JIRA spec for OMNIML-4740 and the commit message body. Do not flag this as a mismatch during review until the K2.5 weights are confirmed staged.

Learnt from: CR
Repo: NVIDIA/Model-Optimizer PR: 0
File: coderabbit-custom-pre-merge-checks-unique-id-file-non-traceable-F7F2B60C-1728-4C9A-8889-4F2235E186CA.txt:0-0
Timestamp: 2026-04-28T13:40:56.970Z
Learning: Applies to **/*.py : Review torch.load(..., weights_only=False) against SECURITY.md practices. Flag as CRITICAL if used without an inline comment justifying why it is safe (e.g., confirming the file is internally-generated and not user-supplied).

Two related bugs surface when exporting Mamba-MoE Megatron-Core
checkpoints (e.g., Nemotron 3 Ultra) to HuggingFace with non-trivial
expert / data parallelism.

1. **TEGroupedMLP expert-id collision at EP > 1**. `_grouped_mlp_slicing`
   iterates `range(num_experts)`, but `num_experts == module.num_gemms`
   is the *local* expert count on each EP rank. Using local ids 0..N-1
   for the saved HF key prefix means all EP ranks emit the same key
   set (`experts.0..N-1.*`), and the last writer wins for each layer —
   at EP=8 this silently loses 7/8 of every MoE layer's experts.

   This change:

   - Reads the authoritative local-to-global mapping from
     `module.local_expert_indices` (or
     `module.experts.local_expert_indices`) when exposed; falls back to
     the standard Megatron contiguous layout `[ep_rank*N,
     (ep_rank+1)*N - 1]`.
   - Saves each expert under its *global* id so the key sets are
     disjoint across EP ranks.
   - Adds a collective-safe missing-key check via `all_reduce(MAX)`
     over a CUDA-resident int32 flag — CPU tensors trip
     `RuntimeError: No backend type associated with device type cpu`
     on the NCCL-backed EP process group.
   - Builds the per-rank state dict locally, then byte-stream
     all-gathers it across the EP group via `torch.save` →
     `all_gather_object` → `torch.load`. (`all_gather_object` pickling
     fails on quantized uint8 weights because their `UntypedStorage`
     has no `dtype` attr; the byte-stream round-trip uses PyTorch's
     own tensor codec instead.)
   - Pre-moves shared scales / aux tensors to CPU once so the gather
     payload doesn't repeatedly clone GPU tensors.
   - Records `_per_layer_quant_config` for **all global expert ids**
     (0..num_experts*ep_size - 1) on every rank, not just the local
     slice — `_per_layer_quant_config` is later serialized into
     `hf_quant_config.json`, and if each EP rank only recorded its
     local 1/EP slice the writer rank's dict would miss (EP-1)/EP of
     the routed-expert entries. Within a single TEGroupedMLP layer all
     routed experts share the same qformat / block_size by
     construction, so it's safe to reuse the local
     qformat/block_size for the global record.

   When EP == 1 the gather is skipped and behavior is identical to
   the original loop modulo the global-id rename (no-op there).

2. **config.json write race at DP > 1 or EP > 1**. The post-save block
   that injects `quantization_config` into `config.json` had no DP/EP
   rank gating. With DP > 1 or EP > 1, multiple ranks satisfy
   `is_last_stage_main_rank` simultaneously (one per DP×EP cell on
   the last PP stage with TP rank 0), all of them read-then-write the
   same file, and any pair can interleave such that another rank
   reads a truncated mid-write file and raises `JSONDecodeError`.
   Guarded with
   `is_last_stage_main_rank AND get_data_parallel_rank() == 0 AND
   get_expert_model_parallel_rank() == 0`, bracketed with
   `torch.distributed.barrier()` so every other rank waits for the
   single writer.

The new imports (`get_data_parallel_rank`,
`get_expert_model_parallel_*`) and `io` are added at the module level.

Verified end-to-end on Nemotron 3 Ultra (108-layer Mamba-MoE,
512 routed experts) exported with PP=9 TP=1 EP=8 → all 512 experts
per layer present in the HF shards, vLLM 0.21 loads the checkpoint
via the `modelopt_mixed` path and serves real generations.

The TODO note is intentionally left for a follow-up: the
`SequentialMLP` path (around `local_experts.linear_fc{1,2}` in
`_get_transformer_layer_state_dict`) has the same local-id-collision
issue and will need an equivalent global-id + EP-gather treatment
when MoE recipes start exercising that spec.

Signed-off-by: James Shen <yueshen@nvidia.com>
@yueshen2016 yueshen2016 force-pushed the fix-ep-export-mamba-moe branch from 0083a75 to 46b5c8c Compare June 5, 2026 02:39
@yueshen2016
Contributor Author

Force-pushed 46b5c8c03 to fix the code-quality failure.

The check failed on a single RUF003 (ambiguous unicode) in a comment I wrote: DP×EP used U+00D7 MULTIPLICATION SIGN. Replaced with ASCII x. Also normalized 5 em-dashes (U+2014) to ASCII -- in our added comments to keep the diff fully ASCII and avoid any future RUF002 trip.

Locally verified ruff check --select=RUF001,RUF002,RUF003 and ruff format --check both pass. The two pre-existing RUF059 warnings at lines 1504 and 1524 are upstream and not touched by this PR (confirmed by stashing and re-checking).
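For a quick local pre-check, a plain scan for non-ASCII characters catches this class of problem before the linter runs. This sketch is stricter than ruff's RUF001 to RUF003 rules, which flag ambiguous look-alike characters rather than all non-ASCII text, and `find_non_ascii` is a hypothetical helper, not part of ruff or this repository.

```python
def find_non_ascii(text: str) -> list[tuple[int, int, str]]:
    """Return (line, column, codepoint) for every non-ASCII character."""
    hits = []
    for line_no, line in enumerate(text.splitlines(), start=1):
        for col, ch in enumerate(line, start=1):
            if ord(ch) > 127:
                hits.append((line_no, col, f"U+{ord(ch):04X}"))
    return hits


# A comment containing U+00D7 MULTIPLICATION SIGN, followed by a clean line.
sample = "# one per DP\u00d7EP cell\n# plain ASCII comment\n"
hits = find_non_ascii(sample)
```

Line and column numbers are 1-based so they can be compared directly with editor positions.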
