diff --git a/README.md b/README.md index 6ad286e6..e3946158 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,14 @@ # DevInterp -[![PyPI version](https://badge.fury.io/py/devinterp.svg)](https://badge.fury.io/py/devinterp) ![Python version](https://img.shields.io/pypi/pyversions/devinterp) ![Contributors](https://img.shields.io/github/contributors/timaeus-research/devinterp) [![Docs](https://img.shields.io/badge/Read_the_Docs!-white?style=flat&logo=Read-the-Docs&logoColor=black)](https://devinterp.timaeus.co/) +[![PyPI version](https://badge.fury.io/py/devinterp.svg)](https://badge.fury.io/py/devinterp) ![Python version](https://img.shields.io/pypi/pyversions/devinterp) ![Contributors](https://img.shields.io/github/contributors/timaeus-research/devinterp) [![Docs](https://img.shields.io/badge/docs-devinterp.timaeus.co-blue?style=flat)](https://devinterp.timaeus.co/) - -## A Python Library for Developmental Interpretability Research - -DevInterp is a python library for conducting research on developmental interpretability, a novel AI safety research agenda rooted in Singular Learning Theory (SLT). DevInterp proposes tools for detecting, locating, and ultimately _controlling_ the development of structure over training. - -[Read more about developmental interpretability](https://www.lesswrong.com/posts/TjaeCWvLZtEDAS5Ex/towards-developmental-interpretability). +DevInterp is [Timaeus](https://timaeus.co)' open source research package, built to allow external researchers to do SLT/DevInterp-style research on Large Language Models. ## Features - **SGLD Sampling** with per-token loss storage to xarray/Zarr - **Local Learning Coefficient (LLC)** estimation from sampling results -- **Susceptibilities** measuring first-order posterior response to data perturbations, localized on model components +- **Susceptibilities** measuring first-order posterior response to data perturbations, optionally restricted to specific model components - **Bayesian Influence Functions (BIF)** as posterior correlations (or covariances) between per-sample losses - **Weight restrictions** for sampling over parameter subsets (e.g., individual attention heads) @@ -27,51 +22,51 @@ uv add devinterp ## Example -See [`examples/quickstart.py`](examples/quickstart.py) for a runnable script that computes LLC and susceptibilities on Qwen2.5-0.5B. +See the [Quickstart Notebook](examples/quickstart.ipynb) ([open in Colab](https://colab.research.google.com/github/timaeus-research/devinterp/blob/main/examples/quickstart.ipynb)) or the [Quickstart Script](examples/quickstart.py) for examples of how to compute LLCs and susceptibilities on Qwen2.5-0.5B (GPU required). ## Quick Start -### Compute the Local Learning Coefficient +### Sampling with Observables ```python -from devinterp.slt.llc import llc +from devinterp.slt.sampling import sample -result = llc( +tree = sample( model=model, - dataset=dataset, # HuggingFace Dataset with "input_ids" - observables={"train": dataset}, + dataset=train_data, + observables={ + "train": train_data, + "code": (code_data, 5), # (dataset, batches_per_draw) + }, lr=0.001, n_beta=30, num_chains=4, num_draws=200, ) - -print(result["llc_mean"]) # scalar LLC -print(result["llc_per_chain"]) # (num_chains,) per-chain LLC -print(result["loss_trace"]) # (num_chains, num_steps) per-step loss, num_steps = num_draws * num_steps_bw_draws + num_burnin_steps +# tree is an xr.DataTree backed by Zarr with full per-token loss traces ``` -### Sample with Observables +### Computing the Local Learning Coefficient ```python -from devinterp.slt.sampling import sample +from devinterp.slt.llc import llc -tree = sample( +result = llc( model=model, - dataset=train_data, - observables={ - "train": train_data, - "code": (code_data, 5), # (dataset, batches_per_draw) - }, + dataset=dataset, # HuggingFace Dataset with "input_ids" + observables={"train": dataset}, lr=0.001, n_beta=30, num_chains=4, num_draws=200, ) -# tree is an xr.DataTree backed by Zarr with full per-token loss traces + +print(result["llc_mean"]) # scalar LLC +print(result["llc_per_chain"]) # (num_chains,) per-chain LLC +print(result["loss_trace"]) # (num_chains, num_steps) per-step loss, num_steps = num_draws * num_steps_bw_draws + num_burnin_steps ``` -### Compute Susceptibilities +### Computing Susceptibilities ```python from devinterp.slt.susceptibilities import susceptibilities @@ -96,7 +91,7 @@ result = susceptibilities( `create_param_masks` supports 85+ HuggingFace model types and TransformerLens. Restriction patterns: `"full"`, `"l0"`, `"l0h1"`, `"l0g0"` (GQA group), `"l0 attn"`, `"l0 mlp"`, `"embed"`, `"unembed"`. -### Compute BIF +### Computing Bayesian Influence Functions ```python from devinterp.slt.bif import bif @@ -172,16 +167,24 @@ llc_value = float(result["llc_mean"]) ## Hyperparameter selection -All sampling is sensitive to hyperparameters. See our [Sampling Hyperparameter Guide](https://timaeus.co/research/2026-04-21-sampling-guide). +All sampling is sensitive to hyperparameters. Our [Sampling Hyperparameter Guide](https://timaeus.co/research/2026-04-21-sampling-guide) covers the three primary knobs — step size (`lr`), inverse temperature (`n_beta`), and localization strength (`localization`) — along with burn-in, steps between draws, and chain count, and walks through diagnosing common failure modes (non-convergence, spikes, NaNs, low signal-to-noise) from the loss traces. ## Further Reading -- [You're Measuring Model Complexity Wrong](https://www.lesswrong.com/posts/6g8cAftfQufLmFDYT/you-re-measuring-model-complexity-wrong) - Introduction to LLC and phase transitions (2024) +Blog Posts: +- [Spectroscopy at Scale: Finding Interpretable Structure in Pythia-1.4B](https://timaeus.co/research/2026-04-21-spectroscopy-main) (2026) +- [Guide for Sampling Hyperparameter Selection](https://timaeus.co/research/2026-04-21-sampling-guide) (2026) + +Papers: - [Structural Inference with Susceptibilities](https://arxiv.org/abs/2504.18274) (2025) - [Towards Spectroscopy: Susceptibility Clusters in Language Models](https://arxiv.org/abs/2601.12703) (2026) - [The Local Learning Coefficient: A Singularity-Aware Complexity Measure](https://arxiv.org/pdf/2308.12108) (2023) -- [Algebraic Geometry and Statistical Learning Theory](https://www.cambridge.org/core/books/algebraic-geometry-and-statistical-learning-theory/9C8FD1BDC817E2FC79117C7F41544A3A#fndtn-information) Watanabe (2009) + +Background: +- [Algebraic Geometry and Statistical Learning Theory](https://www.cambridge.org/core/books/algebraic-geometry-and-statistical-learning-theory/9C8FD1BDC817E2FC79117C7F41544A3A#fndtn-information), Watanabe (2009) +- [Interpreting the Ising Model](https://timaeus.co/research/2026-04-21-spectroscopy-ising) (2026) +- [You're Measuring Model Complexity Wrong](https://www.lesswrong.com/posts/6g8cAftfQufLmFDYT/you-re-measuring-model-complexity-wrong) (2024) ## Credits & Citations @@ -201,3 +204,9 @@ If this package was useful in your work, please cite it as: howpublished = {\url{https://github.com/timaeus-research/devinterp}}, } ``` + +The authors would like to thank Zach Furman, Matthew Farrugia-Roberts, Rohan Hitchcock, and Edmund Lau for useful advice. + +## About Timaeus + +Timaeus is a non-profit advancing AI safety through research in Singular Learning Theory (SLT). We use SLT to understand how training data shapes AI behavior, combining deep mathematical insights from algebraic geometry and statistical physics with empirical research to develop interpretability tools for how capabilities and values emerge during neural network training. This foundational work enables us to build interventions that ensure models are aligned with human values. diff --git a/docs/index.rst b/docs/index.rst index 9da7bf89..3d2bfae4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,17 +1,25 @@ Welcome to DevInterp's documentation! ===================================== -DevInterp is a Python library for conducting research on developmental interpretability, -a novel AI safety research agenda rooted in Singular Learning Theory (SLT). DevInterp -proposes tools for detecting, locating, and ultimately *controlling* the development of -structure over training. +DevInterp is `Timaeus `_' open source research package, built to allow +external researchers to do SLT/DevInterp-style research on Large Language Models. -Read more about `developmental interpretability here `_! - -For questions, `join the DevInterp discord `_! +Source: `github.com/timaeus-research/devinterp `_. +For questions, `Join the SLT and AI Safety Discord `_! .. warning:: This library is under active development. The API may change between releases. + +Features +======== + +- **SGLD Sampling** with per-token loss storage to xarray/Zarr +- **Local Learning Coefficient (LLC)** estimation from sampling results +- **Susceptibilities** measuring first-order posterior response to data perturbations, optionally restricted to specific model components +- **Bayesian Influence Functions (BIF)** as posterior correlations (or covariances) between per-sample losses +- **Weight restrictions** for sampling over parameter subsets (e.g., individual attention heads) + + Installation ============ @@ -24,11 +32,42 @@ Installation **Requirements**: Python 3.10 or higher. +Example +======= + +See the `Quickstart Notebook `_ +(`open in Colab `_) +or the `Quickstart Script `_ +for examples of how to compute LLCs and susceptibilities on Qwen2.5-0.5B (GPU required). + + Quick Start =========== -Compute the Local Learning Coefficient ---------------------------------------- +Sampling with Observables +------------------------- + +.. code-block:: python + + from devinterp.slt.sampling import sample + + tree = sample( + model=model, + dataset=train_data, + observables={ + "train": train_data, + "code": (code_data, 5), # (dataset, batches_per_draw) + }, + lr=0.001, + n_beta=30, + num_chains=4, + num_draws=200, + ) + # tree is an xr.DataTree backed by Zarr with full per-token loss traces + + +Computing the Local Learning Coefficient +---------------------------------------- .. code-block:: python @@ -49,84 +88,52 @@ Compute the Local Learning Coefficient # num_steps = num_draws * num_steps_bw_draws + num_burnin_steps -Sample with Observables ------------------------ +Computing Susceptibilities +-------------------------- .. code-block:: python - from devinterp.slt.sampling import sample + from devinterp.slt.susceptibilities import susceptibilities + from devinterp.slt.weight_restrictions import create_param_masks - tree = sample( + result = susceptibilities( model=model, dataset=train_data, - observables={ - "train": train_data, - "code": (code_data, 5), # (dataset, batches_per_draw) + observables={"train": train_data, "code": code_data}, + weight_restrictions={ + "full": None, + "l0h0": create_param_masks(model, "l0h0"), + "l0h1": create_param_masks(model, "l0h1"), }, + sampling_task="train", lr=0.001, n_beta=30, - num_chains=4, - num_draws=200, ) - # tree is an xr.DataTree backed by Zarr with full per-token loss traces - - -Concepts -======== - -Posterior Sampling with SGLD ----------------------------- - -The core workflow: - -1. Start at a checkpoint :math:`\hat{w}^*` -2. Take SGLD steps (SGD + noise) using one dataset for gradients -3. Evaluate losses on multiple datasets (observables) at each draw -4. Store the full per-token loss chains as Zarr datasets -5. Compute observables (LLC, susceptibilities, BIF) from these chains - -The SGLD noise allows exploring low-loss directions while staying near the original -checkpoint. This samples from the local posterior distribution around the checkpoint. + # result is a DataTree with /susceptibilities and /context subtrees -Local Learning Coefficient (LLC) --------------------------------- +``create_param_masks`` supports 85+ HuggingFace model types and TransformerLens. +Restriction patterns: ``"full"``, ``"l0"``, ``"l0h1"``, ``"l0g0"`` (GQA group), +``"l0 attn"``, ``"l0 mlp"``, ``"embed"``, ``"unembed"``. -The **LLC** measures model complexity by counting "effective parameters" in a region of -weight space: -.. math:: +Computing Bayesian Influence Functions +-------------------------------------- - \hat{\lambda}(\hat{w}^*) = n\beta \cdot (\bar{L}_n - L_n(\hat{w}^*)) - -Unlike parameter count or Hessian rank, LLC accounts for **singularities** -- regions where -multiple parameter configurations produce identical outputs. This makes it suitable for -neural networks. - -**Why LLC matters:** - -- **Detect phase transitions** during training (sudden capability changes) -- **Predict generalization** via the Free Energy formula -- **Compare checkpoints** across training - -Susceptibilities ----------------- - -**Susceptibilities** measure how a model component responds to distribution shifts. For -example, how does an attention head's behavior change when shifting from general text toward -code or math? - -This is computed by sampling with different **weight restrictions** (parameter subsets) and -measuring the covariance between sampling loss and observable loss. - -See `Structural Inference: Interpreting Small Language Models with Susceptibilities -`_ (Baker et al., 2025) for details. +.. code-block:: python -Bayesian Influence Functions (BIF) ----------------------------------- + from devinterp.slt.bif import bif -**BIF** computes pairwise correlations between observable loss traces across sequences from -SGLD sampling results. This reveals which sequences influence each other's loss under -posterior sampling, providing a measure of functional similarity. + result = bif( + model=model, + dataset=train_data, + observables={"train": train_data, "code": code_data}, + lr=0.001, + n_beta=30, + num_chains=4, + num_draws=200, + correlation_method="token", # or "sequence" + ) + # result["influences"] contains pairwise correlation matrix Architecture @@ -156,26 +163,86 @@ sequences: ``HookedTransformer``, or any model returning a tensor or object with ``.logits``) - Dataset must be a HuggingFace ``Dataset`` with an ``"input_ids"`` column of uniform-length sequences -- Loss is next-token cross-entropy +- Loss defaults to next-token cross-entropy + +For non-standard losses, pass ``loss_fn=...`` to ``sample()``, ``bif()``, ``llc()``, +or ``susceptibilities()``. The function takes ``(model, input_ids)`` and must return +per-token loss of shape ``(batch, seq_len-1)``. For more exotic control, +``sample_single_chain()`` in ``devinterp.slt.sampler`` accepts a custom ``evaluate`` +callable. + -For non-standard models, ``sample_single_chain()`` in ``devinterp.slt.sampler`` accepts a -custom ``evaluate`` callable. +Migrating from v1 +================= + +The v2 API replaces the callback-based sampling with a data-centric pipeline. Key +changes: + +.. code-block:: python + + # v1 (old) + from devinterp.slt.sampler import estimate_learning_coeff_with_summary + from devinterp.optim import SGLD + + result = estimate_learning_coeff_with_summary( + model, loader, + sampling_method=SGLD, + sampling_method_kwargs={"lr": 0.001, "nbeta": 30}, + num_chains=4, num_draws=200, + ) + llc = result["llc/mean"] + + # v2 (new) + from devinterp.slt.llc import llc + + result = llc( + model=model, + dataset=dataset, # HF Dataset, not DataLoader + observables={"train": dataset}, + lr=0.001, n_beta=30, + num_chains=4, num_draws=200, + ) + llc_value = float(result["llc_mean"]) + +**What changed:** + +- ``estimate_learning_coeff`` / ``LLCEstimator`` / ``SamplerCallback`` → ``llc()`` and ``compute_llc()`` +- ``DataLoader`` → HuggingFace ``Dataset`` with ``"input_ids"`` column +- ``sampling_method_kwargs={"nbeta": ...}`` → ``n_beta=...`` as a direct parameter +- Results are ``xr.Dataset`` / ``xr.DataTree``, not dicts with string keys +- New capabilities: ``susceptibilities()``, ``bif()``, observables, weight restrictions, per-token loss storage Hyperparameter selection ======================== -All sampling is sensitive to hyperparameters. See our `Sampling Hyperparameter Guide -`_. +All sampling is sensitive to hyperparameters. Our `Sampling Hyperparameter Guide +`_ covers the three primary knobs — +step size (``lr``), inverse temperature (``n_beta``), and localization strength +(``localization``) — along with burn-in, steps between draws, and chain count, and walks +through diagnosing common failure modes (non-convergence, spikes, NaNs, low +signal-to-noise) from the loss traces. + Further Reading =============== -- `You're Measuring Model Complexity Wrong `_ - Introduction to LLC and phase transitions (2024) +Blog Posts: + +- `Spectroscopy at Scale: Finding Interpretable Structure in Pythia-1.4B `_ (2026) +- `Guide for Sampling Hyperparameter Selection `_ (2026) + +Papers: + - `Structural Inference with Susceptibilities `_ (2025) - `Towards Spectroscopy: Susceptibility Clusters in Language Models `_ (2026) - `The Local Learning Coefficient: A Singularity-Aware Complexity Measure `_ (2023) -- `Algebraic Geometry and Statistical Learning Theory `_ Watanabe (2009) + +Background: + +- `Algebraic Geometry and Statistical Learning Theory `_, Watanabe (2009) +- `Interpreting the Ising Model `_ (2026) +- `You're Measuring Model Complexity Wrong `_ (2024) Credits & Citations @@ -200,6 +267,20 @@ If this package was useful in your work, please cite it as: howpublished = {\url{https://github.com/timaeus-research/devinterp}}, } +The authors would like to thank Zach Furman, Matthew Farrugia-Roberts, Rohan Hitchcock, +and Edmund Lau for useful advice. + + +About Timaeus +============= + +Timaeus is a non-profit advancing AI safety through research in Singular Learning Theory +(SLT). We use SLT to understand how training data shapes AI behavior, combining deep +mathematical insights from algebraic geometry and statistical physics with empirical +research to develop interpretability tools for how capabilities and values emerge during +neural network training. This foundational work enables us to build interventions that +ensure models are aligned with human values. + Guides ====== diff --git a/docs/sampling.rst b/docs/sampling.rst index 3a7244f5..f056d23a 100644 --- a/docs/sampling.rst +++ b/docs/sampling.rst @@ -4,8 +4,8 @@ Sampling ======== DevInterp uses Stochastic Gradient Langevin Dynamics (SGLD) to sample from the posterior -distribution around model parameters. This is the foundation for computing LLC, -susceptibilities, and BIF. +distribution around model parameters. This is the foundation for computing Local Learning Coefficients (LLCs), +Susceptibilities, and Bayesian Influence Functions (BIFs). .. figure:: figures/sample-macro.svg :class: dark-invert dark-screen @@ -19,10 +19,9 @@ How Sampling Works The ``sample()`` function: -1. Initializes the model at a checkpoint -2. Runs multiple independent SGLD chains -3. At each draw, evaluates per-token loss on the sampling dataset and all observables -4. Writes everything to a Zarr store, returned as an ``xr.DataTree`` +1. Runs multiple independent SGLD chains +2. At each draw, evaluates per-token loss on the sampling dataset and all observables +3. Writes everything to a Zarr store, returned as an ``xr.DataTree`` .. code-block:: python diff --git a/examples/quickstart.ipynb b/examples/quickstart.ipynb new file mode 100644 index 00000000..e921db7d --- /dev/null +++ b/examples/quickstart.ipynb @@ -0,0 +1,212 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# DevInterp Quickstart\n", + "\n", + "LLC, susceptibilities, and weight restrictions on Qwen2.5-0.5B.\n", + "\n", + "GPU recommended. On Colab: **Runtime → Change runtime type → T4 GPU** (or better)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -q devinterp transformers datasets ipywidgets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from datasets import load_dataset\n", + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "\n", + "from devinterp.slt.llc import llc\n", + "from devinterp.slt.susceptibilities import susceptibilities\n", + "from devinterp.slt.weight_restrictions import (\n", + " create_param_masks,\n", + " preview_weight_restriction,\n", + ")\n", + "from devinterp.utils import default_nbeta, tokenize_and_concatenate\n", + "\n", + "MODEL = \"Qwen/Qwen2.5-0.5B\"\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load model and data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16)\n", + "tokenizer = AutoTokenizer.from_pretrained(MODEL)\n", + "\n", + "raw = load_dataset(\"NeelNanda/pile-10k\", split=\"train\")\n", + "ds = tokenize_and_concatenate(\n", + " raw.select(range(200)),\n", + " tokenizer,\n", + " column_name=\"text\",\n", + " add_bos_token=False,\n", + " max_length=256,\n", + ")\n", + "# A second dataset for probing susceptibilities\n", + "probe = tokenize_and_concatenate(\n", + " raw.select(range(200, 400)),\n", + " tokenizer,\n", + " column_name=\"text\",\n", + " add_bos_token=False,\n", + " max_length=256,\n", + ")\n", + "\n", + "n_beta = default_nbeta(BATCH_SIZE)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## LLC" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "result = llc(\n", + " model=model,\n", + " dataset=ds,\n", + " observables={\"train\": ds},\n", + " lr=1e-4,\n", + " n_beta=n_beta,\n", + " num_chains=2,\n", + " num_draws=50,\n", + " batch_size=BATCH_SIZE,\n", + " num_init_loss_batches=4,\n", + ")\n", + "print(f\"LLC: {result['llc_mean']:.2f} +/- {result['llc_std']:.2f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Weight restrictions\n", + "\n", + "Preview which params a restriction selects before running susceptibilities.\n", + "`\"l0h0\"` = layer 0, head 0. For Qwen2.5 (14 Q heads, 2 KV heads via GQA), Q and O are per-head, so selecting a single head hits ~7% (1/14). K and V are per-KV-head, and each KV head is shared by 7 Q heads, so selecting any of those 7 picks up the full shared KV head → 50% (1/2)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "l0h0_mask = create_param_masks(model, \"l0h0\")\n", + "preview_weight_restriction(model, l0h0_mask)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Susceptibilities with weight restrictions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "result = susceptibilities(\n", + " model=model,\n", + " dataset=ds,\n", + " observables={\"train\": (ds, 2), \"probe\": (probe, 2)},\n", + " weight_restrictions={\n", + " \"full\": None,\n", + " \"l0h0\": l0h0_mask,\n", + " \"l0h1\": create_param_masks(model, \"l0h1\"),\n", + " },\n", + " sampling_task=\"train\",\n", + " lr=1e-4,\n", + " n_beta=n_beta,\n", + " num_chains=2,\n", + " num_draws=50,\n", + " batch_size=BATCH_SIZE,\n", + " num_init_loss_batches=4,\n", + ")\n", + "sus = result[\"susceptibilities\"].dataset\n", + "print(f\"Susceptibilities shape: {dict(sus.dims)}\")\n", + "print(sus[\"sus\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Manual weight restrictions\n", + "\n", + "For unsupported architectures, build masks directly. A mask is just `{param_name: bool_tensor | None}` where `None` means unrestricted. Example: restrict to the first MLP layer's gate projection and the first half of its up projection." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "manual_masks = {}\n", + "for name, param in model.named_parameters():\n", + " if \"model.layers.0.mlp.gate_proj\" in name:\n", + " manual_masks[name] = None # optimize this entire param\n", + " elif \"model.layers.0.mlp.up_proj\" in name:\n", + " # partially mask: only first half of neurons\n", + " mask = torch.zeros_like(param, dtype=torch.bool)\n", + " mask[: param.shape[0] // 2] = True\n", + " manual_masks[name] = mask\n", + "\n", + "preview_weight_restriction(model, manual_masks)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/devinterp/slt/bif.py b/src/devinterp/slt/bif.py index 8b9f45de..ba4131aa 100644 --- a/src/devinterp/slt/bif.py +++ b/src/devinterp/slt/bif.py @@ -17,7 +17,7 @@ import torch import xarray as xr from datasets import Dataset -from tqdm import tqdm +from tqdm.auto import tqdm from devinterp.slt.covariance import ( batch_corrcoef, diff --git a/src/devinterp/slt/sampler.py b/src/devinterp/slt/sampler.py index 75d5f567..7318f555 100644 --- a/src/devinterp/slt/sampler.py +++ b/src/devinterp/slt/sampler.py @@ -21,7 +21,7 @@ from devinterp.slt.lm_loss import NonFiniteLogitsError from torch import nn from torch.utils.data import DataLoader -from tqdm import trange +from tqdm.auto import trange # param name -> mask tensor (or None for unrestricted). # Only params in the dict are optimized; all others are frozen.