fix(ci): make stacked-parity forward pins portable (CI flake) #901
Open
danbraunai-goodfire wants to merge 3 commits into
test_clean_output_bit_identical used jnp.array_equal and the sibling forward pins used rtol=1e-5/atol=1e-6. Those encode bit-exactness against the fixture-generating host, but ubuntu-latest is a heterogeneous runner pool: float32 matmul reduction order differs by ~1 ULP across CPU microarchitectures, so the exact check and the near-zero elements under the tight atol flaked intermittently (failed 3 of the last 4 CI runs; passed only when a run happened to land on a matching CPU). The forwards are CI-fn-independent and the math is unchanged; this is pure reassociation noise, not a regression.

Fold clean_output into the tolerance-based pins (rename -> test_clean_output_matches) and raise the shared forward tolerance to rtol=1e-4/atol=1e-5. Observed cross-arch noise is <=3.1e-6 abs (measured on arm64 vs the fixture host), so the new bound keeps ~3-4x headroom while staying essentially exact for fp32.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01CixstgLX5XaRm8CwAEReXk
…rity-tolerance

# Conflicts:
# param_decomp/tests/stacked_parity/test_stacked_parity.py
…ates)

make_train_step now donates the state (donate="all-except-first", #903), so the shared module-level vu/ci_fn were deleted after the first run_step and the second crashed on a deleted buffer. Build them fresh inside make_state with the same deterministic keys: bit-identical init, independent donatable buffers.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_011Y8zwFyb74dftPrAFHnej2
Description
Relax the stacked_parity forward-equivalence pins from bit-exact / ultra-tight to a portable fp32 tolerance:

- test_clean_output_bit_identical → test_clean_output_matches: drop jnp.array_equal (exact) for the shared _assert_close.
- Shared forward tolerance: rtol=1e-5/atol=1e-6 → rtol=1e-4/atol=1e-5.

This is the second (and final) of two CI-flake fixes. The first, the slow-eval cadence smoke timing fix (e4586e36e, 7abe40b5d), is already on feature/jax (pushed directly while verifying), so this PR carries only the remaining stacked_parity change.

Motivation and Context
CI on feature/jax was red on two unrelated tests:

- test_in_loop_slow_tier_fires_on_cadence_without_stalling: a wall-clock budget that was actually measuring ~3 serialized off-thread matplotlib renders, not dispatch; it tipped past 30s under the CI runner's -n4 oversubscription. Fixed already on feature/jax by timing only the dispatch the loop pays.
- stacked_parity (test_clean_output_* + test_site_inputs_and_weight_deltas_match): flaked intermittently (failed 3 of the last 4 CI runs; passed only when a run happened to land on a matching CPU). ubuntu-latest is a heterogeneous runner pool, and these pins encoded bit-exactness against the fixture-generating host: float32 matmul reduction order differs by ~1 ULP across CPU microarchitectures. The forwards are CI-fn-independent and the math is unchanged: pure reassociation noise, not a regression. This PR addresses it.
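The reassociation effect described above can be reproduced on a single machine. The following is an illustrative sketch only, not code from this PR: it emulates float32 addition in pure Python via struct, sums the same three values in two different orders, and compares an exact check against a tolerance check. The helpers f32, add32 and is_close are hypothetical names, and the closeness formula is the NumPy-style criterion |actual - expected| <= atol + rtol * |expected|, assumed here because the body of _assert_close is not shown.

```python
import struct


def f32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def add32(x: float, y: float) -> float:
    """Emulate a float32 addition: add in float64, then round to float32."""
    return f32(x + y)


def is_close(actual: float, expected: float, rtol: float, atol: float) -> bool:
    """NumPy-style closeness test (an assumption about _assert_close)."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


# Three float32-representable addends; b and c are each half a ULP of 1.0.
a, b, c = 1.0, 2.0**-24, 2.0**-24

# Same mathematical sum, two reduction orders.
left_to_right = add32(add32(a, b), c)  # (a + b) + c
pairwise = add32(a, add32(b, c))       # a + (b + c)

ULP_AT_ONE = 2.0**-23  # spacing of float32 values just above 1.0

bit_identical = left_to_right == pairwise
within_tolerance = is_close(left_to_right, pairwise, rtol=1e-4, atol=1e-5)

print("left-to-right:", left_to_right)
print("pairwise:     ", pairwise)
print("bit identical:", bit_identical)
print("within rtol=1e-4/atol=1e-5:", within_tolerance)
```

The two orders differ by exactly one float32 ULP at 1.0, so an exact comparison rejects the pair while the tolerance comparison accepts it. This is the same kind of last-bit disagreement that a different matmul reduction order can produce, scaled down to three addends.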
How Has This Been Tested?
- … a strong portability signal: 4 passed, 1 xfailed.
- … failure at 2.19e-6 / rel 4.88e-5), so atol=1e-5 keeps ~3–4× headroom while the pins stay essentially exact for fp32.
- … -n4 alongside the four heaviest tests (86 passed) and on two CI runs.
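As a rough check on the headroom figure, the sketch below plugs the numbers quoted in this PR into a NumPy-style closeness criterion, |actual - expected| <= atol + rtol * |expected|. Two things are assumptions rather than facts from the PR: that _assert_close uses this criterion, and that "rel 4.88e-5" means the absolute error divided by the magnitude of the expected value. The helper name is_within is hypothetical.

```python
# Figures quoted in the PR text and commit message.
abs_err = 2.19e-6         # absolute error of the reported failure
rel_err = 4.88e-5         # relative error of the same element
observed_noise = 3.1e-6   # largest cross-arch absolute noise reported

# Assumption: rel_err = abs_err / |expected|, so recover |expected|.
expected_mag = abs_err / rel_err  # a small, near-zero element


def is_within(abs_diff: float, magnitude: float, rtol: float, atol: float) -> bool:
    """NumPy-style bound (an assumption about how _assert_close behaves)."""
    return abs_diff <= atol + rtol * magnitude


old_ok = is_within(abs_err, expected_mag, rtol=1e-5, atol=1e-6)
new_ok = is_within(abs_err, expected_mag, rtol=1e-4, atol=1e-5)

# Headroom of the new absolute tolerance over the largest observed noise.
headroom = 1e-5 / observed_noise

print(f"|expected| ~ {expected_mag:.4f}")
print(f"passes old tolerance: {old_ok}")
print(f"passes new tolerance: {new_ok}")
print(f"atol headroom over observed noise: {headroom:.1f}x")
```

Under these assumptions the element fails the old bound because it is close to zero, where the rtol term contributes almost nothing and atol=1e-6 dominates. The new atol alone sits a little over 3x above the largest reported noise, consistent with the ~3–4× headroom stated above.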
Does this PR introduce a breaking change?
No — test-only.
🤖 Generated with Claude Code
https://claude.ai/code/session_01CixstgLX5XaRm8CwAEReXk