lol #4

Open · wants to merge 81 commits into base: main

Commits (81)

4aea538
Fix BatchTopKSAE training
mntss Oct 14, 2024
fe54b00
Add a simple end to end test
adamkarvonen Dec 17, 2024
9ed4af2
Rename input to inputs per nnsight 0.3.0
adamkarvonen Dec 17, 2024
807f6ef
Complete nnsight 0.2 to 0.3 changes
adamkarvonen Dec 18, 2024
dc30720
Fix frac_alive calculation, perform evaluation over multiple batches
adamkarvonen Dec 18, 2024
067bf7b
Obtain better test results using multiple batches
adamkarvonen Dec 18, 2024
05fe179
Add early stopping in forward pass
adamkarvonen Dec 18, 2024
f1b9b80
Change save_steps to a list of ints
adamkarvonen Dec 18, 2024
d350415
Check for is_tuple to support mlp / attn submodules
adamkarvonen Dec 18, 2024
2ec1890
Merge pull request #26 from mntss/batchtokp_aux_fix
adamkarvonen Dec 18, 2024
c4eed3c
Merge pull request #30 from adamkarvonen/add_tests
adamkarvonen Dec 18, 2024
d416eab
Ensure activation buffer has the correct dtype
adamkarvonen Dec 20, 2024
552a8c2
Fix JumpReLU training and loading
adamkarvonen Dec 20, 2024
712eb98
Begin creation of demo script
adamkarvonen Dec 20, 2024
dcc02f0
Modularize demo script
adamkarvonen Dec 21, 2024
32d198f
Track threshold for batchtopk, rename for consistency
adamkarvonen Dec 21, 2024
b5821fd
Track thresholds for topk and batchtopk during training
adamkarvonen Dec 22, 2024
57f451b
Remove demo script and graphing notebook
adamkarvonen Dec 22, 2024
81968f2
Add option to normalize dataset activations
adamkarvonen Dec 26, 2024
488a154
Fix topk bfloat16 dtype error
adamkarvonen Dec 26, 2024
484ca01
Add bias scaling to topk saes
adamkarvonen Dec 26, 2024
8b95ec9
Use the correct standard SAE reconstruction loss, initialize W_dec to…
adamkarvonen Dec 26, 2024
efd76b1
Also scale topk thresholds when scaling biases
adamkarvonen Dec 26, 2024
67a7857
Merge pull request #31 from saprmarks/add_demo
adamkarvonen Dec 26, 2024
9687bb9
Remove leftover variable, update expected results with standard SAE i…
adamkarvonen Dec 26, 2024
f0bb66d
Track lr decay implementation
adamkarvonen Dec 27, 2024
e0db40b
Clean up lr decay
adamkarvonen Dec 27, 2024
911b958
Add sparsity warmup for trainers with a sparsity penalty
adamkarvonen Dec 27, 2024
a11670f
Merge pull request #32 from saprmarks/add_sparsity_warmup
adamkarvonen Dec 27, 2024
a2d6c43
Standardize learning rate and sparsity schedules
adamkarvonen Dec 29, 2024
e00fd64
Properly set new parameters in end to end test
adamkarvonen Dec 30, 2024
316dbbe
Merge pull request #33 from saprmarks/lr_scheduling
adamkarvonen Dec 30, 2024
1df47d8
Make sure we step the learning rate scheduler
adamkarvonen Dec 30, 2024
8ade55b
Initial matroyshka implementation
adamkarvonen Dec 31, 2024
764d4ac
Fix loading matroyshkas from_pretrained()
adamkarvonen Dec 31, 2024
5383603
norm the correct decoder dimension
adamkarvonen Jan 1, 2025
ceabbc5
Add temperature scaling to matroyshka
adamkarvonen Jan 1, 2025
3e31571
Format with ruff
adamkarvonen Jan 1, 2025
8eaa8b2
Use kaiming initialization if specified in paper, fix batch_top_k aux…
adamkarvonen Jan 1, 2025
ec961ac
Fix jumprelu training
adamkarvonen Jan 1, 2025
c2fe5b8
Add option to ignore bos tokens
adamkarvonen Jan 1, 2025
810dbb8
Add notes
adamkarvonen Jan 2, 2025
3b03b92
Add trainer number to wandb name
adamkarvonen Jan 2, 2025
77da794
Log number of dead features to wandb
adamkarvonen Jan 2, 2025
936a69c
Log dead features for batch top k SAEs
adamkarvonen Jan 2, 2025
370272a
Prevent wandb cuda multiprocessing errors
adamkarvonen Jan 2, 2025
0ff687b
Add a verbose option during training
adamkarvonen Jan 3, 2025
92648d4
Merge pull request #34 from adamkarvonen/matroyshka
adamkarvonen Jan 3, 2025
9751c57
Consolidate LR Schedulers, Sparsity Schedulers, and constrained optim…
adamkarvonen Jan 3, 2025
f19db98
Merge pull request #35 from saprmarks/code_cleanup
adamkarvonen Jan 3, 2025
cfb36ff
Add April Update Standard Trainer
adamkarvonen Jan 3, 2025
8316a44
Add an option to pass LR to TopK trainers
adamkarvonen Jan 3, 2025
3c5a5cd
Save state dicts to cpu
adamkarvonen Jan 3, 2025
832f4a3
Add torch autocast to training loop
adamkarvonen Jan 7, 2025
17aa5d5
Disable autocast for threshold tracking
adamkarvonen Jan 7, 2025
65e7af8
Also update context manager for matroyshka threshold
adamkarvonen Jan 7, 2025
52b0c54
By default, don't normalize Gated activations during inference
adamkarvonen Jan 7, 2025
8363ff7
Import trainers from correct relative location for submodule use
adamkarvonen Jan 7, 2025
c697d0f
Make sure x is on the correct dtype for jumprelu when logging
adamkarvonen Jan 7, 2025
6c2fcfc
Remove experimental matroyshka temperature
adamkarvonen Jan 8, 2025
200ed3b
Normalize decoder after optimzer step
adamkarvonen Jan 13, 2025
0af1971
Standardize and fix topk auxk loss implementation
adamkarvonen Jan 13, 2025
34eefda
Merge pull request #36 from saprmarks/aux_loss_fixes
adamkarvonen Jan 13, 2025
db2b564
Make sure to detach reconstruction before calculating aux loss
adamkarvonen Jan 13, 2025
77f2690
Add citation
adamkarvonen Jan 14, 2025
784a62a
Fix incorrect auxk logging name
adamkarvonen Jan 14, 2025
aa45bf6
Fix matryoshka spelling
adamkarvonen Jan 16, 2025
505a445
Use torch.split() instead of direct indexing for 25% speedup
adamkarvonen Jan 16, 2025
43421f5
simplify matryoshka loss
adamkarvonen Jan 16, 2025
0ff8888
feat: pypi packaging and auto-release with semantic release
chanind Feb 10, 2025
a711efe
Merge pull request #37 from chanind/pypi-package
saprmarks Feb 12, 2025
07975f7
0.1.0
invalid-email-address Feb 12, 2025
944edd1
Add a pytorch activation buffer, enable model truncation
adamkarvonen May 13, 2025
c644ccd
Add better dataset generators
adamkarvonen May 13, 2025
17a41c7
Add optional backup step
adamkarvonen May 13, 2025
fe9d8c7
add activault buffer implementation
adamkarvonen May 13, 2025
d639166
Merge pull request #42 from saprmarks/mixed_datasets
adamkarvonen May 13, 2025
c7b2527
assert right padding for remove_bos logic
andyrdt May 18, 2025
59abc88
mask out bos activations
andyrdt May 18, 2025
b2bb4b5
Merge pull request #44 from andyrdt/andyrdt/bos_bugfix
adamkarvonen May 18, 2025
61ac634
add handling of device and dtype to IdentityDict
canrager Jul 14, 2025
90 changes: 90 additions & 0 deletions .github/workflows/build.yml
@@ -0,0 +1,90 @@
name: build

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.10", "3.11", "3.12"]

    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Cache Huggingface assets
        uses: actions/cache@v4
        with:
          key: huggingface-0-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }}
          path: ~/.cache/huggingface
          restore-keys: |
            huggingface-0-${{ runner.os }}-${{ matrix.python-version }}-
      - name: Load cached Poetry installation
        id: cached-poetry
        uses: actions/cache@v4
        with:
          path: ~/.local # the path depends on the OS
          key: poetry-${{ runner.os }}-${{ matrix.python-version }}-1 # increment to reset cache
      - name: Install Poetry
        if: steps.cached-poetry.outputs.cache-hit != 'true'
        uses: snok/install-poetry@v1
        with:
          virtualenvs-create: true
          virtualenvs-in-project: true
          installer-parallel: true
      - name: Load cached venv
        id: cached-poetry-dependencies
        uses: actions/cache@v4
        with:
          path: .venv
          key: venv-0-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }}
          restore-keys: |
            venv-0-${{ runner.os }}-${{ matrix.python-version }}-
      - name: Install dependencies
        if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
        run: poetry install --no-interaction
      - name: Run Unit Tests
        run: poetry run pytest tests/unit
      - name: Build package
        run: poetry build

  release:
    needs: build
    permissions:
      contents: write
      id-token: write
    # https://github.community/t/how-do-i-specify-job-dependency-running-in-another-workflow/16482
    if: github.event_name == 'push' && github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, 'chore(release):')
    runs-on: ubuntu-latest
    concurrency: release
    environment:
      name: pypi
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 0
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - name: Semantic Release
        id: release
        uses: python-semantic-release/[email protected]
        with:
          github_token: ${{ secrets.GITHUB_TOKEN }}
      - name: Publish package distributions to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
        if: steps.release.outputs.released == 'true'
      - name: Publish package distributions to GitHub Releases
        uses: python-semantic-release/upload-to-gh-release@main
        if: steps.release.outputs.released == 'true'
        with:
          github_token: ${{ secrets.GITHUB_TOKEN }}
2 changes: 1 addition & 1 deletion .gitignore
@@ -99,7 +99,7 @@ ipython_config.py
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
669 changes: 669 additions & 0 deletions CHANGELOG.md

Large diffs are not rendered by default.

32 changes: 23 additions & 9 deletions README.md
@@ -1,19 +1,17 @@
This is a repository for doing dictionary learning via sparse autoencoders on neural network activations. It was developed by Samuel Marks and Aaron Mueller.
This is a repository for doing dictionary learning via sparse autoencoders on neural network activations. It was developed by Samuel Marks, Adam Karvonen, and Aaron Mueller.

For accessing, saving, and intervening on NN activations, we use the [`nnsight`](http://nnsight.net/) package; as of March 2024, `nnsight` is under active development and may undergo breaking changes. That said, `nnsight` is easy to use and quick to learn; if you plan to modify this repo, then we recommend going through the main `nnsight` demo [here](https://nnsight.net/notebooks/tutorials/walkthrough/).

Some dictionaries trained using this repository (and asociated training checkpoints) can be accessed at [https://baulab.us/u/smarks/autoencoders/](https://baulab.us/u/smarks/autoencoders/). See below for more information about these dictionaries.
Some dictionaries trained using this repository (and associated training checkpoints) can be accessed at [https://baulab.us/u/smarks/autoencoders/](https://baulab.us/u/smarks/autoencoders/). See below for more information about these dictionaries. SAEs trained with `dictionary_learning` can be evaluated with [SAE Bench](https://www.neuronpedia.org/sae-bench/info) using a convenient [evaluation script](https://github.com/adamkarvonen/SAEBench/tree/main/sae_bench/custom_saes).

# Set-up

Navigate to the location where you would like to clone this repo, clone and enter the repo, and install the requirements.
```bash
git clone https://github.com/saprmarks/dictionary_learning
cd dictionary_learning
pip install -r requirements.txt
pip install dictionary-learning
```

To use `dictionary_learning`, include it as a subdirectory in some project's directory and import it; see the examples below.
We also provide a [demonstration](https://github.com/adamkarvonen/dictionary_learning_demo), which trains and evaluates 2 SAEs in ~30 minutes before plotting the results.
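
After installing from PyPI, the package is imported as `dictionary_learning`. A minimal sketch (the class name is taken from the package's `__init__.py` shown at the bottom of this diff; the weights path is a placeholder, and `from_pretrained` is the loader referenced in the commit history):

```python
from dictionary_learning import AutoEncoder

# Load a trained dictionary from disk (placeholder path).
ae = AutoEncoder.from_pretrained("path/to/ae.pt")
print(ae)
```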

# Using trained dictionaries

@@ -61,7 +59,9 @@ This allows us to implement different training protocols (e.g. p-annealing) for
Specifically, this repository supports the following trainers:
- [`StandardTrainer`](trainers/standard.py): Implements a training scheme similar to that of [Bricken et al., 2023](https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder).
- [`GatedSAETrainer`](trainers/gdm.py): Implements the training scheme for Gated SAEs described in [Rajamanoharan et al., 2024](https://arxiv.org/abs/2404.16014).
- [`AutoEncoderTopK`](trainers/top_k.py): Implemented the training scheme for Top-K SAEs described in [Gao et al., 2024](https://arxiv.org/abs/2406.04093).
- [`TopKSAETrainer`](trainers/top_k.py): Implements the training scheme for Top-K SAEs described in [Gao et al., 2024](https://arxiv.org/abs/2406.04093).
- [`BatchTopKSAETrainer`](trainers/batch_top_k.py): Implements the training scheme for Batch Top-K SAEs described in [Bussmann et al., 2024](https://arxiv.org/abs/2412.06410).
- [`JumpReluTrainer`](trainers/jumprelu.py): Implements the training scheme for JumpReLU SAEs described in [Rajamanoharan et al., 2024](https://arxiv.org/abs/2407.14435).
- [`PAnnealTrainer`](trainers/p_anneal.py): Extends the `StandardTrainer` by providing the option to anneal the sparsity parameter p.
- [`GatedAnnealTrainer`](trainers/gated_anneal.py): Extends the `GatedSAETrainer` by providing the option for p-annealing, similar to `PAnnealTrainer`.

Expand Down Expand Up @@ -121,8 +121,11 @@ ae = trainSAE(
```
Some technical notes about our training infrastructure and supported features:
* Training uses the `ConstrainedAdam` optimizer defined in `training.py`. This is a variant of Adam which supports constraining the `AutoEncoder`'s decoder weights to be norm 1.
* Neuron resampling: if a `resample_steps` argument is passed to `trainSAE`, then dead neurons will periodically be resampled according to the procedure specified [here](https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-resampling).
* Learning rate warmup: if a `warmup_steps` argument is passed to `trainSAE`, then a linear LR warmup is used at the start of training and, if doing neuron resampling, also after every time neurons are resampled.
* Neuron resampling: if a `resample_steps` argument is passed to the Trainer, then dead neurons will periodically be resampled according to the procedure specified [here](https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-resampling).
* Learning rate warmup: if a `warmup_steps` argument is passed to the Trainer, then a linear LR warmup is used at the start of training and, if doing neuron resampling, also after every time neurons are resampled.
* Sparsity penalty warmup: if a `sparsity_warmup_steps` is passed to the Trainer, then a linear warmup is applied to the sparsity penalty at the start of training.
* Learning rate decay: if a `decay_start` is passed to the Trainer, then a linear LR decay is used from `decay_start` to the end of training.
* If `normalize_activations` is True and passed to `trainSAE`, then the activations will be normalized to have unit mean squared norm. The autoencoder's weights will be scaled before saving, so the activations don't need to be scaled during inference. This is very helpful for hyperparameter transfer between different layers and models. (See the sketch after this list for how these options fit together.)
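
To see how these options fit together, here is a minimal sketch of a training call. The exact keyword names and import paths are assumptions based on the notes above and the `ae = trainSAE(...)` call shown in this diff; check `training.py` and the `trainers/` modules for the real signatures.

```python
import torch

# Assumed import paths -- verify against training.py and trainers/standard.py.
from dictionary_learning import AutoEncoder
from dictionary_learning.training import trainSAE
from dictionary_learning.trainers.standard import StandardTrainer

# Illustrative config; trainer-specific fields (e.g. the sparsity penalty) are omitted.
trainer_cfg = {
    "trainer": StandardTrainer,      # any trainer class listed above
    "dict_class": AutoEncoder,
    "activation_dim": 512,           # example sizes; match your model's hidden dimension
    "dict_size": 16 * 512,
    "lr": 3e-4,
    "warmup_steps": 1000,            # linear LR warmup at the start of training
    "sparsity_warmup_steps": 2000,   # linear warmup of the sparsity penalty
    "decay_start": 20_000,           # linear LR decay from this step to the end of training
    "resample_steps": None,          # set to an int to periodically resample dead neurons
    "device": "cuda" if torch.cuda.is_available() else "cpu",
}

ae = trainSAE(
    data=activation_buffer,          # placeholder: an ActivationBuffer yielding activations
    trainer_configs=[trainer_cfg],
    normalize_activations=True,      # unit mean squared norm; weights rescaled before saving
)
```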

If `submodule` is a model component where the activations are tuples (e.g. this is common when working with residual stream activations), then the buffer yields the first coordinate of the tuple.
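
In other words, the buffer does something equivalent to the following (an illustrative sketch, not the buffer's actual code):

```python
def first_if_tuple(activation):
    """Return hidden states, taking the first element when a submodule's output is a tuple
    (common for residual-stream submodules)."""
    return activation[0] if isinstance(activation, tuple) else activation
```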

@@ -205,4 +208,15 @@ We've included support for some experimental features. We briefly investigated t
* **Replacing L1 loss with entropy**. Based on the ideas in this [post](https://transformer-circuits.pub/2023/may-update/index.html#simple-factorization), we experimented with using entropy to regularize a dictionary's hidden state instead of L1 loss. This seemed to cause the features to split into dead features (which never fired) and very high-frequency features which fired on nearly every input, which was not the desired behavior. But plausibly there is a way to make this work better.
* **Ghost grads**, as described [here](https://transformer-circuits.pub/2024/jan-update/index.html).

# Citation

Please cite the package as follows:

```
@misc{marks2024dictionary_learning,
title = {dictionary_learning},
author = {Samuel Marks and Adam Karvonen and Aaron Mueller},
year = {2024},
howpublished = {\url{https://github.com/saprmarks/dictionary_learning}},
}
```
2 changes: 0 additions & 2 deletions __init__.py

This file was deleted.

6 changes: 6 additions & 0 deletions dictionary_learning/__init__.py
@@ -0,0 +1,6 @@
__version__ = "0.1.0"

from .dictionary import AutoEncoder, GatedAutoEncoder, JumpReluAutoEncoder
from .buffer import ActivationBuffer

__all__ = ["AutoEncoder", "GatedAutoEncoder", "JumpReluAutoEncoder", "ActivationBuffer"]