Skip to content

fix: resolve all issues from the cross-module audit (dev/issues.md)#217

Merged
chaoming0625 merged 5 commits into
mainfrom
worktree-fix-audit-issues
Jun 13, 2026
Merged

fix: resolve all issues from the cross-module audit (dev/issues.md)#217
chaoming0625 merged 5 commits into
mainfrom
worktree-fix-audit-issues

Conversation

@chaoming0625

Copy link
Copy Markdown
Collaborator

Summary

This PR resolves all issues catalogued in the cross-module audit (dev/issues.md), spanning every package under brainstate/. Each fix is paired with a regression test that fails before the change and passes after.

Scope of the audit

Severity Items Status
Critical C1 ✅ fixed
High H1–H21 ✅ fixed
Medium M1–M46 ✅ fixed
Low L1–L29 ✅ fixed
Needs-human-judgment NJ1–NJ4 ✅ addressed (fix or documented contract)
Appendix (unverified findings) 110 entries ✅ triaged + fixed where confirmed

Commit structure

  1. fix(audit): resolve 97 confirmed bugs + NJ1–NJ3 across all modules — the core C/H/M/L fixes across core, graph, interop, nn, random, transform, util.
  2. fix(audit): resolve appendix findings + NJ items across all modules — the 110-entry appendix triage plus the remaining needs-judgment items.
  3. Merge origin/main (#216 …) — reconciles with the independently-merged mapping-engine rewrite (fix(transform): resolve 8 bugs in the vmap/pmap/shard_map mapping engine #216). Where the audit overlapped fix(transform): resolve 8 bugs in the vmap/pmap/shard_map mapping engine #216 (H19, M41, appendix items 14/15), fix(transform): resolve 8 bugs in the vmap/pmap/shard_map mapping engine #216's reviewed rewrite was taken; only the additive, non-overlapping items (a clear ValueError for scalar leaves in _remove_axis, and a corrected _compile_stateful_function type hint) were re-applied on top.
  4. fix(typing): keep PyTree name construction mypy-clean — keeps brainstate.typing (a strict-mypy module) type-clean after the appendix refactor.

Nature of the fixes

  • Runtime validation uses raise TypeError/ValueError rather than assert (asserts are stripped under python -O).
  • Changes target stable public JAX APIs to preserve compatibility across the supported JAX range (0.7.0 → latest).
  • Behavioral contracts that were genuinely ambiguous (the NJ items) are resolved by documenting the existing behavior rather than silently changing it.

Test plan

  • ✅ Full suite: 5235 passed, 23 skipped, 0 failed (pytest brainstate/). The 23 skips are the pre-existing multi-device / --run-slow gated tests.
  • ✅ Type check: mypy brainstate/ — 0 errors, 108 source files.
  • Every fix has a co-located regression test (_foo_test.py) exercising the specific edge case.

Net change on top of current main: 105 files, +7120 / −839 (50 test files, 55 source files).

Addresses every verifier-confirmed finding in the audit (C1, H1-H21,
M1-M46, L1-L29) plus the actionable needs-judgment items NJ1-NJ3, each
with a regression test that fails pre-fix and passes post-fix.

core/_state: numpy-backed HiddenGroupState/HiddenTreeState set_value,
integer index bounds + negative-index handling, TreefyState pytree aux
hygiene (no spurious JIT recompiles), assert->raise validation, tag
normalization, copy() _been_writen, StateTraceStack.merge originals.

transform: thread-safe stale-cache recompile (atomic replace, NJ2);
tracer guard under disable_jit (cond/switch); checkify state write-back;
loop base validation; while_loop ConcretizationTypeError; IR codegen
(keyword names, round rounding_method, 0-d literals, enum imports,
thread-local imports); IR visualize edges/labels; mapping static_argnums;
vjp aliasing; eval_shape read-only states; fwd_grad key; progress bar.

nn: softmin overflow; F1 macro/weighted; Flatten/MaxUnpool/AdaptivePool;
AvgPool/LPPool NaN-safety; regularizer log-normalizers + deterministic
spectral norm; conv per-dim padding semantics + dead lhs_dilation; 0-d
log-det jacobian; metrics validation under jit.

random: wald/multinomial explicit key + per-row validation; seed range /
None / context; dtype-preserving *_like; ir_compilation static_argnums;
choice on object arrays; chisquare/power domain checks.

interop: bias-only BatchNorm; BatchNorm0d for 2-D input; equinox RMSNorm
bias; linen GroupNorm group_size; dtype preservation; deterministic
Dropout export; conv channel-shape validation.

graph: update_states new/raw-State keys; consistent-aliasing node check.
util: namedtuple DotDict; empty-subdict round-trips; dataclass
subclassing + init=False fields; FlattedDict views/iterable; assert->raise.
core misc: Mode.__hash__; DeprecatedModule underscore short-circuit.

Full suite: 5074 passed, 23 skipped.
Works through the 110 unverified appendix findings (and the deferred
needs-judgment items NJ4 + random split_key backup) on top of the
confirmed-bug pass. Each behavioral fix carries a regression test that
fails pre-fix and passes post-fix; documented-intent and false-positive
findings are left as-is. Runtime/user-input validation that used `assert`
is converted to `raise` so it survives `python -O`.

core: _compatible_import.wraps copies each metadata attribute independently;
_deprecation drops dead code and uses the lazy replacement-module property;
_state init-hook metadata excludes internal bookkeeping, copy() gets an
independent hook manager, numel() raises on a poisoned state, StateTraceStack
exit verifies stack ownership; Hook id counter RMW is lock-guarded;
has_hooks validates type up front; environ.reset clears per-key locks and
tolerance honors the env precision dtype; mixin HashableDict hashes via
frozenset, JointTypes alias call raises; typing PyTree validates leaf
hashability, normalizes whitespace, and rejects malformed ellipsis structures.

graph: removed write-only context stacks; convert/flatten/walk assert->raise
with descriptive errors; dropped a redundant Everything predicate.

interop: AddScalar uses brainunit math so united Quantities work.

nn: softmin tuple-axis hint; soft_shrink/hard_shrink preserve integer dtype;
conv assert->raise + padding-parse guards; Threshold accepts Quantity;
removed dead extra_repr; MultiMetric reserves _metric_names; paddings accept
numpy/JAX integer scalars; poolings assert->raise, correct stride message,
channel_axis `is not None`, target_size validation; regularization sample_init
relu-guards the sqrt, validates group_size/order, and targets the requested
spectral norm; SoftplusT/NegSoftplusT demote via maybe_decimal; ScaledSigmoidT
rejects non-positive beta; EntropyReg/DirichletReg document the global-softmax
default (NJ4).

random: multinomial returns integer counts; zipf/geometric/triangular validate
parameters via jit_error_if; rayleigh(scale=None) no longer crashes; _numpy_keys
draws from the full uint32 range; the module-global DEFAULT key read-modify-write
is lock-guarded; split_key/self_assign_multi_keys back up the pre-split key so
restore_key actually rewinds (matches the documented contract).

transform: ifelse enforces the >=2-branch contract; _error_msg falls back to
str.format; fwd_grad guards integer dtypes and non-scalar outputs; grad NaN
check is inexact-aware and has_aux asserts->raise; IdentityMap iterates keys;
scan codegen uses collision-free skolem names; ir_visualize drops phantom pjit
edges; checkpointed_scan validates its two-tuple return; make_jaxpr guards
fn_to_check/None state_vals/keyword-only static_argnums, ignores weak_type-only
changes, and raises ValueError (not UnitMismatchError) on unit mismatch;
_remove_axis reports non-array leaves clearly; progress_bar assert->raise and
guards post-close updates; corrected unvmap/named_call docstrings; _util
assert->IndexError on argnums bounds.

util: split_total rejects bool; DotDict overrides __or__/__ior__; clear_buffer_memory
catches per-buffer; to_pure_dict keeps empty sub-dicts; StateJaxTracer is hashable;
filter WithTag rejects non-str tags and Any/All/Not hash defensively; documented
that a str filter matches by tag, not path.
#216 rewrote the state-mapping engine (_mapping_core.py, _mapping2.py,
_mapping1.py, _shard_map.py) to fix 8 mapping bugs. Several overlap the
audit's own mapping work, so the conflicts in _mapping_core.py /
_mapping2.py were resolved in favor of #216 (the reviewed, more thorough
rewrite), discarding the now-superseded audit changes:

- H19 (pmap with no RandomState -> dummy iota) == #216 #2: took #216's
  version, which feeds the dummy iota uniformly for vmap and pmap.
- M41 (reject static_argnums) is contradicted by #216 #8, which instead
  *supports* static_argnums (closed over, jax.jit parity); took #216's
  behavior and dropped the NotImplementedError + its tests.
- Appendix item 14 (axis_size cross-check) == #216 #4: took #216's.
- Appendix item 15 (read-only batched input scatter-back): #216's engine
  re-stacks read-only inputs to their original value, so the concern does
  not manifest; dropped the old-engine-shaped change.

Re-applied on top of #216 (additive, not covered by #216):
- Appendix item 12: _remove_axis reports a clear ValueError for a
  non-array (scalar) leaf instead of an opaque AttributeError (+ test).
- Appendix item 13: corrected the lying _compile_stateful_function type
  hint (in_axes/args are 2-tuples).

Full transform suite: 1242 passed, 1 skipped.
The appendix refactor of `_MetaPyTree._build` subscripted `_FakePyTree` with the unpacked `leaftype` variable. mypy's strict ratchet on brainstate.typing rejects a bare variable in a type-subscript position ("Variable not valid as a type"). Use the index expression `key[0]` instead (the form the original `__getitem__` used), which mypy accepts. Restores `mypy brainstate/` to a clean run (0 errors, 108 files).

