Checkpoint save/load FSDP and ShardTensor support #1472

Draft

pzharrington wants to merge 6 commits into NVIDIA:main from pzharrington:fsdp-shardtensor-ckpt

Conversation

pzharrington (Collaborator) commented on Mar 5, 2026

PhysicsNeMo Pull Request

Description

Summary

  • FSDP/ShardTensor-aware checkpoint save and load: save_checkpoint and load_checkpoint now automatically detect FSDP-wrapped and DTensor/ShardTensor-distributed models and use PyTorch's Distributed Checkpoint (DCP) state-dict APIs to gather/scatter model and optimizer state. In distributed mode all ranks call the functions collectively, while only rank 0 performs file I/O. This eliminates the need for manual parameter gathering/scattering that recipe code (e.g. StormCast) previously had to implement.
  • New load_model_weights utility: A convenience function for loading a single .mdlus or .pt file directly into a (potentially distributed) model, handling FSDP + DTensor redistribution automatically.
  • StormCast recipe simplification: Removed ~200 lines of manual checkpoint gather/scatter logic from parallel.py (gather_training_state, scatter_optimizer_state, shard_state_dict, scatter_object, get_state_dict_shard) and ~50 lines of rank-0 CPU model/optimizer bookkeeping from trainer.py. All ranks now participate symmetrically in _resume_or_init, calling load_checkpoint / save_checkpoint directly.
  • physicsnemo.core.Module.save: Added an optional state_dict parameter so save_checkpoint can pass a pre-gathered full state dictionary for FSDP/DTensor models without calling self.state_dict() on the distributed module.
  • Minimum torch version bump 2.4 → 2.5: Required because StateDictOptions.broadcast_from_rank0 (used in the pure-FSDP load path) was introduced in PyTorch 2.5. This option enables rank 0 to broadcast the full state dict to all other ranks without manual scatter, which is the standard non-DTensor distributed load mechanism.

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI's assessment of merge readiness; it is not a qualitative judgment of your work, nor an
indication that the PR will be accepted or rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

pzharrington (Collaborator, Author) commented:

@greptileai

greptile-apps bot (Contributor) commented on Mar 5, 2026

Greptile Summary

This PR introduces FSDP and ShardTensor-aware checkpoint save/load by centralising distributed state-dict logic in physicsnemo/utils/checkpoint.py, eliminating ~250 lines of manual gather/scatter code from the StormCast recipe, and adding a new load_model_weights convenience function. The approach is architecturally sound: save_checkpoint / load_checkpoint auto-detect distributed models via _is_distributed_model, enter a collective code path using PyTorch's DCP get_model_state_dict / set_model_state_dict APIs, and restrict file I/O to rank 0. One logic issue was found:

  • _redistribute_sd_for_dtensor docstring mismatch: The docstring states "all other entries are left unchanged," but when a fallback mesh is available, the implementation distributes all remaining plain tensors to the fallback mesh as Replicate DTensors, not just those expected to be DTensors. For models with both ShardTensor and plain parameters, this could cause downstream issues. The docstring should be corrected to accurately document this behavior.

A minor style issue was also identified: a test file saves a plain torch.save() weights file with a .mdlus extension, which is misleading about supported file formats even though the test passes due to type-based dispatch logic.
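For the non-distributed path, the new `load_model_weights` utility presumably reduces to a plain `torch.load` plus `load_state_dict`. A minimal sketch under that assumption (the helper name below is hypothetical; the PR's real utility additionally handles FSDP wrapping, DTensor redistribution, and the `.mdlus` archive format):

```python
import torch
import torch.nn as nn

def load_plain_weights(model: nn.Module, path: str) -> None:
    # Sketch of the simplest case only: a plain, single-process nn.Module
    # and a raw state-dict file produced by torch.save. The PR's actual
    # load_model_weights also redistributes tensors for FSDP/DTensor models.
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
```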

Last reviewed commit: 52d03eb

Comment on lines +163 to +197
def _redistribute_sd_for_dtensor(
    placements: dict[str, tuple[Any, tuple[Any, ...]]],
    state_dict: dict[str, Any],
    device_type: str = "cuda",
) -> dict[str, Any]:
    """Convert plain tensors in *state_dict* to DTensors matching *placements*.

    Entries whose key appears in *placements* are converted via
    ``distribute_tensor`` so that each rank receives its correct local shard.
    All other entries are left unchanged.
    """
    if not placements:
        return state_dict

    # Determine a fallback (Replicate) mesh for keys NOT in placements but
    # that may still need to be DTensors (e.g. when FSDP use_orig_params=False
    # promotes all params to DTensor).
    fallback_mesh = next(iter(placements.values()))[0] if placements else None

    out: dict[str, Any] = {}
    for key, value in state_dict.items():
        if not isinstance(value, torch.Tensor) or isinstance(value, DTensor):
            out[key] = value
            continue

        if key in placements:
            mesh, plc = placements[key]
            out[key] = distribute_tensor(value.to(device_type), mesh, list(plc))
        elif fallback_mesh is not None:
            out[key] = distribute_tensor(
                value.to(device_type), fallback_mesh, [Replicate()]
            )
        else:
            out[key] = value
    return out

Docstring/implementation mismatch in _redistribute_sd_for_dtensor

The function docstring states "All other entries are left unchanged," but the implementation distributes every plain tensor (regardless of whether its key appears in placements) to the fallback_mesh as a Replicate DTensor when fallback_mesh is not None (lines 191–194). For a model that has both user-managed ShardTensor parameters and plain (non-distributed) parameters, this would wrap the plain tensors in a Replicate DTensor backed by the ShardTensor mesh — which may cause shape/device issues downstream when set_model_state_dict tries to assign them.

The docstring should be corrected to accurately describe this fallback behavior, or (if plain parameters are not meant to become DTensors) the fallback branch should be guarded so it only fires for keys that are actually expected to be DTensors.

Suggested change (docstring corrected to match the fallback behavior):

def _redistribute_sd_for_dtensor(
    placements: dict[str, tuple[Any, tuple[Any, ...]]],
    state_dict: dict[str, Any],
    device_type: str = "cuda",
) -> dict[str, Any]:
    """Convert plain tensors in *state_dict* to DTensors matching *placements*.

    Entries whose key appears in *placements* are converted via
    ``distribute_tensor`` so that each rank receives its correct local shard.
    When a fallback mesh is available, remaining plain tensors are also
    distributed as ``Replicate`` DTensors on that mesh.
    """

Comment on lines +413 to +415
weights_file = f"{shared_tmp_dir}/trained_shard.mdlus"
if dm.rank == 0:
    torch.save(model.state_dict(), weights_file)

Incorrect file extension for torch.save output

The weights file is named with a .mdlus extension, but it is saved using torch.save(model.state_dict(), ...), which produces a raw .pt-format file — not the zip/tar archive that the .mdlus format actually requires. The test passes because load_model_weights dispatches on model type (not file extension): since _PosEmbedModel is a plain nn.Module, it calls torch.load directly and ignores the extension. However, the naming is misleading and will confuse future readers about what formats are valid for .mdlus files.

Suggested change (extension matched to the raw torch.save format):

weights_file = f"{shared_tmp_dir}/trained_shard.pt"
if dm.rank == 0:
    torch.save(model.state_dict(), weights_file)
