perf(graph): classification cache, static-collapse, and shared-State dedup #218
Merged
Summary
Upgrades the `brainstate.graph` engine for performance, correctness, and accuracy without changing the public API. The engine flattens object graphs (`Node`s + `State`s) into a static `GraphDef` plus a dynamic state mapping and back; this PR makes the hot paths substantially faster and fixes a latent shared-`State` double-counting bug, while preserving round-trip semantics.

Three themes:

- […] `jax.tree_util.all_leaves` + ABCMeta `isinstance` checks on every traversal; the encoder/traversal kernel classifies each value once; the decoder dispatches on exact edge type.
- `iter_leaf()`/`states()` now dedup shared `State`s by identity (first pre-order path wins), matching the long-standing behavior of `flatten()`/`treefy_states()`. The two halves of the API agreed on nodes but disagreed on states; they now agree.
- […] `StaticEdge`, with a guard that prevents baking live JAX arrays into the static IR.

Performance
Micro-benchmark, MLP depth=32 width=64 (96 leaves, 33 nodes), best of 7×120 iterations, `bench_graph.py`:

[…]

Similar −25% … −54% gains hold across the other benchmarked shapes (tiny `Linear`, MLP-8, `WideBag-128`, `DeepChain-64`, shared/tied weights). The `hash(cached graphdef)` fast path is unchanged (~57 ns). Mechanism confirmed by profiling: a warm `flatten()` now triggers 0 `all_leaves` calls on `State` objects (previously one per `State` per traversal).

What changed, by phase
Phase A: hot-path performance (`_walk.py`, `_flatten.py`)

- `classify(x)` cache keyed by `type(x)`: memoizes each type's kind (`GRAPH_NODE` → `PYTREE` → `STATE` → `STATE_LEAF` → `STATIC`). The probe order is behavior-preserving. `all_leaves` is now consulted at most once per type, not once per value. `register_graph_node_type()` clears the cache so newly registered node types reclassify correctly (covered by a mutation-checked test).
- […] `impl.flatten(node)[0]` directly.
- […] (`type(e) is NodeEdge`/`PytreeEdge`/`StateEdge`/`StateLeafEdge`/`StaticEdge`); the edge dataclasses are `frozen=True` and unsubclassed, so `type(e) is X` is equivalent to `isinstance` and faster.

Phase B: static-collapse (`_flatten.py`)

- […] `StaticEdge` instead of an expanded `PytreeEdge`, shrinking the IR and the work on both encode and decode.
- […] `TreefyState` (identity-hashable; its array child encodes to `StaticEdge`; `hash()` doesn't raise) into a `StaticEdge`, baking a live JAX array into the static IR. The guard `type(value).__hash__ is not object.__hash__` excludes identity-hashable objects from collapse. This is load-bearing and covered by a dedicated regression test (verified to fail when the guard is removed).

Phase C: shared-State dedup + convert cleanup (`_walk.py`, `_operations.py`, `_convert.py`)

- `iter_leaf()`/`states()` dedup shared `State`s by identity via the existing node-`visited` set (containers and `State` leaves are disjoint object sets, so ids never collide). A `dedup_leaves=False` seam preserves all-paths enumeration for any caller that needs it.
- `graph_to_tree()` now pulls `State`s straight out of the `index_ref` `RefMap` instead of re-walking via `states()`.

Behavior change & caller audit
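The identity-based dedup introduced in Phase C, and the behavior change it causes, can be illustrated with a small stand-alone sketch. It does not use the real `brainstate` classes: `State`, `Node`, and this `iter_leaf` are hypothetical stand-ins written for illustration, and only the `dedup_leaves` flag name and the "first pre-order path wins" rule are taken from the description above.

```python
class State:
    """Hypothetical stand-in for a State leaf (not the real brainstate class)."""
    def __init__(self, value):
        self.value = value


class Node:
    """Hypothetical stand-in for a graph node with named children."""
    def __init__(self, **children):
        self.children = children


def iter_leaf(node, dedup_leaves=True, path=(), visited=None):
    """Yield (path, state) pairs in pre-order.

    One `visited` set of object ids serves both nodes and State leaves;
    the two kinds of object never share an id while both are alive.
    """
    visited = set() if visited is None else visited
    if id(node) in visited:                 # nodes are always deduped
        return
    visited.add(id(node))
    for name, child in node.children.items():
        child_path = path + (name,)
        if isinstance(child, Node):
            yield from iter_leaf(child, dedup_leaves, child_path, visited)
        elif isinstance(child, State):
            if dedup_leaves:
                if id(child) in visited:    # shared State already yielded
                    continue
                visited.add(id(child))
            yield child_path, child


tied = State([1.0, 2.0])                    # weight shared by two layers
model = Node(encoder=Node(w=tied), decoder=Node(w=tied, b=State([0.0])))

deduped_paths = [p for p, _ in iter_leaf(model)]
all_paths = [p for p, _ in iter_leaf(model, dedup_leaves=False)]

# Parameter counts: once per unique State vs. once per reachable path.
count_unique = sum(len(s.value) for _, s in iter_leaf(model))
count_per_path = sum(len(s.value) for _, s in iter_leaf(model, dedup_leaves=False))
```

In this toy model the tied weight is reported only under `('encoder', 'w')` when deduplication is on, so a parameter count built on the traversal sees it once; passing `dedup_leaves=False` recovers the per-path enumeration and the doubled count.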
`iter_leaf()`/`states()` previously yielded a shared `State` once per path it was reachable by; they now yield it once, matching `flatten()`/`treefy_states()`. `nodes()`/`iter_node()` already deduped; this removes the asymmetry.

Library-wide caller audit (full suite + per-call-site review of every `states()`/`iter_leaf()`/`iter_node()` consumer): every consumer wants unique states/nodes; none relied on the double-count. `check_consistent_aliasing` is unaffected (it operates across calls, not within a single traversal). One consumer's output changes: `brainstate.nn._utils.count_parameters` now counts a tied weight once (correct) instead of twice; i.e. this fixes a latent double-count rather than regressing it.

Testing
[…] `TreefyState` → `PytreeEdge` encoding; static-collapse (incl. identity-hashable-not-collapsed regression); shared-State dedup (states/iter_leaf dedup, count matches `treefy_states`, nodes still dedup, inconsistent-aliasing still detected, `dedup_leaves=False` recovers all paths); `graph_to_tree` state extraction + sharing.

Notes
- […] `states()`/`iter_leaf()` docstrings.
- […] `dev/superpowers/` and are intentionally gitignored (local dev artifacts), so they are not part of this diff.

Test plan
- `pytest brainstate/graph/`: 223 passed
- `bench_graph.py` baseline vs final: −25% … −54% across operations
- `probe_graph.py`/`probe_graph2.py`: round-trip + dedup correctness
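For readers who want to see the Phase B guard in isolation, here is a minimal sketch of why `type(value).__hash__ is not object.__hash__` matters. The predicate `is_collapsible` and the classes `Handle` and `Config` are hypothetical names invented for this example; only the guard expression itself comes from the description above, and the real collapse logic in `_flatten.py` may check more than this.

```python
from dataclasses import dataclass


def is_collapsible(value) -> bool:
    """Hypothetical predicate: may `value` be stored as a static constant?"""
    # Guard quoted in the PR: reject types that inherit object's identity hash.
    if type(value).__hash__ is object.__hash__:
        return False
    try:
        hash(value)  # unhashable values (e.g. lists) raise TypeError
    except TypeError:
        return False
    return True


class Handle:
    """Identity-hashable: defines neither __eq__ nor __hash__."""
    def __init__(self, payload):
        self.payload = payload


@dataclass(frozen=True)
class Config:
    """Value-hashable: frozen dataclasses get a generated __hash__."""
    units: int
    activation: str


handle = Handle(payload=[1.0, 2.0])
hash(handle)  # does not raise, so a bare hash() probe would accept it

results = {
    "int": is_collapsible(3),
    "tuple": is_collapsible(("relu", 64)),
    "config": is_collapsible(Config(units=64, activation="relu")),
    "list": is_collapsible([1, 2]),
    "handle": is_collapsible(handle),
}
```

The point of the sketch is the `handle` case: `hash()` succeeds on it because the hash is derived from object identity, so a probe that only checks "does `hash()` raise?" would treat a wrapper around live data as a static constant. Comparing the type's `__hash__` against `object.__hash__` rules that case out.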