@sourcery-ai sourcery-ai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry @chaoming0625, your pull request is larger than the review limit of 150000 diff characters

@codecov

codecov Bot commented Jun 13, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 95.76784% with 35 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
brainstate/transform/_ir_visualize.py 0.00% 34 Missing ⚠️
brainstate/graph/_convert.py 91.66% 0 Missing and 1 partial ⚠️

📢 Thoughts on this report? Let us know!

… gate

The audit fixes added validation/edge branches (bad-argument raises,
unit-carrying state paths, IR-visualization node/edge cases, RNG edge
cases, defensive metadata-copy fallbacks) that the existing suite did not
exercise, leaving PR patch coverage at ~84%. Add behavioral regression
tests across 17 co-located *_test.py files so every added line and branch
is executed:

- nn/_poolings, _conv, _metrics, _regularization: bad kernel/stride/
  padding/group args raise the documented Type/ValueError; multiclass
  micro-F1.
- _state: HiddenGroupState/HiddenTreeState dict get/set key, range,
  shape and dimensionless-vs-unit branches.
- transform/_ir_visualize, _ir_tocode, _progress_bar, _grad_transform,
  _make_jaxpr: dropped/duplicate invars, literal scan carry, set-like
  import helper, desc-return validation, non-inexact nan-check leaves,
  unchanged-shape validation.
- random/_seed, _state: seed_context(None) fresh-seed path, numeric
  choice jax path, non-integer self_assign_multi_keys.
- util/_others, typing, graph/_convert, graph/_flatten,
  _compatible_import: DotDict __ror__/__ior__ NotImplemented, flatten
  key collision, PyTree non-string structure, nested-node walk, present
  StateLeafEdge resolution, unsettable wrapper metadata.

Diff line coverage 93% -> 100%; the only remaining uncovered branch is
genuinely unreachable (graph/_convert iter_node always yields graph
nodes). Full suite: 5296 passed, 23 skipped.
@chaoming0625 chaoming0625 merged commit 160ee95 into main Jun 13, 2026
7 checks passed
@chaoming0625 chaoming0625 deleted the worktree-fix-audit-issues branch June 13, 2026 11:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant