[ROCm] reduce temp compile memory usage before training starts #85

Closed
cj401-amd wants to merge 29 commits into main from cj-fix-tmp-mem_rocm-main

Conversation

@cj401-amd

Description

Excessive temp compile memory usage was previously observed with various models, as shown in https://github.com/orgs/ROCm/projects/14/views/7?pane=issue&itemId=149359236&issue=ROCm%7Cframeworks-internal%7C15124. This PR addresses the issue by reducing temp compile memory usage without negatively impacting training performance, as shown in https://github.com/orgs/ROCm/projects/14/views/1?filterQuery=assignee%3A%22cj401-amd%22&pane=issue&itemId=179156668&issue=ROCm%7Cframeworks-internal%7C16353. It can potentially avoid crashes caused by OOM from excessive temp compile memory usage in some training workloads, e.g., 405B.

file changes

| File | Change | Status |
| --- | --- | --- |
| train.py | Remove nan_to_num + grad_dtype guard + hasattr JAX config guard + flax_always_shard_variable | Clean |
| train_compile.py | hasattr guard for JAX 0.7.1 | Clean |
| sharding.py | skip_trivial_specs param | Clean |
| gather_reduce_sc.py | functools.partial(jax.jit, ...) fix | Clean |
| configs/base.yml | float32_weight_sum: false with updated comment | Clean |
| configs/types.py | float32_weight_sum default synced to False | Clean |
| attentions.py | skip_trivial_specs=True added | Clean |
| attention_op.py | Synthetic mask shortcut, scale_factor removed, context_parallel_axis restored, context_parallel_strategy removed with comment | Clean |
| normalizations.py | Replaced einsum with y * effective_scale + explicit sharding | Clean |
| moe.py | 3-tuple tiling with justifying comment, removed .astype(dtype) | Clean |
| embeddings.py | Fixed activation_length_no_exp → output_default_axis_names, added nn.with_logical_constraint | Clean |
| pipeline.py | Replaced shard_map-based permute/shift with pure array ops, removed extra sharding ops | Clean |
| mixtral.py | Converted NNX→Linen MixtralDecoderLayer, MixtralDecoderLayerToLinen = MixtralDecoderLayer alias | Clean |
| decoders.py | shared_embedding as class field, logits sharding constraints added | Clean |
| multi_token_prediction.py | shared_embedding removed from call signatures | Clean |
| models.py | shared_embedding passed at construction, removed from call sites | Clean |
| deepseek.py | remove_size_one_mesh_axis for activation_pspec, removed jax.reshard calls, nd_dense_init(1.0, ...) | Clean |

temp memory changes

| Model | Branch | Total | Output | Temp | Argument | Host Temp | Δ Temp vs rocm-main |
| --- | --- | --- | --- | --- | --- | --- | --- |
| ds-proxy-N1-ep2-pp4 | rocm-main | 59.8 GB | 14.6 GB | 45.1 GB | 14.6 GB | 0.0 GB | |
| ds-proxy-N1-ep2-pp4 | cj-fix-tmp-mem_rocm-main | 44.9 GB | 14.6 GB | 30.3 GB | 14.6 GB | 0.0 GB | −14.8 GB |
| ds-proxy-se2-e256-h4096 | rocm-main | 66.0 GB | 29.3 GB | 36.8 GB | 29.3 GB | 0.0 GB | |
| ds-proxy-se2-e256-h4096 | cj-fix-tmp-mem_rocm-main | 60.3 GB | 29.3 GB | 31.1 GB | 29.3 GB | 0.0 GB | −5.7 GB |
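
The Δ Temp column appears to be the branch's Temp size minus the rocm-main Temp size for the same model. A quick check of that arithmetic, using only the Temp figures from the table (the dictionary layout and function name are just for illustration):

```python
# Temp sizes in GB, copied from the table: (rocm-main, cj-fix-tmp-mem_rocm-main).
temp_gb = {
    "ds-proxy-N1-ep2-pp4": (45.1, 30.3),
    "ds-proxy-se2-e256-h4096": (36.8, 31.1),
}


def temp_delta(baseline_gb, branch_gb):
    """Change in temp memory relative to the baseline, rounded to 0.1 GB."""
    return round(branch_gb - baseline_gb, 1)


deltas = {model: temp_delta(*sizes) for model, sizes in temp_gb.items()}
for model, delta in deltas.items():
    print(f"{model}: {delta:+.1f} GB")
```

Both results match the reported deltas, so the savings come entirely from the Temp column; Output, Argument and Host Temp are unchanged between branches.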

gulsumgudukbay and others added 29 commits April 23, 2026 15:55
(cherry picked from commit a2f9860)
fix rocm version finding
Removed ROCm specific environment variables for fused-attention.
- Import remove_size_one_mesh_axis from sharding utils
- Use remove_size_one_mesh_axis for activation_pspec to handle all mesh
  axes including fsdp_transpose, expert, context correctly
- Remove jax.reshard calls that caused extra temp memory allocation
- Fix dense_init_scale to 1.0 (was self.config.dense_init_scale)

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
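
A rough sketch of what dropping size-one mesh axes from a partition spec can look like. This is not the `remove_size_one_mesh_axis` helper from the sharding utils, whose signature is not shown in this PR; it is a plain-Python stand-in that models a spec as a tuple and the mesh as a name-to-size dict, and the mesh sizes below are made up.

```python
def drop_size_one_axes(spec, mesh_shape):
    """Return `spec` with every mesh axis of size 1 removed.

    spec: tuple whose entries are None, an axis name, or a tuple of axis names.
    mesh_shape: dict mapping mesh axis name -> size (unknown axes count as 1).
    """
    def keep(axis):
        return mesh_shape.get(axis, 1) > 1

    cleaned = []
    for entry in spec:
        if entry is None:
            cleaned.append(None)
        elif isinstance(entry, str):
            cleaned.append(entry if keep(entry) else None)
        else:
            kept = tuple(axis for axis in entry if keep(axis))
            if not kept:
                cleaned.append(None)
            elif len(kept) == 1:
                cleaned.append(kept[0])
            else:
                cleaned.append(kept)
    return tuple(cleaned)


# Hypothetical mesh: only fsdp and expert are larger than 1.
mesh_shape = {"data": 1, "fsdp": 4, "fsdp_transpose": 1, "expert": 2, "context": 1}
activation_spec = (("data", "fsdp", "fsdp_transpose", "expert"), "context", None)
cleaned_spec = drop_size_one_axes(activation_spec, mesh_shape)
print(cleaned_spec)
```

Handling every axis through one size lookup, rather than special-casing a few names, is what lets axes such as fsdp_transpose, expert and context be treated uniformly.
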
JAX 0.7.1 requires fun as first positional argument to jax.jit, so
@jax.jit(static_argnames=[...]) fails. Use functools.partial instead.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
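
A minimal sketch of the decorator pattern this commit message describes. The function `tile_rows` and its argument names are invented for illustration and are not taken from gather_reduce_sc.py.

```python
import functools

import jax
import jax.numpy as jnp


# functools.partial binds the keyword arguments first; the decorated function
# is then passed to jax.jit as its first positional argument, so this form
# does not rely on jax.jit being callable without `fun`.
@functools.partial(jax.jit, static_argnames=("reps",))
def tile_rows(x, reps):
    # `reps` must be static because it determines the output shape.
    return jnp.tile(x, reps)


tiled = tile_rows(jnp.array([1.0, 2.0, 3.0]), reps=2)
print(tiled.shape)
```
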
This config option was added in JAX > 0.7.1; guard with hasattr so
the code runs on both older and newer JAX versions.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
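
The guard described here can be sketched as a small helper. The commit message does not name the config option, so the option names below are placeholders, and `FakeConfig` is a stand-in object used only to keep the example runnable without a specific JAX version.

```python
def update_if_supported(config, name, value):
    """Set option `name` on `config` only if this version defines it.

    Returns True if the option was applied, False if it was skipped.
    """
    if hasattr(config, name):
        config.update(name, value)
        return True
    return False


class FakeConfig:
    """Stand-in for a config object such as jax.config; defines one option."""

    def __init__(self):
        self.known_option = False

    def update(self, name, value):
        setattr(self, name, value)


cfg = FakeConfig()
applied_known = update_if_supported(cfg, "known_option", True)
applied_new = update_if_supported(cfg, "option_added_later", True)
print(applied_known, applied_new)
```

With JAX the same helper would be called with `jax.config` as the first argument; on a version that lacks the option the call is skipped instead of raising.
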
Three root causes identified vs the working cj-reduce-tmp-mem_rocm-main:

1. GMM 9-tuple tile sizes (moe.py): newer rocm-main reintroduced 9-tuple
   (fwd+dlhs+drhs) tiling which changes XLA backward-pass planning and
   adds ~0.5-1 GB temp memory. Revert to 3-tuple (forward-only).

2. Trivial sharding constraints (sharding.py, mixtral.py, attentions.py):
   For pp=8 ep=1, ALL mesh axes inside the pipeline vmap are size 1.
   Every with_sharding_constraint resolves to all-None/() PartitionSpecs
   (trivial), which XLA loop_broadcast_fusion hoists into the pipeline
   scan carry as extra buffers (+5 GB). Fix: add skip_trivial_specs param
   to maybe_shard_with_logical; replace nn.with_logical_constraint in
   MixtralDecoderLayer with shard() helper using skip_trivial_specs=True;
   same for Attention._maybe_shard_with_logical.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
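
A plain-Python sketch of the skip_trivial_specs idea from point 2: if every mesh axis named in a spec has size 1, the constraint is skipped rather than emitted. The real `maybe_shard_with_logical` signature is not shown in this PR; `is_trivial_spec`, `maybe_constrain` and the recording `constrain` callback are hypothetical names, and specs are modelled as tuples.

```python
def is_trivial_spec(spec, mesh_shape):
    """True if no entry of `spec` refers to a mesh axis larger than 1."""
    for entry in spec:
        if entry is None:
            continue
        axes = (entry,) if isinstance(entry, str) else entry
        if any(mesh_shape.get(axis, 1) > 1 for axis in axes):
            return False
    return True


def maybe_constrain(x, spec, mesh_shape, constrain, skip_trivial_specs=False):
    """Apply `constrain(x, spec)` unless the spec is trivial and skipping is on."""
    if skip_trivial_specs and is_trivial_spec(spec, mesh_shape):
        return x  # nothing is emitted, so there is nothing for XLA to hoist
    return constrain(x, spec)


calls = []


def constrain(x, spec):
    # Records the call; a real version would apply a sharding constraint.
    calls.append(spec)
    return x


# Inside the pipeline vmap for pp=8 ep=1, every visible mesh axis has size 1.
inner_mesh = {"data": 1, "fsdp": 1, "expert": 1, "tensor": 1}
maybe_constrain("act", ("data", None, "tensor"), inner_mesh, constrain,
                skip_trivial_specs=True)
print(len(calls))
```

With `skip_trivial_specs=False`, or on a mesh where any named axis is larger than 1, the constraint is still applied, so non-trivial configurations keep their sharding.
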
HLO analysis shows pipeline_module.init_states/_maybe_shard_with_logical
at sharding.py:115 generating loop_broadcast_fusion entries with
bf16[1,1,4096,4096] tensors in the preallocated-temp pool for pp=8 ep=1.
These are trivial constraints (all size-1 mesh axes) that XLA cannot
fully eliminate. Skip them to avoid polluting the scan carry.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
This nan_to_num tree_map forced materialization of every gradient tensor
(f32[8,16,4096,2048] etc.) as an explicit elementwise op, preventing XLA
from eliding them as loop-invariant values through the pipeline scan carry.
Result: +10 sub-computations and +80 parameter() occurrences per compiled
module, blocking loop_broadcast_fusion and adding ~6 GB preallocated temp.

The good branch (cj-reduce-tmp-mem_rocm-main) never had this call.
FP8 NaN sanitization is still handled conditionally in the fp8_stats block.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
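
To illustrate the difference between unconditional and conditional sanitization, here is a small sketch. The `enabled` flag is a hypothetical stand-in for whatever condition the fp8_stats block uses, which is not shown in this PR.

```python
import jax
import jax.numpy as jnp


def sanitize_grads(grads, enabled):
    """Replace NaN/Inf in every gradient leaf, but only when `enabled` is True.

    `enabled` is a Python bool resolved at trace time, so when it is False
    no elementwise op is added for any gradient tensor.
    """
    if not enabled:
        return grads
    return jax.tree_util.tree_map(jnp.nan_to_num, grads)


grads = {"w": jnp.array([1.0, jnp.nan]), "b": jnp.array([jnp.nan])}
untouched = sanitize_grads(grads, enabled=False)
cleaned = sanitize_grads(grads, enabled=True)
print(cleaned["w"])
```
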
- Remove 'cj-reduce ---' debug prefix from print_compiled_memory_stats log
- Restore shard_optimizer_over_data guard (was commented out)
- Restore compiled_trainstep_file guard so pre-compiled files skip recompile
  (forced compile broke any run using compiled_trainstep_file)

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
attentions.py:
- Restore depth_scaling = jnp.sqrt(self.head_dim) in the default branch.
  Both branches were incorrectly set to 1.0, eliminating the T5-style
  1/sqrt(d_k) folded into the query weight initializer. For Mixtral with
  head_dim=128 this produced query weights ~11x larger than intended,
  degrading training convergence from step 0.
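
The "~11x" figure follows from the head dimension, assuming the initializer divides the query weights by depth_scaling:

```python
import math

head_dim = 128
intended_scaling = math.sqrt(head_dim)  # divisor in the default branch
broken_scaling = 1.0                    # value both branches had been set to

# Ratio of the weight magnitude with the bug to the intended magnitude.
ratio = intended_scaling / broken_scaling
print(f"{ratio:.2f}")
```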

attention_op.py:
- Restore context_parallel_axis=self.config.context_sharding (was hardcoded
  to "context", silently breaking the ep-as-cp mesh config where
  context_sharding="expert").
- Add comment explaining scale_factor is intentionally omitted: passing 1.0
  disables QK scaling, while TE's default (None) auto-computes 1/sqrt(head_dim).
- Add comment explaining context_parallel_strategy is omitted because the
  installed TE 2.6.x DotProductAttention does not accept that parameter.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
…um sync

moe.py: Add comment explaining why 3-tuple tiling is correct. The 9-tuple
(fwd+dlhs+drhs) is only used by the megablox custom VJP (_gmm_bwd reads
tiling[3:6] and tiling[-3:] for backward passes). ds-proxy uses megablox=False
/ use_tokamax_gmm=False (jax.lax.ragged_dot path) which only reads tiling[0:3],
so the extra 6 values were always ignored. base.yml documents this explicitly:
"megablox/jax ragged dot - supports forward pass only".

embeddings.py: Replace undefined axis "activation_length_no_exp" with
"activation_length" (output_default_axis_names) in the ShardMode.EXPLICIT
path. The undefined axis would silently map to None (fully replicated) in any
explicit-shard config that lacks a rule for it (e.g. deepseek3-671b-batchsplit).
ds-proxy is unaffected (uses shard_mode="auto"), but the latent bug is real.
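
A toy illustration of why an undefined logical axis is easy to miss: a lookup that returns None for missing rules replicates the dimension silently instead of failing. The rule table and mesh axis names below are invented for the example and do not reproduce any MaxText config.

```python
def logical_to_mesh(logical_axes, rules):
    """Map logical axis names to mesh axes; names without a rule become None."""
    return tuple(rules.get(name) for name in logical_axes)


# Hypothetical rule table with no entry for "activation_length_no_exp".
rules = {
    "activation_batch": "data",
    "activation_length": "context",
    "activation_embed": "tensor",
}

before = logical_to_mesh(
    ("activation_batch", "activation_length_no_exp", "activation_embed"), rules
)
after = logical_to_mesh(
    ("activation_batch", "activation_length", "activation_embed"), rules
)
print(before)
print(after)
```

In this toy table the length dimension is sharded only after the rename; before it, the missing rule leaves that dimension replicated without any error.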

types.py / base.yml: Sync float32_weight_sum default from True (types.py) to
False (base.yml). The False default was set intentionally in commit 4fceae4
to eliminate a ~2 GB temporary f32 tensor from the MoE weight_sum einsum.
bf16 summation over 4 experts (num_experts_per_tok) is numerically acceptable.
Update comment to document the memory trade-off.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
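
The moe.py note above says which slices of the tiling tuple each path reads. A short sketch of that layout, with made-up tile sizes and a hypothetical helper name:

```python
def split_tiling(tiling):
    """Split a GMM tiling tuple into forward, dlhs and drhs tile sizes.

    A 3-tuple carries forward tiles only; a 9-tuple is fwd + dlhs + drhs.
    """
    if len(tiling) == 3:
        return {"fwd": tuple(tiling), "dlhs": None, "drhs": None}
    if len(tiling) == 9:
        return {
            "fwd": tuple(tiling[0:3]),
            "dlhs": tuple(tiling[3:6]),
            "drhs": tuple(tiling[-3:]),
        }
    raise ValueError(f"expected 3 or 9 tile sizes, got {len(tiling)}")


# Illustrative tile sizes only.
nine = (512, 1024, 1024, 256, 512, 512, 128, 256, 256)
three = nine[0:3]

# A forward-only path reads tiling[0:3], so both forms give it the same tiles.
same_forward = split_tiling(nine)["fwd"] == split_tiling(three)["fwd"]
print(same_forward)
```
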
@cj401-amd cj401-amd requested a review from yeandy April 30, 2026 21:46
@i-chaochen

@cj401-amd We need to land this PR in upstream maxtext; we cannot afford the maintenance burden on rocm/maxtext:rocm-main. If your work can reduce memory on the other side, go directly to upstream maxtext.

@i-chaochen i-chaochen left a comment

Just keep the general changes (no ROCm-specific ones), and verify on the other side.

@i-chaochen i-chaochen left a comment

Clean up your commits and work against the correct branch (rocm-main), then upstream to maxtext, not here.

@cj401-amd cj401-amd closed this May 1, 2026
3 participants