Skip to content

perf(graph): classification cache, static-collapse, and shared-State dedup#218

Merged
chaoming0625 merged 8 commits into
mainfrom
worktree-graph-engine-upgrade
Jun 13, 2026
Merged

perf(graph): classification cache, static-collapse, and shared-State dedup#218
chaoming0625 merged 8 commits into
mainfrom
worktree-graph-engine-upgrade

Conversation

@chaoming0625

Copy link
Copy Markdown
Collaborator

Summary

Upgrades the brainstate.graph engine for performance, correctness, and accuracy without changing the public API. The engine flattens object graphs (Nodes + States) into a static GraphDef plus a dynamic state mapping and back; this PR makes the hot paths substantially faster and fixes a latent shared-State double-counting bug, while preserving round-trip semantics.

Three themes:

  1. Hot-path performance — a type-keyed classification cache replaces per-value jax.tree_util.all_leaves + ABCMeta isinstance checks on every traversal; the encoder/traversal kernel classifies each value once; the decoder dispatches on exact edge type.
  2. Correctnessiter_leaf()/states() now dedup shared States by identity (first pre-order path wins), matching the long-standing behavior of flatten()/treefy_states(). The two halves of the API agreed on nodes but disagreed on states; they now agree.
  3. Accuracy/robustness — all-static, value-hashable pytree containers collapse to a single StaticEdge, with a guard that prevents baking live JAX arrays into the static IR.

Performance

Micro-benchmark, MLP depth=32 width=64 (96 leaves, 33 nodes), best of 7×120 iterations, bench_graph.py:

operation baseline after Δ
flatten 629.0 µs 406.8 µs −35%
unflatten 299.3 µs 138.4 µs −54%
treefy_split 627.7 µs 408.7 µs −35%
treefy_merge 312.4 µs 172.0 µs −45%
states 403.8 µs 232.6 µs −42%
iter_leaf 328.1 µs 199.8 µs −39%
iter_node 324.5 µs 187.4 µs −42%
graphdef 619.4 µs 414.8 µs −33%
hash(fresh graphdef) 639.7 µs 448.2 µs −30%
clone 1.01 ms 641.7 µs −37%

Similar −25% … −54% gains hold across the other benchmarked shapes (tiny Linear, MLP-8, WideBag-128, DeepChain-64, shared/tied weights). The hash(cached graphdef) fast path is unchanged (~57 ns). Mechanism confirmed by profiling: a warm flatten() now triggers 0 all_leaves calls on State objects (previously one per State per traversal).

What changed, by phase

Phase A — hot-path performance (_walk.py, _flatten.py)

  • New classify(x) cache keyed by type(x): memoizes each type's kind (GRAPH_NODEPYTREESTATESTATE_LEAFSTATIC). The probe order is behavior-preserving. all_leaves is now consulted at most once per type, not once per value.
  • register_graph_node_type() clears the cache so newly registered node types reclassify correctly (covered by a mutation-checked test).
  • The traversal kernel classifies each node once and iterates impl.flatten(node)[0] directly.
  • The decoder dispatches on exact edge type (type(e) is NodeEdge/PytreeEdge/StateEdge/StateLeafEdge/StaticEdge); the edge dataclasses are frozen=True and unsubclassed, so type(e) is X is equivalent to isinstance and faster.

Phase B — static-collapse (_flatten.py)

  • An all-static, value-hashable pytree container collapses to a single StaticEdge instead of an expanded PytreeEdge, shrinking the IR and the work on both encode and decode.
  • Discovered bug + guard: a naive collapse would fold a bare TreefyState (identity-hashable; its array child encodes to StaticEdge; hash() doesn't raise) into a StaticEdge, baking a live JAX array into the static IR. The guard type(value).__hash__ is not object.__hash__ excludes identity-hashable objects from collapse. This is load-bearing and covered by a dedicated regression test (verified to fail when the guard is removed).

Phase C — shared-State dedup + convert cleanup (_walk.py, _operations.py, _convert.py)

  • iter_leaf()/states() dedup shared States by identity via the existing node-visited set (containers and State leaves are disjoint object sets, so ids never collide). A dedup_leaves=False seam preserves all-paths enumeration for any caller that needs it.
  • graph_to_tree() now pulls States straight out of the index_ref RefMap instead of re-walking via states().

Behavior change & caller audit

iter_leaf()/states() previously yielded a shared State once per path it was reachable by; they now yield it once, matching flatten()/treefy_states(). nodes()/iter_node() already deduped — this removes the asymmetry.

Library-wide caller audit (full suite + per-call-site review of every states()/iter_leaf()/iter_node() consumer): every consumer wants unique states/nodes; none relied on the double-count. check_consistent_aliasing is unaffected (it operates across calls, not within a single traversal). One consumer's output changes: brainstate.nn._utils.count_parameters now counts a tied weight once (correct) instead of twice — i.e. this fixes a latent double-count rather than regressing it.

Testing

  • Full suite at this HEAD: 5312 passed, 23 skipped, 0 failed (~11 min). Graph module alone: 223 passed, 48 subtests.
  • New tests: classification cache (kinds, cache stability, registration invalidation — mutation-checked); bare-TreefyStatePytreeEdge encoding; static-collapse (incl. identity-hashable-not-collapsed regression); shared-State dedup (states/iter_leaf dedup, count matches treefy_states, nodes still dedup, inconsistent-aliasing still detected, dedup_leaves=False recovers all paths); graph_to_tree state extraction + sharing.
  • Round-trip fidelity battery (cycles, self-cycles, shared/tied states, pytree-rooted graphs, int-key dicts, empty nodes, deepcopy independence) all pass.

Notes

  • No public API changes. Pre-1.0; the dedup behavior change is intentional and documented in the states()/iter_leaf() docstrings.
  • The design spec, implementation plan, and benchmark/probe scripts live under dev/superpowers/ and are intentionally gitignored (local dev artifacts), so they are not part of this diff.

Test plan

  • pytest brainstate/graph/ — 223 passed
  • Full suite — 5312 passed, 0 failed
  • bench_graph.py baseline vs final — −25%…−54% across operations
  • probe_graph.py / probe_graph2.py — round-trip + dedup correctness

@sourcery-ai sourcery-ai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry @chaoming0625, you have reached your weekly rate limit of 500000 diff characters.

Please try again later or upgrade to continue using Sourcery

@codecov

codecov Bot commented Jun 13, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 94.20290% with 4 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
brainstate/graph/_flatten.py 92.85% 0 Missing and 2 partials ⚠️
brainstate/graph/_walk.py 94.87% 1 Missing and 1 partial ⚠️

📢 Thoughts on this report? Let us know!

@chaoming0625 chaoming0625 merged commit 4e9886c into main Jun 13, 2026
7 checks passed
@chaoming0625 chaoming0625 deleted the worktree-graph-engine-upgrade branch June 13, 2026 16:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant