Skip to content

fix(slow-eval): run the CI-fn readout in training precision (bf16), not fp32#905

Merged
danbraunai-goodfire merged 2 commits into
feature/jaxfrom
fix/slow-eval-training-precision
Jun 29, 2026
Merged

fix(slow-eval): run the CI-fn readout in training precision (bf16), not fp32#905
danbraunai-goodfire merged 2 commits into
feature/jaxfrom
fix/slow-eval-training-precision

Conversation

@danbraunai-goodfire

Copy link
Copy Markdown
Collaborator

Description

The slow/plot eval tier (make_slow_eval_step, make_position_ci_step) deliberately read the CI fn out in fp32. But the CI transformer's attention routes to cuDNN flash, which only accepts fp16/bf16/fp8 — so the slow tier crashed at the first slow eval on GPU:

NotImplementedError: Q must be fp16/bf16/fp8_e4m3fn/fp8_e5m2, got float32

This hit every run (including bf16-only ones; surfaced while validating the fp8 work, but unrelated to it).

Fix: run the CI-fn readout in training precision (bf16) — matching train.py / eval.py and the hidden-acts + attn-pattern slow tiers, which already cast to COMPUTE_DT. Reductions and returns are upcast to fp32, so the host-side accumulation is byte-unchanged. The two affected tiers were the only slow-eval compute paths still pinned to fp32.

Also incorporates the precision-independent multihost fix from #885: the per-position histogram sample keeps the dp-sharded batch axis, so a bare np.asarray on >1 process spans non-addressable devices — gathered via process_allgather(tiled=True).

Related Issue

Related to #885. Supersedes #885's ci_fn.py fp32-attention routing (no fp32 attention path is needed once the slow tier is bf16); incorporates #885's slow_eval.py multihost gather fix. #885 can be closed if this lands.

Motivation and Context

We shouldn't run the slow eval in fp32: the deployed model runs bf16, so a bf16 readout is the faithful one, it keeps cuDNN flash (faster, avoids the XLA (B,H,T,T) materialize the fp32 path needs), and it removes the GPU crash without a special-case attention impl.

How Has This Been Tested?

  • make type clean; make format clean.
  • param_decomp/tests/test_slow_eval.py — 26 passed (the hand-rolled reduction reference now mirrors the bf16 readout).
  • Found while debugging an fp8 parity ladder whose runs (incl. the bf16 baseline) all crashed here; disabling slow eval unblocked them, confirming this is the cause.

Does this PR introduce a breaking change?

No. Slow-eval metric values shift by bf16-rounding (e.g. a CI value near the alive threshold may flip), which is the intended, more-faithful behavior. Host APIs and log keys are unchanged.

🤖 Generated with Claude Code

https://claude.ai/code/session_019uMvZPo7hyAFgGbhDdLrEL

danbraunai-goodfire and others added 2 commits June 29, 2026 10:38
…ot fp32

The slow/plot eval tier deliberately read the CI fn out in fp32, but the CI
transformer's attention routes to cuDNN flash, which rejects fp32 — so the slow
tier crashed at the first slow eval on GPU:
    NotImplementedError: Q must be fp16/bf16/fp8_e4m3fn/fp8_e5m2, got float32
(hit by every run, including bf16-only ones — unrelated to fp8 work).

Rather than route fp32 to the XLA attention impl (#885), run the CI-fn readout in
training precision (bf16) — matching train.py / eval.py and the hidden-acts +
attn-pattern slow tiers, which already do. This is the more faithful readout (the
deployed model runs bf16) and keeps cuDNN flash (faster, no (B,H,T,T) materialize).
Reductions/returns stay fp32 so the host accumulation is byte-unchanged.

Also incorporates #885's precision-independent multihost fix: the per-position
histogram sample keeps the dp-sharded batch axis, so `np.asarray` on >1 process
spans non-addressable devices — gather it with process_allgather(tiled=True).

Supersedes #885's ci_fn.py change (no fp32 attention path is needed once the slow
tier is bf16); incorporates its slow_eval.py gather fix.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_019uMvZPo7hyAFgGbhDdLrEL
…ad-bearing why

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_019uMvZPo7hyAFgGbhDdLrEL
@danbraunai-goodfire danbraunai-goodfire merged commit 4135496 into feature/jax Jun 29, 2026
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant