Stan/rfc for monorepo docs#110
Merged
Merged
Conversation
…devinterp subrepo/shared/devinterp subrepo: subdir: "shared/devinterp" merged: "9eba2d86" upstream: origin: "git@github.com:timaeus-research/devinterp.git" branch: "rfc-for-monorepo-license-and-update" commit: "9eba2d86" git-subrepo: version: "0.4.9" origin: "https://github.com/ingydotnet/git-subrepo" commit: "4f60dd7"
…com/timaeus-research/timaeus into williamsnell/move-tests-to-triangle
* WIP vibecoded docs * delete claude slop files * more docs changes by claude * docs are OK now
* take the axe to TPUs * Attempt to make claude happy by fixing a file we're no longer using * fix setup * update lockfile * address William's comments * fix cursor style guide at claude's request * remove faulty merge import * FIX merge mistake * another small merge issue * attempted fix GPU test in ci/cd? * WIP fix test that somehow passed before? * remove some more TPU marks * WIP fix gpu large n batched depending on TPU code (ugh) * update snapshot * WIP * WIP * WIP * WIP * WIP * WIP * add snapshot fix --------- Co-authored-by: William Snell <59493198+williamsnell@users.noreply.github.com>
* Apply `ruff check --fix shared/` * Allow unused local variables in ruff check * Fix some linter issues * Enable ruff check in precommit and ci validation * Cleanup some test imports
* Remove typing.Dict, List etc. They've been deprecated in favour of dict, list... * Fix is vs == comparison
* Add high-level test laying out desired API * Add initial code structure * fix ParamGroups import * remove __init__.py partial __all__ for now * Remove redundant dict passing in parser; add base-class level validation with good error messages * Update tests, transformer_lens implementation; LIKELY WRONG * Move model-key matching to be model-agnostic * Improve typing, code structure, readability. Add more tests, try and get them to pass * Fix dtype handling in mask logic, add some sensible assertions, fix the test * change ParamGroups -> MaskedParamDict, refactor as needed. * Move Restriction to base, clean up typing * Move pythia test from test_wr_e2e.py. Fix pythia head handling, add einops masks * Properly move Restriction to base.py * Fix a bunch of tests to match the new API * Refactor MaskedTensor back to TypedDict for torch optimizer compat * optimize_over -> mask * Fix test_llc_vs_modular to use new weight restrictions arg * update WeightRestrictionArguments to match new format, fix some MaskedTensor calls * Change Restriction type to use Union, List (to retain compatibility with legacy config parser) * Update type def in weight_restrictions.py * weight_restriction_path -> weight_restriction * Update weight_restriction_path -> weight_restriction in all test files * sgld optimize_over -> mask * weight_restriction -> weight_restriction*s* * fix weight_restriction -> weight_restrictions * weight_restriction -> weight_restrictions * Remove duplicate import * refactor/simplify tests in test_weight_restrictions.;py * Delete test_weight_restrictions_e2e*. We should have equivalent end-to-end testing with sgld via test_llc_vs_modular.py (which includes snapshots) * Fix typing complaints about deprecated Union, Tuple * Make sure Restrictions is never None * Remove slicing, weight_restrictions.py * Update all references to old weight_restrictions.py, fix some docs * Refactor hessians.py to use new code * Explicitly test the mlp error handling for triangle * Break out model-specific parsing * Update yamls * Update tests (drop _path from weight_restrictions) * Add parsing tests, reject leading 0s in keys * remove test_slicing.py * Update weight restrictions structure, move imports * delete old weight restrictions * remove Optional * Add more comprehensive unit tests of added code * create_param_groups -> create_masked_params * Delete compatibility alias * Update docs * Finish updating docs * weight_restriction_name -> weight_restrictions. config .yaml -> semantic * Update has_mlp comment for pythia * pythia -> gpt_neox * Clearly explain what weight restrictions mean * Make tests compare gpt_neox through aether and transformerlens against all semantic groups * Make test_determinism use a different weight restriction (since most tests just use "full" and that doesn't test weight restriction handling) * Change wr to attn-2-3 * Change snapshots to be dir-specific * Filter delete so we only delete unused snapshots on the same backend as we currently are running * remove 5070 snapshots * Add some cursed messages since overriding syrupy is getting complex * Make syrupy messages more obvious * Fix accidentally printing the message 30 times * Fix error formatting one more time * update snapshots to use attn weight restriction * ones -> ones_like (so we copy device and dtype automatically) * fix keyword: input_shape -> input_tensor * Add cli to preview weight restriction keys, and test to verify it works * semantic patterns -> component groups * delete incomplete code fence * weight_restriction_path -> weight_restrictions * Update shared/aether/src/aether/model/weight_restrictions/wr2.py * Add plain mode, nicer formatting * Re-delete TPU code * Explain the purpose of merge_masks * Add some tests for weirder combinations of parameters. Add an abs(sum) so we have another datapoint * Add back requires_grad = False logic that wasn't replicated in create_masked_params * revert compute_susceptibilities to main, since these paths are hardcoded into test files * Update snapshots (weight_restriction_name=None -> "") * Add versioned SampleActionArgs * Update tests to either use a TypeAdapter or V2 directly * Remove type hints from queue * Bypass any caching of DataTrainingArguments; only deserialize the actual arguments of DataTrainingArguments whenever possible * Add back test (that should fail on main) * Try mark DataTrainingArguments manually as a dataclass to run the post_init * Move the test that uses github-all to test_observables (so it's run in CI, not just on nightlies) * Narrow allowed types of string_or_dict_to_dataset_args * Properly use parse_weight_restrictions; update docstring * Properly wrap SamplePipelineArgs in a type adapter * pile10k -> github_all in tests * DataTrainingArguments -> _DatasetArgs * Remove broken assertion (used to be weight_restriction_name, but that no longer exists) * DataArray -> str -> Path -> str * Move parse_weight_restrictions to shared module, use it in susc_fast * Use DatasetArgs as the exported type; update type info to remove Union[Dataset, DataTrainingArguments] * _DatasetArgs -> DatasetArgs in sample.py * Remove annotations import in data.py * Add docs, simplify string_or_dict_to_dataset_args
* Enable build docs on vercel * Update Vercel configuration: routes was deprecated - Added JSON schema URL to `vercel.json` for improved validation. - Changed `src` and `dest` keys to `source` and `destination` in route definitions * Separated deps installation * Minor change to test build speed * Using UV_NO_SOURCES to skip CUDA deps * Using UV_NO_SOURCES to skip CUDA deps * Additional hint to uv to not get CUDA deps * Restored server-side function execution * Moved install and build commands to scripts * Reverted to deprecated routes to see if that fixes functions * Installing deps from subproject * Explicitly install CPU Torch * Using simple excludes file for nvidia deps * Running latest uv for --excludes support * Running latest uv for --excludes support * Running latest uv for --excludes support * Running latest uv for --excludes support * Running latest uv for --excludes support * installing torch as a separate step * Using uv pip --torch-backend=cpu * Pruning torch from install deps and mocking in docs * No deps! * Bash strict mode * Using uv run --group=docs when building docs * Mocking more libs that require torch * Mocking more libs that require torch * Mocking more libs that require torch * Removing unused excludes.txt * Failing build when autodoc imports fail * Added comment about additional mocked packages * Added docs build instructions --------- Co-authored-by: Adam Newgas <adam@timaeus.co>
* Refactor sampler metrics for sgld.py * Restore with torch.no_grad(), to keep git diff smaller * refactor metrics code into separate method * Add comprehensive tests of metrics and sampler; fix some bugs in sgld now masks doesn't have a default * Tidy up test more, fix noise accounting * Fix tests by running in float64 mode to reduce fp roundoff * Properly handle metrics for multiple groups; add tests across different (and multi) devices * fix noise generator device * Combine wr and loc metrics into "prior". Remove redundant localization and weight_decay args to SGMCMC (we already have prior), simplify priors. * add tests for rmsprop, rmsprop scaling * Properly transfer numel across devices * Add more type hints, docstrings * For now, remove existing metrics * Update existing tests, sampler code to use new metrics * Remove mala test that uses sampler code, but keep other mala test * Add return type hints * Add explanatory note about tolerances and allclose. * replace diagnostics/optimizer_metrics with save_metrics * Fix LLCArguments missing diagnostics field after SamplerArguments refactor PR #1276 replaced SamplerArguments.diagnostics (list[str]) with save_metrics (bool) but didn't update llc.py, which still referenced llc_args.diagnostics in 6 places and llc_args.optimizer_metrics in the sampling kwargs. This caused AttributeError in all LLC integration tests. Move diagnostics to LLCArguments where it belongs (it controls LLC-level concerns: MALA acceptance rate, radial traces, convergence checks, contribution plots). Bridge the old metrics list to the new SGMCMC save_metrics bool via _needs_optimizer_metrics. Update snapshots. Made-with: Cursor * Rename Metrics.grad to scaled_grad, split prior into localization + weight_decay Addresses Stan's review feedback on #1276: `grad` was confusable with PyTorch's raw `p.grad`; now `scaled_grad` makes clear it includes lr, nbeta, and preconditioner scaling. The single `prior` field is replaced by separate `localization` and `weight_decay` fields with a `prior` property for the combined norm. SGMCMC decomposes sub-priors in step() before the parameter update to capture correct pre-update values. Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Made-with: Cursor * Fix unbound `p` in SGLD metrics init and null grad in `_update_metrics` When save_metrics=True but localization and bounding_box_size are both 0, the `p` variable from the conditional loop was never assigned, causing UnboundLocalError on `p.device`. Use `next(iter(group["params"])).device` instead. Also guard `p.grad.data` access in `_update_metrics` for the case where `p.grad is None` (step() sets dw to zeros but still calls _update_metrics). Made-with: Cursor * Test methodology improvements Addresses Stan's review comment. **What changed** - **`test_sampler_metrics.py` — rewritten.** Deleted the `run_sampler_metrics_test` helper that assembled 1152 parametric combinations and added indirection that made the tests hard to verify. Replaced with focused component tests (`TestScaledGrad`, `TestLocalization`, `TestWeightDecay`, `TestNoise`, `TestAllComponents`, `TestNumel`, `TestMask`, `TestGetMetricsInterface`). Each test isolates a single metric component with hand-computed expected values using the 3-4-5 Pythagorean triple for clean norms. Both SGLD and SGMCMC.sgld are tested via a parametrized fixture. GPU and multi-GPU tests use pytest markers instead of `pytest.skip`. No `torch.distributed` needed — the multi-GPU test just places params on separate devices without DDP. - **`test_metrics.py` — new file.** Unit tests for the `Metrics` dataclass itself: `add_sum_squared_` + `sqrt_` produces correct L2 norms, accumulation across calls, `zero_()`, `to()` returns independent copy, `prior` property combines localization and weight_decay. - **`test_prior.py` — cleaned up.** Replaced all `torch.randn` with explicit tensor values. Added a mutation check proving `center="initial"` clones rather than aliases the parameter tensor. Removed `test_custom_key` (redundant with `test_composite_prior`). Merged two uniform-filtering tests. - **`test_sgmcmc.py` — tightened tolerances.** `compare_parameters` now requires `atol` as keyword-only (no default). Tried `atol=0` everywhere. Parameters are bitwise identical in most SGLD-vs-SGMCMC tests; three combos with `weight_decay=0.05` hit float32 operation-order rounding up to ~6e-8, so those use `atol=1e-7`. Metrics comparisons between SGLD and SGMCMC use `atol=1e-7` (float32 accumulation order). All other tests (`optimize_over`, `deterministic`, `vs_SGNHT`, `param_groups`, `rmsprop_equals_sgld`) pass at `atol=0` for parameters. - **`sgld.py` — production assertion.** Replaced `torch.testing.assert_close` (a testing API) with `assert torch.allclose(...)`. This is a sanity check that reconstructed metric components sum to the actual gradient update. `assert` is stripped by `python -O`, so no runtime cost in optimized mode. Tolerance `atol=1e-6, rtol=0` — same expression computed in different operation order on float32. - **`sgmcmc.py` — production assertion.** Added the same assertion as `sgld.py`. **Conventions applied** - `torch.testing.assert_close` with explicit `atol`/`rtol` everywhere in tests - Literal expected values instead of programmatic recomputation - `pytestmark` filter warnings with comments explaining which warnings and why - `@pytest.mark.gpu` / `@pytest.mark.only_multi_gpu` instead of runtime `pytest.skip` * Added multi-GPU job for devinterp * Address Stan's review: fp32 accumulation, device placement, lifecycle docs Stan's Dec 2025 review flagged several issues with the metrics implementation. This commit addresses them: - add_sum_squared_() now casts to float32 before squaring to avoid precision loss with bf16/fp16 inputs (previously squared in input dtype, only upcasting during sum reduction) - get_metrics() accumulates on CPU explicitly: transfers each group's scalar metrics to CPU first, then does all arithmetic there. Avoids picking an arbitrary device when groups span multiple GPUs. - UniformPrior.key gets a comment explaining it's required by the Prior ABC but never used - Metrics docstring now documents the full lifecycle (init, zero, accumulate per-param, sqrt, get_metrics on CPU) - Inline comments in step() and get_metrics() explain the zero/accumulate/sqrt flow and why get_metrics re-squares per-group norms - Multi-GPU test now explicitly asserts result is on CPU Made-with: Cursor * Fix CompositePrior missing key and UniformPrior type hint CompositePrior.__init__ never set self.key (required by Prior ABC), which would raise AttributeError if all priors were filtered out. UniformPrior.initialize still used Iterator instead of Sequence. Made-with: Cursor * Add rmsprop_sgld to mask tests Extends make_masked_optimizer fixture to cover SGMCMC.rmsprop_sgld, which routes through CompositePreconditioner (rmsprop + MaskPreconditioner). Parametrizes test_mask_restricts_scaled_grad with analytical expected values per preconditioner (3.0 for identity, 7.5 for rmsprop first-step scaling). Made-with: Cursor * Address code review feedback on sampler metrics PR Made-with: Cursor * Fix SGMCMC numel to count only masked-in elements When a mask is active, _update_metrics was using p.numel() which counts all elements regardless of the mask. Now checks whether the preconditioning overall_coef is a tensor (indicating a mask) and uses count_nonzero() to count only the masked-in elements. Made-with: Cursor * Wire sampler metrics into Zarr output Add /metrics group to the sampling Zarr schema with per-step L2 norms (chain, step) and numel as a group attribute. Introduces StepCallback protocol, on_step closure, MetricsTree protocol with jaxtyping dims, and migrate_to_v3 support for the metrics group. Validates numel from the optimizer against the schema-time computation from create_masked_params. Made-with: Cursor * Fix bfloat16 assertion failure in _update_metrics The debug assertion in SGLD._update_metrics and SGMCMC._update_metrics failed on bfloat16 models (e.g. timaeus/triangle-40k) because the metrics recomputation multiplied scalars and tensors in a different order than step() computed d_p. With bfloat16's 7-bit mantissa, the different associativity produced errors of ~0.016, exceeding atol=1e-6. Fix: match the multiplication order in _update_metrics to the order used in step() — e.g. `_half_lr * (grad_coef * p.grad.mul(nbeta))` instead of `_half_lr * grad_coef * overall_coef * p.grad * nbeta`. Made-with: Cursor * Added missing parameter to snapshots * Add unscaled_grad, distance, and dot product metrics Extend the Metrics dataclass with two new norm fields (unscaled_grad, distance) and three dot product fields (dot_grad_prior, dot_grad_noise, dot_prior_noise). These capture the gradient without preconditioner scaling, raw displacement from init params, and geometric relationships between update components — all needed when a preconditioner (e.g. RMSprop) is present. Both SGLD and SGMCMC optimizers now compute and accumulate these in _update_metrics. Initial params are stored whenever save_metrics=True (not just when localization > 0). The Zarr schema, on_step callback, and MetricsTree protocol are updated to include all new fields. Made-with: Cursor * Added missing parameter to snapshots * Tighten SGMCMC metrics assertions with scale-aware tolerance Replace the loose 5e-2 tolerance with principled, dtype-aware bounds. Any decomposition of d_p distributes half_lr over separate terms, breaking IEEE 754 associativity. The error from c*(a+b) vs c*a+c*b scales with the component magnitudes, so we bound it by 16 ULP of the reference scale -- tight enough to catch structural bugs while accommodating bf16 rounding. Two assertions: (1) grad + combined_prior ≈ half_lr * d_p (2) loc + wd ≈ combined_prior (distribution of prior_coef) Also fixes the noise double-counting of overall_coef and cleans up the SGLD assertion (cosmetic only, tolerance unchanged). Made-with: Cursor * Added comments to clarify metrics-saving behavior * Added docs for metrics * Address PR review feedback: clean up dead code and tighten types - Expand SamplingMethodKwargs TypedDict with missing fields (noise_level, localization, weight_decay, bounding_box_size) - Remove dead optimizer_metrics dict and stop passing it to callbacks - Replace defensive numel guard with assert when save_metrics=True - Add deprecation note on SamplerArguments.save_metrics Made-with: Cursor * Checking that metrics are purely observational Added a job variant to test_determinism.py that has metrics enabled. Its loss is compared against the same snapshot as the non-metrics runs to ensure they match. * When opening Samples data, use an empty group for metrics when missing instead of having no such node in the tree. Otherwise it doesn't match the protocol * Fix metrics lost during zarr-to-netcdf conversion The zarr schema wrote metrics into a child group (metrics/scaled_grad, etc.), but sample_from_dict converts to .nc via xr.open_zarr() which only reads the root group -- silently dropping the child group and all metrics data. Fix: write metrics as root-level flat arrays (metrics_scaled_grad, etc.) so they survive the xr.open_zarr roundtrip. migrate_to_v3 now collects them back into the /metrics DataTree group via _extract_metrics(). Unit tests now operate on Samples v3 objects (post-migration) rather than raw zarr stores, so they won't need updating when the save infrastructure is refactored. Made-with: Cursor * Improved docstring of has_data() * Fix SGMCMC distance metric to use pre-update params SGMCMC._update_metrics() was computing distance from p.data after the parameter update (deterministic + noise), while SGLD correctly computed it before. This caused an off-by-one-step inconsistency between the distance metric and all other component metrics. Compute distance before p.data is modified, matching SGLD's pattern. Merge the now-identical SGLD/SGMCMC exact-value tests and parametrize several hardcoded dot-product tests with the sampler_cls fixture. Made-with: Cursor * Address review items 4 & 5: docstring fixes - Remove stale `weight_decay` param doc from SGMCMC.__init__ (not a valid __init__ parameter; only exists on factory methods where it's documented) - Clarify test_dot_products_accumulate_across_params docstring: explicitly note it tests multiple params in a single group, not multiple groups Made-with: Cursor --------- Co-authored-by: z0u <sandy@timaeus.co> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
* Empty commit * Add count sketch metrics to sampling pipeline Implements count sketch projections (#1724) for tracking approximate gradient decomposition vectors during SGMCMC sampling. Sketches preserve inner products in expectation (E[<Sv, Sw>] = <v, w>), enabling post-hoc Gram matrix and directional drift analysis without storing full parameter vectors. - CountSketch class with deterministic hash/sign generation from seed - SketchBuffer for per-step accumulation of 5 quantities (scaled_grad, unscaled_grad, localization, weight_decay, noise) - SGMCMC integration: init_sketch(), _update_sketches(), get_sketches() - SketchConfig in SamplerArgs, sketch arrays in Zarr schema - optimizer_callback on sample_single_chain for pre-step initialization - 45 new tests across 3 files (unit, optimizer-level, pipeline) Made-with: Cursor * Refactor sketch lifecycle to match Metrics; add v3 migration Aligns sketch initialization with Metrics: sketch_dim/sketch_seed are now constructor args on SGMCMC (threaded through all factory methods), removing init_sketch() and the optimizer_callback plumbing from sampler.py. Adds _extract_sketches for v2->v3 migration and wires SketchTree into the Samples protocol. Guards FSDP path against sketch usage until global-offset support is implemented. Also: sketch_dim in tests reduced from 2048 to 64 (was larger than input_dim), tolerance comments added explaining O(1/sqrt(sketch_dim)). Made-with: Cursor * Added a GPU test that checks metrics against sketches with a real (but small) model * Tighten test assertions to match writing-tests conventions Add strict=True to np.testing.assert_array_equal, add loop-variable identification to all in-loop assertions, simplify redundant error messages, and improve the torch.all diagnostic in test_hash_signs_are_pm1. Made-with: Cursor * Extract shared decomposition from metrics and sketch code paths _update_metrics and _update_sketches computed identical component vectors independently. Extract _decompose_update returning a _ComponentVectors NamedTuple, consumed by _accumulate_metrics and _scatter_sketches. Debug assertions now protect both code paths. Also seed the global RNG in test_all_quantities_over_multiple_steps to fix flakiness. Made-with: Cursor * Fixed GPU tests re. missing RANK * Added sketch_metrics:null to snapshot files * Address Adam's review: simplify SketchBuffer, add citation and docstrings - Add Charikar et al. 2002 reference to CountSketch docstring - Make sketch() delegate to scatter_into_() instead of duplicating logic - Replace per-param-group sketch buffers with a single self._sketch_buf (sketch accumulation is purely additive, unlike Metrics which needs per-group buffers for non-linear norm aggregation) - Drop SketchBuffer.aggregate/to_cpu; add standard to(device) method - Add FSDP comment on sketch buffer noting all-reduce requirement - Add formulas to SketchTree docstrings matching MetricsTree pattern Made-with: Cursor * Fix decomposition assertion tolerance for low-precision dtypes The debug assertions in _decompose_update used d_p.dtype for the tolerance eps, but d_p can be promoted to float32 by the mask preconditioner while the underlying arithmetic (grad+prior sum in step()) happens at p.grad precision (e.g. bfloat16). This caused spurious failures in the l0h0 integration test. Use per-assertion eps: _grad_pre.dtype for assertion (1) — captures the actual d_p computation precision before overall_coef promotion — and p.grad.dtype for assertion (2) — the loc+wd sum is always at input precision. Made-with: Cursor * Add count sketch support to legacy SGLD optimizer Mirrors the sketch integration already in SGMCMC: accepts sketch_dim and sketch_seed, decomposes update components via _ComponentVectors named tuple (shared between metrics and sketch paths), and scatters into a SketchBuffer each step. Includes unit tests (init, zero components, norm cross-check, mask handling, without-metrics) and a GPU integration test variant. Made-with: Cursor * Added Aether docs for count sketch metrics * Add Marimo notebook demoing count sketch metrics Loads a .zarr samples store with sketch_metrics enabled, computes L2 norms from sketch vectors, and overlays them against the scalar metric norms to demonstrate the norm-preservation property of count sketches. Made-with: Cursor * Fix SGLD decomposition assertion tolerance for low-precision dtypes Mirrors the dtype-aware, scale-dependent tolerance already used in SGMCMC. The hardcoded atol=1e-6 failed on GPU with bfloat16 models where a single ULP at the comparison scale exceeds that threshold. Made-with: Cursor * Add comments explaining why _ComponentVectors is duplicated across optimizers SGLD is deprecated and will be removed; sharing a private type with SGMCMC would couple the two and complicate that cleanup. Made-with: Cursor * Made most optimizer args keyword-only Construction call sites: 1. `sampler.py` (aether): the main production callsite. Passes `params` positionally, then `**sampling_method_kwargs`. Already kwargs-only. 2. `test_sgmcmc.py`: all 5 construction calls use either `**kwargs` unpacking or explicit keyword arguments like `lr=1.0, noise_level=1.0, nbeta=1.0`. 3. `sgld_test.py`: uses `lr=2*lr, noise_level=0.0, localization=0.0, nbeta=1.0`. All kwargs. 4. Docstring examples: `SGLD(model.parameters(), lr=0.1, nbeta=...)`. Already kwargs. The only positional argument anyone ever passes is `params` (the first arg), which makes sense to keep positional. Everything after it is already passed as keyword arguments everywhere. * Added method parameter type hints for new functions in sgmcmc: * Add dot product cross-checks for sketch metrics Sketch norms were already cross-checked against scalar metric norms, but this only validates each sketch individually. Dot product cross-checks verify that all sketches share the same hash space by comparing ⟨S(v), S(w)⟩ against the exact dot product metrics (dot_grad_prior, dot_grad_noise, dot_prior_noise). Uses absolute tolerance (margin · ‖v‖·‖w‖ / √k) rather than relative, since dot product estimation variance scales with the product of norms, not the dot product magnitude. Made-with: Cursor * Added plots of dot products and cosine similarity to sketch demo notebook * Added links to sketch paper in code comments * Added test that exercises sketch computation when params with gradients are non-contiguous * Add `device` property and return `self` from `.to()` on same device Follows PyTorch tensor convention: `.to()` is a no-op when already on the target device. Raises RuntimeError on inconsistent internal state. Made-with: Cursor * Fix count sketch citations: credit Weinberger et al. for inner product property The original Charikar et al. (2002) paper is about frequency estimation, not inner product preservation. The property E[<Sv, Sw>] = <v, w> is proved by Weinberger et al. (2009) in the feature hashing literature. Made-with: Cursor * Added docs noting that different wrs are not comparable * Discussed choice of t=1 for count sketches * Added version number to sketches so we can track semantic changes * Remove redundant dt.attrs fallback in _extract_metrics and _extract_sketches ds is derived from dt.to_dataset() and subsequent operations preserve attrs, so dt.attrs and ds.attrs always have the same content here. Made-with: Cursor
* Start working on action: sample port to devinterp
* Port action: sample to devinterp works
* Minor cleanup
* Add susceptibilities post-processing to devinterp port
Port of aether's calc_sus_fast to work directly on raw sampling
DataTrees (flat variable format from do_sampling_with_observables).
Pure numpy/xarray, no aether imports. Includes end-to-end parity
test against aether's sample + post-process pipeline.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Add tiny sus YAML configs for susceptibilities parity test
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Add BIF post-processing to devinterp port
Port of aether's BIF (Bayesian Influence Function) computation.
Includes batched covariance/correlation utils, token-wise and
sequence-level BIF modes. Works on raw sampling DataTrees with
flat variable format. End-to-end parity test against aether.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Add sample() API, simplify tests, fix warnings
- Add user-friendly sample() entry point to sampling.py that handles
observable construction, context length derivation, and config wiring
- do_sampling_with_observables takes output_path instead of ZarrStore,
returns lazy DataTree directly
- Remove zarr consolidated metadata (not needed for local files)
- Remove DataLoader num_workers (data is pre-tokenized, no I/O to parallelize)
- Set HF_HUB_DOWNLOAD_TIMEOUT in test conftest to avoid flaky timeouts
- Simplify all 3 parity tests to use sample() API
- Merge _DROPPED_VARS and _EXPECTED_DIFFERING_ATTRS into _strip_reference
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Fix BIF docstring: Bayesian Influence Function, not Information Flow
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Extract shared test helpers into conftest, add sample smoke test
Move model/data loading and param mask construction into conftest
helpers (load_model_and_data, build_param_masks) to reduce
duplication across the 3 parity tests. Add smoke test for sampling.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Enable sampling metrics by default in tests
Add save_metrics parameter to sample() API. Enable metrics in
tiny-sample.yaml so the aether reference includes them. Parity test
now verifies metrics are bitwise identical to aether.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Move test YAML configs into test directory
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Remove dead code: vocab_size, seed_worker, SamplingMethodKwargs, loader param
- Remove vocab_size calculation and uint16 optimization (int64 is fine
for all vocab sizes, negligible storage impact vs loss arrays)
- Remove _seed_worker and worker_init_fn (no DataLoader workers)
- Remove SamplingMethodKwargs TypedDict (never enforced)
- Remove loader parameter from sample_single_chain (always built internally)
- Replace Optional[X] with X | None throughout sampler.py
- Fix SamplerConfig import in sampler.py
- Clean up unused imports (numpy, IterableDataset, _torch_dtype_from_str)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Simplify sampling pipeline: collapse layers, remove dead code
- Inline do_sampling_with_observables and sampler.sample into sample()
Call chain is now sample() -> sample_single_chain(), two functions
- Simplify write_init_loss: plain forward loop, no SGLD machinery
- Remove no_grad detection (only needed for old init_loss hack)
- Remove _maybe_no_grad helper
- Remove SamplerCallback dependency, use plain Callable
- Remove results accumulation from inner loop (unused by callbacks)
- Simplify make_data_feed from 65 lines to 10
- Remove dead code: _seed_worker, _torch_dtype_from_str, vocab_size,
SamplingMethodKwargs TypedDict, loader parameter
- Re-export old devinterp sample() for backwards compatibility
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Simplify observables, add pydantic validation to SamplerConfig
- Remove Observable ABC, ObservableVarSpec, validation machinery
- Observable is now a plain class: compute_loss(model) + input_ids
- Schema built directly from observable attributes in sampling.py
- SamplerConfig converted from dataclass to pydantic BaseModel
(validates types, value ranges, rejects extra fields)
- Cosmetic: update docstrings, remove tqdm RANK positioning,
fix unused variables
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Simplify susceptibilities, add sample_susceptibilities helper
- Clean up compute_susceptibilities: import to_obs_id from observables,
simplify variable names, add output format comment
- Add sample_susceptibilities() convenience wrapper that runs sampling
for multiple weight restrictions and computes susceptibilities in one call
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Add standalone test with only public dependencies
Tests sample, susceptibilities, and BIF using EleutherAI/pythia-14m
(public model) and public Pile datasets. Zero aether imports.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Review and simplify all devinterp files (3643 -> 1787 lines)
- bif.py: merge normalize helpers, simplify concatenation, remove
unused imports and dask checks (390 -> 247)
- covariance.py: remove unused Union import (106 -> 103)
- lm_loss.py: remove unused logits return, inline validation,
remove logging (99 -> 47)
- observables.py: use compute_per_token_loss (no logits) (110 unchanged)
- sampling.py: add sample_susceptibilities helper, use
compute_per_token_loss, handle model_path alias in conftest (430)
- writing.py: remove Buffer partial row mask, chunk validation,
dtype validation verbosity, unused exceptions (543 -> 231)
- zarr_schema.py: remove unused coordinate system, group validation,
post-creation validation (493 -> 145)
- conftest.py: handle model_path as alias for model_name_or_path
- standalone test: use public EleutherAI/pythia-14m model
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Add high-level susceptibilities() and bif() helpers
Move sample_susceptibilities from sampling.py to susceptibilities.py
as susceptibilities(). Add bif() to bif.py. Both take model + datasets
and handle sampling internally. Low-level compute_* functions remain
for power users.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Expose gradient_accumulation_steps, init_noise; default epoch_mode=cycle
- Add gradient_accumulation_steps and init_noise to sample() signature
- Change epoch_mode default from "once" to "cycle" (avoids errors on
small datasets, no behavior change when dataset is large enough)
- compute_covariance already exposed on bif()/compute_bif()
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Document why parity tests use triangle-40k (CUDA non-determinism)
Pythia's scaled_dot_product_attention has non-deterministic CUDA
kernels. Verified bitwise parity with AETHER_DETERMINISTIC=1.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Add metadata to zarr output, fix markers and type hints
- Store sampler config, observables, dtype, chain_buffer_size,
shared_observable_dims, tokenizer_context_length in zarr metadata
- Metadata test compares all fields against aether with explicit
skip lists for fields we can't produce
- Fix mutable default: callbacks=[] -> callbacks=None
- Add missing type hints (num_draws, num_burnin_steps, num_steps_bw_draws)
- Add @pytest.mark.gpu to test_sample_parity tests
- Remove @pytest.mark.nightly from sus/bif tests (fast enough)
- Add TODO notes about removing reference caching
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Add LLC computation from sampling results
- llc_new.py: compute_llc() post-processes stored sampling_loss and
init_loss arrays. llc() helper runs sampling + LLC in one call.
Returns llc_mean, llc_std, llc_per_chain, loss_trace, init_loss.
- Parity test verifies LLC matches manual xarray computation
- Standalone test with public model (pythia-14m)
- Named llc_new.py to avoid conflicting with existing llc.py
(LLCEstimator used by old devinterp tests)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Remove dtype/autocast from library, let user control precision
- Remove dtype parameter and internal autocast from sample() and
sample_single_chain(). Users wrap with torch.autocast themselves.
- Parity tests add torch.autocast to match aether's behavior.
- Standalone tests work without autocast (model dtype is sufficient).
- Verified model on GPU works (deepcopy + .to(device) handles it).
- Update docstrings: "PyTorch model" instead of "PyTorch model on CPU".
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Document why parity tests use triangle-40k (CUDA non-determinism)
Add Pythia-14m parity test with deterministic mode enabled.
Fix tokenizer model_max_length for models that don't set it (Pythia).
Set CUBLAS_WORKSPACE_CONFIG at module level for deterministic CuBLAS.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Fix review findings: input_ids assertion, HF model support, feed exhaustion, init_noise
- Add input_ids consistency assertion in Observable.compute_loss()
- Handle generic HF models in lm_forward_logits (output.logits fallback)
- Raise RuntimeError on data exhaustion instead of silent StopIteration
- Store init_noise as float|None in SamplerConfig (was bool)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Clean devinterp for public release: move legacy callbacks to aether
Move old callback-based devinterp modules (SamplerCallback, LLCEstimator,
MalaAcceptanceRate, OnlineLossStatistics, old sample()) into aether's
analysis/slt/, where they are used by the legacy LLC action. Remove dead
code (cov, gradient, norms, trace, wbic, mechinterp, vis_utils, backends).
Rename llc_new.py to llc.py now that the old one is moved.
The public devinterp API is now:
- slt/sampling.py: sample() → xr.DataTree
- slt/llc.py: llc(), compute_llc()
- slt/bif.py: bif(), compute_bif()
- slt/susceptibilities.py: susceptibilities(), compute_susceptibilities()
- optim/: SGMCMC, SGLD, SGNHT, Metrics, priors, preconditioners
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Update devinterp docs for new public API
Rewrite README with new sample/llc/bif/susceptibilities API and working
examples. Update Sphinx index with concepts section (LLC, susceptibilities,
BIF) adapted from aether docs. Update autodoc to cover new slt modules,
remove references to deleted modules (backends, mechinterp, vis_utils).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Add sampling guide with illustrations from aether
Add sampling.rst with chain/draw/step explanations and SVG figures
adapted from aether's actions.rst. Covers how sampling works, key
parameters, weight restrictions, and output format.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Mark credits/citation as TODO for v2 attribution review
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Fix ruff F541: remove placeholderless f-strings
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Run make format on devinterp
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Add xarray, zarr, datasets, pydantic to devinterp dependencies
The new sampling pipeline requires these at runtime. CI was failing
because devinterp's requirements.txt didn't list them.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Remove old example notebooks (use old callback API)
These notebooks reference the removed callback-based API (LLCEstimator,
SamplerCallback, MalaAcceptanceRate, etc.) and would need to be
rewritten for the new sample()/llc()/bif() API. New examples are
out of scope for this release.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Address review: GPU markers, dedup types, type hints, public names
- Add @pytest.mark.gpu to LLC parity tests
- Deduplicate EpochMode/SamplingMethodLiteral (canonical in config.py)
- Add type hints to _iter_blocks in bif.py
- Make set_seed and is_transformer_lens_model public (drop underscore)
- Remove old example notebooks (use deleted callback API)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Fix devinterp deps: declare in pyproject.toml, not requirements.txt
uv doesn't read setuptools dynamic dependencies from requirements.txt.
Move dependency declarations to pyproject.toml so uv.lock includes
xarray, zarr, datasets, and pydantic. Fixes CI ModuleNotFoundError.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Add legacy comments to moved files, add missing type hints
Mark all files moved from devinterp to aether with a legacy comment
so reviewers know they're intentionally not modernized. Add type hints
to is_transformer_lens_model, _cycle, and _numpy_fill_value.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Switch docs to furo theme with dark mode
Replace the messy conf.py (duplicate configs, alabaster fallback) with
a clean furo-based config matching aether's style. Dark mode, clean
sidebar, no Pydantic inherited method noise.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Invert SVG figures in dark mode
Add custom.css that applies filter: invert() to .dark-invert images
when furo's dark theme is active, matching aether's behavior.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Move parity tests from devinterp to aether
Parity tests (sample, LLC, BIF, susceptibilities) require both devinterp
and aether. They belong in aether's test suite, not devinterp's. This
keeps devinterp's tests free of aether dependencies so it can be tested
and published independently.
- Move test_*_parity.py, conftest.py, YAML configs, reference_data
to shared/aether/tests/integration/devinterp_parity/
- Keep only test_standalone.py in devinterp (no aether imports)
- Revert CI change (devinterp GPU tests no longer need aether)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Add type annotations to make_evaluate_fn and EvaluateFn type alias
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Add output formats reference to docs
Document exact xr.Dataset/DataTree structure for sample(),
compute_llc(), compute_bif(), and compute_susceptibilities()
including all variables, dimensions, coordinates, and attributes.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Add v1 → v2 migration guide to README
Before/after code example showing the old callback API vs new
data-centric API, with a summary of what changed.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Add zarr usage examples to output formats doc
Show how to save to a specific path, reopen saved results, and
post-process from disk.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Fix test_model_device_moving leaking cuda:1 as default device
torch.cuda.set_device(1) was never reset, causing all subsequent
tests in the same process to use cuda:1 as the default device.
This broke the devinterp parity tests on multi-GPU runners with
"Expected all tensors to be on the same device, cuda:0 and cuda:1".
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Skip mask_padding_in_loss in parity metadata comparison
Aether added mask_padding_in_loss to SamplerArgs in the padding PR.
Devinterp doesn't support padding masks, so skip this field in the
metadata comparison test.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Bump requires-python to >=3.10 (match/case in writing.py)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Skip sketch_metrics in parity metadata comparison
Aether added sketch_metrics to SamplerArgs (count sketches PR).
Devinterp doesn't support count sketches.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Vendor tokenize_and_concatenate, drop transformer_lens and einops
Replace the transformer_lens.utils.tokenize_and_concatenate import with
a vendored copy in devinterp.utils (numpy reshape instead of einops).
Remove 16 unused symbols from devinterp/utils.py and move the ones
aether still needs into aether/analysis/slt/legacy_utils.py.
Add test_dataset_loading_parity.py to verify the standalone dataset
loading path matches get_cached_dataset.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Remove unused imports flagged by ruff
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Fix lint and formatting from pre-commit hooks
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Add YAML-driven weight_restrictions to devinterp, replacing aether dependency
Implements create_param_masks() in devinterp that supports 96 HuggingFace
model types plus TransformerLens, driven by a static Python dict generated
from johan/head_structures.yaml. No new runtime dependencies (no PyYAML).
Switches parity test conftest to use devinterp's create_param_masks instead
of aether's create_masked_params. Adds mask parity tests verifying exact
equivalence across all restriction types for both triangle-40k (TL) and
pythia-14m-deduped (GPT-NeoX).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Add group mask support and model structure validation tests
- Add l0g0 restriction pattern for GQA group masking (Q heads sharing a KV head)
- Vendor tiny_configs.py for creating minimal model configs in tests
- Add test_listed_params_exist: verifies all YAML entries map to real params
(mirrors johan/test_yaml_entries_exist.py, checks named_parameters | state_dict)
- Add test_dead_head_gradient: zeroing O projection kills gradients on masked
head/group elements (mirrors johan/test_head_structure_yaml.py)
- Both tests pass for all 95 model types with transformers 5.5
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Simplify _tiny_configs: table-driven, zero special cases
Replace scattered per-attribute loops with a single _ATTR_TABLE. Gate on
HEAD_STRUCTURES membership, eliminating SKIP_MODELS and all model-specific
hacks (dbrx, deepseek, falcon, etc). Use NUM_HEADS=4 (codegen mp_num
compatibility) and even_if_none for head_dim/num_key_value_heads.
Remove prophetnet (seq2seq, not a causal LM) from _model_structures.py,
eliminating the last special case. _tiny_configs.py is now purely generic.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Support TransformerLens GQA models in weight restrictions
- Add _W_K/_W_V/_b_K/_b_V variants to hooked_transformer entry (TL stores
GQA KV heads with underscore prefix and n_kv size instead of n_heads)
- Fix _get_model_config to read TL's n_key_value_heads (not num_key_value_heads)
- Both MHA and GQA variants are listed; whichever doesn't exist is skipped
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Add HF↔TransformerLens mask parity test
Verifies that masking in HF then converting to TL produces the same
magic-value positions as converting then masking in TL. Uses tiny configs
with manual HookedTransformerConfig + TL weight conversion functions.
Covers 8 cases: gpt2, gpt_neox, bloom, llama (MHA + GQA), mistral (GQA),
gemma (GQA), qwen2 (GQA). Runs in ~3s with no network access.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Test all restriction types in TL parity, add mlp.W_gate to TL entry
- Parametrize test over full/l0/l0h1/l0g0/l0 attn/l0 mlp (48 cases)
- Add mlp.W_gate to hooked_transformer mlp list (gated MLP models)
- Exclude embed/unembed from TL parity: tied-weight models make the
comparison inherently asymmetric (HF has one tensor, TL has two)
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Expand TL parity test to 12 architectures, fix OPT word_embed_proj_dim
Add gpt_neo, gptj, opt, phi to TL mask parity test (now 84 cases).
Fix OPT tiny config by adding word_embed_proj_dim to HIDDEN aliases.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Support list-of-dict variants in HEAD_STRUCTURES, bidirectional TL test
- HEAD_STRUCTURES entries can now be a list of dicts (variants). _get_spec
picks the best match by counting how many listed params exist in the model.
- hooked_transformer: MHA variant (W_K) and GQA variant (_W_K) as separate
entries instead of duplicating both in one dict.
- mixtral/qwen2_moe: old transformers (per-expert indexed) and new
transformers (fused expert weights) as separate variants.
- Fix llama attn list: was missing self_attn.k_proj.weight.
- TL parity test now checks bidirectionally (catches missing entries).
Filters out TL-only zero biases via baseline comparison.
3 xfails for Gemma embed scaling and Mixtral MoE gate restructuring.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Fix missing unembed biases, restore embed+unembed TL parity test
Add missing unembed bias entries for 21 models (lm_head.bias,
cls.predictions.decoder.bias, etc.) found by johan/fix_embed_unembed.py.
Use FG/BG magic values in TL parity test so every param gets a non-default
value and TL-only biases can't hide. Re-enable embed+unembed restriction.
Skip b_K (intentionally omitted, mathematically inert for non-rotary attn).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Use create_param_masks in standalone susceptibilities test
Replace manual param name matching with create_param_masks(model, "l0h1")
to exercise the weight_restrictions module in the integration test.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Add quickstart example and update README with create_param_masks
- examples/quickstart.py: LLC + susceptibilities on Qwen2.5-0.5B
- Update README susceptibilities snippet to use create_param_masks
- Document supported restriction patterns
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Rename tokenize_and_concatenate output column from "tokens" to "input_ids"
Every caller immediately renamed it anyway. Removes a papercut for new users.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Remove johan/fix_embed_unembed.py from tracking
Utility script, not part of the devinterp package.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Fix input validation and side effects in devinterp public API
- Fix sample() silently moving user's model to GPU via _write_init_loss
(now saves/restores original device and training mode)
- Fix bif() shadowing sample()'s batch_size/device params (renamed to
bif_batch_size/bif_device so sampling params pass through **kwargs)
- Add upfront validation in susceptibilities(): require "full" key,
sampling_task in observables, and >1 observable
- Add bounds checking in create_param_masks for layer/head/group indices
(previously returned empty dict or raw IndexError)
- Add type check in create_param_masks for non-str/list input
- Validate sequence length >= 2 in sample() (length 1 produced NaN)
- Validate max_length and empty text in tokenize_and_concatenate
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Move test_tl_mask_parity to aether parity tests (needs transformer_lens)
test_tl_mask_parity.py imports transformer_lens, which is an aether
dependency, not a devinterp dependency. The test-devinterp CI job
(which only installs devinterp deps) fails with ModuleNotFoundError.
Move to shared/aether/tests/integration/devinterp_parity/ where
transformer_lens is available.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Remove HF_HUB_OFFLINE from test_weight_restriction_parity, add gpu marker
The module-level os.environ.setdefault("HF_HUB_OFFLINE", "1") poisoned
the environment for all tests in the same pytest session, causing
unrelated tests to fail with OfflineModeIsEnabled errors.
Also adds @pytest.mark.gpu since these tests load real models from
HuggingFace and should only run on GPU runners with cached models.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Fix test_dataset_loading_parity after tokenize_and_concatenate rename
tokenize_and_concatenate now outputs "input_ids" directly (changed in
82976b03b), so the rename_column("tokens", "input_ids") call fails.
Remove the now-unnecessary rename.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Fix susceptibilities() output_path overwrite, document output_path in docstrings
susceptibilities() calls sample() once per weight restriction. If a user
passed output_path via **kwargs, every call would overwrite the same zarr.
Now appends the WR name (e.g. samples_full.zarr, samples_l0h1.zarr).
Also mention output_path in the **kwargs docstrings for llc(), bif(),
and susceptibilities() so users know it's available.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* Remove unused devinterp_sampler.py and generate_model_structures.py
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Parameterize dataset loading parity test over pythia and triangle-40k
Drop the tokenizer.model_max_length > 1_000_000 guard now that both
sentinel (pythia) and non-sentinel (triangle-40k) tokenizers are
exercised. The fallback is a no-op for models whose tokenizers already
report a sensible model_max_length.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Initialize sampler loss as tensor; assert susceptibility dims
- sampler.py: loss is now always a float32 tensor, so .detach() and
torch.isnan(loss) don't depend on implicit float->tensor promotion.
Uses `loss = loss + ...` since `torch.zeros(())` doesn't support `+=`
against the (batch, seq_len-1) per-token loss.
- susceptibilities.py: assert (chain, draw, batch_{oid}, token_pos)
layout before positional slicing. Fix stale target_pos -> token_pos
comments.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Simplify Observable by using stored input_ids directly
- compute_loss: use self.input_ids[s:e] directly instead of re-pulling
from a cycled dataloader and asserting equality. The stored tensor
was already populated from the same deterministic loader in __init__.
- __init__: iterate loader once via enumerate, no cycling needed
(DeterministicShuffledSampler + drop_last yields exactly
batches_per_draw batches).
- Drop self._feed and _cycle static method, no longer used.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Relocate legacy_utils and delete orphaned legacy callbacks
- Move get_seeded_dataloader from legacy_utils.py into aether.utils.utils,
next to set_seed.
- Swap test_bif_functions.py to import set_seed from aether.utils.utils
(behaviorally equivalent for int inputs).
- Swap aether/analysis/loss.py to import get_seeded_dataloader from
aether.utils.utils.
- Delete legacy_utils.py: only 2 symbols were used externally (now moved),
the rest (prepare_input, split_results, EvalResults, EvaluateFn,
Outputs, cycle, get_init_loss_multi_batch) had no importers anywhere.
- Delete loss_callbacks.py: OnlineLossStatistics had zero importers.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Move legacy SLT callback files into slt/legacy/ subdirectory
Per Will's review suggestion. Files moved:
callback.py, mala.py, llc_callbacks.py
into:
aether/analysis/slt/legacy/
Updated 5 import sites: sampler.py, llc.py, estimators.py use absolute
imports; mala.py and llc_callbacks.py use same-directory relative
imports within the new subpackage.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Restore SamplerCallback cast in sample action
Collateral from f7e520a2d (devinterp public release) — the cast was
relaxed to Callable when SamplerCallback was dropped from devinterp,
but the aether sampler.sample() signature still declares
Sequence[SamplerCallback]. Re-import SamplerCallback from its new
aether home and restore the cast to match the signature.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Add TransformerLens attribution to tokenize_and_concatenate docstring
Both licenses are MIT (TL: Copyright 2022 TransformerLensOrg; devinterp:
Copyright 2025 Ashgro). Expanded the attribution to include the copyright
notice and a note on local divergences: input validation, numpy reshape
in place of einops, and the output column rename from "tokens" to
"input_ids".
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Remove redundant .astype(np.float32) in susceptibilities
The computation's operands are already float32 (SAMPLES_LOSS_DTYPE_STR =
"float32"), and aether's equivalent susc_fast.py has no astype here.
Removing aligns with aether and drops dead work. Parity tests green.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Document _tiny_configs.py purpose and _ATTR_TABLE structure
Per review: explain why the file exists, what make_tiny produces, and
what the (value, [attr_names]) tuple means (handles HuggingFace's
inconsistent attribute naming across model families).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Drop n_beta scalar data var; read from metadata attrs
Per Will's review: scalar-dim data arrays are painful (aether has
workarounds scattered across its writer and test infrastructure, and is
actively migrating away from them). n_beta was the only scalar data var
in devinterp's sample output, and it was already being written into
attrs.metadata.config.sampler via the pydantic config dump.
Changes:
- sampling.py: drop the arrays_meta["n_beta"] spec and the
writer.write("n_beta", ...) call.
- llc.py: read n_beta from attrs["metadata"]["config"]["sampler"]
instead of the data var.
- test_sample_parity.py: add "n_beta" to _AETHER_ONLY_VARS so the
aether reference (which still has it as a data var) strips it
before comparison.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Use dot-access for top-level metadata attr in sample parity test
xarray's __getattr__ falls through to top-level attrs, so
ds.attrs["metadata"] can be written as ds.metadata.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Cache sample() output by path, with completed flag and config check
Per Will's review: add filesystem-backed caching so users don't lose
hours of sampling when a downstream call crashes.
- sample() writes attrs["completed"] = 1 at the end of a successful
run. Interrupted runs leave the flag unset, so partial zarrs won't
be silently reused.
- On subsequent calls with the same output_path, sample() validates
the existing file: readable as zarr, completed flag set, and stored
sampler config matches the current call. Mismatches raise with a
clear "delete and retry" message that includes the path.
- Warns when output_path is None and the run is non-trivial
(num_chains * num_draws >= 32): results land in a temp directory and
are lost on process exit.
bif() and susceptibilities() pick this up for free since they already
pass output_path through to sample(); susceptibilities() derives a
per-WR path so each WR caches independently.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Add caching integration tests for bif() and susceptibilities()
- test_bif_caching: completed=1 is written, cache reused on matching
args, RuntimeError raised on sampler config mismatch.
- test_susceptibilities_per_wr_cache: each WR's per-path zarr has
completed=1 after a multi-WR run.
Also:
- Raise the big-run warning threshold to num_chains * total_steps *
batch_size > 1000 and clarify the message (say "save" not
"persist").
- Drop the now-redundant model_max_length > 1_000_000 guard in
test_standalone.py, matching T1.15.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Add custom-model integration tests
Verify that sample(), bif(), llc(), and susceptibilities() work with a
plain torch.nn.Module that isn't a HuggingFace wrapper or
TransformerLens model. Exercises the fallback path in lm_forward_logits
(model returns logits directly).
TinyTransformer: Embedding + MultiheadAttention + Linear, returning
logits from forward(). Hand-crafted 'head0' mask mirrors the README's
'l0h0' pattern so susceptibilities test runs on a genuine restriction.
No gpu marker -- runs in test-basic.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Generalize gpt_neox branch to all HuggingFace models
The hasattr(model, "gpt_neox") check was unnecessarily narrow -- the
return_dict=False / use_cache=False optimization applies to any
HuggingFace causal LM, not just Pythia. Replace with
isinstance(model, PreTrainedModel), which covers every HF model we
care about without naming specific families.
Verified: parity tests (including Pythia) pass, test_custom_model
(TinyTransformer fallback) passes, and the Qwen quickstart runs
end-to-end.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Add loss_fn parameter to sample/bif/llc/susceptibilities
Per Will's suggestion on lm_loss.py: let users supply their own loss
function, with the existing cross-entropy as the default. The LossFn
type is (model, input_ids) -> per-token loss of shape (batch, seq-1).
A single loss_fn is applied uniformly to the sampling dataset, the
init_loss computation, and every observable -- the loss is a property
of the model, not of which dataset is being evaluated.
- lm_loss.py: define LossFn type; make_evaluate_fn(loss_fn=None)
defaults to compute_per_token_loss.
- observables.py: Observable(loss_fn=None) stores and uses it.
- sampling.py: sample(loss_fn=None) threads through to Observable,
make_evaluate_fn, and _write_init_loss.
- bif/llc/susceptibilities: add explicit loss_fn kwarg for
discoverability (still also flows through **kwargs).
- test_custom_model.py: add test_custom_loss_fn verifying the hook.
- README: document the escape hatch in Model Requirements.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Replace weak custom-loss test with LLC scaling-invariance test
Mathematical invariant: LLC(loss_fn=2*L, n_beta=N/2) equals
LLC(loss_fn=L, n_beta=N). Scaling the loss by 2 doubles the gradient;
halving n_beta cancels (SGLD update is grad.mul(nbeta)), so sampling
trajectories are bit-identical. The LLC arithmetic
n_beta*(mean_loss - init_loss) also cancels.
This is strictly stronger than a "loss_fn was called" check -- it
verifies that the user-supplied loss flows through the gradient path
(for SGLD dynamics), the init_loss computation, and the LLC
arithmetic, end-to-end.
The custom loss is written inline (direct log_softmax + gather) rather
than calling compute_per_token_loss, so the test doesn't depend on the
default implementation details.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Tighten invariance test to exact equality
Empirically verified the SGLD trajectory is bit-identical under the
(2*loss, nbeta/2) transformation, so assert_close with nonzero
tolerance was stronger than needed. Plain `==` is clearer.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Validate loss_fn output shape in _write_init_loss
A user-supplied loss_fn that returns e.g. a scalar would silently
propagate garbage through the sampling path (scalars broadcast into
the zarr writer and observable buffers). _write_init_loss runs first
in sample(), so a single shape check on its first call catches bad
loss_fn before any sampling, gradients, or observable evaluations
happen.
Example messages:
loss_fn must return per-token loss of shape (2, 15), got ()
loss_fn must return per-token loss of shape (2, 15), got (2, 16)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Add include_sampling_task option to susceptibilities()
Per Will's review: Max found value in running susceptibilities with
only the sampling_task observable (per-token variation within the
task is still informative). Previously devinterp required at least
one non-sampling observable.
Adds include_sampling_task: bool = False to both susceptibilities()
and compute_susceptibilities(), matching aether's API. Default
behavior unchanged; when True, sampling_task is included in obs_ids
and the pre-check is skipped.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Document loss_fn and include_sampling_task in docstrings
- sample(), bif(), llc(): add loss_fn to the Args section.
- susceptibilities(): add loss_fn and include_sampling_task.
- compute_susceptibilities(): add include_sampling_task and clarify
observable_names now interacts with it.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Add preview_weight_restriction for inspecting mask selection
Port of aether's preview-wr CLI formatting as a plain function. Takes
a model and a ParamMasks dict and prints a tree view with per-key
selected/total counts, or one param name per line in plain mode.
Useful for debugging ambiguous selection patterns (Will's "does unembed
include the layer norm?" example).
No new dependency; uses raw ANSI escapes for colors.
Demoed in examples/quickstart.py on two cases:
- The "l0h0" string-derived mask, with a comment explaining Qwen2.5's
GQA-driven asymmetry (~7% on Q/O, 50% on K/V).
- The hand-crafted manual_masks example at the bottom.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Port micro_callback and sampling_loss_micro zarr vars from aether
Port of aether's MicroCallback protocol + sampling_loss_micro /
sampling_input_ids_micro zarr variables.
- sampler.py: add MicroCallback Protocol, micro_callback param to
sample_single_chain, and call it inside the gradient accumulation
loop with (loss, input_ids, chain, step, micro_step).
- sampling.py: add sampling_loss_micro (chain, step, batch_sampling,
token_pos) and sampling_input_ids_micro (chain, step, batch_sampling,
token) schema entries. Define on_micro closure that writes per-micro
slices to the zarr; pass as micro_callback to sample_single_chain.
- test_sample_parity.py: drop sampling_loss_micro and
sampling_input_ids_micro from _AETHER_ONLY_VARS since they're now
in devinterp's output. Parity now compares them bit-for-bit.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Exercise gradient_accumulation_steps > 1 in parity fixtures
Bump gradient_accumulation_steps to 2 in tiny-sample.yaml,
tiny-sample-pythia.yaml, tiny-bif-1-sample.yaml, and
tiny-sus-1-sample.yaml. Thread the value through the 5 devinterp
sample() call sites in the parity tests.
Caught a bug in the first iteration of T1.19's micro-callback port
that would otherwise have slipped through: the on_micro closure pushed
a batch_size slice instead of a full micro-batch row, which is
invisible at grad_accum=1 because one slice equals the full row.
Previews T1.1's wider "vary all sampler args" coverage.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Expand sample() with 9 config kwargs; partial T1.1 parity-fixture randomization
- sampling.py: sample() now accepts localization, noise_level,
llc_weight_decay, bounding_box_size, sampling_method,
sampling_method_kwargs, shuffle, epoch_mode, and
match_sampling_input_ids_across_chains as kwargs and threads them
into SamplerConfig. Previously these fell back to SamplerConfig
defaults, meaning yaml overrides were silently ignored on the
devinterp side.
- conftest.py: new sampler_kwargs(sampler_cfg) helper picks the
subset of keys that sample() accepts. Used from all 4 parity
tests (sample, llc, bif, sus) to reduce boilerplate.
- Parity yamls now exercise many more args with randomized
non-default values (batch_size, num_chains, num_draws,
num_burnin_steps, num_steps_bw_draws, gradient_accumulation_steps,
num_init_loss_batches, init_seed, localization, noise_level,
llc_weight_decay, bounding_box_size, save_metrics,
sampling_method, shuffle, match_sampling_input_ids_across_chains,
init_noise). Covers Will's T1.1 gated-on-approval concern.
Known failing parity tests this exposes (real aether/devinterp
divergences, not regressions):
- test_llc_parity: LLC diverges when num_chains > 1 AND
gradient_accumulation_steps > 2. Devinterp's compute_llc averages
sampling_loss (draw-level); aether's calculate reduces across micro
samples via sampling_loss_micro. For num_chains=1 the totals
coincide, otherwise they don't.
- test_sample_parity_pythia: init_loss diverges ~8x. Currently
uses rmsprop_sgld + init_noise=true + shuffle=false. Not yet
bisected to the specific culprit.
- test_susceptibilities_parity: susceptibility values diverge ~10x.
Same suspect settings as pythia.
tiny-sample.yaml is set to a combination known to pass (num_chains=3,
gradient_accumulation_steps=2), so test_sample_parity and
test_llc_parity stay green on this yaml. Pythia and sus yamls
intentionally kept in the aggressive T1.1 state so the divergences
remain visible for investigation.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Move SAMPLING_METHODS next to SamplingMethodLiteral in config.py
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Drop epoch_mode from sample() public signature
Users don't need the "once" vs "cycle" choice on the public entry
point -- "cycle" is the only sensible mode for typical SGLD runs.
SamplerConfig and sample_single_chain still carry the knob for
internal/extension use, and sample_single_chain's default is
aligned to "cycle" to match SamplerConfig and sample().
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Forward SamplerConfig fields through _write_init_loss
`_write_init_loss` hardcoded shuffle=True, used init_seed as the
dataloader seed for every chain regardless of
match_sampling_input_ids_across_chains, ignored init_noise, and
iterated the loader with zip(range(num_batches), loader) — silently
truncating and dividing by num_batches when the dataset was smaller
than needed. Thread all four through, matching sample_single_chain's
masked init_noise application. Also swap _make_feed's cycle branch to
itertools.cycle to match aether's first-epoch-caching semantics.
Shrink tiny-sample-pythia.yaml's dataset with metadata.limit=10 and
epoch_mode: cycle so the parity test actually exercises the init_loss
cycling path.
* Remove reference-zarr caching from devinterp parity tests
Each parity test had a reference_<X> fixture that cached the aether
subprocess output to a local zarr under reference_data/ and reused
it across runs. The cache could go stale when aether changed without
anyone noticing, and the TODO to remove it was already in place.
Drop the reference_<X> fixtures and point the tests at the uncached
aether_<X> fixture directly. Rename the pythia fixture
reference_sample_pythia → aether_sample_pythia for consistency.
* Broaden SamplerArgs coverage across parity yamls
- tiny-sample.yaml: adopt aggressive config (burnin=1, bw_draws=2,
grad_accum=3, non-default localization/noise_level/llc_wd/bbox,
explicit sampling_method/shuffle/match/init_noise). Set the two
observables to distinct batches_per_draw (2/3) to catch bugs that
assume a shared value.
- tiny-sample-pythia.yaml: add non-default sampling_method_kwargs
(alpha=0.95, eps=0.05) to exercise kwargs pass-through to
SGMCMC.rmsprop_sgld.
- tiny-sus-1-sample.yaml: swap rmsprop_sgld -> sgld (new coverage of
the deprecated SGLD class path) and flip match=true so the (shuffle,
match) = (F,T) corner is tested. Reduce to a single observable to
exercise that path.
- tiny-bif-1-sample.yaml: flip match=false so (T,F) is tested.
Net coverage: sampling_method sgld/sgmcmc_sgld/rmsprop_sgld (3 of 5);
all 4 shuffle x match combinations; sampling_method_kwargs exercised;
single- and multi-observable paths both covered; mixed per-observable
batches_per_draw. sgnht/sgmcmc_sgnht remain blocked by sample()
unconditionally forwarding localization/noise_level/weight_decay that
those constructors don't accept.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Match aether LLC by reducing sampling_loss_micro, not sampling_loss
aether's calculate action reduces `samples.sampling.loss` — which is
the v3 name for `sampling_loss_micro` (per-step, per-micro-batch
losses from micro_callback). compute_llc was averaging
`sampling_loss` (per-draw, post-step), so the scalar LLC differed
even though the upstream sample zarrs matched bit-for-bit.
Switch compute_llc to read `sampling_loss_micro`, reducing over
`batch_sampling, token_pos` for loss_trace and then over `step` for
llc_per_chain. loss_trace now has dims (chain, step) instead of
(chain, draw); docstring updated.
* Gate devinterp parity + standalone tests behind workflow_dispatch
Introduce a `manual` pytest marker for tests that are too slow / HF-heavy
for standard CI but should stay runnable from the GitHub Actions UI.
- Register `manual` in pyproject.toml and shared/aether/pytest.ini.
- Apply `pytestmark = pytest.mark.manual` to every file under
shared/aether/tests/integration/devinterp_parity/ (7 files) and to
the two HF-downloading integration tests under
shared/devinterp/tests/integration/ (test_standalone.py,
test_caching.py). test_custom_model.py and the optim/slt unit tests
stay in CI.
- Append `and not manual` to both MARKER_FILTERs in gpu_test.yml so the
standard PR CI path skips these tests.
- Add .github/workflows/devinterp-parity.yml with `on: workflow_dispatch`
only. Two jobs (aether-side parity, devinterp-side standalone) run
pytest with `-m "manual and not only_multi_gpu"` on a single_gpu
runner. Triggered from Actions -> Devinterp Parity -> Run workflow.
Collection checks:
aether `-m manual` on devinterp_parity/ collects 116 tests.
aether `-m "not manual"` on devinterp_parity/ collects 0.
devinterp `-m manual` on tests/ collects 6 (4 standalone + 2 caching),
374 deselected.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Rename manual marker -> devinterp_parity
Clarify scope: the gated bucket is specifically the devinterp
port's parity + standalone integration tests, not a generic
"manual-only" bucket.
- Rename the pytest marker in pyproject.toml and shared/aether/pytest.ini.
- Update all 9 test files (7 under devinterp_parity/, 2 devinterp
standalone integration tests).
- Update the four MARKER_FILTER lines in gpu_test.yml and the two
pytest -m invocations in devinterp-parity.yml.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Expose rmsprop_eps / rmsprop_alpha as top-level sample() args
Previously users had to pass them through sampling_method_kwargs as
dict keys. They're now first-class kwargs on sample(), validated up
front: only permitted when sampling_method='rmsprop_sgld', and must
not collide with the same key in sampling_method_kwargs. Under the
hood they're merged into sampling_method_kwargs so SamplerConfig and
the serialized metadata keep their existing shape (matching aether's
layout).
Parity conftest translates yaml's sampling_method_kwargs.{eps,alpha}
into the new top-level args on the devinterp side; aether continues
reading them from sampling_method_kwargs. tiny-sample-pythia.yaml
already carries these values, so the pythia parity test now
exercises both code paths end-to-end.
Tests in devinterp cover the validation errors and bitwise
equivalence between top-level args and sampling_method_kwargs.
* Ungate devinterp standalone integration tests
Timing the full shared/devinterp/tests/integration/ folder shows the
whole set runs in ~31s wall time:
test_standalone.py 4 tests ~15.3s (bif 5.0s, sus 3.9s, sample 3.7s, llc 2.8s)
test_caching.py 2 tests ~10.2s (bif_caching 4.4s + 4.8s setup, sus 1.1s)
test_custom_model.py 7 tests ~3.4s
That's squarely in "keep in standard CI" territory, and continuous
regression coverage of the public-facing devinterp API is more valuable
than the saved seconds. Drop the devinterp_parity marker from
test_standalone.py and test_caching.py; they now run on every PR under
gpu_test.yml's test-devinterp-gpu job.
Also remove the now-empty test-devinterp-standalone job from
devinterp-parity.yml; the workflow is now aether-only (spawns aether
subprocesses, genuinely slow).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Split LLC parity onto its own tiny-llc-sample.yaml
test_llc_parity had been reusing tiny-sample.yaml, so its parity
signal was entirely coincident with test_sample_parity. Give LLC an
independent config that also closes three remaining gaps in the
parity-yaml coverage matrix:
- bounding_box_size: None (all 4 existing yamls set a non-None value)
- batches_per_draw: 1 and 4 (only 2 and 3 appeared)
- weight_restrictions: l0 (plain layer, no head filter; new variety)
Also varies lr (0.002) and n_beta (15.0), which were identical across
all 4 existing yamls.
Preserves the known-tricky LLC combo (num_chains>1 and
gradient_accumulation_steps>2) that test_llc_parity originally needed
as a regression guard -- see T1.1 commit b3d39b5e8.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Diversify LLC yaml: rmsprop_sgld with distinct kwargs
Before: 3 of 5 parity yamls used sgmcmc_sgld (sample, llc, bif),
while rmsprop_sgld appeared only on the pythia yaml.
Swap the new LLC yaml to rmsprop_sgld with a distinct kwargs combo
(alpha=0.9, eps=0.2, add_grad_correction=true) from pythia's
(alpha=0.95, eps=0.05). This:
- Rebalances sampling_method distribution (sgmcmc_sgld 3->2,
rmsprop_sgld 1->2).
- Exercises rmsprop on the fast triangle path (pythia is slower and
uses deterministic CUDA algorithms for bitwise repro).
- Covers add_grad_correction=True, which no yaml exercised before.
- Hits two distinct (alpha, eps) points.
Aether's sampler imports SGMCMC.rmsprop_sgld directly from devinterp
(sampler.py:22,67), so any kwarg the factory accepts has guaranteed
parity across the two sides.
Config stays in the safe-combo zone for rmsprop (shuffle=T,
init_noise=F, match=T) -- not the known-failing combination from
T1.1 commit b3d39b5e8.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Shrink devinterp sample() to supported methods only
SamplingMethodLiteral exposed 5 options, but only 3 actually worked
via sample(): sgld, sgmcmc_sgld, rmsprop_sgld. sgnht and sgmcmc_sgnht
TypeError because sample() (and aether's process_sampling_method_kwargs)
unconditionally forward localization / noise_level / weight_decay that
those factories don't accept -- a limitation shared by both libraries.
Rather than paper over this with an inspect.signature filter, cut the
public surface down to what the library actually supports and
recommends:
- sgmcmc_sgld (the modern SGLD implementation)
- rmsprop_sgld (adaptive variant)
Drop sgld, sgnht, sgmcmc_sgnht from SamplingMethodLiteral and
SAMPLING_METHODS. The bare SGLD class already warns "deprecated, use
SGMCMC.sgld instead" on every construction, so keeping it as a public
name was misleading. The SGLD/SGNHT classes themselves stay in
devinterp.optim for anyone who wants to bypass sample() and call
sample_single_chain directly with a custom optimizer class.
Test + yaml fallout:
- tiny-sus-1-sample.yaml: sampling_method sgld -> sgmcmc_sgld.
- test_custom_model.py's rmsprop-arg-rejection test: swap the
non-rmsprop marker method from "sgld" to "sgmcmc_sgld" so it still
validates the error path.
Coverage across the 5 parity yamls now: sample, bif, sus all on
sgmcmc_sgld; pythia and llc on rmsprop_sgld (with distinct kwargs
combos -- alpha/eps for pythia, alpha/eps/add_grad_correction for
llc). Both supported methods exercised.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Document sample()'s optimizer-hyperparameter args
sample() grew 9 kwargs in commit b3d39b5e8 (T1.1 expansion) but the
docstring was never updated. Add entries for: localization, noise_level,
llc_weight_decay, bounding_box_size, sampling_method,
sampling_method_kwargs, shuffle, match_sampling_input_ids_across_chains.
Also extend the "Key Parameters" section of docs/sampling.rst with an
"Optimizer hyperparameters" subsection covering the same knobs plus
init_noise.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Document rmsprop_eps / rmsprop_alpha aliases in sampling.rst
The Optimizer hyperparameters section mentioned that rmsprop's alpha
and eps can go into sampling_method_kwargs, but didn't surface the
top-level rmsprop_eps / rmsprop_alpha convenience aliases that
sample() exposes. Add them.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Drop add_grad_correction=true from LLC parity yaml
RMSpropPreconditioner raises NotImplementedError("Gradient correction
not yet implemented for RMSprop") at construction when
add_grad_correction=True (preconditioner.py:159-161). Aether uses the
same devinterp preconditioner class via SGMCMC.rmsprop_sgld, so this
fails identically on both sides -- not a parity issue, just an
unimplemented feature.
The LLC yaml still covers rmsprop_sgld with distinct (alpha=0.9,
eps=0.2) vs. pythia's (alpha=0.95, eps=0.05), so
sampling_method_kwargs pass-through is still exercised.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Raise ValueError on empty obs_ids in compute_susceptibilities
The high-level susceptibilities() wrapper (cd5620d81) already rejects
an observables dict whose only entry is the sampling task, but the
lower-level compute_susceptibilities() — which the parity test calls
directly to control wr_map construction — never had that check.
Result: obs_ids silently became [], the inner loop produced no
results, and _build_output crashed with an opaque
`IndexError: list index out of range` on `wr_order[0]`.
Mirror the wrapper's validation down into compute_susceptibilities()
with a message that names the available loss_* observables and the
include_sampling_task flag so the fix is obvious.
Also restore the pile-github observable to tiny-sus-1-sample.yaml so
the sus parity fixture exercises a real probe dataset (distinct from
the sampling task). The second observable was dropped in d1dca3975.
* Fix init_noise to respect masks on partially-restricted params
Both libraries' samplers read the parameter mask from
optimizer.param_groups via pg.get("mask", 1.0) or 1.0. That worked for
the legacy SGLD class (which keeps "mask" on each group), but
SGMCMC.__init__ pops "mask" out of each group during construction and
folds it into a MaskPreconditioner. By the time init_noise ran,
pg.get("mask") returned the default 1.0 and noise was added to the
full parameter tensor -- including positions the restriction was meant
to freeze. (It also would have raised
RuntimeError: Boolean value of Tensor is ambiguous on the SGLD path
with a partial mask, since `tensor or 1.0` calls bool(tensor).)
Source the mask from the dict that was built before the optimizer
constructor ran: masked_parameters in aether, param_masks in devinterp.
Matches the pattern already used in devinterp's _write_init_loss
(sampling.py:596-605), which was correct -- so before this fix
devinterp's init_loss was computed on a correctly-masked-noisy model
while the sampling chain started from an unmasked-noisy model. The two
sites now agree.
No behavioral change for weight_restrictions="full" on either side:
mask is None there, so noise is applied everywhere as before.
Scope notes for aether:
- No aether snapshot or reference test sets init_noise=True with a
partial weight restriction (verified across
integration/actions/__snapshots__, seed_stability, test_determinism,
test_vs_known_llc, unit test fixtures). This change is invisible to
the rest of the aether test surface.
- Callers that combined init_noise with a partial WR were previously
getting unintended full-tensor noise; after the fix they get the
masked-only noise the API names suggest.
Exposure: tiny-sample.yaml now sets init_noise: true (was false) on a
partial WR (l0h1) so test_sample_parity compares both sides'
correctly-masked init_loss.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Clean up Sphinx docs: RST math, missing utils page, duplicate modules
- Convert \$\$...\$\$ and \$...\$ in SGLD and SGNHT docstrings to
\`.. math::\` blocks and \`:math:\` roles. Napoleon / RST doesn't
recognize the MyST dollar-delimiter syntax, so the formulas were
rendering as a string of individual italic glyphs. Also replace the
undefined \`:python:\` role with plain \`\`code\`\` spans (same
fix: that role isn't registered anywhere, which was firing its own
sphinx errors).
- Add source/devinterp.utils.rst. The toctree in index.rst referenced
it but the file didn't exist (orphan automodule was living inside
the retired source/devinterp.rst).
- Delete source/devinterp.rst and source/modules.rst. Both duplicated
content that's now in the three targeted \`devinterp.\{slt,optim,utils}.rst\`
files, and both were orphan warnings in the build.
- Fix a missing blank line before :param: in SGMCMC.sgld that was
firing "Unexpected indentation" and "Block quote ends without a
blank line" warnings.
Build goes from 19 warnings to 0 (plus some pre-existing pyright
notes unrelated to docs).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Retain mask on SGMCMC param groups so init_noise can see it
Earlier fix rewrote aether's init_noise loop to source the mask from
masked_parameters.values(). That turned out to hit a KeyError: PyTorch's
Optimizer stores the param-group dicts by reference, so SGMCMC's
destructive group.pop("mask", None) mutates the same dicts held by
masked_parameters. The init_noise loop then saw the mutated (mask-less)
dicts and crashed.
Simpler fix: stop mutating. Change group.pop -> group.get in
SGMCMC._init_group. The mask has already been absorbed into a
MaskPreconditioner by then; keeping the key in the group dict costs only
a few serialized bytes in optimizer.state_dict() and has no functional
impact (SGMCMC.step() uses the preconditioner, not group["mask"]).
Verified neither aether nor devinterp invokes SGMCMC/SGLD
optimizer.state_dict() anywhere.
With SGMCMC no longer popping, aether's original pg-based init_noise
loop works as written -- except for one pre-existing bug: the pattern
mask = pg.get("mask", 1.0) or 1.0
crashes with "bool value of Tensor ambiguous" when the mask is a
multi-element tensor (previously hidden by SGMCMC popping the key).
Replace with an explicit None check. Revert the masked_parameters-based
rewrite entirely -- the shape of the loop matches aether's established
style.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Match aether: _write_init_loss applies init_noise to all params
Aether's write_init_loss passes weight_restrictions="full" to its
sampler (actions/sampling/sampling.py:555) and init_noise=args.init_noise
on line 574, so init_loss is computed at a state where noise has been
added to the entire model -- independent of the weight restriction that
the sampling chain uses.
Devinterp's _write_init_loss was applying init_noise only to
param_masks entries, so test_sample_parity with init_noise=true +
partial WR (l0h1) diverged:
init_loss ACTUAL (devinterp, masked noise): ~8.9
init_loss DESIRED (aether, full noise): ~14.9
Mirror aether: iterate over all named_parameters when applying and
reverting init_noise. The sampling chain itself (sample_single_chain)
still applies masked noise per param_masks -- that behavior is correct
on both sides; only the init_loss reference point is deliberately
unrestricted to match aether.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* Vary device/dtype across parity yamls
All four parity fixtures ran the same pairing — CUDA + bfloat16 autocast —
leaving several code paths untested: CPU execution, the autocast-disabled
float32 branch (which aether takes when dtype='float32'), and
devinterp's "caller wraps their own autocast" contract without an ambient
autocast context.
Spread coverage across the sampling yamls:
tiny-sample.yaml cpu + float32 (no autocast; CPU path)
tiny-sample-pythia.yaml cuda + bfloat16 (unchanged; determinism test)
tiny-llc-sample.yaml cuda + bfloat16 (unchanged; dtype now explicit)
tiny-bif-1-sample.yaml cuda + bfloat16 (unchanged; dtype now explicit)
tiny-sus-1-sample.yaml cuda + float32 (GPU autocast-disabled branch)
Covers 3 of the 4 cells in {cpu,cuda}×{float32,bfloat16}; the missing
cell (cpu+bfloat16) is a rarely-useful combo.
Plumbing in conftest.py:
- ``test_device`` at the top level of the yaml is a test-harness
directive; ``aether_subprocess_env`` maps it to ``PJRT_DEVICE``, and
``strip_test_keys`` removes it before handing the yaml to aether.
- ``parameters.dtype`` is aether-native; the devinterp fixture reads it
to choose between ``torch.autocast`` (bf16/fp16) and
``contextlib.nullcontext()`` (float32), matching aether's
``enabled=(dtype != torch.float32)`` behavior.
- ``load_model_and_data`` now loads the model with the yaml's dtype
instead of hardcoding ``torch.bfloat16``.
Also benches confirm CPU isn't a cost hit for these tiny fixtures
(triangle-40k sampling: 28.8s on CPU vs 30.1s on GPU; CPU wins at this
scale because GPU launch overhead dominates).
* Shrink tiny-bif-1-sample.yaml to reduce BIF post-processing memory
BIF post-processing memory scales roughly with
num_chains * (batch_size * batches_per_draw * num_draws * seq_len)^2,
and this fixture was hitting ~18 GiB peak VRAM during
test_bif_parity on a GPU runner. Dropping the four key knobs to 2
each (num_chains, batch_size, num_steps_bw_draws,
batches_per_draw on both observables) brings peak VRAM to ~4.6 GiB
and runtime from 77s to 52s.
Coverage retained: two chains still exercise chain reduction
(stack mode), both observables are present so inter-observable
correlation is still tested, the unique
match_sampling_input_ids_across_chains=false + shuffle=true +
init_noise=false combo is preserved, and weight_restrictions=full
is unchanged. Pushing further (num_chains=1 or num_draws=1) would
make BIF math degenerate with no chain or draw variance.
* Apply devinterp_parity marker via decorator, not pytestmark
The repo-level test_pyproject_markers_are_used greps for
@pytest.mark.<name> decorators to validate marker usage, so the
module-level `pytestmark = pytest.mark.devinterp_parity` form
registered zero uses and tripped the check.
Replace the 7 module-level assignments with per-function
@pytest.mark.devinterp_parity decorators (stacked above the
existing @pytest.mark.gpu / @pytest.mark.parametrize). Collection
under `-m devinterp_parity` still yields 116 tests, same as before.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
---------
Co-authored-by: Johan Sokrates Wind <johan@timaeus.co>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
* change reqs, install, deps, etc. for devinterp port * fix ci/cd * readme changes * update index.rst * changes to docs, readme, pyproject * update docs * update docs yet again
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
No description provided.