fix(slow-eval): run the CI-fn readout in training precision (bf16), not fp32#905
Merged
danbraunai-goodfire merged 2 commits intoJun 29, 2026
Merged
Conversation
…ot fp32
The slow/plot eval tier deliberately read the CI fn out in fp32, but the CI
transformer's attention routes to cuDNN flash, which rejects fp32 — so the slow
tier crashed at the first slow eval on GPU:
NotImplementedError: Q must be fp16/bf16/fp8_e4m3fn/fp8_e5m2, got float32
(hit by every run, including bf16-only ones — unrelated to fp8 work).
Rather than route fp32 to the XLA attention impl (#885), run the CI-fn readout in
training precision (bf16) — matching train.py / eval.py and the hidden-acts +
attn-pattern slow tiers, which already do. This is the more faithful readout (the
deployed model runs bf16) and keeps cuDNN flash (faster, no (B,H,T,T) materialize).
Reductions/returns stay fp32 so the host accumulation is byte-unchanged.
Also incorporates #885's precision-independent multihost fix: the per-position
histogram sample keeps the dp-sharded batch axis, so `np.asarray` on >1 process
spans non-addressable devices — gather it with process_allgather(tiled=True).
Supersedes #885's ci_fn.py change (no fp32 attention path is needed once the slow
tier is bf16); incorporates its slow_eval.py gather fix.
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_019uMvZPo7hyAFgGbhDdLrEL
…ad-bearing why Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_019uMvZPo7hyAFgGbhDdLrEL
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
The slow/plot eval tier (
make_slow_eval_step,make_position_ci_step) deliberately read the CI fn out in fp32. But the CI transformer's attention routes to cuDNN flash, which only accepts fp16/bf16/fp8 — so the slow tier crashed at the first slow eval on GPU:This hit every run (including bf16-only ones; surfaced while validating the fp8 work, but unrelated to it).
Fix: run the CI-fn readout in training precision (bf16) — matching
train.py/eval.pyand the hidden-acts + attn-pattern slow tiers, which already cast toCOMPUTE_DT. Reductions and returns are upcast to fp32, so the host-side accumulation is byte-unchanged. The two affected tiers were the only slow-eval compute paths still pinned to fp32.Also incorporates the precision-independent multihost fix from #885: the per-position histogram sample keeps the dp-sharded batch axis, so a bare
np.asarrayon >1 process spans non-addressable devices — gathered viaprocess_allgather(tiled=True).Related Issue
Related to #885. Supersedes #885's
ci_fn.pyfp32-attention routing (no fp32 attention path is needed once the slow tier is bf16); incorporates #885'sslow_eval.pymultihost gather fix. #885 can be closed if this lands.Motivation and Context
We shouldn't run the slow eval in fp32: the deployed model runs bf16, so a bf16 readout is the faithful one, it keeps cuDNN flash (faster, avoids the XLA
(B,H,T,T)materialize the fp32 path needs), and it removes the GPU crash without a special-case attention impl.How Has This Been Tested?
make typeclean;make formatclean.param_decomp/tests/test_slow_eval.py— 26 passed (the hand-rolled reduction reference now mirrors the bf16 readout).Does this PR introduce a breaking change?
No. Slow-eval metric values shift by bf16-rounding (e.g. a CI value near the alive threshold may flip), which is the intended, more-faithful behavior. Host APIs and log keys are unchanged.
🤖 Generated with Claude Code
https://claude.ai/code/session_019uMvZPo7hyAFgGbhDdLrEL