diff --git a/.env.template b/.env.template new file mode 100644 index 0000000..0852a9b --- /dev/null +++ b/.env.template @@ -0,0 +1,3 @@ +HF_TOKEN= +NDIF_API_KEY= +WANDB_API_KEY= \ No newline at end of file diff --git a/README.md b/README.md index 18aef43..d760d30 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,16 @@ # Crosslayer Transcoder - This repository trains Crosslayer Transcoders and variants with PyTorch/Lightning on multi‑GPU via tensor parallelism. It implements Anthropic’s [crosslayer transcoders](https://transformer-circuits.pub/2024/crosscoders/index.html) and related architectures (per‑layer transcoders, MOLTs, SAEs, Matryoshka CLTs) and supports losses such as ReLU, JumpReLU, TopK, and BatchTopK, for learning human‑interpretable features from LLM activations and building replacement models for [circuit tracing](https://transformer-circuits.pub/2025/attribution-graphs/methods.html). We want to understand the “brain” of LLMs: what their representations encode and what algorithms emerged from circuits. To start, we can learn feature dictionaries with [Sparse Autoencoders](https://transformer-circuits.pub/2023/monosemantic-features) to break activations into human‑interpretable features. They tell us _what_ features representations contain but not _how_ they interact to make circuits and algorithms. For that, we need to sparsify the entire model (we call this a sparse replacement model), not just representations of a single layer. One approach are [Transcoders](https://arxiv.org/abs/2406.11944), which learn features that approximate MLP components, which lets us swap in a replacement model and trace circuits end to end. [Crosslayer transcoders](https://transformer-circuits.pub/2024/crosscoders/index.html) allow features to affect all subsequent layers, essentially letting features live across layers. This yields smaller and more interpretable circuits and enables [circuit tracing](https://transformer-circuits.pub/2025/attribution-graphs/methods.html) and studies of [LLM biology](https://transformer-circuits.pub/2025/attribution-graphs/biology.html). - ## Implemented and Planned Features - > **⚠️ Early Development Disclaimer** > This repository is still in very early development and under active development. It's not yet a stable, production-ready package. There will likely be many breaking changes in the future as the codebase evolves. Use at your own risk and expect API changes between commits. ### Architectures + - ✅ Per-Layer Transcoder (PLT) - ✅ Crosslayer Transcoder (CLT) - ✅ Sparse Mixture of Linear Transforms (MOLT) @@ -20,6 +18,7 @@ We want to understand the “brain” of LLMs: what their representations encode - ⏳ SAEs (by tweaking the activation data extractor) ### Nonlinearities and Loss Functions + - ✅ ReLU and JumpReLU (via straight-through estimators) - ✅ TopK - ✅ BatchTopK (per layer and across layers) @@ -29,6 +28,7 @@ We want to understand the “brain” of LLMs: what their representations encode - ✅ Activation standardization ### Training + - ✅ On-demand activation extraction and streaming using a shared-memory activation buffer - ⚠️ Tensor parallelism using PyTorch DTensor API (requires PyTorch 2.8; comms optimization in progress) - ⏳ Sparse Kernels @@ -37,13 +37,13 @@ We want to understand the “brain” of LLMs: what their representations encode - ✅ Mixed precision (float16 + gradient scaler or bfloat16), gradient accumulation, checkpointing, profiling ### Metrics (logged to WandB during training) + - ✅ Replacement Model Accuracy and KL divergence - ✅ Dead Features - ✅ Feature activation frequency and other statistics - ✅ L0 - ⏳ Replacement Model Score - ## Installation Recommended: use the setup script (it installs uv if needed and creates the venv). @@ -58,17 +58,47 @@ cd crosslayer-transcoder ``` Notes: + - This will create `.venv/` and install from `pyproject.toml`, using `uv.lock` for reproducibility. - For GPU installs, ensure you have a compatible PyTorch build for your CUDA setup. If needed, follow the official PyTorch instructions to select the right wheel for your CUDA version. +## Environment Variables + +### Required Environment Variables + +- `HF_TOKEN` - HuggingFace API token for accessing models and datasets +- `NDIF_API_KEY` - NDIF API key required for training features +- `WANDB_API_KEY` - Weights & Biases API key for experiment tracking and logging +### Setup Options +**Option 1: During installation (recommended)** + +The `setup.sh` script will automatically prompt you for these values. You can press Enter to skip any prompt, but note that all three are required for training to work properly. + +**Option 2: Manual configuration** + +```bash +# Copy the template +cp .env.template .env + +# Edit the .env file with your API keys +nano .env # or use your preferred editor + +``` + +### Getting API Keys + +- **HuggingFace Token**: Get it from [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens) +- **NDIF API Key**: Sign up [here](https://ndif.us/) +- **Weights & Biases Key**: Get it from [https://wandb.ai/authorize](https://wandb.ai/authorize) ## How to Use (Configure and Customize with Lightning CLI) You can customize almost everything: datasets and activation extraction, model architecture, loss functions, and all training hyperparameters. This works by using PyTorch Lightning’s CLI to read a YAML config that defines which classes to use and how to compose them. By editing a single `config.yaml`, you control the entire run and keep every parameter in one place; you can still override any field from the command line for quick experiments. - Why this is great + - Single source of truth for all settings → easy to reproduce and share - Composable: swap architectures, losses, data pipelines by changing class entries in YAML - Discoverable and explicit: every knob is visible in one file, with sane defaults @@ -129,7 +159,6 @@ The `config` folder contains example configuration files for different architect 3. **It plugs into the same training loop** - multi‑GPU (DDP) works out of the box 4. **Tensor parallelism works automatically** because PyTorch Lightning handles the distributed setup and PyTorch's Distributed Tensor API shards your model across GPUs without requiring changes to your component code - ## Testing Run the test suite to ensure everything is working correctly: diff --git a/crosslayer_transcoder/main.py b/crosslayer_transcoder/main.py index 2afce89..3324ee5 100755 --- a/crosslayer_transcoder/main.py +++ b/crosslayer_transcoder/main.py @@ -4,6 +4,7 @@ """ import os +from dotenv import load_dotenv import lightning as L from lightning.pytorch.cli import LightningCLI @@ -29,6 +30,7 @@ def add_arguments_to_parser(self, parser): def main(): """Main entry point for training.""" + load_dotenv() # Set up wandb directories os.environ.setdefault("WANDB_DIR", f"{os.getcwd()}/wandb") os.environ.setdefault("WANDB_CACHE_DIR", f"{os.getcwd()}/wandb_cache") diff --git a/pyproject.toml b/pyproject.toml index 9077485..f45dc81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ dependencies = [ "transformers>=4.46.0", "numpy>=1.24.0", "jsonargparse[signatures]>=4.27.7", + "dotenv>=0.9.9", ] [project.optional-dependencies] diff --git a/setup.sh b/setup.sh index 48d1062..269ee8b 100755 --- a/setup.sh +++ b/setup.sh @@ -54,6 +54,82 @@ if [ $? -eq 0 ]; then # Fix permissions for all executables in .venv/bin chmod +x .venv/bin/* + # Setup environment variables + echo "" + echo -e "${BLUE}🔑 Setting up environment variables...${NC}" + + # Create .env from template if it doesn't exist + if [ ! -f ".env" ]; then + if [ -f ".env.template" ]; then + cp .env.template .env + echo -e "${GREEN}✅ Created .env file from template${NC}" + else + echo -e "${YELLOW}⚠️ .env.template not found, creating .env file${NC}" + touch .env + fi + else + echo -e "${GREEN}✅ .env file already exists${NC}" + fi + + # Function to prompt for environment variable + prompt_env_var() { + local var_name="$1" + local var_description="$2" + local current_value="${!var_name}" + + # Check if already set in environment + if [ -n "$current_value" ]; then + echo -e "${GREEN}✅ $var_name already set in environment${NC}" + # Ensure it's in .env file + if ! grep -q "^${var_name}=" .env 2>/dev/null; then + echo "${var_name}=${current_value}" >> .env + fi + return + fi + + # Check if set in .env file + if [ -f ".env" ]; then + local env_value=$(grep "^${var_name}=" .env | cut -d '=' -f 2-) + if [ -n "$env_value" ]; then + echo -e "${GREEN}✅ $var_name already set in .env file${NC}" + export "${var_name}=${env_value}" + return + fi + fi + + # Prompt user + echo "" + echo -e "${YELLOW}${var_description}${NC}" + echo -e "${BLUE}[Press Enter to skip]${NC}" + read -p "${var_name}=" user_input + + # Update .env file + if grep -q "^${var_name}=" .env 2>/dev/null; then + # Update existing line + sed -i.bak "s|^${var_name}=.*|${var_name}=${user_input}|" .env && rm .env.bak + else + # Add new line + echo "${var_name}=${user_input}" >> .env + fi + + # Export if not empty + if [ -n "$user_input" ]; then + export "${var_name}=${user_input}" + echo -e "${GREEN}✅ $var_name set and exported${NC}" + else + echo -e "${YELLOW}⚠️ $var_name left empty (required for training)${NC}" + fi + } + + # Prompt for each environment variable + prompt_env_var "HF_TOKEN" "HuggingFace API token - required for model/dataset access" + prompt_env_var "NDIF_API_KEY" "NDIF API key - required for training" + prompt_env_var "WANDB_API_KEY" "Weights & Biases API key - required for logging" + + echo "" + echo -e "${BLUE}📝 Environment variables saved to .env file${NC}" + echo "" + # Add uv to PATH permanently by updating shell profile UV_PATH_EXPORT='export PATH="$HOME/.local/bin:$PATH"' diff --git a/uv.lock b/uv.lock index 5cd028c..775d136 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = "==3.12.*" resolution-markers = [ "sys_platform == 'linux'", @@ -333,6 +333,7 @@ version = "0.1.0" source = { editable = "." } dependencies = [ { name = "datasets" }, + { name = "dotenv" }, { name = "einops" }, { name = "h5py" }, { name = "jaxtyping" }, @@ -362,6 +363,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "datasets", specifier = ">=3.6.0" }, + { name = "dotenv", specifier = ">=0.9.9" }, { name = "einops", specifier = ">=0.8.1" }, { name = "h5py", specifier = ">=3.13.0" }, { name = "ipykernel", marker = "extra == 'dev'", specifier = ">=6.29.5" }, @@ -458,6 +460,17 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d5/7c/e9fcff7623954d86bdc17782036cbf715ecab1bec4847c008557affe1ca8/docstring_parser-0.16-py3-none-any.whl", hash = "sha256:bf0a1387354d3691d102edef7ec124f219ef639982d096e26e3b60aeffa90637", size = 36533, upload-time = "2024-03-15T10:39:41.527Z" }, ] +[[package]] +name = "dotenv" +version = "0.9.9" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "python-dotenv" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/b2/b7/545d2c10c1fc15e48653c91efde329a790f2eecfbbf2bd16003b5db2bab0/dotenv-0.9.9-py2.py3-none-any.whl", hash = "sha256:29cf74a087b31dafdb5a446b6d7e11cbce8ed2741540e2339c69fbef92c94ce9", size = 1892, upload-time = "2025-02-19T22:15:01.647Z" }, +] + [[package]] name = "einops" version = "0.8.1" @@ -1673,6 +1686,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, ] +[[package]] +name = "python-dotenv" +version = "1.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f0/26/19cadc79a718c5edbec86fd4919a6b6d3f681039a2f6d66d14be94e75fb9/python_dotenv-1.2.1.tar.gz", hash = "sha256:42667e897e16ab0d66954af0e60a9caa94f0fd4ecf3aaf6d2d260eec1aa36ad6", size = 44221, upload-time = "2025-10-26T15:12:10.434Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/1b/a298b06749107c305e1fe0f814c6c74aea7b2f1e10989cb30f544a1b3253/python_dotenv-1.2.1-py3-none-any.whl", hash = "sha256:b81ee9561e9ca4004139c6cbba3a238c32b03e4894671e181b671e8cb8425d61", size = 21230, upload-time = "2025-10-26T15:12:09.109Z" }, +] + [[package]] name = "python-engineio" version = "4.12.2"