[ROCm] reduce temp compile memory usage before training starts#85
Closed
cj401-amd wants to merge 29 commits into
Closed
[ROCm] reduce temp compile memory usage before training starts#85cj401-amd wants to merge 29 commits into
cj401-amd wants to merge 29 commits into
Conversation
This reverts commit 11a8852.
(cherry picked from commit a2f9860) fix rocm version finding
Removed ROCm specific environment variables for fused-attention.
- Import remove_size_one_mesh_axis from sharding utils - Use remove_size_one_mesh_axis for activation_pspec to handle all mesh axes including fsdp_transpose, expert, context correctly - Remove jax.reshard calls that caused extra temp memory allocation - Fix dense_init_scale to 1.0 (was self.config.dense_init_scale) Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
JAX 0.7.1 requires fun as first positional argument to jax.jit, so @jax.jit(static_argnames=[...]) fails. Use functools.partial instead. Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
This config option was added in JAX > 0.7.1; guard with hasattr so the code runs on both older and newer JAX versions. Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
Three root causes identified vs the working cj-reduce-tmp-mem_rocm-main: 1. GMM 9-tuple tile sizes (moe.py): newer rocm-main reintroduced 9-tuple (fwd+dlhs+drhs) tiling which changes XLA backward-pass planning and adds ~0.5-1 GB temp memory. Revert to 3-tuple (forward-only). 2. Trivial sharding constraints (sharding.py, mixtral.py, attentions.py): For pp=8 ep=1, ALL mesh axes inside the pipeline vmap are size 1. Every with_sharding_constraint resolves to all-None/() PartitionSpecs (trivial), which XLA loop_broadcast_fusion hoists into the pipeline scan carry as extra buffers (+5 GB). Fix: add skip_trivial_specs param to maybe_shard_with_logical; replace nn.with_logical_constraint in MixtralDecoderLayer with shard() helper using skip_trivial_specs=True; same for Attention._maybe_shard_with_logical. Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
HLO analysis shows pipeline_module.init_states/_maybe_shard_with_logical at sharding.py:115 generating loop_broadcast_fusion entries with bf16[1,1,4096,4096] tensors in the preallocated-temp pool for pp=8 ep=1. These are trivial constraints (all size-1 mesh axes) that XLA cannot fully eliminate. Skip them to avoid polluting the scan carry. Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
This nan_to_num tree_map forced materialization of every gradient tensor (f32[8,16,4096,2048] etc.) as an explicit elementwise op, preventing XLA from eliding them as loop-invariant values through the pipeline scan carry. Result: +10 sub-computations and +80 parameter() occurrences per compiled module, blocking loop_broadcast_fusion and adding ~6 GB preallocated temp. The good branch (cj-reduce-tmp-mem_rocm-main) never had this call. FP8 NaN sanitization is still handled conditionally in the fp8_stats block. Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
- Remove 'cj-reduce ---' debug prefix from print_compiled_memory_stats log - Restore shard_optimizer_over_data guard (was commented out) - Restore compiled_trainstep_file guard so pre-compiled files skip recompile (forced compile broke any run using compiled_trainstep_file) Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
attentions.py: - Restore depth_scaling = jnp.sqrt(self.head_dim) in the default branch. Both branches were incorrectly set to 1.0, eliminating the T5-style 1/sqrt(d_k) folded into the query weight initializer. For Mixtral with head_dim=128 this produced query weights ~11x larger than intended, degrading training convergence from step 0. attention_op.py: - Restore context_parallel_axis=self.config.context_sharding (was hardcoded to "context", silently breaking the ep-as-cp mesh config where context_sharding="expert"). - Add comment explaining scale_factor is intentionally omitted: passing 1.0 disables QK scaling, while TE's default (None) auto-computes 1/sqrt(head_dim). - Add comment explaining context_parallel_strategy is omitted because the installed TE 2.6.x DotProductAttention does not accept that parameter. Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
…um sync moe.py: Add comment explaining why 3-tuple tiling is correct. The 9-tuple (fwd+dlhs+drhs) is only used by the megablox custom VJP (_gmm_bwd reads tiling[3:6] and tiling[-3:] for backward passes). ds-proxy uses megablox=False / use_tokamax_gmm=False (jax.lax.ragged_dot path) which only reads tiling[0:3], so the extra 6 values were always ignored. base.yml documents this explicitly: "megablox/jax ragged dot - supports forward pass only". embeddings.py: Replace undefined axis "activation_length_no_exp" with "activation_length" (output_default_axis_names) in the ShardMode.EXPLICIT path. The undefined axis would silently map to None (fully replicated) in any explicit-shard config that lacks a rule for it (e.g. deepseek3-671b-batchsplit). ds-proxy is unaffected (uses shard_mode="auto"), but the latent bug is real. types.py / base.yml: Sync float32_weight_sum default from True (types.py) to False (base.yml). The False default was set intentionally in commit 4fceae4 to eliminate a ~2 GB temporary f32 tensor from the MoE weight_sum einsum. bf16 summation over 4 experts (num_experts_per_tok) is numerically acceptable. Update comment to document the memory trade-off. Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
|
@cj401-amd We need to land the PR to upstream maxtext, we cannot afford the maintenance burden on rocm/maxtext:rocm-main, if you work can reduce the memory on other side, directly go to upstream maxtext. |
i-chaochen
reviewed
Apr 30, 2026
i-chaochen
left a comment
There was a problem hiding this comment.
just keep those general changes (no ROCm specific ones), and verify on another side.
i-chaochen
suggested changes
Apr 30, 2026
i-chaochen
left a comment
There was a problem hiding this comment.
clean up your commits and agaisnt to the correct branch (rocm-main), then upstream to maxtext, not here
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
Previously, it was observed that excessively temp compile memory usage with various models as shown https://github.com/orgs/ROCm/projects/14/views/7?pane=issue&itemId=149359236&issue=ROCm%7Cframeworks-internal%7C15124. This PR is trying to address the issue and reduce the temp compile memory usage without impacting the training performance negatively as shown https://github.com/orgs/ROCm/projects/14/views/1?filterQuery=assignee%3A%22cj401-amd%22&pane=issue&itemId=179156668&issue=ROCm%7Cframeworks-internal%7C16353. it can potentially avoid the crash due to OOM resulting from excessive temp compile usage for some training workloads. i.e., 405B.
file changes
train.pytrain_compile.pysharding.pygather_reduce_sc.pyconfigs/base.ymlconfigs/types.pyattentions.pyattention_op.pynormalizations.pymoe.pyembeddings.pypipeline.pymixtral.pydecoders.pymulti_token_prediction.pymodels.pydeepseek.pytemp memory changes