fix(load_run): replicate HF prefix without cross-host allgather (dp>=64 OOM) #898
…64 OOM)

jax.device_put(host_array, P()) runs multihost_utils.assert_equal -> process_allgather(tiled=True), tiling the ~1GB embedding to ~process_count GB (OOM at dp=64). Build the replicated global array from each host's local copy via make_array_from_callback instead.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_012PUbea772bseCPww4uv3JE
@claude could you review this?
Claude finished @danbraunai-goodfire's task

Code Review Completed ✅
Code Review Feedback

✅ Excellent Technical Solution

This PR addresses a real scaling bottleneck with an elegant and well-reasoned fix. The approach demonstrates deep understanding of JAX's distributed semantics.

Technical Analysis

The Problem:
The Solution (
Code Quality Assessment

Strengths:
Implementation Details
repl = NamedSharding(mesh, P()) # Replicated placement
return jax.tree.map(
lambda a: jax.make_array_from_callback(a.shape, repl, lambda _idx: a)
if eqx.is_array(a)
else a,
tree,
)

This is textbook correct usage of
Consistency Analysis

Potential Improvement: The codebase has an inconsistency in replication patterns:

Recommendation: Consider updating

Safety & Testing

✅ Memory safety - Eliminates the problematic allocation

Minor Observations

Verdict: LGTM ✅

This is a high-quality fix that:

Ready to merge - No blocking issues identified.
feature/jax removed the separate HF prefix (the model now carries the embedding internally and is placed replicated via place_via_shardings), so the prefix-specific no-allgather helper is gone. Relocated the fix into place_via_shardings: replicated leaves are built from the host-local copy (make_array_from_callback) instead of device_put, avoiding the cross-host equality allgather that OOMs at dp>=64.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_011Y8zwFyb74dftPrAFHnej2
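The relocated fix described in the commit message above can be sketched roughly as follows. This is a minimal illustration, not the repository's actual place_via_shardings: the helper names place_leaf and place_tree, the "dp" mesh axis name, and the toy parameter tree are all assumptions made for the example. It relies on every host holding an identical local copy of each replicated leaf, which is the condition the PR states for the prefix.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def place_leaf(host_array, sharding):
    """Place one host-local array under `sharding` (hypothetical helper)."""
    if sharding.is_fully_replicated:
        # Each device's shard is cut straight from this host's local copy,
        # so nothing has to be gathered from other hosts. This is only
        # valid if all hosts hold the same values.
        return jax.make_array_from_callback(
            host_array.shape, sharding, lambda index: host_array[index]
        )
    # Other leaves fall back to device_put in this sketch.
    return jax.device_put(host_array, sharding)


def place_tree(tree, shardings):
    """Apply place_leaf to matching leaves of two same-structured trees."""
    return jax.tree_util.tree_map(place_leaf, tree, shardings)


# Toy usage on whatever devices are available (a single CPU device works).
mesh = Mesh(np.array(jax.devices()), ("dp",))
n_devices = len(jax.devices())
params = {
    "embedding": np.arange(12, dtype=np.float32).reshape(4, 3),
    "per_shard": np.zeros((2 * n_devices, 3), dtype=np.float32),
}
shardings = {
    "embedding": NamedSharding(mesh, P()),      # replicated
    "per_shard": NamedSharding(mesh, P("dp")),  # split along axis 0
}
placed = place_tree(params, shardings)
```

The callback indexes the local copy with the slice JAX requests rather than returning the whole array, so the same helper stays correct if a leaf's sharding is later changed to a non-replicated one.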
Description
The GPU memory of this allocation did not scale with batch size — it scaled with the data-parallel process count, because of the path fixed here.
build_target placed the frozen HF Llama prefix (embedding + pre-suffix blocks) replicated via jax.device_put(host_array, NamedSharding(mesh, P())). In multi-host JAX that path runs multihost_utils.assert_equal → process_allgather(tiled=True) to verify the hosts agree, which tiles the ~1 GB bf16 embedding to ~process_count GB. Replace it with jax.make_array_from_callback, which builds the replicated global array directly from each host's local copy, with no cross-host allgather.

Motivation and Context
At dp=64 (8 nodes) the tiled allgather needed ~64 GB in a single allocation and OOM'd every rank during model build (RESOURCE_EXHAUSTED: 62.62GiB). It was tempting to blame the doubled train batch, but per-device batch was unchanged (dp doubled with it): the allocation tracks process count, not batch. dp=32 survived only because the tile was ~half the size. The prefix is read from the same cached safetensors on every host, so the equality check is pure overhead; dropping it is safe and removes the process-count memory scaling.

How Has This Been Tested?
make type clean. The dp=64 (batch 128) Llama-3.1-8B layer-18 decomposition that OOM'd at build on this exact path now clears model load and trains; dp=32 runs unaffected.

Does this PR introduce a breaking change?
No. Replicated placement is unchanged; only the construction path differs (skips the cross-host equality allgather).
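The sizes quoted under Motivation and Context can be sanity-checked with a few lines of arithmetic. The sketch below assumes the published Llama-3.1-8B embedding dimensions (vocabulary 128,256, hidden size 4,096) and assumes one JAX process per data-parallel rank, so that process_count equals dp; neither figure is stated explicitly in this PR.

```python
# Assumptions (not stated in the PR): Llama-3.1-8B embedding dimensions,
# and process_count == dp (one process per data-parallel rank).
VOCAB_SIZE = 128_256
HIDDEN_SIZE = 4_096
BYTES_PER_BF16 = 2
GIB = 1024**3


def embedding_bytes() -> int:
    """Size of one bf16 copy of the embedding table."""
    return VOCAB_SIZE * HIDDEN_SIZE * BYTES_PER_BF16


def tiled_allgather_gib(process_count: int) -> float:
    """Size of the buffer holding one embedding copy per process."""
    return embedding_bytes() * process_count / GIB


single_copy_gib = embedding_bytes() / GIB  # 0.978515625
dp64_gib = tiled_allgather_gib(64)         # 62.625
dp32_gib = tiled_allgather_gib(32)         # 31.3125
```

Under these assumptions the dp=64 buffer comes to about 62.6 GiB, which lines up with the 62.62GiB in the RESOURCE_EXHAUSTED message, and the dp=32 buffer is half that, consistent with dp=32 surviving.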