
fix(load_run): replicate HF prefix without cross-host allgather (dp>=64 OOM)#898

Merged
danbraunai-goodfire merged 3 commits into feature/jax from fix/prefix-replicate-no-allgather on Jun 29, 2026

Conversation

@danbraunai-goodfire danbraunai-goodfire commented Jun 25, 2026

Collaborator

Description

The size of this GPU allocation did not scale with batch size; it scaled with the data-parallel process count, because of the path fixed here.

build_target placed the frozen HF Llama prefix (embedding + pre-suffix blocks) replicated via jax.device_put(host_array, NamedSharding(mesh, P())). In multi-host JAX that path runs multihost_utils.assert_equal -> process_allgather(tiled=True) to verify that the hosts agree, which tiles the ~1 GB bf16 embedding to ~process_count GB. Replace it with jax.make_array_from_callback, which builds the replicated global array directly from each host's local copy, with no cross-host allgather.

Motivation and Context

At dp=64 (8 nodes) the tiled allgather needed ~64 GB in a single allocation and OOM'd every rank during model build (RESOURCE_EXHAUSTED: 62.62GiB). It was tempting to blame the doubled train batch, but per-device batch was unchanged (dp doubled with it) — the allocation tracks process count, not batch. dp=32 survived only because the tile was ~half the size. The prefix is read from the same cached safetensors on every host, so the equality check is pure overhead — dropping it is safe and removes the process-count memory scaling.
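The process-count scaling can be checked with a few lines of arithmetic. This is a rough sketch, not output from the run: it assumes the embedding has the shape given in the public Llama-3.1-8B config (vocabulary 128,256, hidden size 4,096), is stored in bf16, and that there is one JAX process per data-parallel rank, so process_count equals dp.

```python
# Back-of-the-envelope size of the tiled allgather buffer.
# Assumptions (not taken from the PR): Llama-3.1-8B embedding shape
# (128256 x 4096) in bf16, and one JAX process per data-parallel rank.
VOCAB_SIZE = 128_256
HIDDEN_SIZE = 4_096
BYTES_PER_BF16 = 2
GIB = 2**30


def embedding_bytes() -> int:
    """Bytes held by one host-local copy of the embedding."""
    return VOCAB_SIZE * HIDDEN_SIZE * BYTES_PER_BF16


def tiled_allgather_gib(process_count: int) -> float:
    """Size in GiB of a buffer holding one embedding copy per process."""
    return process_count * embedding_bytes() / GIB


if __name__ == "__main__":
    for dp in (32, 64):
        print(f"dp={dp}: {tiled_allgather_gib(dp):.4f} GiB")
```

Under these assumptions the dp=64 figure comes out at 62.625 GiB, which lines up with the reported RESOURCE_EXHAUSTED size, and the dp=32 figure is half of that.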

How Has This Been Tested?

  • make type clean.
  • A dp=64 (batch 128) Llama-3.1-8B layer-18 decomposition that OOM'd at build on this exact path now clears model load and trains; dp=32 runs unaffected.

Does this PR introduce a breaking change?

No. Replicated placement is unchanged; only the construction path differs (skips the cross-host equality allgather).

…64 OOM)

jax.device_put(host_array, P()) runs multihost_utils.assert_equal ->
process_allgather(tiled=True), tiling the ~1GB embedding to ~process_count GB
(OOM at dp=64). Build the replicated global array from each host's local copy
via make_array_from_callback instead.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_012PUbea772bseCPww4uv3JE
@danbraunai-goodfire

Collaborator Author

@claude could you review this?

claude Bot commented Jun 29, 2026

Contributor

Claude finished @danbraunai-goodfire's task


Code Review Completed ✅

Todo List:

  • Set up repository environment
  • Read and analyze the changed file (param_decomp_lab/experiments/lm/load_run.py)
  • Review the technical approach and implementation
  • Check for potential issues, edge cases, or improvements
  • Provide comprehensive code review feedback

Code Review Feedback

Excellent Technical Solution

This PR addresses a real scaling bottleneck with an elegant and well-reasoned fix. The approach demonstrates deep understanding of JAX's distributed semantics.

Technical Analysis

The Problem:

  • jax.device_put(host_array, NamedSharding(mesh, P())) triggers multihost_utils.assert_equal
  • This runs process_allgather(tiled=True) which scales memory usage with process count (~dp GB)
  • At dp>=64, this caused 62+ GB allocations and OOM during model building

The Solution (_replicate_host_pytree):

  • Replaces device_put with jax.make_array_from_callback
  • Creates replicated arrays directly from local host copies
  • Eliminates cross-host equality check and associated memory scaling
  • Safe because HF weights are identical across hosts (same cached safetensors)

Code Quality Assessment

Strengths:

  1. Clear documentation - The function docstring perfectly explains the issue and rationale (lines 75-78)
  2. Correct implementation - Uses jax.make_array_from_callback with proper sharding spec
  3. Targeted fix - Only affects the problematic path (Llama targets), leaves SimpleMLP unchanged
  4. Type safety - Preserves generic typing with [T] parameter
  5. Equinox compatibility - Correctly handles arrays vs. non-arrays with eqx.is_array check

Implementation Details

_replicate_host_pytree (lines 74-85):

repl = NamedSharding(mesh, P())  # Replicated placement
return jax.tree.map(
    lambda a: jax.make_array_from_callback(a.shape, repl, lambda _idx: a)
    if eqx.is_array(a)
    else a,
    tree,
)

This is textbook correct usage of make_array_from_callback:

  • Creates global arrays with replicated sharding
  • Callback lambda _idx: a returns the same local array for any device index
  • Bypasses JAX's cross-host verification entirely
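For readers who want to try the pattern outside the repository, here is a minimal self-contained sketch. It is not the PR's code: the function name replicate_host_pytree and the example tree are illustrative, and an isinstance check stands in for eqx.is_array so the snippet needs only jax and numpy. The callback indexes the host array with the requested index, which for a fully replicated sharding selects the whole array.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def replicate_host_pytree(tree, mesh):
    """Place every array leaf of `tree` fully replicated over `mesh`.

    Illustrative stand-in for the helper discussed above; assumes each
    process already holds an identical host copy of every array leaf.
    """
    repl = NamedSharding(mesh, P())  # empty PartitionSpec: replicate everywhere

    def place(leaf):
        # Stand-in for eqx.is_array: pass non-array leaves through untouched.
        if not isinstance(leaf, (np.ndarray, jax.Array)):
            return leaf
        # The callback is invoked for addressable devices only; for a
        # replicated sharding `idx` spans the full array.
        return jax.make_array_from_callback(leaf.shape, repl, lambda idx: leaf[idx])

    return jax.tree.map(place, tree)


# Works in a single process too: a 1-D mesh over whatever devices exist.
mesh = Mesh(np.array(jax.devices()), ("dp",))
host_tree = {
    "embed": np.arange(12, dtype=np.float32).reshape(3, 4),
    "num_layers": 18,  # non-array leaf, returned as-is
}
placed = replicate_host_pytree(host_tree, mesh)
```

Because nothing verifies that hosts agree, this is only appropriate when the host copies are known to be identical, as with weights read from the same cached file.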

Consistency Analysis

Potential Improvement: The codebase has an inconsistency in replication patterns:

  • SimpleMLP (line 108): Still uses llama_simple_mlp.replicate_prefix() -> jax.device_put
  • Llama (lines 122-124): Now uses _replicate_host_pytree() -> jax.make_array_from_callback

Recommendation: Consider updating llama_simple_mlp.replicate_prefix() to use the same pattern for consistency, though this is not blocking since SimpleMLP likely doesn't hit the same scale issues.

Safety & Testing

  • Memory safety - Eliminates the problematic allocation
  • Semantic preservation - Same replicated placement, different construction
  • Distributed correctness - Safe because identical weights across hosts
  • Tested at scale - Verified to resolve dp=64 OOM and maintain dp=32 compatibility

Minor Observations

  1. Function naming - _replicate_host_pytree is descriptive and follows Python conventions
  2. Placement - Logical location in the file, near the build_target function that uses it
  3. No breaking changes - External interface unchanged

Verdict: LGTM

This is a high-quality fix that:

  • Solves a real scaling problem with minimal, targeted changes
  • Uses the correct JAX distributed API
  • Maintains semantic equivalence with better performance characteristics
  • Includes excellent documentation

Ready to merge - No blocking issues identified.


danbraunai-goodfire and others added 2 commits June 29, 2026 11:16
feature/jax removed the separate HF prefix (the model now carries the
embedding internally and is placed replicated via place_via_shardings),
so the prefix-specific no-allgather helper is gone. Relocated the fix
into place_via_shardings: replicated leaves are built from the host-local
copy (make_array_from_callback) instead of device_put, avoiding the
cross-host equality allgather that OOMs at dp>=64.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_011Y8zwFyb74dftPrAFHnej2
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_011Y8zwFyb74dftPrAFHnej2
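
The relocated fix described in the commit message above might look roughly like the sketch below. The real place_via_shardings is not shown in this conversation, so the names place_leaf and place_via_shardings_sketch, the signature, and the example tree are all assumptions; only the dispatch idea (replicated leaves via make_array_from_callback, everything else via device_put) comes from the commit message.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def place_leaf(leaf, sharding):
    """Place one host array according to `sharding` (hypothetical helper)."""
    if sharding.is_fully_replicated:
        # Build the replicated global array from the host-local copy.
        return jax.make_array_from_callback(
            leaf.shape, sharding, lambda idx: leaf[idx]
        )
    # Other leaves keep the ordinary device_put path.
    return jax.device_put(leaf, sharding)


def place_via_shardings_sketch(tree, shardings):
    """Map `place_leaf` over a tree and a matching tree of shardings."""
    return jax.tree.map(place_leaf, tree, shardings)


n_dev = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ("dp",))
tree = {
    "embed": np.ones((4, 8), dtype=np.float32),
    "per_rank": np.zeros((2 * n_dev, 8), dtype=np.float32),
}
shardings = {
    "embed": NamedSharding(mesh, P()),         # replicated
    "per_rank": NamedSharding(mesh, P("dp")),  # split along the first axis
}
placed = place_via_shardings_sketch(tree, shardings)
```

The leading dimension of per_rank is chosen as a multiple of the device count so the sharded placement divides evenly on any machine.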
@danbraunai-goodfire danbraunai-goodfire merged commit 951f5d6 into feature/jax Jun 29, 2026
1 check passed