fix: resolve all issues from the cross-module audit (dev/issues.md)#217
Merged
Conversation
Addresses every verifier-confirmed finding in the audit (C1, H1-H21, M1-M46, L1-L29) plus the actionable needs-judgment items NJ1-NJ3, each with a regression test that fails pre-fix and passes post-fix. core/_state: numpy-backed HiddenGroupState/HiddenTreeState set_value, integer index bounds + negative-index handling, TreefyState pytree aux hygiene (no spurious JIT recompiles), assert->raise validation, tag normalization, copy() _been_writen, StateTraceStack.merge originals. transform: thread-safe stale-cache recompile (atomic replace, NJ2); tracer guard under disable_jit (cond/switch); checkify state write-back; loop base validation; while_loop ConcretizationTypeError; IR codegen (keyword names, round rounding_method, 0-d literals, enum imports, thread-local imports); IR visualize edges/labels; mapping static_argnums; vjp aliasing; eval_shape read-only states; fwd_grad key; progress bar. nn: softmin overflow; F1 macro/weighted; Flatten/MaxUnpool/AdaptivePool; AvgPool/LPPool NaN-safety; regularizer log-normalizers + deterministic spectral norm; conv per-dim padding semantics + dead lhs_dilation; 0-d log-det jacobian; metrics validation under jit. random: wald/multinomial explicit key + per-row validation; seed range / None / context; dtype-preserving *_like; ir_compilation static_argnums; choice on object arrays; chisquare/power domain checks. interop: bias-only BatchNorm; BatchNorm0d for 2-D input; equinox RMSNorm bias; linen GroupNorm group_size; dtype preservation; deterministic Dropout export; conv channel-shape validation. graph: update_states new/raw-State keys; consistent-aliasing node check. util: namedtuple DotDict; empty-subdict round-trips; dataclass subclassing + init=False fields; FlattedDict views/iterable; assert->raise. core misc: Mode.__hash__; DeprecatedModule underscore short-circuit. Full suite: 5074 passed, 23 skipped.
Works through the 110 unverified appendix findings (and the deferred needs-judgment items NJ4 + random split_key backup) on top of the confirmed-bug pass. Each behavioral fix carries a regression test that fails pre-fix and passes post-fix; documented-intent and false-positive findings are left as-is. Runtime/user-input validation that used `assert` is converted to `raise` so it survives `python -O`. core: _compatible_import.wraps copies each metadata attribute independently; _deprecation drops dead code and uses the lazy replacement-module property; _state init-hook metadata excludes internal bookkeeping, copy() gets an independent hook manager, numel() raises on a poisoned state, StateTraceStack exit verifies stack ownership; Hook id counter RMW is lock-guarded; has_hooks validates type up front; environ.reset clears per-key locks and tolerance honors the env precision dtype; mixin HashableDict hashes via frozenset, JointTypes alias call raises; typing PyTree validates leaf hashability, normalizes whitespace, and rejects malformed ellipsis structures. graph: removed write-only context stacks; convert/flatten/walk assert->raise with descriptive errors; dropped a redundant Everything predicate. interop: AddScalar uses brainunit math so united Quantities work. nn: softmin tuple-axis hint; soft_shrink/hard_shrink preserve integer dtype; conv assert->raise + padding-parse guards; Threshold accepts Quantity; removed dead extra_repr; MultiMetric reserves _metric_names; paddings accept numpy/JAX integer scalars; poolings assert->raise, correct stride message, channel_axis `is not None`, target_size validation; regularization sample_init relu-guards the sqrt, validates group_size/order, and targets the requested spectral norm; SoftplusT/NegSoftplusT demote via maybe_decimal; ScaledSigmoidT rejects non-positive beta; EntropyReg/DirichletReg document the global-softmax default (NJ4). random: multinomial returns integer counts; zipf/geometric/triangular validate parameters via jit_error_if; rayleigh(scale=None) no longer crashes; _numpy_keys draws from the full uint32 range; the module-global DEFAULT key read-modify-write is lock-guarded; split_key/self_assign_multi_keys back up the pre-split key so restore_key actually rewinds (matches the documented contract). transform: ifelse enforces the >=2-branch contract; _error_msg falls back to str.format; fwd_grad guards integer dtypes and non-scalar outputs; grad NaN check is inexact-aware and has_aux asserts->raise; IdentityMap iterates keys; scan codegen uses collision-free skolem names; ir_visualize drops phantom pjit edges; checkpointed_scan validates its two-tuple return; make_jaxpr guards fn_to_check/None state_vals/keyword-only static_argnums, ignores weak_type-only changes, and raises ValueError (not UnitMismatchError) on unit mismatch; _remove_axis reports non-array leaves clearly; progress_bar assert->raise and guards post-close updates; corrected unvmap/named_call docstrings; _util assert->IndexError on argnums bounds. util: split_total rejects bool; DotDict overrides __or__/__ior__; clear_buffer_memory catches per-buffer; to_pure_dict keeps empty sub-dicts; StateJaxTracer is hashable; filter WithTag rejects non-str tags and Any/All/Not hash defensively; documented that a str filter matches by tag, not path.
#216 rewrote the state-mapping engine (_mapping_core.py, _mapping2.py, _mapping1.py, _shard_map.py) to fix 8 mapping bugs. Several overlap the audit's own mapping work, so the conflicts in _mapping_core.py / _mapping2.py were resolved in favor of #216 (the reviewed, more thorough rewrite), discarding the now-superseded audit changes: - H19 (pmap with no RandomState -> dummy iota) == #216 #2: took #216's version, which feeds the dummy iota uniformly for vmap and pmap. - M41 (reject static_argnums) is contradicted by #216 #8, which instead *supports* static_argnums (closed over, jax.jit parity); took #216's behavior and dropped the NotImplementedError + its tests. - Appendix item 14 (axis_size cross-check) == #216 #4: took #216's. - Appendix item 15 (read-only batched input scatter-back): #216's engine re-stacks read-only inputs to their original value, so the concern does not manifest; dropped the old-engine-shaped change. Re-applied on top of #216 (additive, not covered by #216): - Appendix item 12: _remove_axis reports a clear ValueError for a non-array (scalar) leaf instead of an opaque AttributeError (+ test). - Appendix item 13: corrected the lying _compile_stateful_function type hint (in_axes/args are 2-tuples). Full transform suite: 1242 passed, 1 skipped.
The appendix refactor of `_MetaPyTree._build` subscripted `_FakePyTree` with the unpacked `leaftype` variable. mypy's strict ratchet on brainstate.typing rejects a bare variable in a type-subscript position ("Variable not valid as a type"). Use the index expression `key[0]` instead (the form the original `__getitem__` used), which mypy accepts. Restores `mypy brainstate/` to a clean run (0 errors, 108 files).
Contributor
There was a problem hiding this comment.
Sorry @chaoming0625, your pull request is larger than the review limit of 150000 diff characters
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
… gate The audit fixes added validation/edge branches (bad-argument raises, unit-carrying state paths, IR-visualization node/edge cases, RNG edge cases, defensive metadata-copy fallbacks) that the existing suite did not exercise, leaving PR patch coverage at ~84%. Add behavioral regression tests across 17 co-located *_test.py files so every added line and branch is executed: - nn/_poolings, _conv, _metrics, _regularization: bad kernel/stride/ padding/group args raise the documented Type/ValueError; multiclass micro-F1. - _state: HiddenGroupState/HiddenTreeState dict get/set key, range, shape and dimensionless-vs-unit branches. - transform/_ir_visualize, _ir_tocode, _progress_bar, _grad_transform, _make_jaxpr: dropped/duplicate invars, literal scan carry, set-like import helper, desc-return validation, non-inexact nan-check leaves, unchanged-shape validation. - random/_seed, _state: seed_context(None) fresh-seed path, numeric choice jax path, non-integer self_assign_multi_keys. - util/_others, typing, graph/_convert, graph/_flatten, _compatible_import: DotDict __ror__/__ior__ NotImplemented, flatten key collision, PyTree non-string structure, nested-node walk, present StateLeafEdge resolution, unsettable wrapper metadata. Diff line coverage 93% -> 100%; the only remaining uncovered branch is genuinely unreachable (graph/_convert iter_node always yields graph nodes). Full suite: 5296 passed, 23 skipped.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR resolves all issues catalogued in the cross-module audit (
dev/issues.md), spanning every package underbrainstate/. Each fix is paired with a regression test that fails before the change and passes after.Scope of the audit
Commit structure
fix(audit): resolve 97 confirmed bugs + NJ1–NJ3 across all modules— the core C/H/M/L fixes acrosscore,graph,interop,nn,random,transform,util.fix(audit): resolve appendix findings + NJ items across all modules— the 110-entry appendix triage plus the remaining needs-judgment items.Merge origin/main (#216 …)— reconciles with the independently-merged mapping-engine rewrite (fix(transform): resolve 8 bugs in the vmap/pmap/shard_map mapping engine #216). Where the audit overlapped fix(transform): resolve 8 bugs in the vmap/pmap/shard_map mapping engine #216 (H19, M41, appendix items 14/15), fix(transform): resolve 8 bugs in the vmap/pmap/shard_map mapping engine #216's reviewed rewrite was taken; only the additive, non-overlapping items (a clearValueErrorfor scalar leaves in_remove_axis, and a corrected_compile_stateful_functiontype hint) were re-applied on top.fix(typing): keep PyTree name construction mypy-clean— keepsbrainstate.typing(a strict-mypy module) type-clean after the appendix refactor.Nature of the fixes
raise TypeError/ValueErrorrather thanassert(asserts are stripped underpython -O).Test plan
pytest brainstate/). The 23 skips are the pre-existing multi-device /--run-slowgated tests.mypy brainstate/— 0 errors, 108 source files._foo_test.py) exercising the specific edge case.Net change on top of current
main: 105 files, +7120 / −839 (50 test files, 55 source files).