Skip to content

fix(eval): make the slow-eval tier work on multi-host GPU#885

Closed
danbraunai-goodfire wants to merge 3 commits into
feature/jaxfrom
fix/ci-fn-fp32-attn
Closed

fix(eval): make the slow-eval tier work on multi-host GPU#885
danbraunai-goodfire wants to merge 3 commits into
feature/jaxfrom
fix/ci-fn-fp32-attn

Conversation

@danbraunai-goodfire

@danbraunai-goodfire danbraunai-goodfire commented Jun 22, 2026

Copy link
Copy Markdown
Collaborator

Description

The slow/plot eval tier crashed on multi-host GPU (>1 process) at the first slow eval (step 1000). Two independent bugs, both fixed here:

1. cuDNN-fp32 attention (param_decomp/ci_fn.py). The CI-transformer attention always requested the cuDNN flash impl via attn_implementation(), which only accepts fp16/bf16/fp8. Training and the fast eval run the CI fn under bf16 autocast, but the slow tier deliberately reads it out in fp32 (slow_eval.py: "a fp32-CI-fn readout"). On GPU that raised:

NotImplementedError: Q must be fp16/bf16/fp8_e4m3fn/fp8_e5m2, got float32

Fix: pick the impl by dtype — cuDNN for fp16/bf16 (training + fast eval, unchanged), the XLA impl for fp32 (slow eval). The XLA composite materializes (B,H,T,T), but that only matters for the long-sequence target forward, not this seq-512 CI transformer.

2. np.asarray on a process-sharded array (param_decomp/slow_eval.py). accumulate_site_reductions did np.asarray(flat_lower[site]) / flat_logits[site], which keep the dp-sharded batch axis. On >1 process those span non-addressable devices and np.asarray raises:

RuntimeError: Fetching value for `jax.Array` that spans non-addressable (non process local) devices ...

(The density/sum reductions are all-reduced → addressable, so they were fine; only the raw per-position histogram sample is sharded.) Fix: gather across processes with jax.experimental.multihost_utils.process_allgather(..., tiled=True).

Motivation and Context

Together these crash every multi-GPU LM decomposition run at its first slow eval — including the in-flight baseline p-0ff8e5d3, which reached step 1000 then died from the gRPC coordination cascade these exceptions trigger on the other ranks. CPU tests never caught either: attn_implementation() returns "xla" off-GPU, and single-process np.asarray is addressable.

How Has This Been Tested?

  • make type / basedpyright clean.
  • param_decomp/tests/test_slow_eval.py: 26 passed (MPLBACKEND=agg; the data-path + gather tests pass, incl. accumulate_site_reductions).
  • On btdr (2 nodes / 16 GPU): pre-fix the run aborts at the step-1000 slow eval (first with the Q must be fp16… error, then the non-addressable-fetch RuntimeError); with both fixes the slow tier completes and training continues. [validation in progress]

Does this PR introduce a breaking change?

No. Training and fast-eval paths (bf16, addressable reductions) are unchanged; only the fp32 attention path switches to the XLA impl and the sharded histogram sample is gathered across processes.

🤖 Generated with Claude Code

danbraunai-goodfire and others added 3 commits June 22, 2026 18:16
…ject it

The CI-transformer attention always requested cuDNN flash, which only accepts
fp16/bf16/fp8. The slow-eval path reads the CI fn out in fp32 (by design), so on
GPU cuDNN raised NotImplementedError at the first slow eval (step 1000), crashing
the run (the baseline died here too; CPU tests passed because they use the XLA
impl). Use cuDNN for fp16/bf16 (training/fast-eval, unchanged) and the XLA impl
for fp32.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01UQQMNAqT6t8VJJhCajNEbE
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01NaCwbA8z5iK9YuaANUToNB
accumulate_site_reductions did np.asarray(flat_lower[site]) / flat_logits[site],
but those keep the dp-sharded batch axis, so on >1 process they span
non-addressable devices and np.asarray raises RuntimeError. (The density/sum
reductions are all-reduced, hence addressable — only the raw per-position sample
is sharded.) Gather across processes with multihost_utils.process_allgather
(tiled=True); order is irrelevant for the histogram sample. This is the second
multi-host slow-eval blocker, after the cuDNN-fp32 attention fix.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01UQQMNAqT6t8VJJhCajNEbE
@danbraunai-goodfire danbraunai-goodfire changed the title fix(ci_fn): use XLA attention impl for fp32 so GPU slow-eval doesn't crash fix(eval): make the slow-eval tier work on multi-host GPU Jun 22, 2026
@danbraunai-goodfire

Copy link
Copy Markdown
Collaborator Author

Validated on btdr (2 nodes / 16 GPU): with both fixes, a run clears the step-1000 slow eval (previously the crash point) and continues training — [eval @ 1000] and [eval @ 2000] both fire, slow-tier metrics (eval/slow/loss/CIHiddenActsReconLoss/*, etc.) log, and there are no Q must be fp16 / non-addressable-fetch errors (run p-a067a8d0).

@danbraunai-goodfire

Copy link
Copy Markdown
Collaborator Author

superseded by #905

danbraunai-goodfire added a commit that referenced this pull request Jun 29, 2026
…ot fp32 (#905)

* fix(slow-eval): run the CI-fn readout in training precision (bf16), not fp32

The slow/plot eval tier deliberately read the CI fn out in fp32, but the CI
transformer's attention routes to cuDNN flash, which rejects fp32 — so the slow
tier crashed at the first slow eval on GPU:
    NotImplementedError: Q must be fp16/bf16/fp8_e4m3fn/fp8_e5m2, got float32
(hit by every run, including bf16-only ones — unrelated to fp8 work).

Rather than route fp32 to the XLA attention impl (#885), run the CI-fn readout in
training precision (bf16) — matching train.py / eval.py and the hidden-acts +
attn-pattern slow tiers, which already do. This is the more faithful readout (the
deployed model runs bf16) and keeps cuDNN flash (faster, no (B,H,T,T) materialize).
Reductions/returns stay fp32 so the host accumulation is byte-unchanged.

Also incorporates #885's precision-independent multihost fix: the per-position
histogram sample keeps the dp-sharded batch axis, so `np.asarray` on >1 process
spans non-addressable devices — gather it with process_allgather(tiled=True).

Supersedes #885's ci_fn.py change (no fp32 attention path is needed once the slow
tier is bf16); incorporates its slow_eval.py gather fix.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_019uMvZPo7hyAFgGbhDdLrEL

* style(slow-eval): trim the bf16-readout and gather comments to the load-bearing why

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_019uMvZPo7hyAFgGbhDdLrEL

---------

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant