Checkpoint save/load FSDP and ShardTensor support #1472
pzharrington wants to merge 6 commits into NVIDIA:main
Conversation
Greptile Summary

This PR introduces FSDP and ShardTensor-aware checkpoint save/load by centralising distributed state-dict logic in …
A minor style issue was also identified: a test file saves a plain …

Last reviewed commit: 52d03eb
```python
def _redistribute_sd_for_dtensor(
    placements: dict[str, tuple[Any, tuple[Any, ...]]],
    state_dict: dict[str, Any],
    device_type: str = "cuda",
) -> dict[str, Any]:
    """Convert plain tensors in *state_dict* to DTensors matching *placements*.

    Entries whose key appears in *placements* are converted via
    ``distribute_tensor`` so that each rank receives its correct local shard.
    All other entries are left unchanged.
    """
    if not placements:
        return state_dict

    # Determine a fallback (Replicate) mesh for keys NOT in placements but
    # that may still need to be DTensors (e.g. when FSDP use_orig_params=False
    # promotes all params to DTensor).
    fallback_mesh = next(iter(placements.values()))[0] if placements else None

    out: dict[str, Any] = {}
    for key, value in state_dict.items():
        if not isinstance(value, torch.Tensor) or isinstance(value, DTensor):
            out[key] = value
            continue

        if key in placements:
            mesh, plc = placements[key]
            out[key] = distribute_tensor(value.to(device_type), mesh, list(plc))
        elif fallback_mesh is not None:
            out[key] = distribute_tensor(
                value.to(device_type), fallback_mesh, [Replicate()]
            )
        else:
            out[key] = value
    return out
```
Docstring/implementation mismatch in _redistribute_sd_for_dtensor
The function docstring states "All other entries are left unchanged," but the implementation distributes every plain tensor (regardless of whether its key appears in placements) to the fallback_mesh as a Replicate DTensor when fallback_mesh is not None (lines 191–194). For a model that has both user-managed ShardTensor parameters and plain (non-distributed) parameters, this would wrap the plain tensors in a Replicate DTensor backed by the ShardTensor mesh — which may cause shape/device issues downstream when set_model_state_dict tries to assign them.
The docstring should be corrected to accurately describe this fallback behavior, or (if plain parameters are not meant to become DTensors) the fallback branch should be guarded so it only fires for keys that are actually expected to be DTensors.
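If the guarded-fallback route is taken, the key-dispatch logic would look roughly like the sketch below. This is an illustrative, torch-free mock of the control flow only (the function name `redistribute_keys`, the `expected_dtensor_keys` parameter, and the tagged return values are hypothetical); the real code would call `distribute_tensor` where the tags are produced.

```python
# Hypothetical sketch of the guarded fallback suggested above: only keys
# known to require DTensor wrapping fall back to Replicate; genuinely
# plain tensors pass through unchanged, matching the docstring.
def redistribute_keys(placements, expected_dtensor_keys, state_dict):
    out = {}
    # Same fallback-mesh choice as the reviewed code: borrow the mesh of
    # the first entry in placements.
    fallback_mesh = next(iter(placements.values()))[0] if placements else None
    for key, value in state_dict.items():
        if key in placements:
            mesh, plc = placements[key]
            out[key] = ("shard", mesh, plc)          # distribute_tensor(value, mesh, plc)
        elif fallback_mesh is not None and key in expected_dtensor_keys:
            out[key] = ("replicate", fallback_mesh)  # distribute_tensor(value, mesh, [Replicate()])
        else:
            out[key] = ("plain",)                    # left unchanged
    return out
```

With this guard, a model mixing user-managed ShardTensor parameters and ordinary plain parameters would no longer have its plain parameters silently wrapped in Replicate DTensors on the ShardTensor mesh.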
Suggested change (docstring only):

```python
def _redistribute_sd_for_dtensor(
    placements: dict[str, tuple[Any, tuple[Any, ...]]],
    state_dict: dict[str, Any],
    device_type: str = "cuda",
) -> dict[str, Any]:
    """Convert plain tensors in *state_dict* to DTensors matching *placements*.

    Entries whose key appears in *placements* are converted via
    ``distribute_tensor`` so that each rank receives its correct local shard.
    When a fallback mesh is available, remaining plain tensors are also
    distributed as ``Replicate`` DTensors on that mesh.
    """
```
```python
weights_file = f"{shared_tmp_dir}/trained_shard.mdlus"
if dm.rank == 0:
    torch.save(model.state_dict(), weights_file)
```
Incorrect file extension for torch.save output
The weights file is named with a .mdlus extension, but it is saved using torch.save(model.state_dict(), ...), which produces a raw .pt-format file — not the zip/tar archive that the .mdlus format actually requires. The test passes because load_model_weights dispatches on model type (not file extension): since _PosEmbedModel is a plain nn.Module, it calls torch.load directly and ignores the extension. However, the naming is misleading and will confuse future readers about what formats are valid for .mdlus files.
Suggested change:

```diff
-weights_file = f"{shared_tmp_dir}/trained_shard.mdlus"
+weights_file = f"{shared_tmp_dir}/trained_shard.pt"
 if dm.rank == 0:
     torch.save(model.state_dict(), weights_file)
```
PhysicsNeMo Pull Request
Description
Summary
- `save_checkpoint` and `load_checkpoint` now automatically detect FSDP-wrapped and DTensor/ShardTensor-distributed models and use PyTorch's Distributed Checkpoint (DCP) state-dict APIs to gather/scatter model and optimizer state. In distributed mode all ranks call the functions collectively, while only rank 0 performs file I/O. This eliminates the need for manual parameter gathering/scattering that recipe code (e.g. StormCast) previously had to implement.
- `load_model_weights` utility: a convenience function for loading a single `.mdlus` or `.pt` file directly into a (potentially distributed) model, handling FSDP + DTensor redistribution automatically.
- Removed from `parallel.py`: `gather_training_state`, `scatter_optimizer_state`, `shard_state_dict`, `scatter_object`, and `get_state_dict_shard`, along with ~50 lines of rank-0 CPU model/optimizer bookkeeping from `trainer.py`. All ranks now participate symmetrically in `_resume_or_init`, calling `load_checkpoint`/`save_checkpoint` directly.
- `physicsnemo.core.Module.save`: added an optional `state_dict` parameter so `save_checkpoint` can pass a pre-gathered full state dictionary for FSDP/DTensor models without calling `self.state_dict()` on the distributed module.
- `StateDictOptions.broadcast_from_rank0` (used in the pure-FSDP load path) was introduced in PyTorch 2.5. This option enables rank 0 to broadcast the full state dict to all other ranks without manual scatter, which is the standard non-DTensor distributed load mechanism.

Checklist
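The collective save/load pattern described in the summary can be sketched with PyTorch's DCP state-dict APIs roughly as follows. This is an illustrative sketch, not the PR's implementation: it assumes PyTorch >= 2.5 and an initialized process group, and the helper names `save_distributed_checkpoint` / `load_distributed_checkpoint` are hypothetical. It is not runnable outside a `torch.distributed` job.

```python
import torch
import torch.distributed as dist
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    set_model_state_dict,
)

def save_distributed_checkpoint(model, path):
    # All ranks call collectively: gather a full (unsharded) state dict,
    # offloaded to CPU to avoid GPU OOM on large models.
    opts = StateDictOptions(full_state_dict=True, cpu_offload=True)
    state = get_model_state_dict(model, options=opts)
    if dist.get_rank() == 0:  # only rank 0 performs file I/O
        torch.save(state, path)

def load_distributed_checkpoint(model, path):
    # Rank 0 reads the file; broadcast_from_rank0 (PyTorch >= 2.5) pushes
    # the full state dict to the other ranks and re-shards it into the
    # FSDP/DTensor model.
    state = torch.load(path, map_location="cpu") if dist.get_rank() == 0 else {}
    opts = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
    set_model_state_dict(model, state, options=opts)
```

The key property this pattern gives the recipe code is symmetry: every rank executes the same two calls, and the gather/broadcast mechanics live inside the DCP APIs rather than in per-recipe scatter helpers.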
Dependencies
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.