Skip to content

[Training] Add memory estimator breakdown#3910

Open
yaoyu-33 wants to merge 3 commits into
mainfrom
priya/issue1673-memory-estimator
Open

[Training] Add memory estimator breakdown#3910
yaoyu-33 wants to merge 3 commits into
mainfrom
priya/issue1673-memory-estimator

Conversation

@yaoyu-33
Copy link
Copy Markdown
Contributor

Summary

  • Add a structured estimate_training_memory API and formatter around the existing theoretical memory utility.
  • Account for MoE layer patterns, routed expert EP/ETP sharding, context parallel partitioning, and distributed optimizer shard sizes.
  • Preserve the existing training-time aggregate memory report while adding docs, focused arithmetic tests, and memory-tuning skill guidance.

Closes #1673

Validation

  • git diff --check origin/main..HEAD
  • python3 -m py_compile src/megatron/bridge/training/utils/theoretical_memory_utils.py tests/unit_tests/training/utils/test_theoretical_memory_utils.py
  • /home/yuya/.local/bin/ruff format src/megatron/bridge/training/utils/theoretical_memory_utils.py tests/unit_tests/training/utils/test_theoretical_memory_utils.py
  • /home/yuya/.local/bin/ruff check src/megatron/bridge/training/utils/theoretical_memory_utils.py tests/unit_tests/training/utils/test_theoretical_memory_utils.py
  • /home/yuya/.local/bin/pre-commit run --all-files

Not run:

  • uv run pre-commit run --all-files: local uv run cannot install nvidia-resiliency-ext==0.6.0 because this workstation is manylinux_2_31_x86_64 and the locked wheel is only available for manylinux_2_39_{x86_64,aarch64}.
  • Targeted pytest: not run locally per task instructions; cw, sbatch, and srun are not available in this environment.

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 21, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@yaoyu-33
Copy link
Copy Markdown
Contributor Author

/ok to test 20e5d87

Comment on lines +542 to +552
def _expert_optimizer_shard_size(
config: ConfigContainer,
*,
tensor_parallel_size: int,
expert_parallel_size: int,
expert_tensor_parallel_size: int,
) -> int:
data_parallel_size = _positive_int_attr(config, "data_parallel_size", 1)
shard_size = data_parallel_size * tensor_parallel_size
shard_size //= max(1, expert_parallel_size * expert_tensor_parallel_size)
return max(1, shard_size)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: context_parallel_size is missing from the expert optimizer shard size.

For dense parameters (line 151), the optimizer shard size correctly includes CP: data_parallel_size * context_parallel_size. But for experts, CP ranks also share the same expert parameters, so the expert data-parallel shard size should be data_parallel_size * tensor_parallel_size * context_parallel_size / (EP * ETP), not data_parallel_size * tensor_parallel_size / (EP * ETP).

When context_parallel_size > 1, this under-counts the shard size by a factor of CP, making bytes_per_parameter too high for experts (overestimate).

Suggested change
def _expert_optimizer_shard_size(
config: ConfigContainer,
*,
tensor_parallel_size: int,
expert_parallel_size: int,
expert_tensor_parallel_size: int,
) -> int:
data_parallel_size = _positive_int_attr(config, "data_parallel_size", 1)
shard_size = data_parallel_size * tensor_parallel_size
shard_size //= max(1, expert_parallel_size * expert_tensor_parallel_size)
return max(1, shard_size)
def _expert_optimizer_shard_size(
config: ConfigContainer,
*,
tensor_parallel_size: int,
context_parallel_size: int,
expert_parallel_size: int,
expert_tensor_parallel_size: int,
) -> int:
data_parallel_size = _positive_int_attr(config, "data_parallel_size", 1)
shard_size = data_parallel_size * tensor_parallel_size * context_parallel_size
shard_size //= max(1, expert_parallel_size * expert_tensor_parallel_size)
return max(1, shard_size)

The caller at line 168 would also need context_parallel_size=context_parallel_size.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 304786d: expert optimizer shard size now includes context_parallel_size, and the MoE CP test asserts the corrected bytes-per-parameter value.

Comment on lines +39 to +41
- `6 + 12 / shard_size` bytes per parameter when the distributed optimizer is enabled

For dense parameters, `shard_size` is `data_parallel_size * context_parallel_size`.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: the doc says shard_size for dense is data_parallel_size * context_parallel_size, but if the bug above is fixed the expert shard formula should also be updated to mention CP:

For routed MoE experts, expert parameters are divided by expert_model_parallel_size * expert_tensor_parallel_size, and optimizer state uses the expert data-parallel shard size which includes context parallel ranks.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated in 304786d: the docs now state that the expert data-parallel shard size includes context parallel ranks.

@claude
Copy link
Copy Markdown
Contributor

claude Bot commented May 21, 2026

Review of [Training] Add memory estimator breakdown. Bug: _expert_optimizer_shard_size is missing context_parallel_size (theoretical_memory_utils.py:542-552). The dense optimizer shard size correctly includes CP at line 151 but the expert shard size at line 550 uses data_parallel_size times tensor_parallel_size without multiplying by CP. When CP > 1 this under-counts the shard size overestimating expert memory. See inline comment for suggested fix. Test coverage gaps: MTP layer counting, moe_layer_freq as a list, moe_latent_size latent projection branch, shared expert parameters, VPP activation penalty, and report_theoretical_memory integration are all untested. Suggested test cases: No perf tests impacted.

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
@yaoyu-33 yaoyu-33 force-pushed the priya/issue1673-memory-estimator branch from 20e5d87 to 304786d Compare May 21, 2026 00:35
@yaoyu-33
Copy link
Copy Markdown
Contributor Author

/ok to test 304786d

@yaoyu-33 yaoyu-33 added area:training Training loop, callbacks, and runtime integration feature New capabilities, enhancements, or enablement work needs-review PR is ready for code review and waiting on a reviewer labels May 21, 2026
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
@yaoyu-33
Copy link
Copy Markdown
Contributor Author

/ok to test 01d2da3

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
@yaoyu-33
Copy link
Copy Markdown
Contributor Author

/ok to test dc1c741

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:training Training loop, callbacks, and runtime integration feature New capabilities, enhancements, or enablement work needs-review PR is ready for code review and waiting on a reviewer

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Memory Estimator from MBridge

1 participant