fix(eval): make the slow-eval tier work on multi-host GPU #885
Closed
danbraunai-goodfire wants to merge 3 commits into
…ject it

The CI-transformer attention always requested cuDNN flash, which only accepts fp16/bf16/fp8. The slow-eval path reads the CI fn out in fp32 (by design), so on GPU cuDNN raised NotImplementedError at the first slow eval (step 1000), crashing the run (the baseline died here too; CPU tests passed because they use the XLA impl). Use cuDNN for fp16/bf16 (training/fast-eval, unchanged) and the XLA impl for fp32.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01UQQMNAqT6t8VJJhCajNEbE
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01NaCwbA8z5iK9YuaANUToNB
accumulate_site_reductions did np.asarray(flat_lower[site]) / flat_logits[site], but those keep the dp-sharded batch axis, so on >1 process they span non-addressable devices and np.asarray raises RuntimeError. (The density/sum reductions are all-reduced, hence addressable; only the raw per-position sample is sharded.) Gather across processes with multihost_utils.process_allgather (tiled=True); order is irrelevant for the histogram sample. This is the second multi-host slow-eval blocker, after the cuDNN-fp32 attention fix.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01UQQMNAqT6t8VJJhCajNEbE
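The gather described in the last commit message can be sketched roughly as follows. This is a minimal illustration, not the code in `slow_eval.py`: the helper names `gather_sample` and `sample_histogram`, the bin count, and the value range are assumptions made for the example. Only `process_allgather(..., tiled=True)` comes from the commit message. Run in a single process, the call simply returns the array unchanged, which is why a single-process test would not exercise the failure.

```python
import jax.numpy as jnp
import numpy as np
from jax.experimental import multihost_utils


def gather_sample(x):
    """Return the full sample as a host NumPy array (hypothetical helper)."""
    # tiled=True stitches the per-process pieces together instead of stacking
    # them along a new leading axis, so the result keeps the array's rank.
    return np.asarray(multihost_utils.process_allgather(x, tiled=True))


def sample_histogram(x, bins=10, value_range=(0.0, 1.0)):
    """Histogram counts over every gathered element (hypothetical helper)."""
    # A histogram only counts values per bin, so the order in which the
    # per-process pieces arrive does not affect the result.
    values = gather_sample(x).ravel()
    counts, _ = np.histogram(values, bins=bins, range=value_range)
    return counts


# Single-process demonstration with made-up data.
sample = (jnp.arange(12.0) / 12.0).reshape(4, 3)
gathered = gather_sample(sample)
counts = sample_histogram(sample, bins=4)
```

Because the histogram is order-independent, the sketch makes no attempt to restore the original batch ordering after the gather.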
Validated on btdr (2 nodes / 16 GPU): with both fixes, a run clears the step-1000 slow eval (previously the crash point) and continues training.
This was referenced Jun 23, 2026
Superseded by #905.
danbraunai-goodfire added a commit that referenced this pull request on Jun 29, 2026
…ot fp32 (#905)

* fix(slow-eval): run the CI-fn readout in training precision (bf16), not fp32

The slow/plot eval tier deliberately read the CI fn out in fp32, but the CI transformer's attention routes to cuDNN flash, which rejects fp32, so the slow tier crashed at the first slow eval on GPU:

    NotImplementedError: Q must be fp16/bf16/fp8_e4m3fn/fp8_e5m2, got float32

(hit by every run, including bf16-only ones; unrelated to fp8 work).

Rather than route fp32 to the XLA attention impl (#885), run the CI-fn readout in training precision (bf16), matching train.py / eval.py and the hidden-acts + attn-pattern slow tiers, which already do. This is the more faithful readout (the deployed model runs bf16) and keeps cuDNN flash (faster, no (B,H,T,T) materialize). Reductions/returns stay fp32 so the host accumulation is byte-unchanged.

Also incorporates #885's precision-independent multihost fix: the per-position histogram sample keeps the dp-sharded batch axis, so `np.asarray` on >1 process spans non-addressable devices; gather it with process_allgather(tiled=True).

Supersedes #885's ci_fn.py change (no fp32 attention path is needed once the slow tier is bf16); incorporates its slow_eval.py gather fix.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_019uMvZPo7hyAFgGbhDdLrEL

* style(slow-eval): trim the bf16-readout and gather comments to the load-bearing why

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_019uMvZPo7hyAFgGbhDdLrEL

---------

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
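The pattern the #905 message describes (forward pass in bf16, reductions returned in fp32) might look something like the sketch below. It is only an illustration under assumptions: `readout_sum` and `forward_fn` are hypothetical names, and the real readout in the repository is not shown on this page.

```python
import jax.numpy as jnp


def readout_sum(forward_fn, inputs):
    """Run forward_fn on bf16 inputs, then reduce over the batch axis in fp32.

    Hypothetical helper: forward_fn stands in for the CI-fn forward pass.
    """
    # The forward pass sees bf16, i.e. the training precision.
    out_bf16 = forward_fn(inputs.astype(jnp.bfloat16))
    # Upcast before reducing so the value handed to the host is fp32.
    return jnp.sum(out_bf16.astype(jnp.float32), axis=0)


# Made-up stand-in for the forward pass; it records the dtype it received.
seen_dtypes = []


def toy_forward(x):
    seen_dtypes.append(x.dtype)
    return x * 2


result = readout_sum(toy_forward, jnp.ones((4, 3), dtype=jnp.float32))
```

Keeping the upcast next to the reduction means downstream host-side accumulation receives the same dtype it did before the change, whatever precision the forward pass used.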
Description
The slow/plot eval tier crashed on multi-host GPU (>1 process) at the first slow eval (step 1000). Two independent bugs, both fixed here:
1. cuDNN-fp32 attention (`param_decomp/ci_fn.py`). The CI-transformer attention always requested the cuDNN flash impl via `attn_implementation()`, which only accepts fp16/bf16/fp8. Training and the fast eval run the CI fn under bf16 autocast, but the slow tier deliberately reads it out in fp32 (`slow_eval.py`: "a fp32-CI-fn readout"). On GPU that raised `NotImplementedError: Q must be fp16/bf16/fp8_e4m3fn/fp8_e5m2, got float32`.
   Fix: pick the impl by dtype, using cuDNN for fp16/bf16 (training + fast eval, unchanged) and the XLA impl for fp32 (slow eval). The XLA composite materializes `(B,H,T,T)`, but that only matters for the long-sequence target forward, not this seq-512 CI transformer.
2. `np.asarray` on a process-sharded array (`param_decomp/slow_eval.py`). `accumulate_site_reductions` did `np.asarray(flat_lower[site])` / `flat_logits[site]`, which keep the dp-sharded batch axis. On >1 process those span non-addressable devices and `np.asarray` raises a `RuntimeError`. (The density/sum reductions are all-reduced and therefore addressable, so they were fine; only the raw per-position histogram sample is sharded.)
   Fix: gather across processes with `jax.experimental.multihost_utils.process_allgather(..., tiled=True)`.

Motivation and Context
Together these crash every multi-GPU LM decomposition run at its first slow eval, including the in-flight baseline p-0ff8e5d3, which reached step 1000 and then died from the gRPC coordination cascade these exceptions trigger on the other ranks. CPU tests never caught either bug: `attn_implementation()` returns `"xla"` off-GPU, and in a single process every array is addressable, so `np.asarray` succeeds.

How Has This Been Tested?
- `make type` / basedpyright clean.
- `param_decomp/tests/test_slow_eval.py`: 26 passed (`MPLBACKEND=agg`; the data-path + gather tests pass, incl. `accumulate_site_reductions`).
- […] `Q must be fp16…` error, then the non-addressable-fetch `RuntimeError`); with both fixes the slow tier completes and training continues. [validation in progress]

Does this PR introduce a breaking change?
No. Training and fast-eval paths (bf16, addressable reductions) are unchanged; only the fp32 attention path switches to the XLA impl and the sharded histogram sample is gathered across processes.
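As a rough illustration of the dtype-based selection in item 1 of the Description, a helper along these lines would keep cuDNN for half-precision inputs and fall back to XLA otherwise. The names `pick_attention_impl` and `on_gpu` and the string-based dtype check are assumptions for this sketch; the actual `attn_implementation()` in `ci_fn.py` is not shown on this page.

```python
# Hypothetical sketch: names and signature are illustrative, not the repo's API.
# dtypes for which this PR keeps the cuDNN flash path.
CUDNN_DTYPES = frozenset({"float16", "bfloat16"})


def pick_attention_impl(dtype_name: str, on_gpu: bool) -> str:
    """Choose an attention implementation name from the input dtype.

    dtype_name is the string form of the array dtype, e.g. str(q.dtype).
    """
    if on_gpu and dtype_name in CUDNN_DTYPES:
        # Training and fast eval (bf16/fp16) stay on the flash kernel.
        return "cudnn"
    # fp32 (the slow-eval readout) and any non-GPU backend use the XLA path.
    return "xla"


train_impl = pick_attention_impl("bfloat16", on_gpu=True)
slow_eval_impl = pick_attention_impl("float32", on_gpu=True)
cpu_impl = pick_attention_impl("bfloat16", on_gpu=False)
```

The returned string could then be passed as the `implementation` argument of `jax.nn.dot_product_attention`, which accepts `"xla"` and `"cudnn"`. Note that #905 later made this fp32 branch unnecessary by running the slow-tier readout in bf16.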
🤖 Generated with Claude Code