feat: add MegatronMIMO model conversion#3905
Conversation
|
/ok to test 9ce2fea |
Signed-off-by: Li Ding <liding@nvidia.com>
9ce2fea to
ce528bf
Compare
|
/ok to test ce528bf |
Signed-off-by: Li Ding <liding@nvidia.com>
|
/ok to test 617c404 |
|
/ok to test 6408e9b |
Signed-off-by: Li Ding <liding@nvidia.com>
|
/ok to test acbd4cd |
Signed-off-by: Li Ding <liding@nvidia.com>
|
/ok to test 0979433 |
| f"--component {name!r}: ranks per component ({ranks_per_component}) " | ||
| f"is not divisible by total_model_parallel_size ({mp}); " | ||
| f"specify dp=N explicitly." | ||
| ) | ||
| parallelism.data_parallel_size = ranks_per_component // mp | ||
| parallelism.rank_offset = offset | ||
| # Whether auto-filled or user-supplied, advance offset by this | ||
| # component's ranks so the next auto-filled component lands after it. | ||
| offset += parallelism.total_ranks | ||
|
|
||
|
|
||
| def run_import( |
There was a problem hiding this comment.
Minor: _auto_fill_layout is order-dependent when mixing user-supplied dp= and auto-filled components. If a user-supplied component (which keeps rank_offset=0 by default) appears after an auto-filled component, both can end up at the same rank_offset, causing overlapping rank assignments. This only happens with mixed user/auto layouts and dict iteration order, so it may never surface in practice — but worth a comment or a post-loop overlap check to future-proof.
Light Code ReviewOverall this is a well-structured PR with comprehensive unit test coverage (131 test functions across 10 new/modified test files). Bug fix in state.py (correctness): The save_generator fix (replacing keys_for_file with tensors_to_save.keys()) is a genuine correctness fix -- the old code could produce a model.safetensors.index.json that maps keys to files that do not contain them. No dedicated regression test for this specific bug path was added; consider adding one if practical. _auto_fill_layout ordering dependence (minor): In examples/conversion/convert_megatron_mimo.py, when mixing user-supplied dp= components with auto-filled ones, the auto-fill is iteration-order-dependent. A user-supplied component keeps its default rank_offset=0 and is not repositioned, so a later auto-filled component could overlap it in rank space. See inline comment. _rebuild_derived_specs_from_standard_provider mutates in place: Setting standard_provider.mtp_num_layers = None mutates the provider object in place. If the same provider instance is used in a non-MIMO path elsewhere, MTP would be silently disabled. Documented as intentional for MIMO v1, but worth noting. Missing test for _bridge_parallel_state_globals: The build_model.py helper sets about 10 parallel_state globals but has no direct unit test. A targeted test would catch regressions if the mcore global names change. The PR adds excellent unit test coverage for CLI parsing, route validation, orchestrator import/export/streaming, Qwen3.5-VL default conversion, MIMO provider factory, TransformerConfig finalization, and a functional checkpoint roundtrip test. Suggested test cases: No perf tests impacted. |
Summary
This PR adds the framework-level conversion path for MegatronMIMO:
MimoModelwith per-component parallelism;safetensors.Dense Qwen3.5-VL is the first model family using the framework. It only adds standard bridge/provider metadata needed for default route construction; the conversion infrastructure itself is model-family agnostic.
This enables the workflow: download an HF VLM checkpoint → convert to MIMO → continue training/SFT/LoRA with heterogeneous component parallelism.
Why MIMO Needs Dedicated Conversion Infrastructure
The standard bridge path assumes one provider, one model, one global
parallel_state, and one rank topology. MegatronMIMO breaks those assumptions: language, vision, and future modality components can use different TP/PP/DPlayouts and may live on different rank sets.
There are two core problems this PR solves:
Global distributed state. Standard bridge mappings read TP/PP/DP groups from MCore global
parallel_state. MegatronMIMO intentionally does not initialize one global topology; each component owns its ownProcessGroupCollection. The conversion orchestrator therefore runs one component route at a time, temporarily bridges MCoreparallel_statefrom that component's process groups, and attaches the samepg_collectionto the target submodule. This lets existing bridge mapping/gather code run without changing standard conversion internals.Component-local naming. Standard bridges describe parameters in a single Megatron namespace, e.g.
language_model.*andvision_model.*. InsideMimoModel, those are separate target submodules, so the componentprefix must be stripped before dispatch. A route table maps each component to its target submodule and builds a route-local mapping registry by cloning the source bridge mappings with the component prefix removed.
The rest of the framework follows from those two constraints:
TransformerConfigs inside MIMO specs are finalized before model construction.parallel_state.MimoModelschema has no MTP submodule.Framework-Level Design
The implementation is centered on
MegatronMIMOBridge, anAutoBridgesubclass with MIMO-specific entry points.to_megatron_mimo_provider()builds aMegatronMIMOProvider,to_megatron_model()imports HF weights into a constructedMimoModel, andimport_ckpt()/export_ckpt()provide the checkpoint-level user workflow.Default model support is metadata-driven:
mimo_source_prefixes, e.g.{"language": "language_model.", "images": "vision_model."};modality_keys = {"images": "qwen_visual"}, plusbuild_language_model_spec()andspecial_token_ids;MIMOComponentroutes and builds the provider viaMegatronMIMOProvider.from_standard_provider().For each route, the framework clones the standard bridge mapping registry, strips the component prefix, temporarily exposes the component process groups through
parallel_state, and invokes the existing standard bridge weight import/export code on the target submodule. This keeps most conversion logic shared with the standard bridge path.Explicit
register_mimo_conversion_spec()support remains as an escape hatch for models whose provider or route construction cannot be derived from standard metadata.Checkpoint save/load uses the regular Megatron distributed checkpointing path with MIMO-aware module names and component process groups. Derived model specs are rebuilt from the persisted standard provider on load, so HF import -> MIMO checkpoint -> fresh export works across processes.
Adding another model family should usually require only standard bridge/provider metadata; the generic orchestrator should not need model-family edits.
Validation
Qwen/Qwen3.5-27BHF → MIMO checkpoint → HF export on 8 GPUs (language=tp=4,images=tp=4), non-MTP parity1184/1184,0mismatches. Expected ignored MTP keys:15.Not included in this PR: an L0/L1 functional test. The current conversion path needs multiple GPUs and a real HF model; a synthetic 2-GPU functional test can be added as follow-up.
Limitations
User Surface
The concrete import/export examples