diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..c99b233 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,90 @@ +name: build + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Cache Huggingface assets + uses: actions/cache@v4 + with: + key: huggingface-0-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }} + path: ~/.cache/huggingface + restore-keys: | + huggingface-0-${{ runner.os }}-${{ matrix.python-version }}- + - name: Load cached Poetry installation + id: cached-poetry + uses: actions/cache@v4 + with: + path: ~/.local # the path depends on the OS + key: poetry-${{ runner.os }}-${{ matrix.python-version }}-1 # increment to reset cache + - name: Install Poetry + if: steps.cached-poetry.outputs.cache-hit != 'true' + uses: snok/install-poetry@v1 + with: + virtualenvs-create: true + virtualenvs-in-project: true + installer-parallel: true + - name: Load cached venv + id: cached-poetry-dependencies + uses: actions/cache@v4 + with: + path: .venv + key: venv-0-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }} + restore-keys: | + venv-0-${{ runner.os }}-${{ matrix.python-version }}- + - name: Install dependencies + if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' + run: poetry install --no-interaction + - name: Run Unit Tests + run: poetry run pytest tests/unit + - name: Build package + run: poetry build + + release: + needs: build + permissions: + contents: write + id-token: write + # https://github.community/t/how-do-i-specify-job-dependency-running-in-another-workflow/16482 + if: github.event_name == 'push' && github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, 'chore(release):') + runs-on: ubuntu-latest + concurrency: release + environment: + name: pypi + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Semantic Release + id: release + uses: python-semantic-release/python-semantic-release@v8.0.7 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + - name: Publish package distributions to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + if: steps.release.outputs.released == 'true' + - name: Publish package distributions to GitHub Releases + uses: python-semantic-release/upload-to-gh-release@main + if: steps.release.outputs.released == 'true' + with: + github_token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore index 98cc45c..71326c1 100644 --- a/.gitignore +++ b/.gitignore @@ -99,7 +99,7 @@ ipython_config.py # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock +poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
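Note: the `release` job above relies on `python-semantic-release` reading its configuration from `pyproject.toml`, which is not shown in this diff. A minimal sketch of the sections it assumes is below; the key names follow python-semantic-release v8, but the specific values (the Poetry version path, the build command, and the `chore(release):` prefix that the job's `if:` guard checks for) are assumptions about this project, not contents of its actual file.

```toml
# Hypothetical pyproject.toml excerpt -- illustrative only, not taken from this repository.
[tool.semantic_release]
# Where semantic-release writes the bumped version (assumes Poetry manages the version field).
version_toml = ["pyproject.toml:tool.poetry.version"]
# Produce sdist/wheel artifacts for the PyPI and GitHub Releases upload steps.
build_command = "pip install poetry && poetry build"
# Release commits are prefixed so the workflow's `if:` guard can skip re-triggering a release.
commit_message = "chore(release): {version}"
```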
diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..207ed78 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,669 @@ +# CHANGELOG + + + +## v0.1.0 (2025-02-12) + +### Feature + +* feat: pypi packaging and auto-release with semantic release ([`0ff8888`](https://github.com/saprmarks/dictionary_learning/commit/0ff88883e7caac8ebd7ea0d8e07585451d8b7f9f)) + +### Unknown + +* Merge pull request #37 from chanind/pypi-package + +feat: pypi packaging and auto-release with semantic release ([`a711efe`](https://github.com/saprmarks/dictionary_learning/commit/a711efe3b60aabc99a35e7279cd35fa8bf4c930a)) + +* simplify matryoshka loss ([`43421f5`](https://github.com/saprmarks/dictionary_learning/commit/43421f5934a1476cb3f32f0b9e1b5d14b84540a1)) + +* Use torch.split() instead of direct indexing for 25% speedup ([`505a445`](https://github.com/saprmarks/dictionary_learning/commit/505a4455358f079db9f2b0309cc0922169869965)) + +* Fix matryoshka spelling ([`aa45bf6`](https://github.com/saprmarks/dictionary_learning/commit/aa45bf6ed9aa981f6a266f333e6d4a8b9d459909)) + +* Fix incorrect auxk logging name ([`784a62a`](https://github.com/saprmarks/dictionary_learning/commit/784a62a405be4ee8754a76ad4d3e61fd7de06348)) + +* Add citation ([`77f2690`](https://github.com/saprmarks/dictionary_learning/commit/77f2690abcd56ce19aaf3c1404dcfcfc6cf9381b)) + +* Make sure to detach reconstruction before calculating aux loss ([`db2b564`](https://github.com/saprmarks/dictionary_learning/commit/db2b5642e2966559a907e4885bf3317ea997a494)) + +* Merge pull request #36 from saprmarks/aux_loss_fixes + +Aux loss fixes, standardize decoder normalization ([`34eefda`](https://github.com/saprmarks/dictionary_learning/commit/34eefdafcbcac784f3761abf5037c5178cbfd866)) + +* Standardize and fix topk auxk loss implementation ([`0af1971`](https://github.com/saprmarks/dictionary_learning/commit/0af19713feb5b4c35788039245013736bf974383)) + +* Normalize decoder after optimzer step ([`200ed3b`](https://github.com/saprmarks/dictionary_learning/commit/200ed3bed09c88d336c25a886eee4cb98c1e616e)) + +* Remove experimental matroyshka temperature ([`6c2fcfc`](https://github.com/saprmarks/dictionary_learning/commit/6c2fcfc2a8108ba720591eb414be6ab16157dc36)) + +* Make sure x is on the correct dtype for jumprelu when logging ([`c697d0f`](https://github.com/saprmarks/dictionary_learning/commit/c697d0f83984f0f257be2044231c30f2abb15aa1)) + +* Import trainers from correct relative location for submodule use ([`8363ff7`](https://github.com/saprmarks/dictionary_learning/commit/8363ff779eee04518edaac9d10d97e459f708b66)) + +* By default, don't normalize Gated activations during inference ([`52b0c54`](https://github.com/saprmarks/dictionary_learning/commit/52b0c54ba92630cfb2ae007f020ed447d4a5ba9f)) + +* Also update context manager for matroyshka threshold ([`65e7af8`](https://github.com/saprmarks/dictionary_learning/commit/65e7af80441e5b601114756afc36a4041cec152f)) + +* Disable autocast for threshold tracking ([`17aa5d5`](https://github.com/saprmarks/dictionary_learning/commit/17aa5d52f818545afe5fbbe3edf1f774cde92f44)) + +* Add torch autocast to training loop ([`832f4a3`](https://github.com/saprmarks/dictionary_learning/commit/832f4a32428cda68ec418aff9abe7dca66a9f66e)) + +* Save state dicts to cpu ([`3c5a5cd`](https://github.com/saprmarks/dictionary_learning/commit/3c5a5cdef682cbeb12e23b825f39709f518e2c0a)) + +* Add an option to pass LR to TopK trainers 
([`8316a44`](https://github.com/saprmarks/dictionary_learning/commit/8316a4418dc4acb70ccad9854d3b05df1b817b9d)) + +* Add April Update Standard Trainer ([`cfb36ff`](https://github.com/saprmarks/dictionary_learning/commit/cfb36fff661fa60f38a2d1b372b6802517c08257)) + +* Merge pull request #35 from saprmarks/code_cleanup + +Consolidate LR Schedulers, Sparsity Schedulers, and constrained optimizers ([`f19db98`](https://github.com/saprmarks/dictionary_learning/commit/f19db98106302ed1d75dc8380160463ff812b1ad)) + +* Consolidate LR Schedulers, Sparsity Schedulers, and constrained optimizers ([`9751c57`](https://github.com/saprmarks/dictionary_learning/commit/9751c57731a25c04871e8173d16a0e4d902edc19)) + +* Merge pull request #34 from adamkarvonen/matroyshka + +Add Matroyshka, Fix Jump ReLU training, modify initialization ([`92648d4`](https://github.com/saprmarks/dictionary_learning/commit/92648d4e3d28aa397dbc89c43147aa6faf8874b7)) + +* Add a verbose option during training ([`0ff687b`](https://github.com/saprmarks/dictionary_learning/commit/0ff687bdc12cba66a0233825cb301df28da3a9db)) + +* Prevent wandb cuda multiprocessing errors ([`370272a`](https://github.com/saprmarks/dictionary_learning/commit/370272a4aac0ad0e59a2982073aa7b08970712b6)) + +* Log dead features for batch top k SAEs ([`936a69c`](https://github.com/saprmarks/dictionary_learning/commit/936a69c38a74980830f24fc851c40fb93abe8f07)) + +* Log number of dead features to wandb ([`77da794`](https://github.com/saprmarks/dictionary_learning/commit/77da7945f520f448b0524e476f539b3a44a4ca43)) + +* Add trainer number to wandb name ([`3b03b92`](https://github.com/saprmarks/dictionary_learning/commit/3b03b92b97d61a95e98b6f187dad97e939f6f977)) + +* Add notes ([`810dbb8`](https://github.com/saprmarks/dictionary_learning/commit/810dbb8bdce4ac6f1ce371872297b4f7a104e3f6)) + +* Add option to ignore bos tokens ([`c2fe5b8`](https://github.com/saprmarks/dictionary_learning/commit/c2fe5b89e78ae4a9d41a4809f4d00b8a3fcd0b36)) + +* Fix jumprelu training ([`ec961ac`](https://github.com/saprmarks/dictionary_learning/commit/ec961acde2244b98b26bcf796c3ec00b721088bb)) + +* Use kaiming initialization if specified in paper, fix batch_top_k aux_k_alpha ([`8eaa8b2`](https://github.com/saprmarks/dictionary_learning/commit/8eaa8b2407eabd714bbe7d55fd0c15fcb05fba24)) + +* Format with ruff ([`3e31571`](https://github.com/saprmarks/dictionary_learning/commit/3e31571b20d3e86823540882ec03c87b155d8e3d)) + +* Add temperature scaling to matroyshka ([`ceabbc5`](https://github.com/saprmarks/dictionary_learning/commit/ceabbc5233dcf28f0f5afd53e0de850d19f34d78)) + +* norm the correct decoder dimension ([`5383603`](https://github.com/saprmarks/dictionary_learning/commit/53836033b305142fb6d076a52a7679e0642ddb7a)) + +* Fix loading matroyshkas from_pretrained() ([`764d4ac`](https://github.com/saprmarks/dictionary_learning/commit/764d4ac4450ea6b7d79de52fdec70c7c1e0dfb79)) + +* Initial matroyshka implementation ([`8ade55b`](https://github.com/saprmarks/dictionary_learning/commit/8ade55b6eb57ed7c7b06a70187ee68e1056bb95b)) + +* Make sure we step the learning rate scheduler ([`1df47d8`](https://github.com/saprmarks/dictionary_learning/commit/1df47d83d9dea07d2fb905509b635ac6139bcd48)) + +* Merge pull request #33 from saprmarks/lr_scheduling + +Lr scheduling ([`316dbbe`](https://github.com/saprmarks/dictionary_learning/commit/316dbbe9a905bdab91fb2db63bbc61646e7039a6)) + +* Properly set new parameters in end to end test 
([`e00fd64`](https://github.com/saprmarks/dictionary_learning/commit/e00fd643050584f4cfe15ad41e6a01e29e3c0766)) + +* Standardize learning rate and sparsity schedules ([`a2d6c43`](https://github.com/saprmarks/dictionary_learning/commit/a2d6c43e94ef068821441d47fef8ae7b3215d09e)) + +* Merge pull request #32 from saprmarks/add_sparsity_warmup + +Add sparsity warmup ([`a11670f`](https://github.com/saprmarks/dictionary_learning/commit/a11670fc6b96b1af3fe8a97175218041f2a9791f)) + +* Add sparsity warmup for trainers with a sparsity penalty ([`911b958`](https://github.com/saprmarks/dictionary_learning/commit/911b95890e20998df92710a01d158f4663d6834b)) + +* Clean up lr decay ([`e0db40b`](https://github.com/saprmarks/dictionary_learning/commit/e0db40b8fadcdd1e24c1945829ecd4eb57451fa8)) + +* Track lr decay implementation ([`f0bb66d`](https://github.com/saprmarks/dictionary_learning/commit/f0bb66d1c25bcb7dc8df62d8dbc3bfd47d26b14c)) + +* Remove leftover variable, update expected results with standard SAE improvements ([`9687bb9`](https://github.com/saprmarks/dictionary_learning/commit/9687bb9858ef05306227309af99cd5c09d91642a)) + +* Merge pull request #31 from saprmarks/add_demo + +Add option to normalize dataset, track thresholds for TopK SAEs, Fix Standard SAE ([`67a7857`](https://github.com/saprmarks/dictionary_learning/commit/67a7857ca63eb9299c340bc8f9804cdd569df1a9)) + +* Also scale topk thresholds when scaling biases ([`efd76b1`](https://github.com/saprmarks/dictionary_learning/commit/efd76b138f429bb8e5e969e2e45926e886fdd71b)) + +* Use the correct standard SAE reconstruction loss, initialize W_dec to W_enc.T ([`8b95ec9`](https://github.com/saprmarks/dictionary_learning/commit/8b95ec9b6e9a6d8d6255092e51b7580dccac70d6)) + +* Add bias scaling to topk saes ([`484ca01`](https://github.com/saprmarks/dictionary_learning/commit/484ca01f405e5791968883123718fd67ee35f299)) + +* Fix topk bfloat16 dtype error ([`488a154`](https://github.com/saprmarks/dictionary_learning/commit/488a1545922249cdb9ce5a5885c1931a5c21a37f)) + +* Add option to normalize dataset activations ([`81968f2`](https://github.com/saprmarks/dictionary_learning/commit/81968f2659082996539f08ea3188a5d2ed327696)) + +* Remove demo script and graphing notebook ([`57f451b`](https://github.com/saprmarks/dictionary_learning/commit/57f451b5635c4677ab47a4172aa588a5bdffdb4e)) + +* Track thresholds for topk and batchtopk during training ([`b5821fd`](https://github.com/saprmarks/dictionary_learning/commit/b5821fd87e3676e7a9ab6b87d423c03c57a344dd)) + +* Track threshold for batchtopk, rename for consistency ([`32d198f`](https://github.com/saprmarks/dictionary_learning/commit/32d198f738c61b0c1109f1803c43e01afb977d3e)) + +* Modularize demo script ([`dcc02f0`](https://github.com/saprmarks/dictionary_learning/commit/dcc02f04e504331011a54ce851a91976daf15170)) + +* Begin creation of demo script ([`712eb98`](https://github.com/saprmarks/dictionary_learning/commit/712eb98f78d9537aa3ff01a1d9e007361e67c267)) + +* Fix JumpReLU training and loading ([`552a8c2`](https://github.com/saprmarks/dictionary_learning/commit/552a8c2c12d41b5d520c99bf3534dff5329f0fde)) + +* Ensure activation buffer has the correct dtype ([`d416eab`](https://github.com/saprmarks/dictionary_learning/commit/d416eab5de1edfe8ea75c972cdf78d9de68642c2)) + +* Merge pull request #30 from adamkarvonen/add_tests + +Add end to end test, upgrade nnsight to support 0.3.0, fix bugs ([`c4eed3c`](https://github.com/saprmarks/dictionary_learning/commit/c4eed3cca27e93f0ad80cd49057cb862d03c86d7)) + +* Merge pull 
request #26 from mntss/batchtokp_aux_fix + +Fix BatchTopKSAE training ([`2ec1890`](https://github.com/saprmarks/dictionary_learning/commit/2ec18905045109ec0647bc127bacb794312fc2f6)) + +* Check for is_tuple to support mlp / attn submodules ([`d350415`](https://github.com/saprmarks/dictionary_learning/commit/d350415e119cacb6547703eb9733daf8ef57075b)) + +* Change save_steps to a list of ints ([`f1b9b80`](https://github.com/saprmarks/dictionary_learning/commit/f1b9b800bc8e2cc308d4d14690df71f854b30fce)) + +* Add early stopping in forward pass ([`05fe179`](https://github.com/saprmarks/dictionary_learning/commit/05fe179f5b0616310253deaf758c370071f534fa)) + +* Obtain better test results using multiple batches ([`067bf7b`](https://github.com/saprmarks/dictionary_learning/commit/067bf7b05470f61b9ed4f38b95be55c5ac45fb8f)) + +* Fix frac_alive calculation, perform evaluation over multiple batches ([`dc30720`](https://github.com/saprmarks/dictionary_learning/commit/dc3072089c24ce1eb8bc40e9f5248c69a92f5174)) + +* Complete nnsight 0.2 to 0.3 changes ([`807f6ef`](https://github.com/saprmarks/dictionary_learning/commit/807f6ef735872a5cab68773a315f15bc920c3d72)) + +* Rename input to inputs per nnsight 0.3.0 ([`9ed4af2`](https://github.com/saprmarks/dictionary_learning/commit/9ed4af245a22e095e932d6065d368c58947d9a3d)) + +* Add a simple end to end test ([`fe54b00`](https://github.com/saprmarks/dictionary_learning/commit/fe54b001cba976ca96d46add8539580268dc5cb6)) + +* Create LICENSE ([`32fec9c`](https://github.com/saprmarks/dictionary_learning/commit/32fec9c4556b3acaa709d756e8693edde1e74644)) + +* Fix BatchTopKSAE training ([`4aea538`](https://github.com/saprmarks/dictionary_learning/commit/4aea5388811284f4fd3daa8fb97916073bfe8841)) + +* dtype for loading SAEs ([`932e10a`](https://github.com/saprmarks/dictionary_learning/commit/932e10a46523347e8c2da70a10bb8e6dd42d17c6)) + +* Merge pull request #22 from pleask/jumprelu + +Implement jumprelu training ([`713f638`](https://github.com/saprmarks/dictionary_learning/commit/713f6389dde35177c83361f90daaba99b5ac3d08)) + +* Merge branch 'main' into jumprelu ([`099dbbf`](https://github.com/saprmarks/dictionary_learning/commit/099dbbfcdcad07dfc85dd65bfbd15ca9eece70a5)) + +* Merge pull request #21 from pleask/separate-wandb-runs + +Use separate wandb runs for each SAE being trained ([`df60f52`](https://github.com/saprmarks/dictionary_learning/commit/df60f52737f18ce0b1ecd2eb9e08d0706871442d)) + +* Merge branch 'main' into jumprelu ([`3dfc069`](https://github.com/saprmarks/dictionary_learning/commit/3dfc069d39ceeb33ce60581fc7cb17f08ec0e428)) + +* implement jumprelu training ([`16bdfd9`](https://github.com/saprmarks/dictionary_learning/commit/16bdfd95bc04000b89f81b0496df59f17653a2f8)) + +* handle no wandb ([`8164d32`](https://github.com/saprmarks/dictionary_learning/commit/8164d32ec79325d3cc31063098b9108386eb15cf)) + +* Merge pull request #20 from pleask/batchtopk + +Implement BatchTopK ([`b001fb0`](https://github.com/saprmarks/dictionary_learning/commit/b001fb0fd358efc7647acf835123a5e874a9a822)) + +* separate runs for each sae being trained ([`7d3b127`](https://github.com/saprmarks/dictionary_learning/commit/7d3b12778070b88fd39c439751973ac83afbe7a0)) + +* add batchtopk ([`f08e00b`](https://github.com/saprmarks/dictionary_learning/commit/f08e00b2585ab9a965984af4932614a2e408b6e3)) + +* Move f_gate to encoder's dtype ([`43bdb3b`](https://github.com/saprmarks/dictionary_learning/commit/43bdb3b903f7a45ee52b4d865f6d6b7bd60647a3)) + +* Ensure that x_hat is in correct dtype 
([`3376f1b`](https://github.com/saprmarks/dictionary_learning/commit/3376f1bd9d05bedd03179475052d3a26a61fad7a)) + +* Preallocate buffer memory to lower peak VRAM usage when replenishing buffer ([`90aff63`](https://github.com/saprmarks/dictionary_learning/commit/90aff63b042c50c3c81a3977b62248254115e907)) + +* Perform logging outside of training loop to lower peak memory usage ([`57f8812`](https://github.com/saprmarks/dictionary_learning/commit/57f8812ff93d4d9ac437d29a74f1d920daa45515)) + +* Remove triton usage ([`475fece`](https://github.com/saprmarks/dictionary_learning/commit/475feceba9e47d6e74b17c87844253f0a209d75d)) + +* Revert to triton TopK implementation ([`d94697d`](https://github.com/saprmarks/dictionary_learning/commit/d94697df1783da8b6739e565c3a1bd297b8b1e98)) + +* Add relative reconstruction bias from GDM Gated SAE paper to evaluate() ([`8984b01`](https://github.com/saprmarks/dictionary_learning/commit/8984b0112e6f9eebcf869aba78ad713b2016d6a6)) + +* git push origin main:Merge branch 'ElanaPearl-small_bug_fixes' into main ([`2d586e4`](https://github.com/saprmarks/dictionary_learning/commit/2d586e417cd30473e1c608146df47eb5767e2527)) + +* simplifying readme ([`9c46e06`](https://github.com/saprmarks/dictionary_learning/commit/9c46e061eb3b29d055e7221ce92524c6546d2a59)) + +* simplify readme ([`5c96003`](https://github.com/saprmarks/dictionary_learning/commit/5c9600344e033b5a7834a48914e958b257bcb720)) + +* add missing imports ([`7f689d9`](https://github.com/saprmarks/dictionary_learning/commit/7f689d9a3a60d577a0d860ac306ae7ba0c71240a)) + +* fix arg name in trainer_config ([`9577d26`](https://github.com/saprmarks/dictionary_learning/commit/9577d26c92affa71a9dcc3a3b3f6cb905f230388)) + +* update sae training example code ([`9374546`](https://github.com/saprmarks/dictionary_learning/commit/937454616f087a6e30afa2ae5f6d52ea685ebfee)) + +* Merge branch 'main' of https://github.com/saprmarks/dictionary_learning into main ([`7d405f7`](https://github.com/saprmarks/dictionary_learning/commit/7d405f7d7555444c66121bc853ab027f49c408b0)) + +* GatedSAE: moved feature re-normalization into encode ([`f628c0e`](https://github.com/saprmarks/dictionary_learning/commit/f628c0ef2ec53d20ffd4d3d06f84100054c358e1)) + +* documenting JumpReLU SAE support ([`322b6c0`](https://github.com/saprmarks/dictionary_learning/commit/322b6c0c75767b7fe110d1454b9dcd4106bb942b)) + +* support for JumpReluAutoEncoders ([`57df4e7`](https://github.com/saprmarks/dictionary_learning/commit/57df4e75cbf181e3662058a6609ab2bb5921c9c4)) + +* Add submodule_name to PAnnealTrainer ([`ecdac03`](https://github.com/saprmarks/dictionary_learning/commit/ecdac0376285912d9468695c024b39100c663b07)) + +* host SAEs on huggingface ([`0ae37fe`](https://github.com/saprmarks/dictionary_learning/commit/0ae37feeb8beac0fce5036c6ff4188c86627775e)) + +* fixed batch loading in examine_dimension ([`82485d7`](https://github.com/saprmarks/dictionary_learning/commit/82485d78bcb6d3bcec67965743fac32e6d29ff37)) + +* Merge pull request #17 from saprmarks/collab + +Merge Collab Branch ([`cdf8222`](https://github.com/saprmarks/dictionary_learning/commit/cdf82227d24295fe8a83fbcfe785e6d6d4f2b997)) + +* moved experimental trainers to collab-dev ([`8d1d581`](https://github.com/saprmarks/dictionary_learning/commit/8d1d581f3df482c77ca99d0839f1677b19ca1ae7)) + +* Merge branch 'main' into collab ([`dda38b9`](https://github.com/saprmarks/dictionary_learning/commit/dda38b94a491261fd92bf9754f1c673221d7f270)) + +* Update README.md 
([`4d6c6a6`](https://github.com/saprmarks/dictionary_learning/commit/4d6c6a6cb5816571e045f3c42c9f5b508d395d83)) + +* remove a sentence ([`2d40ed5`](https://github.com/saprmarks/dictionary_learning/commit/2d40ed598074c57904e9566d82bbd8ce27b661b5)) + +* add a list of trainers to the README ([`746927a`](https://github.com/saprmarks/dictionary_learning/commit/746927ae0b597e1fcb69aed58a5e9d4b6103732c)) + +* add architecture details to README ([`60422a8`](https://github.com/saprmarks/dictionary_learning/commit/60422a87231439425b9e27384352b03bc245365a)) + +* make wandb integration optional ([`a26c4e5`](https://github.com/saprmarks/dictionary_learning/commit/a26c4e57985458735bbf887685b679d16008de98)) + +* make wandb integration optional ([`0bdc871`](https://github.com/saprmarks/dictionary_learning/commit/0bdc871a95dae4de17b5116eda38f20d2375ebd1)) + +* Fix tutorial 404 ([`deb3df7`](https://github.com/saprmarks/dictionary_learning/commit/deb3df7906c8a0d00a4286f42cb65ae27667b2a7)) + +* Add missing values to config ([`9e44ea9`](https://github.com/saprmarks/dictionary_learning/commit/9e44ea9dc015c6bf919bd61aa40892be1da66dc3)) + +* changed TrainerTopK class name ([`c52ff00`](https://github.com/saprmarks/dictionary_learning/commit/c52ff008869a021b5f58d1beb80f8afe014757c5)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`c04ee3b`](https://github.com/saprmarks/dictionary_learning/commit/c04ee3b006ae72e69266d0ac2163035aee326b6a)) + +* fixed loss_recovered to incorporate top_k ([`6be5635`](https://github.com/saprmarks/dictionary_learning/commit/6be563540801caf185069051985b453dacc421d8)) + +* fixed TopK loss (spotted by Anish) ([`a3b71f7`](https://github.com/saprmarks/dictionary_learning/commit/a3b71f71212b839c8814ffa4223a5026837738c3)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`40bcdf6`](https://github.com/saprmarks/dictionary_learning/commit/40bcdf65b646f0e387d030b0c2211eaf07636b4c)) + +* naming conventions ([`5ff7fa1`](https://github.com/saprmarks/dictionary_learning/commit/5ff7fa101da07dfdb0663a484214b75c79e02fe0)) + +* small fix to triton kernel ([`5d21265`](https://github.com/saprmarks/dictionary_learning/commit/5d21265bd390d35b937d10d83cdf617151212cb3)) + +* small updates for eval ([`585e820`](https://github.com/saprmarks/dictionary_learning/commit/585e82070620771ee5bef4278d4d500b02983e0c)) + +* added some housekeeping stuff to top_k ([`5559c2c`](https://github.com/saprmarks/dictionary_learning/commit/5559c2c02d84df49531a631d3f4b29ef8acf94c4)) + +* add support for Top-k SAEs ([`2d549d0`](https://github.com/saprmarks/dictionary_learning/commit/2d549d0d98e400fedf4d7c4127d540f97240b89e)) + +* add transcoder eval ([`8446f4f`](https://github.com/saprmarks/dictionary_learning/commit/8446f4fc1aa9e7a08ece6e2fd59e6fa9583a7501)) + +* add transcoder support ([`c590a25`](https://github.com/saprmarks/dictionary_learning/commit/c590a254990691947b244e09849db7b288ed6bee)) + +* added wandb finish to trainer ([`113c042`](https://github.com/saprmarks/dictionary_learning/commit/113c042101b6df6de60b04c7e65116c3a9460904)) + +* fixed anneal end bug ([`fbd9ee4`](https://github.com/saprmarks/dictionary_learning/commit/fbd9ee41ed23d65cbaedb43447b64ae4117dab9a)) + +* added layer and lm_name ([`d173235`](https://github.com/saprmarks/dictionary_learning/commit/d17323572d23067bbba949732a842b1c2c149188)) + +* adding layer and lm_name to trainer config 
([`6168ee0`](https://github.com/saprmarks/dictionary_learning/commit/6168ee0210308a42f3536f5bff19db70e91311ae)) + +* make tracer_args optional ([`31b2828`](https://github.com/saprmarks/dictionary_learning/commit/31b2828869bd560ac29eafbd3abf06f752063047)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`87d2b58`](https://github.com/saprmarks/dictionary_learning/commit/87d2b58da4b5714a44e6d301f2b5595e6bdd4296)) + +* bug fix evaluating CE loss with NNsight models ([`f8d81a1`](https://github.com/saprmarks/dictionary_learning/commit/f8d81a1d56b96f34c26fcc9f3feac0cb11ab3065)) + +* Combining P Annealing and Anthropic Update ([`44318e9`](https://github.com/saprmarks/dictionary_learning/commit/44318e999d6d123daad63fa399935ba339421070)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`43e9ca6`](https://github.com/saprmarks/dictionary_learning/commit/43e9ca63664dafd9a9f23f81b0bf57917a9f36ba)) + +* removing normalization ([`7a98d77`](https://github.com/saprmarks/dictionary_learning/commit/7a98d77318b3abcc0aab237de455eb33e20f691e)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`5f2b598`](https://github.com/saprmarks/dictionary_learning/commit/5f2b598cdbeb311f32c4ce1e2e816769240bb75e)) + +* added buffer for NNsight models (not LanguageModel classes) as an extra class. We'll want to combine the three buffers wo currently have at some point ([`f19d284`](https://github.com/saprmarks/dictionary_learning/commit/f19d2843f9fc64192ddac12f345a4ad910b96310)) + +* fixed nnsight issues model tracing for chess-gpt ([`7e8c9f9`](https://github.com/saprmarks/dictionary_learning/commit/7e8c9f95cd25bb6bc56def8210852841a30f22fd)) + +* added W_O projection to HeadBuffer ([`47bd4cd`](https://github.com/saprmarks/dictionary_learning/commit/47bd4cdea4a64563d2f8ba9ab39b246caf9f3c8c)) + +* added support for training SAEs on individual heads ([`a0e3119`](https://github.com/saprmarks/dictionary_learning/commit/a0e31199f2a02c86328bc4551f1d9a0b89d0d87b)) + +* added support for training SAEs on individual heads ([`47351b4`](https://github.com/saprmarks/dictionary_learning/commit/47351b4f6ca0bbd42981f73a784e1a395a941025)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`7de0bd3`](https://github.com/saprmarks/dictionary_learning/commit/7de0bd3d062693d8a35f309a6bc8b494c98408a3)) + +* default hyperparameter adjustments ([`a09346b`](https://github.com/saprmarks/dictionary_learning/commit/a09346b928a9782f57ec137b95d9e7636eda2abf)) + +* normalization in gated_new ([`104aba2`](https://github.com/saprmarks/dictionary_learning/commit/104aba291b0c17a5ec9e86655a281f457ce14cbc)) + +* fixing bug where inputs can get overwritten ([`93fd46e`](https://github.com/saprmarks/dictionary_learning/commit/93fd46e3884daf2fb2e17d952b7a4030b0129957)) + +* fixing tuple bug ([`b05dcaf`](https://github.com/saprmarks/dictionary_learning/commit/b05dcafc816370be8f0584700fd5a882be4a2e8f)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`73b5663`](https://github.com/saprmarks/dictionary_learning/commit/73b5663ed47c91c1c0d2fa8d47c029686fcf8a48)) + +* multiple steps debugging ([`de3eef1`](https://github.com/saprmarks/dictionary_learning/commit/de3eef10d322502f150dc63e9f71d84c9b777b71)) + +* adding gradient pursuit function ([`72941f1`](https://github.com/saprmarks/dictionary_learning/commit/72941f10e794401b2a6b682aa097f4db3f7aa1fe)) + +* bugfix 
([`53aabc0`](https://github.com/saprmarks/dictionary_learning/commit/53aabc0ae45464fd3d1d1d384969fe7066d94a7a)) + +* bugfix ([`91691b5`](https://github.com/saprmarks/dictionary_learning/commit/91691b5b8da50f6b1d44eae501529d72b935752e)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`9ce7d80`](https://github.com/saprmarks/dictionary_learning/commit/9ce7d80ec96e7324c095ffef81039d0e6a896feb)) + +* logging more things ([`8498a75`](https://github.com/saprmarks/dictionary_learning/commit/8498a754acacca467494182dbf7444b34e1184c3)) + +* changing initialization for AutoEncoderNew ([`c7ee7ec`](https://github.com/saprmarks/dictionary_learning/commit/c7ee7ec8e7c4bc235cf969f7653a2d99f9bd5723)) + +* fixing gated SAE encoder scheme ([`4084bc3`](https://github.com/saprmarks/dictionary_learning/commit/4084bc3fa50f0764864630b2fe476722a9303b47)) + +* changes to gatedSAE API ([`9e001d1`](https://github.com/saprmarks/dictionary_learning/commit/9e001d170c752c1887c27942bf3d6336322a0ff0)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`05b397b`](https://github.com/saprmarks/dictionary_learning/commit/05b397bcc60f3026da7d55aefebfd3b2223273a6)) + +* changing initialization ([`ebe0d57`](https://github.com/saprmarks/dictionary_learning/commit/ebe0d57c62ebde85386fd7ec59157758e85d3ce3)) + +* finished combining gated and p-annealing ([`4c08614`](https://github.com/saprmarks/dictionary_learning/commit/4c08614403d51328c672983cb28aaaee846092bc)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`8e0a6f9`](https://github.com/saprmarks/dictionary_learning/commit/8e0a6f998ded264270c019ee9b14ffb9c31d650a)) + +* gated_anneal first steps ([`ba8b8fa`](https://github.com/saprmarks/dictionary_learning/commit/ba8b8fa1efda86ea843b0a837f98f106ab089448)) + +* jump SAE ([`873b764`](https://github.com/saprmarks/dictionary_learning/commit/873b764b5a17bdc8704a4a871362e1b03de3ef5f)) + +* adapted loss logging in p_anneal ([`33997c0`](https://github.com/saprmarks/dictionary_learning/commit/33997c05699862896191dc683c9622efc3e97f95)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`1eecbda`](https://github.com/saprmarks/dictionary_learning/commit/1eecbdaf651dffa1ed4962d79d3f2577d1979e91)) + +* merging gated and Anthropic SAEs ([`b6a24d0`](https://github.com/saprmarks/dictionary_learning/commit/b6a24d001234e38c2f6b4c52215d65fdcb50a09e)) + +* revert trainer naming ([`c0af6d9`](https://github.com/saprmarks/dictionary_learning/commit/c0af6d9c20fda36ee700e2884611fd12edc3fb59)) + +* restored trainer naming ([`2ec3c67`](https://github.com/saprmarks/dictionary_learning/commit/2ec3c6768d21b019ba12d0065876011e85bc2aae)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`fe7e93b`](https://github.com/saprmarks/dictionary_learning/commit/fe7e93bf606a6c8e2e2d13335565455359905345)) + +* various changes ([`32027ae`](https://github.com/saprmarks/dictionary_learning/commit/32027ae3781e367affc23c3fde5fc504ef49ebc4)) + +* debug panneal ([`463907d`](https://github.com/saprmarks/dictionary_learning/commit/463907dab4ee91254d0ae674752d3c3803a8044d)) + +* debug panneal ([`8c00100`](https://github.com/saprmarks/dictionary_learning/commit/8c00100423223dc21d57ee4114f9ef6b38ee209e)) + +* debug panneal ([`dc632cd`](https://github.com/saprmarks/dictionary_learning/commit/dc632cd69df0c1719ebed7bdd677d7373f37dc74)) + +* debug panneal 
([`166f6a9`](https://github.com/saprmarks/dictionary_learning/commit/166f6a9e582d45728d1e8291c6ab451dbb7a35fd)) + +* debug panneal ([`bcebaa6`](https://github.com/saprmarks/dictionary_learning/commit/bcebaa6b2adedaecd0779d6551929b3a213aef1e)) + +* debug pannealing ([`446c568`](https://github.com/saprmarks/dictionary_learning/commit/446c568d32ff7c93c9688c193dd459abe9086ed5)) + +* p_annealing loss buffer ([`e4d4a35`](https://github.com/saprmarks/dictionary_learning/commit/e4d4a3532536d9b450f39c034c0aabd8e95560fa)) + +* implement Ben's p-annealing strategy ([`06a27f0`](https://github.com/saprmarks/dictionary_learning/commit/06a27f096c0e62df695d60d9e1ec7df77c305498)) + +* panneal changes ([`fe4ff6f`](https://github.com/saprmarks/dictionary_learning/commit/fe4ff6fa5d0c85942b45fffa0bb2908f4d13a2aa)) + +* logging trainer names to wandb ([`f9c5e45`](https://github.com/saprmarks/dictionary_learning/commit/f9c5e45a85ed345fdd95502a4de7a873c25f8456)) + +* bugfixes for StandardTrainerNew ([`70acd85`](https://github.com/saprmarks/dictionary_learning/commit/70acd8572b1b5250ac38cb9be04069cb1a6f981e)) + +* trainer for new anthropic infrastructure ([`531c285`](https://github.com/saprmarks/dictionary_learning/commit/531c28596cbbe296c45a2a6e22ea175e8633f2a1)) + +* adding r_mag parameter to GSAE ([`198ddf4`](https://github.com/saprmarks/dictionary_learning/commit/198ddf4bd4210b95a11b0a29862fe615d1774fe0)) + +* gatedSAE trainer ([`3567d6d`](https://github.com/saprmarks/dictionary_learning/commit/3567d6d2a2cb6d810df32b838029daacc354aaaa)) + +* cosmetic change ([`0200976`](https://github.com/saprmarks/dictionary_learning/commit/0200976ba04409d477e5321b586b844dd545b976)) + +* GatedAutoEncoder class ([`2cfc47b`](https://github.com/saprmarks/dictionary_learning/commit/2cfc47b42e89c294e14c98969a993f5910604211)) + +* p annealing not affected by resampling ([`ad8d837`](https://github.com/saprmarks/dictionary_learning/commit/ad8d8371411067c6a031d87faf08f4ec2fe96032)) + +* integrated trainer update ([`c7613d3`](https://github.com/saprmarks/dictionary_learning/commit/c7613d386a5677451f6da6f9260ceb9d28a3a4d4)) + +* Merge branch 'collab' into p_annealing ([`933b80c`](https://github.com/saprmarks/dictionary_learning/commit/933b80c91a3e49e2e7a761422c629588774370eb)) + +* fixed p calculation ([`9837a6f`](https://github.com/saprmarks/dictionary_learning/commit/9837a6fa4e88303b7694aa3556485661fa512f1c)) + +* getting rid of useless seed arguement ([`377c762`](https://github.com/saprmarks/dictionary_learning/commit/377c762d9a9333aed42ad097d393796fcf8a7e57)) + +* trainer initializes SAE ([`7dffb66`](https://github.com/saprmarks/dictionary_learning/commit/7dffb663a0dcc5f5e3c2855e24e9f8b322704bcc)) + +* trainer initialized SAE ([`6e80590`](https://github.com/saprmarks/dictionary_learning/commit/6e80590fb441c53df70345bfd20da4fbad7c9cf9)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`c58d23d`](https://github.com/saprmarks/dictionary_learning/commit/c58d23d5a6e2d38c0ff47e42b157f1686f7e98a6)) + +* changes to lista p_anneal trainers ([`3cc6642`](https://github.com/saprmarks/dictionary_learning/commit/3cc6642b414608e5d0e86c733b0855f927afa52c)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`9dfd3db`](https://github.com/saprmarks/dictionary_learning/commit/9dfd3dbf42d3ad35b0bb32f9d8374ac00201edda)) + +* decoupled lr warmup and p warmup in p_anneal trainer 
([`c3c1645`](https://github.com/saprmarks/dictionary_learning/commit/c3c164540476d69ff4c3bfa7f9a1a4532c4603c0)) + +* Merge pull request #14 from saprmarks/p_annealing + +added annealing and trainer_param_callback ([`61927bc`](https://github.com/saprmarks/dictionary_learning/commit/61927bcf99537a15651a9829a6a261cffad9e65f)) + +* cosmetic changes to interp ([`4a7966f`](https://github.com/saprmarks/dictionary_learning/commit/4a7966f979ea4b660613c980cdefd48494511955)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`c76818e`](https://github.com/saprmarks/dictionary_learning/commit/c76818e4dbf7e980251a6f652529e50cd1b1b7de)) + +* Merge pull request #13 from jannik-brinkmann/collab + +add ListaTrainer ([`d4d2fd9`](https://github.com/saprmarks/dictionary_learning/commit/d4d2fd9b57a4ab380a56b1b5fa0faf1d91a29989)) + +* additional evluation metrics ([`fa2ec08`](https://github.com/saprmarks/dictionary_learning/commit/fa2ec081e2ff42377eb98b031320933806b2faf7)) + +* add GroupSAETrainer ([`60e6068`](https://github.com/saprmarks/dictionary_learning/commit/60e6068924a42b8252d11b398b9972205b46ece4)) + +* added annealing and trainer_param_callback ([`18e3fca`](https://github.com/saprmarks/dictionary_learning/commit/18e3fcaaf5428e998d26a0be80f1be56ffea7981)) + +* Merge remote-tracking branch 'upstream/collab' into collab ([`4650c2a`](https://github.com/saprmarks/dictionary_learning/commit/4650c2a7db87c7ca32db043cb15db8a28450a013)) + +* fixing neuron resampling ([`a346be9`](https://github.com/saprmarks/dictionary_learning/commit/a346be9abc6644fd59ae493e44ef8fdbd1e339e2)) + +* improvements to saving and logging ([`4a1d7ae`](https://github.com/saprmarks/dictionary_learning/commit/4a1d7ae76d59713fe0c4722e821ad3882c0aa757)) + +* can export buffer config ([`d19d8d9`](https://github.com/saprmarks/dictionary_learning/commit/d19d8d956da3e04ab899b93fc67c63b0a7bd5020)) + +* fixing evaluation.py ([`c91a581`](https://github.com/saprmarks/dictionary_learning/commit/c91a5815e4e11197a8031d21193381f9b596b95c)) + +* fixing bug in neuron resampling ([`67a03c7`](https://github.com/saprmarks/dictionary_learning/commit/67a03c763feec3bcebd9070389b8481257bdf10b)) + +* add ListaTrainer ([`880f570`](https://github.com/saprmarks/dictionary_learning/commit/880f5706a42c337e021530855166089b6722e1df)) + +* fixing neuron resampling in standard trainer ([`3406262`](https://github.com/saprmarks/dictionary_learning/commit/3406262b31dd97f29130532d694aecd62f092f80)) + +* improvements to training and evaluating ([`b111d40`](https://github.com/saprmarks/dictionary_learning/commit/b111d40898d97123722cda60084f46d0766cd3e2)) + +* Factoring out SAETrainer class ([`fabd001`](https://github.com/saprmarks/dictionary_learning/commit/fabd001d97f869c01e67ea26f2e02822eba9ab82)) + +* updating syntax for buffer ([`035a0f9`](https://github.com/saprmarks/dictionary_learning/commit/035a0f9d4ffa8e7307ae637fb801a78c0ea9eb95)) + +* updating readme for from_pretrained ([`70e8c2a`](https://github.com/saprmarks/dictionary_learning/commit/70e8c2a13682ef12658f92b459c1bf552cb78180)) + +* from_pretrained ([`db96abc`](https://github.com/saprmarks/dictionary_learning/commit/db96abc96be7ba975bb09a41c7a81b13c2ea5f3e)) + +* Change syntax for specifying activation dimensions and batch sizes ([`bdf1f19`](https://github.com/saprmarks/dictionary_learning/commit/bdf1f19b292b152b3c4601fc7a77fc6d66cd04c0)) + +* Merge branch 'main' of https://github.com/saprmarks/dictionary_learning into main 
([`86c7475`](https://github.com/saprmarks/dictionary_learning/commit/86c7475a945c0a70c0a82c914d9733c8d2bcc651)) + +* activation_dim for IdentityDict is optional ([`be1b68c`](https://github.com/saprmarks/dictionary_learning/commit/be1b68c0df0de955d722f1739f5c115dfbfbf702)) + +* update umap requirement ([`776b53e`](https://github.com/saprmarks/dictionary_learning/commit/776b53e506a2c720139d056542a3397d883e2c79)) + +* Merge pull request #10 from adamkarvonen/shell_script_change + +Add sae_set_name to local_path for dictionary downloader ([`33b5a6b`](https://github.com/saprmarks/dictionary_learning/commit/33b5a6be4ea3c76aa918178f2dfcd3f7c81e2b97)) + +* Add sae_set_name to local_path for dictionary downloader ([`d6163be`](https://github.com/saprmarks/dictionary_learning/commit/d6163be200d28653394c2b9adac540c7a27e2659)) + +* dispatch no longer needed when loading models ([`69c32ca`](https://github.com/saprmarks/dictionary_learning/commit/69c32ca6fcf1c94c4b7fb7ac8b82fe7257123400)) + +* removed in_and_out option for activation buffer ([`cf6ad1d`](https://github.com/saprmarks/dictionary_learning/commit/cf6ad1d72de9fc11acba34e73a03799e2b893692)) + +* updating readme with 10_32768 dictionaries ([`614883f`](https://github.com/saprmarks/dictionary_learning/commit/614883f9476613e7c1c48b951cd3947451e1f534)) + +* upgrade to nnsight 0.2 ([`cbc5f79`](https://github.com/saprmarks/dictionary_learning/commit/cbc5f7991c9233579c36b4972c6273f3f250f0ef)) + +* downloader script ([`7a305c5`](https://github.com/saprmarks/dictionary_learning/commit/7a305c583dbbf06f3dbb223387dc3536a489b0de)) + +* fixing device issue in buffer ([`b1b44f1`](https://github.com/saprmarks/dictionary_learning/commit/b1b44f12e176e73544d863d1d41009a284bc1db5)) + +* added pretrained_dictionary_downloader.sh ([`0028ebe`](https://github.com/saprmarks/dictionary_learning/commit/0028ebe739ac90e2587a86b92b0aa4b2c0b8497e)) + +* added pretrained_dictionary_downloader.sh ([`8b63d8d`](https://github.com/saprmarks/dictionary_learning/commit/8b63d8d6d74f51c00b191519d383de7f6052df0b)) + +* added pretrained_dictionary_downloader.sh ([`6771aff`](https://github.com/saprmarks/dictionary_learning/commit/6771aff6543b320e14fb3db99e0c6fd2613cc905)) + +* efficiency improvements ([`94844d4`](https://github.com/saprmarks/dictionary_learning/commit/94844d4fa9ce4a593faf9b709cf61a45447f84f3)) + +* adding identity dict ([`76bd32f`](https://github.com/saprmarks/dictionary_learning/commit/76bd32fe87bf3c7f3ce45d13d6fe6a69c81e05b4)) + +* debugging interp ([`2f75db3`](https://github.com/saprmarks/dictionary_learning/commit/2f75db31233b1296af97c2002194888715355759)) + +* Merge branch 'main' of https://github.com/saprmarks/dictionary_learning into main ([`86812f5`](https://github.com/saprmarks/dictionary_learning/commit/86812f5dae6a4ebc1605f3b067c27d7b8b96001e)) + +* warns user when evaluating without enough data ([`246c472`](https://github.com/saprmarks/dictionary_learning/commit/246c472d7efb845875c4aa67a8e0dfd417c28f6d)) + +* cleaning up interp ([`95d7310`](https://github.com/saprmarks/dictionary_learning/commit/95d7310ef39ed2fe7a496d0a63a142fe569bdcf5)) + +* examine_dimension returns mbottom_tokens and logit stats ([`40137ff`](https://github.com/saprmarks/dictionary_learning/commit/40137ffe47d9c3ee03e9b46f994c5bd98f5b953e)) + +* continuing merge ([`db693a6`](https://github.com/saprmarks/dictionary_learning/commit/db693a6c4c290bb670f37c0a7e222e25b6b916c6)) + +* progress on merge 
([`949b3a7`](https://github.com/saprmarks/dictionary_learning/commit/949b3a755c1458e7d216cc02dc5bf7d8e8f62a1a)) + +* changes to buffer.py ([`792546b`](https://github.com/saprmarks/dictionary_learning/commit/792546b35c45fda3e93abcb0f8cc28f70d0e439c)) + +* fixing some things in buffer.py ([`f58688e`](https://github.com/saprmarks/dictionary_learning/commit/f58688e574f5353f906a470abfbcc386730fdda6)) + +* updating requirements ([`a54b496`](https://github.com/saprmarks/dictionary_learning/commit/a54b4961a7ac9996566a3c32f4d216968afac7b1)) + +* updating requirements ([`a1db591`](https://github.com/saprmarks/dictionary_learning/commit/a1db5917be710c046736574a48bc7f0c2ea98506)) + +* identity dictionary ([`5e1f35e`](https://github.com/saprmarks/dictionary_learning/commit/5e1f35e09abc20c6ee7bc43cfba6231d97121403)) + +* bug fix for neuron resampling ([`b281b53`](https://github.com/saprmarks/dictionary_learning/commit/b281b538c1de2b5ce220b429dd3ea4be44c5b72f)) + +* UMAP visualizations ([`81f8e1f`](https://github.com/saprmarks/dictionary_learning/commit/81f8e1f164def236423e53b89da37d50c115fc62)) + +* better normalization for ghost_loss ([`fc74af7`](https://github.com/saprmarks/dictionary_learning/commit/fc74af75ca2d9d4fdbca6fefb3feb583ef11583d)) + +* neuron resampling without replacement ([`4565e9a`](https://github.com/saprmarks/dictionary_learning/commit/4565e9a14975a4a2d9c736ba7c5551b6c9685ae2)) + +* simplifications to interp functions ([`2318666`](https://github.com/saprmarks/dictionary_learning/commit/231866665154d80e933b5d9ab5be5de5a522c398)) + +* Second nnsight 0.2 pass through ([`3bcebed`](https://github.com/saprmarks/dictionary_learning/commit/3bcebedb801d5654edb3fc7118144953af2366da)) + +* Conversion to nnsight 0.2 first pass ([`cac410a`](https://github.com/saprmarks/dictionary_learning/commit/cac410a72e52cbd6f359fd69bd6fdb346923a9e1)) + +* detaching another thing in ghost grads ([`2f212d6`](https://github.com/saprmarks/dictionary_learning/commit/2f212d6cab348d565633f1bdc0d3e305a6e98d42)) + +* Neuron resampling no longer errors when resampling zero neurons ([`376dd3b`](https://github.com/saprmarks/dictionary_learning/commit/376dd3b51b1433625386ca357c61497a13b6bf6d)) + +* NNsight v0.2 Updates ([`90bbc76`](https://github.com/saprmarks/dictionary_learning/commit/90bbc762aaf369a138f544f2e1f3a4e7a6b5fc4a)) + +* cosmetic improvements to buffer.py ([`b2bd5f0`](https://github.com/saprmarks/dictionary_learning/commit/b2bd5f09cc7f657b7121f0659514d81336903bba)) + +* fix to ghost grads ([`9531fe5`](https://github.com/saprmarks/dictionary_learning/commit/9531fe5f65a23acb32e0f1c96920d67bb1bed15b)) + +* fixing table formatting ([`0e69c8c`](https://github.com/saprmarks/dictionary_learning/commit/0e69c8cc7c446db0ddf86da984417965714ec7ec)) + +* Fixing some table formatting ([`75f927f`](https://github.com/saprmarks/dictionary_learning/commit/75f927f4c722db4c05d64d732b1d025ecdc186aa)) + +* gpt2-small support ([`f82146c`](https://github.com/saprmarks/dictionary_learning/commit/f82146cf586e53407d639ef81f64e1be481a666b)) + +* fixing bug relevant to UnifiedTransformer support ([`9ec9ce4`](https://github.com/saprmarks/dictionary_learning/commit/9ec9ce494384ab303db26be066b3a8004230a16a)) + +* Getting rid of histograms ([`31d09d7`](https://github.com/saprmarks/dictionary_learning/commit/31d09d7136d97c553b8f06c1074ef08ea65be879)) + +* Fixing tables in readme ([`5934011`](https://github.com/saprmarks/dictionary_learning/commit/59340116bb24cbc01cefa76c52641b5b5b46a340)) + +* Updates to the readme 
([`a5ca51e`](https://github.com/saprmarks/dictionary_learning/commit/a5ca51ea13cfcd4bb286d644e3416a9af3b5fc53)) + +* Fixing ghost grad bugs ([`633d583`](https://github.com/saprmarks/dictionary_learning/commit/633d583ddaa3090039fca3f1f3e8820ded942e76)) + +* Handling ghost grad case with no dead neurons ([`4f19425`](https://github.com/saprmarks/dictionary_learning/commit/4f19425a4e09ea93bb7ebaad436c2ef227cb420e)) + +* adding support for buffer on other devices ([`f3cf296`](https://github.com/saprmarks/dictionary_learning/commit/f3cf296fe00bf547412f7d500b8993796e30a8b9)) + +* support for ghost grads ([`25d2a62`](https://github.com/saprmarks/dictionary_learning/commit/25d2a62fcaa8bc9be048b5e37aa57441e78262b5)) + +* add an implementation of ghost gradients ([`2e09210`](https://github.com/saprmarks/dictionary_learning/commit/2e09210099d991d45500488dac9654d141815530)) + +* fixing a bug with warmup, adding utils ([`47bbde1`](https://github.com/saprmarks/dictionary_learning/commit/47bbde13f47010bbebf6ac393ae3cdc59b804e9d)) + +* remove HF arg from buffer. rename search_utils to interp ([`7276f17`](https://github.com/saprmarks/dictionary_learning/commit/7276f17288286429162432af6a30763fa80f8117)) + +* typo fix ([`3f6b922`](https://github.com/saprmarks/dictionary_learning/commit/3f6b922c031f9b31652c3998f0ce1e985629c62a)) + +* Merge branch 'main' of https://github.com/saprmarks/dictionary_learning into main ([`278084b`](https://github.com/saprmarks/dictionary_learning/commit/278084b0a54e5804a064358a1fb28bc007e4fae4)) + +* added utils for converting hf dataset to generator ([`82fff19`](https://github.com/saprmarks/dictionary_learning/commit/82fff1968ae883afec82d14246041df793ffd170)) + +* add ablated token effects to ; restore support for HF datasets ([`799e2ca`](https://github.com/saprmarks/dictionary_learning/commit/799e2caeb3f4f4f922531cfb3b14dd34d999ae9d)) + +* merge in function for examining features ([`986bf96`](https://github.com/saprmarks/dictionary_learning/commit/986bf9646e82f35186c74ce88e6c6e4dc1c8470f)) + +* easier submodule/dictionary feature examination ([`2c8b985`](https://github.com/saprmarks/dictionary_learning/commit/2c8b98567e1908a4279efc342f46bd4bd72ab618)) + +* Adding lr warmup after every time neurons are resampled ([`429c582`](https://github.com/saprmarks/dictionary_learning/commit/429c582f84be12d6c326b131f926b33d48698c7b)) + +* fixing issues with EmptyStream exception ([`39ff6e1`](https://github.com/saprmarks/dictionary_learning/commit/39ff6e1cdccb438d335c39c36656657f974f585f)) + +* Minor changes due to updates in nnsight ([`49bbbac`](https://github.com/saprmarks/dictionary_learning/commit/49bbbac6a653398be8726587c2c634e0fd831f02)) + +* Revert "restore support for streaming HF datasets" + +This reverts commit b43527b9b6b24521f6eba68242dc22c3c68173d8. 
([`23ada98`](https://github.com/saprmarks/dictionary_learning/commit/23ada983527a748887b7481e255b8dfdb310a23d)) + +* restore support for streaming HF datasets ([`b43527b`](https://github.com/saprmarks/dictionary_learning/commit/b43527b9b6b24521f6eba68242dc22c3c68173d8)) + +* first version of automatic feature labeling ([`c6753f6`](https://github.com/saprmarks/dictionary_learning/commit/c6753f62967503583aae33978b0684d5af0947e5)) + +* Add feature_effect function to search_utils.py ([`0ada2c6`](https://github.com/saprmarks/dictionary_learning/commit/0ada2c654b2dcc71e14869afc813b3adce445472)) + +* Merge branch 'main' of https://github.com/saprmarks/dictionary_learning into main ([`fab70b1`](https://github.com/saprmarks/dictionary_learning/commit/fab70b1b1a17fbe46fbdc54ea34095457c8cbe64)) + +* adding sqrt to MSE ([`63b2174`](https://github.com/saprmarks/dictionary_learning/commit/63b217449c651c78da68571bb032563ac73ebd71)) + +* Merge pull request #1 from cadentj/main + +Update README.md ([`fd79bb3`](https://github.com/saprmarks/dictionary_learning/commit/fd79bb34a7cb56bd987ce8a24764a72586999431)) + +* Update README.md ([`cf5ec24`](https://github.com/saprmarks/dictionary_learning/commit/cf5ec240bcb31db7007dceb7b4362967b044fd01)) + +* Update README.md ([`55f33f2`](https://github.com/saprmarks/dictionary_learning/commit/55f33f226d94baace938501d741ccfb5e9816a56)) + +* evaluation.py ([`2edf59e`](https://github.com/saprmarks/dictionary_learning/commit/2edf59ebb2a625e0862cecd5e4d84249589d95b9)) + +* evaluating dictionaries ([`71e28fb`](https://github.com/saprmarks/dictionary_learning/commit/71e28fbfa2976b099e849c766176252fa8d9fbc2)) + +* Removing experimental use of sqrt on MSELoss ([`865bbb5`](https://github.com/saprmarks/dictionary_learning/commit/865bbb58fdd1af681a2a435f546f4f6dceaaf930)) + +* Adding readme, evaluation, cleaning up ([`ddac948`](https://github.com/saprmarks/dictionary_learning/commit/ddac948a7971e526a47a9dae7311a25c0c56a81c)) + +* some stuff for saving dicts ([`d1f0e21`](https://github.com/saprmarks/dictionary_learning/commit/d1f0e21afc6395ddec71e274bbd3075750f4a76f)) + +* removing device from buffer ([`398f15c`](https://github.com/saprmarks/dictionary_learning/commit/398f15cb5d44ba81e12dee5299841a983e9f54df)) + +* Merge branch 'main' of https://github.com/saprmarks/dictionary_learning into main ([`7f013c2`](https://github.com/saprmarks/dictionary_learning/commit/7f013c2441620391eabba4f408deaa14140a5239)) + +* lr schedule + enabling stretched mlp ([`4eaf7e3`](https://github.com/saprmarks/dictionary_learning/commit/4eaf7e35e8c1c461da761a71968d8e9d1ef0c6b3)) + +* add random feature search ([`e58cc67`](https://github.com/saprmarks/dictionary_learning/commit/e58cc67cb8303b48cf40cb52e586d464f8cb6b48)) + +* restore HF support and progress bar ([`7e2b6c6`](https://github.com/saprmarks/dictionary_learning/commit/7e2b6c69aa7095680affe58c4251577f96505915)) + +* Merge branch 'main' of https://github.com/saprmarks/dictionary_learning into main ([`d33ef05`](https://github.com/saprmarks/dictionary_learning/commit/d33ef052e7d4175a5042855c04a6f3b60acb07ff)) + +* more support for saving checkpints ([`0ca258a`](https://github.com/saprmarks/dictionary_learning/commit/0ca258af3775910ce20d4cce541ff4de962bef3d)) + +* fix unit column bug + add scheduler ([`5a05c8c`](https://github.com/saprmarks/dictionary_learning/commit/5a05c8cd1b29894e8ba77b115727f1511c3334bd)) + +* fix merge bugs: checkpointing support 
([`9c5bbd8`](https://github.com/saprmarks/dictionary_learning/commit/9c5bbd8a3ac82e8611434d7ba95da172a80a44a0)) + +* Merge: add HF datasets and checkpointing ([`ccf6ed1`](https://github.com/saprmarks/dictionary_learning/commit/ccf6ed1d9fdc7c0df68c879c893f919d8c192b83)) + +* checkpointing, progress bar, HF dataset support ([`fd8a3ee`](https://github.com/saprmarks/dictionary_learning/commit/fd8a3ee3ee70354191c4d8ecce9d4f8b878d40c6)) + +* progress bar for training autoencoders ([`0a8064d`](https://github.com/saprmarks/dictionary_learning/commit/0a8064dd7ef93904c4b5b4edb9fc7ddbc1e42af1)) + +* implementing neuron resampling ([`f9b9d02`](https://github.com/saprmarks/dictionary_learning/commit/f9b9d020cd5c2daf857d44de2c956a6df2cf7cc3)) + +* lotsa stuff ([`bc09ba4`](https://github.com/saprmarks/dictionary_learning/commit/bc09ba48a701900311d7049dab52549b8239cb15)) + +* adding __init__.py file for imports ([`3d9fd43`](https://github.com/saprmarks/dictionary_learning/commit/3d9fd43957b8c35e1d6377aa33341f663ae5d289)) + +* modifying buffer ([`ba9441b`](https://github.com/saprmarks/dictionary_learning/commit/ba9441b444cd56b2a01c341357d3ede11b06e2b6)) + +* first commit ([`ea89e90`](https://github.com/saprmarks/dictionary_learning/commit/ea89e90e3f737ec8e2a339cfd0b2f1a1082ef850)) + +* Initial commit ([`741f4d6`](https://github.com/saprmarks/dictionary_learning/commit/741f4d6e1d07e55f6c6df5340cc22b9c7f8d49b7)) diff --git a/README.md b/README.md index c2e625c..febe79f 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,17 @@ -This is a repository for doing dictionary learning via sparse autoencoders on neural network activations. It was developed by Samuel Marks and Aaron Mueller. +This is a repository for doing dictionary learning via sparse autoencoders on neural network activations. It was developed by Samuel Marks, Adam Karvonen, and Aaron Mueller. For accessing, saving, and intervening on NN activations, we use the [`nnsight`](http://nnsight.net/) package; as of March 2024, `nnsight` is under active development and may undergo breaking changes. That said, `nnsight` is easy to use and quick to learn; if you plan to modify this repo, then we recommend going through the main `nnsight` demo [here](https://nnsight.net/notebooks/tutorials/walkthrough/). -Some dictionaries trained using this repository (and asociated training checkpoints) can be accessed at [https://baulab.us/u/smarks/autoencoders/](https://baulab.us/u/smarks/autoencoders/). See below for more information about these dictionaries. +Some dictionaries trained using this repository (and associated training checkpoints) can be accessed at [https://baulab.us/u/smarks/autoencoders/](https://baulab.us/u/smarks/autoencoders/). See below for more information about these dictionaries. SAEs trained with `dictionary_learning` can be evaluated with [SAE Bench](https://www.neuronpedia.org/sae-bench/info) using a convenient [evaluation script](https://github.com/adamkarvonen/SAEBench/tree/main/sae_bench/custom_saes). # Set-up Navigate to the to the location where you would like to clone this repo, clone and enter the repo, and install the requirements. ```bash -git clone https://github.com/saprmarks/dictionary_learning -cd dictionary_learning -pip install -r requirements.txt +pip install dictionary-learning ``` -To use `dictionary_learning`, include it as a subdirectory in some project's directory and import it; see the examples below. 
+We also provide a [demonstration](https://github.com/adamkarvonen/dictionary_learning_demo), which trains and evaluates 2 SAEs in ~30 minutes before plotting the results. # Using trained dictionaries @@ -61,7 +59,9 @@ This allows us to implement different training protocols (e.g. p-annealing) for Specifically, this repository supports the following trainers: - [`StandardTrainer`](trainers/standard.py): Implements a training scheme similar to that of [Bricken et al., 2023](https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder). - [`GatedSAETrainer`](trainers/gdm.py): Implements the training scheme for Gated SAEs described in [Rajamanoharan et al., 2024](https://arxiv.org/abs/2404.16014). -- [`AutoEncoderTopK`](trainers/top_k.py): Implemented the training scheme for Top-K SAEs described in [Gao et al., 2024](https://arxiv.org/abs/2406.04093). +- [`TopKSAETrainer`](trainers/top_k.py): Implements the training scheme for Top-K SAEs described in [Gao et al., 2024](https://arxiv.org/abs/2406.04093). +- [`BatchTopKSAETrainer`](trainers/batch_top_k.py): Implements the training scheme for Batch Top-K SAEs described in [Bussmann et al., 2024](https://arxiv.org/abs/2412.06410). +- [`JumpReluTrainer`](trainers/jumprelu.py): Implements the training scheme for JumpReLU SAEs described in [Rajamanoharan et al., 2024](https://arxiv.org/abs/2407.14435). - [`PAnnealTrainer`](trainers/p_anneal.py): Extends the `StandardTrainer` by providing the option to anneal the sparsity parameter p. - [`GatedAnnealTrainer`](trainers/gated_anneal.py): Extends the `GatedSAETrainer` by providing the option for p-annealing, similar to `PAnnealTrainer`. @@ -121,8 +121,11 @@ ae = trainSAE( ``` Some technical notes about our training infrastructure and supported features: * Training uses the `ConstrainedAdam` optimizer defined in `training.py`. This is a variant of Adam which supports constraining the `AutoEncoder`'s decoder weights to be norm 1. -* Neuron resampling: if a `resample_steps` argument is passed to `trainSAE`, then dead neurons will periodically be resampled according to the procedure specified [here](https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-resampling). -* Learning rate warmup: if a `warmup_steps` argument is passed to `trainSAE`, then a linear LR warmup is used at the start of training and, if doing neuron resampling, also after every time neurons are resampled. +* Neuron resampling: if a `resample_steps` argument is passed to the Trainer, then dead neurons will periodically be resampled according to the procedure specified [here](https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-resampling). +* Learning rate warmup: if a `warmup_steps` argument is passed to the Trainer, then a linear LR warmup is used at the start of training and, if doing neuron resampling, also after every time neurons are resampled. +* Sparsity penalty warmup: if a `sparsity_warmup_steps` argument is passed to the Trainer, then a linear warmup is applied to the sparsity penalty at the start of training. +* Learning rate decay: if a `decay_start` argument is passed to the Trainer, then a linear LR decay is used from `decay_start` to the end of training. +* If `normalize_activations` is True and passed to `trainSAE`, then the activations will be normalized to have unit mean squared norm. The autoencoder's weights will be scaled before saving, so the activations don't need to be scaled during inference.
This is very helpful for hyperparameter transfer between different layers and models. If `submodule` is a model component where the activations are tuples (e.g. this is common when working with residual stream activations), then the buffer yields the first coordinate of the tuple. @@ -205,4 +208,15 @@ We've included support for some experimental features. We briefly investigated t * **Replacing L1 loss with entropy**. Based on the ideas in this [post](https://transformer-circuits.pub/2023/may-update/index.html#simple-factorization), we experimented with using entropy to regularize a dictionary's hidden state instead of L1 loss. This seemed to cause the features to split into dead features (which never fired) and very high-frequency features which fired on nearly every input, which was not the desired behavior. But plausibly there is a way to make this work better. * **Ghost grads**, as described [here](https://transformer-circuits.pub/2024/jan-update/index.html). +# Citation +Please cite the package as follows: + +``` +@misc{marks2024dictionary_learning, + title = {dictionary_learning}, + author = {Samuel Marks, Adam Karvonen, and Aaron Mueller}, + year = {2024}, + howpublished = {\url{https://github.com/saprmarks/dictionary_learning}}, +} +``` diff --git a/__init__.py b/__init__.py deleted file mode 100644 index d4f5e83..0000000 --- a/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .dictionary import AutoEncoder, GatedAutoEncoder, JumpReluAutoEncoder -from .buffer import ActivationBuffer \ No newline at end of file diff --git a/dictionary_learning/__init__.py b/dictionary_learning/__init__.py new file mode 100644 index 0000000..2067aaa --- /dev/null +++ b/dictionary_learning/__init__.py @@ -0,0 +1,6 @@ +__version__ = "0.1.0" + +from .dictionary import AutoEncoder, GatedAutoEncoder, JumpReluAutoEncoder +from .buffer import ActivationBuffer + +__all__ = ["AutoEncoder", "GatedAutoEncoder", "JumpReluAutoEncoder", "ActivationBuffer"] diff --git a/dictionary_learning/activault_s3_buffer.py b/dictionary_learning/activault_s3_buffer.py new file mode 100644 index 0000000..1b94a7e --- /dev/null +++ b/dictionary_learning/activault_s3_buffer.py @@ -0,0 +1,744 @@ +"""Copyright (2025) Tilde Research Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import asyncio +import io +import json +import os +import random +import signal +import sys +import time +import warnings +from multiprocessing import Process, Queue, Value +from typing import Optional + +import einops +import aiohttp +import boto3 +import torch +import torch.nn as nn +import multiprocessing as mp +import warnings +import logging + +logger = logging.getLogger(__name__) + +# Constants for file sizes +KB = 1024 +MB = KB * KB + +# Cache directory constants +OUTER_CACHE_DIR = "cache" +INNER_CACHE_DIR = "cache" +BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", "main") + + +def _metadata_path(run_name): + """Generate the metadata file path for a given run name.""" + return f"{run_name}/metadata.json" + + +def _statistics_path(run_name): + """Generate the statistics file path for a given run name.""" + return f"{run_name}/statistics.json" + + +async def download_chunks(session, url, total_size, chunk_size): + """Download file chunks asynchronously with retries.""" + tries_left = 5 + while tries_left > 0: + chunks = [ + (i, min(i + chunk_size - 1, total_size - 1)) + for i in range(0, total_size, chunk_size) + ] + tasks = [ + asyncio.create_task(request_chunk(session, url, start, end)) + for start, end in chunks + ] + responses = await asyncio.gather(*tasks, return_exceptions=True) + + results = [] + retry = False + for response in responses: + if isinstance(response, Exception): + logger.error(f"Error occurred: {response}") + logger.error( + f"Session: {session}, URL: {url}, Tries left: {tries_left}" + ) + tries_left -= 1 + retry = True + break + else: + results.append(response) + + if not retry: + return results + + return None + + +async def request_chunk(session, url, start, end): + """Request a specific chunk of a file.""" + headers = {"Range": f"bytes={start}-{end}"} + try: + async with session.get(url, headers=headers) as response: + response.raise_for_status() + return start, await response.read() + except Exception as e: + return e + + +def download_loop(*args): + """Run the asynchronous download loop.""" + asyncio.run(_async_download(*args)) + + +def compile(byte_buffers, shuffle=True, seed=None, return_ids=False): + """Compile downloaded chunks into a tensor.""" + combined_bytes = b"".join( + chunk for _, chunk in sorted(byte_buffers, key=lambda x: x[0]) + ) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # n = np.frombuffer(combined_bytes, dtype=np.float16) + # t = torch.from_numpy(n) + # t = torch.frombuffer(combined_bytes, dtype=dtype) # torch.float32 + buffer = io.BytesIO(combined_bytes) + t = torch.load(buffer) + if ( + isinstance(t, dict) and "states" in t and not return_ids + ): # backward compatibility + t = t["states"] # ignore input_ids + buffer.close() + + if shuffle and not return_ids: + t = shuffle_megabatch_tokens(t, seed) + + return t + + +def shuffle_megabatch_tokens(t, seed=None): + """ + Shuffle within a megabatch (across batches and sequences), using each token as the unit of shuffling. 
+ + Args: + t (torch.Tensor): Input tensor of shape (batch_size * batches_per_file, sequence_length, d_in + 1) + seed (int): Seed for the random number generator + + Returns: + torch.Tensor: Shuffled tensor of the same shape as input + """ + original_shape = ( + t.shape + ) # (batch_size * batches_per_file, sequence_length, d_in + 1) + + total_tokens = ( + original_shape[0] * original_shape[1] + ) # reshape to (total_tokens, d_in + 1) + t_reshaped = t.reshape(total_tokens, -1) + + rng = torch.Generator() + if seed is not None: + rng.manual_seed(seed) + + shuffled_indices = torch.randperm(total_tokens, generator=rng) + t_shuffled = t_reshaped[shuffled_indices] + + t = t_shuffled.reshape(original_shape) # revert + + return t + + +def write_tensor(t, buffer, writeable_tensors, readable_tensors, ongoing_downloads): + """Write a tensor to the shared buffer.""" + idx = writeable_tensors.get(block=True) + if isinstance(buffer[0], SharedBuffer): + buffer[idx].states.copy_(t["states"]) + buffer[idx].input_ids.copy_(t["input_ids"]) + else: + buffer[idx] = t + + readable_tensors.put(idx, block=True) + with ongoing_downloads.get_lock(): + ongoing_downloads.value -= 1 + + +async def _async_download( + buffer, + file_index, + s3_paths, + stop, + readable_tensors, + writeable_tensors, + ongoing_downloads, + concurrency, + bytes_per_file, + chunk_size, + shuffle, + seed, + return_ids, +): + """Asynchronously download and process files from S3.""" + connector = aiohttp.TCPConnector(limit=concurrency) + async with aiohttp.ClientSession(connector=connector) as session: + while file_index.value < len(s3_paths) and not stop.value: + with ongoing_downloads.get_lock(): + ongoing_downloads.value += 1 + with file_index.get_lock(): + url = s3_paths[file_index.value] + file_index.value += 1 + bytes_results = await download_chunks( + session, url, bytes_per_file, chunk_size + ) + if bytes_results is not None: + try: + t = compile(bytes_results, shuffle, seed, return_ids) + write_tensor( + t, + buffer, + writeable_tensors, + readable_tensors, + ongoing_downloads, + ) + except Exception as e: + logger.error(f"Exception while downloading: {e}") + logger.error(f"Failed URL: {url}") + stop.value = True # Set stop flag + break # Exit the loop + else: + logger.error(f"Failed to download URL: {url}") + with ongoing_downloads.get_lock(): + ongoing_downloads.value -= 1 + + +class S3RCache: + """A cache that reads data from Amazon S3.""" + + @classmethod + def from_credentials( + self, aws_access_key_id, aws_secret_access_key, *args, **kwargs + ): + s3_client = boto3.client( + "s3", + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + endpoint_url=os.environ.get("S3_ENDPOINT_URL"), + ) + return S3RCache(s3_client, *args, **kwargs) + + def __init__( + self, + s3_client, + s3_prefix, + bucket_name=BUCKET_NAME, + device="cpu", + concurrency=100, + chunk_size=MB * 16, + buffer_size=2, + shuffle=True, + preserve_file_order=False, + seed=42, + paths=None, + n_workers=1, + return_ids=False, + ) -> None: + """Initialize S3 cache.""" + ensure_spawn_context() + + # Configure S3 client with correct signature version + self.s3_client = ( + boto3.client( + "s3", + region_name="eu-north1", # Make sure this matches your bucket region + config=boto3.session.Config(signature_version="s3v4"), + ) + if s3_client is None + else s3_client + ) + + self.s3_prefix = s3_prefix + self.bucket_name = bucket_name + self.device = device + self.concurrency = concurrency + self.chunk_size = chunk_size + self.buffer_size = 
buffer_size + self.shuffle = shuffle + self.preserve_file_order = preserve_file_order + self.seed = seed + self.return_ids = return_ids + + random.seed(self.seed) + torch.manual_seed(self.seed) # unclear if this has effect + # but we drill down the seed to download loop anyway + + self.paths = paths + self._s3_paths = self._list_s3_files() + if isinstance(self.s3_prefix, list): + target_prefix = self.s3_prefix[0] + else: + target_prefix = self.s3_prefix + response = self.s3_client.get_object( + Bucket=bucket_name, Key=_metadata_path(target_prefix) + ) + content = response["Body"].read() + self.metadata = json.loads(content) + # self.metadata["bytes_per_file"] = 1612711320 + self._activation_dtype = eval(self.metadata["dtype"]) + + self._running_processes = [] + self.n_workers = n_workers + + self.readable_tensors = Queue(maxsize=self.buffer_size) + self.writeable_tensors = Queue(maxsize=self.buffer_size) + + for i in range(self.buffer_size): + self.writeable_tensors.put(i) + + if self.return_ids: + self.buffer = [ + SharedBuffer( + self.metadata["shape"], + self.metadata["input_ids_shape"], + self._activation_dtype, + ) + for _ in range(self.buffer_size) + ] + for shared_buffer in self.buffer: + shared_buffer.share_memory() + else: + self.buffer = torch.empty( + (self.buffer_size, *self.metadata["shape"]), + dtype=self._activation_dtype, + ).share_memory_() + + self._stop = Value("b", False) + self._file_index = Value("i", 0) + self._ongoing_downloads = Value("i", 0) + + signal.signal(signal.SIGTERM, self._catch_stop) + signal.signal(signal.SIGINT, self._catch_stop) + + self._initial_file_index = 0 + + @property + def current_file_index(self): + return self._file_index.value + + def set_file_index(self, index): + self._initial_file_index = index + + def _catch_stop(self, *args, **kwargs): + logger.info("cleaning up before process is killed") + self._stop_downloading() + sys.exit(0) + + def sync(self): + self._s3_paths = self._list_s3_files() + + def _reset(self): + self._file_index.value = self._initial_file_index + self._ongoing_downloads.value = 0 + self._stop.value = False + + while not self.readable_tensors.empty(): + self.readable_tensors.get() + + while not self.writeable_tensors.empty(): + self.writeable_tensors.get() + for i in range(self.buffer_size): + self.writeable_tensors.put(i) + + def _list_s3_files(self): + """List and prepare all data files from one or more S3 prefixes.""" + paths = [] + combined_metadata = None + combined_config = None + + # Handle single prefix case for backward compatibility + prefixes = ( + [self.s3_prefix] if isinstance(self.s3_prefix, str) else self.s3_prefix + ) + + # Process each prefix + for prefix in prefixes: + # Get metadata for this prefix + response = self.s3_client.get_object( + Bucket=self.bucket_name, Key=_metadata_path(prefix) + ) + metadata = json.loads(response["Body"].read()) + + # Get config for this prefix + try: + config_response = self.s3_client.get_object( + Bucket=self.bucket_name, + Key=f"{'/'.join(prefix.split('/')[:-1])}/cfg.json", + ) + config = json.loads(config_response["Body"].read()) + except Exception as e: + logger.warning( + f"Warning: Could not load config for prefix {prefix}: {e}" + ) + config = {} + + # Initialize combined metadata and config from first prefix + if combined_metadata is None: + combined_metadata = metadata.copy() + combined_config = config.copy() + # Initialize accumulation fields + combined_config["total_tokens"] = 0 + combined_config["n_total_files"] = 0 + combined_config["batches_processed"] = 0 + 
else: + # Verify metadata compatibility + if metadata["shape"][1:] != combined_metadata["shape"][1:]: + raise ValueError( + f"Incompatible shapes between datasets: {metadata['shape']} vs {combined_metadata['shape']}" + ) + if metadata["dtype"] != combined_metadata["dtype"]: + raise ValueError(f"Incompatible dtypes between datasets") + + # Accumulate config fields + combined_config["total_tokens"] += config.get("total_tokens", 0) + combined_config["n_total_files"] += config.get("n_total_files", 0) + combined_config["batches_processed"] += config.get("batches_processed", 0) + + # List files for this prefix + paginator = self.s3_client.get_paginator("list_objects_v2") + page_iterator = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix) + + prefix_paths = [] + for page in page_iterator: + if "Contents" not in page: + continue + + for obj in page["Contents"]: + if ( + obj["Key"] != _metadata_path(prefix) + and obj["Key"] != _statistics_path(prefix) + and not obj["Key"].endswith("cfg.json") + ): + url = self.s3_client.generate_presigned_url( + "get_object", + Params={"Bucket": self.bucket_name, "Key": obj["Key"]}, + ExpiresIn=604700, + ) + prefix_paths.append(url) + + paths.extend(prefix_paths) + + # Store the combined metadata and config + self.metadata = combined_metadata + self.config = combined_config # Store combined config for potential later use + + if self.preserve_file_order: + # chronological upload order + return sorted(paths) + else: + # shuffle the file order + random.shuffle(paths) + return paths + + def __iter__(self): + self._reset() + + if self._running_processes: + raise ValueError( + "Cannot iterate over cache a second time while it is downloading" + ) + + if len(self._s3_paths) > self._initial_file_index: + while len(self._running_processes) < self.n_workers: + p = Process( + target=download_loop, + args=( + self.buffer, + self._file_index, + self._s3_paths[ + self._initial_file_index : + ], # Start from the initial index + self._stop, + self.readable_tensors, + self.writeable_tensors, + self._ongoing_downloads, + self.concurrency, + self.metadata["bytes_per_file"], + self.chunk_size, + self.shuffle, + self.seed, + self.return_ids, + ), + ) + p.start() + self._running_processes.append(p) + time.sleep(0.75) + + return self + + def _next_tensor(self): + try: + idx = self.readable_tensors.get(block=True) + if self.return_ids: + t = { + "states": self.buffer[idx].states.clone().detach(), + "input_ids": self.buffer[idx].input_ids.clone().detach(), + } + else: + t = self.buffer[idx].clone().detach() + + self.writeable_tensors.put(idx, block=True) + return t + except Exception as e: + logger.error(f"exception while iterating: {e}") + self._stop_downloading() + raise StopIteration + + def __next__(self): + while ( + self._file_index.value < len(self._s3_paths) + or not self.readable_tensors.empty() + or self._ongoing_downloads.value > 0 + ): + return self._next_tensor() + + if self._running_processes: + self._stop_downloading() + raise StopIteration + + def finalize(self): + self._stop_downloading() + + def _stop_downloading(self): + logger.info("stopping workers...") + self._file_index.value = len(self._s3_paths) + self._stop.value = True + + while not all([not p.is_alive() for p in self._running_processes]): + if not self.readable_tensors.empty(): + self.readable_tensors.get() + + if not self.writeable_tensors.full(): + self.writeable_tensors.put(0) + + time.sleep(0.25) + + for p in self._running_processes: + p.join() # still join to make sure all resources are cleaned up + 
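+        # Leave the shared counters and the process list in a clean state so the
+        # cache can be iterated again after shutdown (see the check in __iter__).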
+ self._ongoing_downloads.value = 0 + self._running_processes = [] + + +""" +tl;dr of why we need this: +shared memory is handled differently for nested structures -- see buffer intiialization +we can initialize a dict with two tensors with shared memory, and these tensors themselves are shared but NOT the dict +hence writing to buffer[idx] in write_tensor will not actually write to self.buffer[idx], which _next_tensor uses +(possibly a better fix, but for now this works) +""" + + +class SharedBuffer(nn.Module): + def __init__(self, shape, input_ids_shape, dtype): + super().__init__() + self.states = nn.Parameter(torch.ones(shape, dtype=dtype), requires_grad=False) + self.input_ids = nn.Parameter( + torch.ones(input_ids_shape, dtype=torch.int64), requires_grad=False + ) + + def forward(self): + return {"states": self.states, "input_ids": self.input_ids} + + +### mini-helper for multiprocessing +def ensure_spawn_context(): + """ + Ensures multiprocessing uses 'spawn' context if not already set. + Returns silently if already set to 'spawn'. + Issues warning if unable to set to 'spawn'. + """ + if mp.get_start_method(allow_none=True) != "spawn": + try: + mp.set_start_method("spawn", force=True) + except RuntimeError: + warnings.warn( + "Multiprocessing start method is not 'spawn'. This may cause issues." + ) + + +def create_s3_client( + access_key_id: Optional[str] = None, + secret_access_key: Optional[str] = None, + endpoint_url: Optional[str] = None, +) -> boto3.client: + """Create an S3 client configured for S3-compatible storage services. + + This function creates a boto3 S3 client with optimized settings for reliable + data transfer. It supports both direct credential passing and environment + variable configuration. + + Args: + access_key_id: S3 access key ID. If None, reads from AWS_ACCESS_KEY_ID env var + secret_access_key: S3 secret key. If None, reads from AWS_SECRET_ACCESS_KEY env var + endpoint_url: S3-compatible storage service endpoint URL + + Returns: + boto3.client: Configured S3 client with optimized settings + + Environment Variables: + - AWS_ACCESS_KEY_ID: S3 access key ID (if not provided as argument) + - AWS_SECRET_ACCESS_KEY: S3 secret key (if not provided as argument) + + Example: + ```python + # Using environment variables + s3_client = create_s3_client() + + # Using explicit credentials + s3_client = create_s3_client( + access_key_id="your_key", + secret_access_key="your_secret", + endpoint_url="your_endpoint_url" + ) + ``` + + Note: + The client is configured with path-style addressing and S3v4 signatures + for maximum compatibility with S3-compatible storage services. 
+ """ + access_key_id = access_key_id or os.environ.get("AWS_ACCESS_KEY_ID") + secret_access_key = secret_access_key or os.environ.get("AWS_SECRET_ACCESS_KEY") + endpoint_url = endpoint_url or os.environ.get("S3_ENDPOINT_URL") + + if not access_key_id or not secret_access_key: + raise ValueError( + "S3 credentials must be provided either through arguments or " + "AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables" + ) + + if not endpoint_url: + raise ValueError( + "S3 endpoint URL must be provided either through arguments or " + "S3_ENDPOINT_URL environment variable" + ) + + session = boto3.session.Session() + return session.client( + service_name="s3", + aws_access_key_id=access_key_id, + aws_secret_access_key=secret_access_key, + endpoint_url=endpoint_url, + use_ssl=True, + verify=True, + config=boto3.session.Config( + s3={"addressing_style": "path"}, + signature_version="s3v4", + # Advanced configuration options (currently commented out): + # retries=dict( + # max_attempts=3, # Number of retry attempts + # mode='adaptive' # Adds exponential backoff + # ), + # max_pool_connections=20, # Limits concurrent connections + # connect_timeout=60, # Connection timeout in seconds + # read_timeout=300, # Read timeout in seconds + # tcp_keepalive=True, # Enable TCP keepalive + ), + ) + + +class ActivaultS3ActivationBuffer: + def __init__( + self, + cache: S3RCache, + batch_size: int = 8192, + device: str = "cpu", + io: str = "out", + ): + self.cache = iter(cache) # Make sure it's an iterator + self.batch_size = batch_size + self.device = device + self.io = io + + self.states = None # Shape: [N, D] + self.read_mask = None # Shape: [N] + self.refresh() # Load the first batch + + def __iter__(self): + return self + + def __next__(self): + with torch.no_grad(): + if (~self.read_mask).sum() < self.batch_size: + self.refresh() + + if self.states is None or self.states.shape[0] == 0: + raise StopIteration + + unreads = (~self.read_mask).nonzero().squeeze() + if unreads.ndim == 0: + unreads = unreads.unsqueeze(0) + selected = unreads[ + torch.randperm(len(unreads), device=self.device)[: self.batch_size] + ] + self.read_mask[selected] = True + return self.states[selected] + + def refresh(self): + try: + next_batch = next(self.cache) # dict with "states" key + except StopIteration: + self.states = None + self.read_mask = None + return + + states = next_batch["states"].to(self.device) # [B, L, D] + flat_states = einops.rearrange(states, "b l d -> (b l) d").contiguous() + self.states = flat_states + self.read_mask = torch.zeros( + flat_states.shape[0], dtype=torch.bool, device=self.device + ) + + def close(self): + if hasattr(self.cache, "finalize"): + self.cache.finalize() + elif hasattr(self.cache, "close"): + self.cache.close() + + +if __name__ == "__main__": + device = "cuda" + sae_batch_size = 2048 + io = "out" + + # example activault usage + + BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", "main") + s3_prefix = ["mistral.8b.fineweb/blocks.9.hook_resid_post"] + cache = S3RCache.from_credentials( + aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"), + aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"), + s3_prefix=s3_prefix, + bucket_name=BUCKET_NAME, + device=device, + buffer_size=2, + return_ids=True, + shuffle=True, + n_workers=2, + ) + + s3_buffer = ActivaultS3ActivationBuffer( + cache, batch_size=sae_batch_size, device=device, io=io + ) diff --git a/buffer.py b/dictionary_learning/buffer.py similarity index 93% rename from buffer.py rename to dictionary_learning/buffer.py 
index 86f24f9..f0e02e1 100644 --- a/buffer.py +++ b/dictionary_learning/buffer.py @@ -26,7 +26,9 @@ def __init__(self, ctx_len=128, # length of each context refresh_batch_size=512, # size of batches in which to process the data when adding to buffer out_batch_size=8192, # size of batches in which to yield activations - device='cpu' # device on which to store the activations + device='cpu', # device on which to store the activations + remove_bos: bool = False, + add_special_tokens: bool = True, ): if io not in ['in', 'out']: @@ -40,7 +42,7 @@ def __init__(self, d_submodule = submodule.out_features except: raise ValueError("d_submodule cannot be inferred and must be specified directly") - self.activations = t.empty(0, d_submodule, device=device) + self.activations = t.empty(0, d_submodule, device=device, dtype=model.dtype) self.read = t.zeros(0).bool() self.data = data @@ -54,7 +56,9 @@ def __init__(self, self.refresh_batch_size = refresh_batch_size self.out_batch_size = out_batch_size self.device = device - + self.remove_bos = remove_bos and (self.model.tokenizer.bos_token_id is not None) + self.add_special_tokens = add_special_tokens + def __iter__(self): return self @@ -96,7 +100,8 @@ def tokenized_batch(self, batch_size=None): return_tensors='pt', max_length=self.ctx_len, padding=True, - truncation=True + truncation=True, + add_special_tokens=self.add_special_tokens ) def refresh(self): @@ -105,7 +110,7 @@ def refresh(self): self.activations = self.activations[~self.read] current_idx = len(self.activations) - new_activations = t.empty(self.activation_buffer_size, self.d_submodule, device=self.device) + new_activations = t.empty(self.activation_buffer_size, self.d_submodule, device=self.device, dtype=self.model.dtype) new_activations[: len(self.activations)] = self.activations self.activations = new_activations @@ -115,21 +120,28 @@ def refresh(self): while current_idx < self.activation_buffer_size: with t.no_grad(): + tokens = self.tokenized_batch() with self.model.trace( - self.text_batch(), + tokens, **tracer_kwargs, invoker_args={"truncation": True, "max_length": self.ctx_len}, ): if self.io == "in": - hidden_states = self.submodule.input[0].save() + hidden_states = self.submodule.inputs[0].save() else: hidden_states = self.submodule.output.save() - input = self.model.input.save() - attn_mask = input.value[1]["attention_mask"] + input = self.model.inputs.save() + + self.submodule.output.stop() + + mask = (input.value[1]["attention_mask"] != 0) hidden_states = hidden_states.value if isinstance(hidden_states, tuple): hidden_states = hidden_states[0] - hidden_states = hidden_states[attn_mask != 0] + if self.remove_bos: + bos_mask = (input.value[1]["input_ids"] == self.model.tokenizer.bos_token_id) + mask = mask & ~bos_mask + hidden_states = hidden_states[mask] remaining_space = self.activation_buffer_size - current_idx assert remaining_space > 0 @@ -251,8 +263,8 @@ def refresh(self): while len(self.activations) < self.n_ctxs * self.ctx_len: with t.no_grad(): with self.model.trace(self.text_batch(), **tracer_kwargs, invoker_args={'truncation': True, 'max_length': self.ctx_len}, remote=self.remote): - input = self.model.input.save() - hidden_states = self.model.model.layers[self.layer].self_attn.o_proj.input[0][0]#.save() + input = self.model.inputs.save() + hidden_states = self.model.model.layers[self.layer].self_attn.o_proj.inputs[0][0]#.save() if isinstance(hidden_states, tuple): hidden_states = hidden_states[0] @@ -416,7 +428,7 @@ def refresh(self): invoker_args={"truncation": True, 
"max_length": self.ctx_len}, ): if self.io in ["in", "in_and_out"]: - hidden_states_in = self.submodule.input[0].save() + hidden_states_in = self.submodule.inputs[0].save() if self.io in ["out", "in_and_out"]: hidden_states_out = self.submodule.output.save() diff --git a/config.py b/dictionary_learning/config.py similarity index 100% rename from config.py rename to dictionary_learning/config.py diff --git a/dictionary.py b/dictionary_learning/dictionary.py similarity index 65% rename from dictionary.py rename to dictionary_learning/dictionary.py index f0eb176..238a866 100644 --- a/dictionary.py +++ b/dictionary_learning/dictionary.py @@ -2,17 +2,20 @@ Defines the dictionary classes """ -from abc import ABC, abstractclassmethod, abstractmethod +from abc import ABC, abstractmethod import torch as t import torch.nn as nn import torch.nn.init as init +import einops + class Dictionary(ABC, nn.Module): """ A dictionary consists of a collection of vectors, an encoder, and a decoder. """ - dict_size : int # number of features in the dictionary - activation_dim : int # dimension of the activation vectors + + dict_size: int # number of features in the dictionary + activation_dim: int # dimension of the activation vectors @abstractmethod def encode(self, x): @@ -20,7 +23,7 @@ def encode(self, x): Encode a vector x in the activation space. """ pass - + @abstractmethod def decode(self, f): """ @@ -41,25 +44,29 @@ class AutoEncoder(Dictionary, nn.Module): """ A one-layer autoencoder. """ + def __init__(self, activation_dim, dict_size): super().__init__() self.activation_dim = activation_dim self.dict_size = dict_size self.bias = nn.Parameter(t.zeros(activation_dim)) self.encoder = nn.Linear(activation_dim, dict_size, bias=True) - - # rows of decoder weight matrix are unit vectors self.decoder = nn.Linear(dict_size, activation_dim, bias=False) - dec_weight = t.randn_like(self.decoder.weight) - dec_weight = dec_weight / dec_weight.norm(dim=0, keepdim=True) - self.decoder.weight = nn.Parameter(dec_weight) + + # initialize encoder and decoder weights + w = t.randn(activation_dim, dict_size) + ## normalize columns of w + w = w / w.norm(dim=0, keepdim=True) * 0.1 + ## set encoder and decoder weights + self.encoder.weight = nn.Parameter(w.clone().T) + self.decoder.weight = nn.Parameter(w.clone()) def encode(self, x): return nn.ReLU()(self.encoder(x - self.bias)) - + def decode(self, f): return self.decoder(f) + self.bias - + def forward(self, x, output_features=False, ghost_mask=None): """ Forward pass of an autoencoder. 
@@ -67,72 +74,127 @@ def forward(self, x, output_features=False, ghost_mask=None): output_features : if True, return the encoded features as well as the decoded x ghost_mask : if not None, run this autoencoder in "ghost mode" where features are masked """ - if ghost_mask is None: # normal mode + if ghost_mask is None: # normal mode f = self.encode(x) x_hat = self.decode(f) if output_features: return x_hat, f else: return x_hat - - else: # ghost mode + + else: # ghost mode f_pre = self.encoder(x - self.bias) f_ghost = t.exp(f_pre) * ghost_mask.to(f_pre) f = nn.ReLU()(f_pre) - x_ghost = self.decoder(f_ghost) # note that this only applies the decoder weight matrix, no bias + x_ghost = self.decoder( + f_ghost + ) # note that this only applies the decoder weight matrix, no bias x_hat = self.decode(f) if output_features: return x_hat, x_ghost, f else: return x_hat, x_ghost - + + def scale_biases(self, scale: float): + self.encoder.bias.data *= scale + self.bias.data *= scale + + def normalize_decoder(self): + norms = t.norm(self.decoder.weight, dim=0).to(dtype=self.decoder.weight.dtype, device=self.decoder.weight.device) + + if t.allclose(norms, t.ones_like(norms)): + return + print("Normalizing decoder weights") + + test_input = t.randn(10, self.activation_dim) + initial_output = self(test_input) + + self.decoder.weight.data /= norms + + new_norms = t.norm(self.decoder.weight, dim=0) + assert t.allclose(new_norms, t.ones_like(new_norms)) + + self.encoder.weight.data *= norms[:, None] + self.encoder.bias.data *= norms + + new_output = self(test_input) + + # Errors can be relatively large in larger SAEs due to floating point precision + assert t.allclose(initial_output, new_output, atol=1e-4) + + @classmethod - def from_pretrained(cls, path, dtype=t.float, device=None): + def from_pretrained(cls, path, dtype=t.float, device=None, normalize_decoder=True): """ Load a pretrained autoencoder from a file. """ state_dict = t.load(path) - dict_size, activation_dim = state_dict['encoder.weight'].shape + dict_size, activation_dim = state_dict["encoder.weight"].shape autoencoder = cls(activation_dim, dict_size) autoencoder.load_state_dict(state_dict) + + # This is useful for doing analysis where e.g. feature activation magnitudes are important + # If training the SAE using the April update, the decoder weights are not normalized + if normalize_decoder: + autoencoder.normalize_decoder() + if device is not None: autoencoder.to(dtype=dtype, device=device) + return autoencoder - + + class IdentityDict(Dictionary, nn.Module): """ An identity dictionary, i.e. the identity function. """ - def __init__(self, activation_dim=None): + + def __init__(self, activation_dim=None, dtype=None, device=None): super().__init__() self.activation_dim = activation_dim self.dict_size = activation_dim + self.device = device + self.dtype = dtype def encode(self, x): + if self.device is not None: + x = x.to(self.device) + if self.dtype is not None: + x = x.to(self.dtype) return x - + def decode(self, f): + if self.device is not None: + f = f.to(self.device) + if self.dtype is not None: + f = f.to(self.dtype) return f - + def forward(self, x, output_features=False, ghost_mask=None): + if self.device is not None: + x = x.to(self.device) + if self.dtype is not None: + x = x.to(self.dtype) if output_features: return x, x else: return x - + @classmethod - def from_pretrained(cls, path, dtype=t.float, device=None): + def from_pretrained(cls, activation_dim, path, dtype=None, device=None): """ Load a pretrained dictionary from a file. 
""" - return cls(None) - + return cls(activation_dim, device=device, dtype=dtype) + + class GatedAutoEncoder(Dictionary, nn.Module): """ An autoencoder with separate gating and magnitude networks. """ - def __init__(self, activation_dim, dict_size, initialization='default', device=None): + + def __init__(self, activation_dim, dict_size, initialization="default", device=None): super().__init__() self.activation_dim = activation_dim self.dict_size = dict_size @@ -142,7 +204,7 @@ def __init__(self, activation_dim, dict_size, initialization='default', device=N self.gate_bias = nn.Parameter(t.empty(dict_size, device=device)) self.mag_bias = nn.Parameter(t.empty(dict_size, device=device)) self.decoder = nn.Linear(dict_size, activation_dim, bias=False, device=device) - if initialization == 'default': + if initialization == "default": self._reset_parameters() else: initialization(self) @@ -161,8 +223,9 @@ def _reset_parameters(self): dec_weight = t.randn_like(self.decoder.weight) dec_weight = dec_weight / dec_weight.norm(dim=0, keepdim=True) self.decoder.weight = nn.Parameter(dec_weight) + self.encoder.weight = nn.Parameter(dec_weight.clone().T) - def encode(self, x, return_gate=False): + def encode(self, x: t.Tensor, return_gate:bool=False, normalize_decoder:bool=False): """ Returns features, gate value (pre-Heavyside) """ @@ -178,65 +241,74 @@ def encode(self, x, return_gate=False): f = f_gate * f_mag - # W_dec norm is not kept constant, as per Anthropic's April 2024 Update - # Normalizing after encode, and renormalizing before decode to enable comparability - f = f * self.decoder.weight.norm(dim=0, keepdim=True) + if normalize_decoder: + # If the SAE is trained without ConstrainedAdam, the decoder vectors are not normalized + # Normalizing after encode, and renormalizing before decode to enable comparability + f = f * self.decoder.weight.norm(dim=0, keepdim=True) if return_gate: return f, nn.ReLU()(pi_gate) return f - def decode(self, f): - # W_dec norm is not kept constant, as per Anthropic's April 2024 Update - # Normalizing after encode, and renormalizing before decode to enable comparability - f = f / self.decoder.weight.norm(dim=0, keepdim=True) + def decode(self, f: t.Tensor, normalize_decoder:bool=False): + if normalize_decoder: + # If the SAE is trained without ConstrainedAdam, the decoder vectors are not normalized + # Normalizing after encode, and renormalizing before decode to enable comparability + f = f / self.decoder.weight.norm(dim=0, keepdim=True) return self.decoder(f) + self.decoder_bias - - def forward(self, x, output_features=False): + + def forward(self, x:t.Tensor, output_features:bool=False, normalize_decoder:bool=False): f = self.encode(x) x_hat = self.decode(f) - f = f * self.decoder.weight.norm(dim=0, keepdim=True) + if normalize_decoder: + f = f * self.decoder.weight.norm(dim=0, keepdim=True) if output_features: return x_hat, f else: return x_hat + def scale_biases(self, scale: float): + self.decoder_bias.data *= scale + self.mag_bias.data *= scale + self.gate_bias.data *= scale + def from_pretrained(path, device=None): """ Load a pretrained autoencoder from a file. 
""" state_dict = t.load(path) - dict_size, activation_dim = state_dict['encoder.weight'].shape + dict_size, activation_dim = state_dict["encoder.weight"].shape autoencoder = GatedAutoEncoder(activation_dim, dict_size) autoencoder.load_state_dict(state_dict) if device is not None: autoencoder.to(device) return autoencoder - + + class JumpReluAutoEncoder(Dictionary, nn.Module): """ An autoencoder with jump ReLUs. """ - def __init__(self, activation_dim, dict_size, device='cpu'): + def __init__(self, activation_dim, dict_size, device="cpu"): super().__init__() self.activation_dim = activation_dim self.dict_size = dict_size self.W_enc = nn.Parameter(t.empty(activation_dim, dict_size, device=device)) self.b_enc = nn.Parameter(t.zeros(dict_size, device=device)) - self.W_dec = nn.Parameter(t.empty(dict_size, activation_dim, device=device)) + self.W_dec = nn.Parameter( + t.nn.init.kaiming_uniform_(t.empty(dict_size, activation_dim, device=device)) + ) self.b_dec = nn.Parameter(t.zeros(activation_dim, device=device)) - self.threshold = nn.Parameter(t.zeros(dict_size, device=device)) + self.threshold = nn.Parameter(t.ones(dict_size, device=device) * 0.001) # Appendix I self.apply_b_dec_to_input = False - # rows of decoder weight matrix are initialized to unit vectors - self.W_enc.data = t.randn_like(self.W_enc) - self.W_enc.data = self.W_enc / self.W_enc.norm(dim=0, keepdim=True) - self.W_dec.data = self.W_enc.data.clone().T + self.W_dec.data = self.W_dec / self.W_dec.norm(dim=1, keepdim=True) + self.W_enc.data = self.W_dec.data.clone().T def encode(self, x, output_pre_jump=False): if self.apply_b_dec_to_input: @@ -244,17 +316,15 @@ def encode(self, x, output_pre_jump=False): pre_jump = x @ self.W_enc + self.b_enc f = nn.ReLU()(pre_jump * (pre_jump > self.threshold)) - f = f * self.W_dec.norm(dim=1) if output_pre_jump: return f, pre_jump else: return f - + def decode(self, f): - f = f / self.W_dec.norm(dim=1) return f @ self.W_dec + self.b_dec - + def forward(self, x, output_features=False): """ Forward pass of an autoencoder. @@ -267,15 +337,20 @@ def forward(self, x, output_features=False): return x_hat, f else: return x_hat - + + def scale_biases(self, scale: float): + self.b_dec.data *= scale + self.b_enc.data *= scale + self.threshold.data *= scale + @classmethod def from_pretrained( - cls, - path: str | None = None, - load_from_sae_lens: bool = False, - dtype: t.dtype = t.float32, - device: t.device | None = None, - **kwargs, + cls, + path: str | None = None, + load_from_sae_lens: bool = False, + dtype: t.dtype = t.float32, + device: t.device | None = None, + **kwargs, ): """ Load a pretrained autoencoder from a file. 
@@ -284,13 +359,17 @@ def from_pretrained( """ if not load_from_sae_lens: state_dict = t.load(path) - dict_size, activation_dim = state_dict['W_enc'].shape + activation_dim, dict_size = state_dict["W_enc"].shape autoencoder = JumpReluAutoEncoder(activation_dim, dict_size) autoencoder.load_state_dict(state_dict) + autoencoder = autoencoder.to(dtype=dtype, device=device) else: from sae_lens import SAE + sae, cfg_dict, _ = SAE.from_pretrained(**kwargs) - assert cfg_dict["finetuning_scaling_factor"] == False, "Finetuning scaling factor not supported" + assert ( + cfg_dict["finetuning_scaling_factor"] == False + ), "Finetuning scaling factor not supported" dict_size, activation_dim = cfg_dict["d_sae"], cfg_dict["d_in"] autoencoder = JumpReluAutoEncoder(activation_dim, dict_size, device=device) autoencoder.load_state_dict(sae.state_dict()) @@ -300,11 +379,13 @@ def from_pretrained( device = autoencoder.W_enc.device return autoencoder.to(dtype=dtype, device=device) + # TODO merge this with AutoEncoder class AutoEncoderNew(Dictionary, nn.Module): """ The autoencoder architecture and initialization used in https://transformer-circuits.pub/2024/april-update/index.html#training-saes """ + def __init__(self, activation_dim, dict_size): super().__init__() self.activation_dim = activation_dim @@ -326,10 +407,10 @@ def __init__(self, activation_dim, dict_size): def encode(self, x): return nn.ReLU()(self.encoder(x)) - + def decode(self, f): return self.decoder(f) - + def forward(self, x, output_features=False): """ Forward pass of an autoencoder. @@ -337,19 +418,19 @@ def forward(self, x, output_features=False): """ if not output_features: return self.decode(self.encode(x)) - else: # TODO rewrite so that x_hat depends on f + else: # TODO rewrite so that x_hat depends on f f = self.encode(x) x_hat = self.decode(f) # multiply f by decoder column norms f = f * self.decoder.weight.norm(dim=0, keepdim=True) return x_hat, f - + def from_pretrained(path, device=None): """ Load a pretrained autoencoder from a file. """ state_dict = t.load(path) - dict_size, activation_dim = state_dict['encoder.weight'].shape + dict_size, activation_dim = state_dict["encoder.weight"].shape autoencoder = AutoEncoderNew(activation_dim, dict_size) autoencoder.load_state_dict(state_dict) if device is not None: diff --git a/evaluation.py b/dictionary_learning/evaluation.py similarity index 71% rename from evaluation.py rename to dictionary_learning/evaluation.py index 6b3b0e5..ba56437 100644 --- a/evaluation.py +++ b/dictionary_learning/evaluation.py @@ -3,6 +3,8 @@ """ import torch as t +from collections import defaultdict + from .buffer import ActivationBuffer, NNsightActivationBuffer from nnsight import LanguageModel from .config import DEBUG @@ -22,12 +24,21 @@ def loss_recovered( How much of the model's loss is recovered by replacing the component output with the reconstruction by the autoencoder? """ - + if max_len is None: invoker_args = {} else: invoker_args = {"truncation": True, "max_length": max_len } + with model.trace("_"): + temp_output = submodule.output.save() + + output_is_tuple = False + # Note: isinstance() won't work here as torch.Size is a subclass of tuple, + # so isinstance(temp_output.shape, tuple) would return True even for torch.Size. 
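+    # `output_is_tuple` is reused below: when reading activations we take element 0 of
+    # the tuple, and when patching reconstructions (or zeros) back in we write to
+    # submodule.output[0][:] rather than submodule.output[:].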
+ if type(temp_output.shape) == tuple: + output_is_tuple = True + # unmodified logits with model.trace(text, invoker_args=invoker_args): logits_original = model.output.save() @@ -36,21 +47,18 @@ def loss_recovered( # logits when replacing component activations with reconstruction by autoencoder with model.trace(text, **tracer_args, invoker_args=invoker_args): if io == 'in': - x = submodule.input[0] - if type(submodule.input.shape) == tuple: x = x[0] + x = submodule.input if normalize_batch: scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x = x * scale elif io == 'out': x = submodule.output - if type(submodule.output.shape) == tuple: x = x[0] + if output_is_tuple: x = x[0] if normalize_batch: scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x = x * scale elif io == 'in_and_out': - x = submodule.input[0] - if type(submodule.input.shape) == tuple: x = x[0] - print(f'x.shape: {x.shape}') + x = submodule.input if normalize_batch: scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x = x * scale @@ -58,35 +66,38 @@ def loss_recovered( raise ValueError(f"Invalid value for io: {io}") x = x.save() - # pull this out so dictionary can be written without FakeTensor (top_k needs this) - x_hat = dictionary(x.view(-1, x.shape[-1])).view(x.shape).to(model.dtype) + # If we incorrectly handle output_is_tuple, such as with some mlp submodules, we will get an error here. + assert len(x.shape) == 3, f"Expected x to have shape (B, L, D), got {x.shape}, output_is_tuple: {output_is_tuple}" + + x_hat = dictionary(x).to(model.dtype) # intervene with `x_hat` with model.trace(text, **tracer_args, invoker_args=invoker_args): if io == 'in': - x = submodule.input[0] + x = submodule.input if normalize_batch: scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x_hat = x_hat / scale - if type(submodule.input.shape) == tuple: - submodule.input[0][:] = x_hat - else: - submodule.input = x_hat + submodule.input[:] = x_hat elif io == 'out': x = submodule.output + if output_is_tuple: x = x[0] if normalize_batch: scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x_hat = x_hat / scale - if type(submodule.output.shape) == tuple: - submodule.output = (x_hat,) + if output_is_tuple: + submodule.output[0][:] = x_hat else: - submodule.output = x_hat + submodule.output[:] = x_hat elif io == 'in_and_out': - x = submodule.input[0] + x = submodule.input if normalize_batch: scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x_hat = x_hat / scale - submodule.output = x_hat + if output_is_tuple: + submodule.output[0][:] = x_hat + else: + submodule.output[:] = x_hat else: raise ValueError(f"Invalid value for io: {io}") @@ -96,22 +107,20 @@ def loss_recovered( # logits when replacing component activations with zeros with model.trace(text, **tracer_args, invoker_args=invoker_args): if io == 'in': - x = submodule.input[0] - if type(submodule.input.shape) == tuple: - submodule.input[0][:] = t.zeros_like(x[0]) - else: - submodule.input = t.zeros_like(x) + x = submodule.input + submodule.input[:] = t.zeros_like(x) elif io in ['out', 'in_and_out']: x = submodule.output - if type(submodule.output.shape) == tuple: + if output_is_tuple: submodule.output[0][:] = t.zeros_like(x[0]) else: - submodule.output = t.zeros_like(x) + submodule.output[:] = t.zeros_like(x) else: raise ValueError(f"Invalid value for io: {io}") - input = model.input.save() + input = model.inputs.save() logits_zero = model.output.save() + logits_zero = logits_zero.value # get everything 
into the right format @@ -144,7 +153,7 @@ def loss_recovered( return tuple(losses) - +@t.no_grad() def evaluate( dictionary, # a dictionary activations, # a generator of activations; if an ActivationBuffer, also compute loss recovered @@ -154,26 +163,31 @@ def evaluate( normalize_batch=False, # normalize batch before passing through dictionary tracer_args={'use_cache': False, 'output_attentions': False}, # minimize cache during model trace. device="cpu", + n_batches: int = 1, ): - with t.no_grad(): - - out = {} # dict of results + assert n_batches > 0 + out = defaultdict(float) + active_features = t.zeros(dictionary.dict_size, dtype=t.float32, device=device) + for _ in range(n_batches): try: x = next(activations).to(device) if normalize_batch: x = x / x.norm(dim=-1).mean() * (dictionary.activation_dim ** 0.5) - except StopIteration: raise StopIteration( "Not enough activations in buffer. Pass a buffer with a smaller batch size or more data." ) - x_hat, f = dictionary(x, output_features=True) l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() l1_loss = f.norm(p=1, dim=-1).mean() l0 = (f != 0).float().sum(dim=-1).mean() - frac_alive = t.flatten(f, start_dim=0, end_dim=1).any(dim=0).sum() / dictionary.dict_size + + features_BF = t.flatten(f, start_dim=0, end_dim=-2).to(dtype=t.float32) # If f is shape (B, L, D), flatten to (B*L, D) + assert features_BF.shape[-1] == dictionary.dict_size + assert len(features_BF.shape) == 2 + + active_features += features_BF.sum(dim=0) # cosine similarity between x and x_hat x_normed = x / t.linalg.norm(x, dim=-1, keepdim=True) @@ -193,17 +207,16 @@ def evaluate( x_dot_x_hat = (x * x_hat).sum(dim=-1) relative_reconstruction_bias = x_hat_norm_squared.mean() / x_dot_x_hat.mean() - out["l2_loss"] = l2_loss.item() - out["l1_loss"] = l1_loss.item() - out["l0"] = l0.item() - out["frac_alive"] = frac_alive.item() - out["frac_variance_explained"] = frac_variance_explained.item() - out["cossim"] = cossim.item() - out["l2_ratio"] = l2_ratio.item() - out['relative_reconstruction_bias'] = relative_reconstruction_bias.item() + out["l2_loss"] += l2_loss.item() + out["l1_loss"] += l1_loss.item() + out["l0"] += l0.item() + out["frac_variance_explained"] += frac_variance_explained.item() + out["cossim"] += cossim.item() + out["l2_ratio"] += l2_ratio.item() + out['relative_reconstruction_bias'] += relative_reconstruction_bias.item() if not isinstance(activations, (ActivationBuffer, NNsightActivationBuffer)): - return out + continue # compute loss recovered loss_original, loss_reconstructed, loss_zero = loss_recovered( @@ -218,9 +231,13 @@ def evaluate( ) frac_recovered = (loss_reconstructed - loss_zero) / (loss_original - loss_zero) - out["loss_original"] = loss_original.item() - out["loss_reconstructed"] = loss_reconstructed.item() - out["loss_zero"] = loss_zero.item() - out["frac_recovered"] = frac_recovered.item() + out["loss_original"] += loss_original.item() + out["loss_reconstructed"] += loss_reconstructed.item() + out["loss_zero"] += loss_zero.item() + out["frac_recovered"] += frac_recovered.item() + + out = {key: value / n_batches for key, value in out.items()} + frac_alive = (active_features != 0).float().sum() / dictionary.dict_size + out["frac_alive"] = frac_alive.item() - return out + return out \ No newline at end of file diff --git a/grad_pursuit.py b/dictionary_learning/grad_pursuit.py similarity index 100% rename from grad_pursuit.py rename to dictionary_learning/grad_pursuit.py diff --git a/interp.py b/dictionary_learning/interp.py similarity index 98% rename 
from interp.py rename to dictionary_learning/interp.py index 283965b..18ac308 100644 --- a/interp.py +++ b/dictionary_learning/interp.py @@ -101,7 +101,7 @@ def _list_decode(x): inputs = buffer.tokenized_batch(batch_size=n_inputs) with t.no_grad(), model.trace(inputs, **tracer_kwargs): - tokens = model.input[1][ + tokens = model.inputs[1][ "input_ids" ].save() # if you're getting errors, check here; might only work for pythia models activations = submodule.output @@ -188,4 +188,4 @@ def feature_umap( hover_name=df.index, color=colors, ) - raise ValueError("n_components must be 2 or 3") + raise ValueError("n_components must be 2 or 3") \ No newline at end of file diff --git a/dictionary_learning/pytorch_buffer.py b/dictionary_learning/pytorch_buffer.py new file mode 100644 index 0000000..9d943e4 --- /dev/null +++ b/dictionary_learning/pytorch_buffer.py @@ -0,0 +1,225 @@ +import torch as t +from transformers import AutoModelForCausalLM, AutoTokenizer +import gc +from tqdm import tqdm +import contextlib + + +class EarlyStopException(Exception): + """Custom exception for stopping model forward pass early.""" + + pass + + +def collect_activations( + model: AutoModelForCausalLM, + submodule: t.nn.Module, + inputs_BL: dict[str, t.Tensor], + use_no_grad: bool = True, +) -> t.Tensor: + """ + Registers a forward hook on the submodule to capture the residual (or hidden) + activations. We then raise an EarlyStopException to skip unneeded computations. + + Args: + model: The model to run. + submodule: The submodule to hook into. + inputs_BL: The inputs to the model. + use_no_grad: Whether to run the forward pass within a `t.no_grad()` context. Defaults to True. + """ + activations_BLD = None + + def gather_target_act_hook(module, inputs, outputs): + nonlocal activations_BLD + # For many models, the submodule outputs are a tuple or a single tensor: + # If "outputs" is a tuple, pick the relevant item: + # e.g. if your layer returns (hidden, something_else), you'd do outputs[0] + # Otherwise just do outputs + if isinstance(outputs, tuple): + activations_BLD = outputs[0] + else: + activations_BLD = outputs + + raise EarlyStopException("Early stopping after capturing activations") + + handle = submodule.register_forward_hook(gather_target_act_hook) + + # Determine the context manager based on the flag + context_manager = t.no_grad() if use_no_grad else contextlib.nullcontext() + + try: + # Use the selected context manager + with context_manager: + _ = model(**inputs_BL) + except EarlyStopException: + pass + except Exception as e: + print(f"Unexpected error during forward pass: {str(e)}") + raise + finally: + handle.remove() + + if activations_BLD is None: + # This should ideally not happen if the hook worked and EarlyStopException was raised, + # but handle it just in case. + raise RuntimeError( + "Failed to collect activations. The hook might not have run correctly." + ) + + return activations_BLD + + +class ActivationBuffer: + """ + Implements a buffer of activations. The buffer stores activations from a model, + yields them in batches, and refreshes them when the buffer is less than half full. 
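+
+    Example (illustrative sketch; the model, submodule, and sizes are assumptions,
+    not recommendations):
+
+        model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")
+        submodule = model.gpt_neox.layers[3]   # module whose output activations we want
+        data = iter(text_corpus)               # generator of enough strings to fill the buffer
+        buffer = ActivationBuffer(data, model, submodule, d_submodule=512, n_ctxs=1000)
+        acts = next(buffer)                    # tensor of shape (out_batch_size, d_submodule)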
+ """ + + def __init__( + self, + data, # generator which yields text data + model: AutoModelForCausalLM, # Language Model from which to extract activations + submodule, # submodule of the model from which to extract activations + d_submodule=None, # submodule dimension; if None, try to detect automatically + io="out", # can be 'in' or 'out'; whether to extract input or output activations + n_ctxs=3e4, # approximate number of contexts to store in the buffer + ctx_len=128, # length of each context + refresh_batch_size=512, # size of batches in which to process the data when adding to buffer + out_batch_size=8192, # size of batches in which to yield activations + device="cpu", # device on which to store the activations + remove_bos: bool = False, + add_special_tokens: bool = True, + ): + if io not in ["in", "out"]: + raise ValueError("io must be either 'in' or 'out'") + + if d_submodule is None: + try: + if io == "in": + d_submodule = submodule.in_features + else: + d_submodule = submodule.out_features + except: + raise ValueError( + "d_submodule cannot be inferred and must be specified directly" + ) + self.activations = t.empty(0, d_submodule, device=device, dtype=model.dtype) + self.read = t.zeros(0).bool() + + self.data = data + self.model = model + self.submodule = submodule + self.d_submodule = d_submodule + self.io = io + self.n_ctxs = n_ctxs + self.ctx_len = ctx_len + self.activation_buffer_size = n_ctxs * ctx_len + self.refresh_batch_size = refresh_batch_size + self.out_batch_size = out_batch_size + self.device = device + self.add_special_tokens = add_special_tokens + self.tokenizer = AutoTokenizer.from_pretrained(model.name_or_path) + self.remove_bos = remove_bos and (self.tokenizer.bos_token_id is not None) + + if not self.tokenizer.pad_token: + self.tokenizer.pad_token = self.tokenizer.eos_token + + def __iter__(self): + return self + + def __next__(self): + """ + Return a batch of activations + """ + with t.no_grad(): + # if buffer is less than half full, refresh + if (~self.read).sum() < self.activation_buffer_size // 2: + self.refresh() + + # return a batch + unreads = (~self.read).nonzero().squeeze() + idxs = unreads[ + t.randperm(len(unreads), device=unreads.device)[: self.out_batch_size] + ] + self.read[idxs] = True + return self.activations[idxs] + + def text_batch(self, batch_size=None): + """ + Return a list of text + """ + if batch_size is None: + batch_size = self.refresh_batch_size + try: + return [next(self.data) for _ in range(batch_size)] + except StopIteration: + raise StopIteration("End of data stream reached") + + def tokenized_batch(self, batch_size=None): + """ + Return a batch of tokenized inputs. + """ + texts = self.text_batch(batch_size=batch_size) + return self.tokenizer( + texts, + return_tensors="pt", + max_length=self.ctx_len, + padding=True, + truncation=True, + add_special_tokens=self.add_special_tokens, + ).to(self.device) + + def refresh(self): + gc.collect() + t.cuda.empty_cache() + self.activations = self.activations[~self.read] + + current_idx = len(self.activations) + new_activations = t.empty( + self.activation_buffer_size, + self.d_submodule, + device=self.device, + dtype=self.model.dtype, + ) + + new_activations[: len(self.activations)] = self.activations + self.activations = new_activations + + # Optional progress bar when filling buffer. At larger models / buffer sizes (e.g. gemma-2-2b, 1M tokens on a 4090) this can take a couple minutes. 
+ # pbar = tqdm(total=self.activation_buffer_size, initial=current_idx, desc="Refreshing activations") + + while current_idx < self.activation_buffer_size: + with t.no_grad(): + input = self.tokenized_batch() + hidden_states = collect_activations(self.model, self.submodule, input) + mask = (input["attention_mask"] != 0) + if self.remove_bos: + bos_mask = (input["input_ids"] == self.tokenizer.bos_token_id) + mask = mask & ~bos_mask + hidden_states = hidden_states[mask] + + remaining_space = self.activation_buffer_size - current_idx + assert remaining_space > 0 + hidden_states = hidden_states[:remaining_space] + + self.activations[current_idx : current_idx + len(hidden_states)] = ( + hidden_states.to(self.device) + ) + current_idx += len(hidden_states) + + # pbar.update(len(hidden_states)) + + # pbar.close() + self.read = t.zeros(len(self.activations), dtype=t.bool, device=self.device) + + @property + def config(self): + return { + "d_submodule": self.d_submodule, + "io": self.io, + "n_ctxs": self.n_ctxs, + "ctx_len": self.ctx_len, + "refresh_batch_size": self.refresh_batch_size, + "out_batch_size": self.out_batch_size, + "device": self.device, + } diff --git a/dictionary_learning/trainers/__init__.py b/dictionary_learning/trainers/__init__.py new file mode 100644 index 0000000..4135a82 --- /dev/null +++ b/dictionary_learning/trainers/__init__.py @@ -0,0 +1,19 @@ +from .standard import StandardTrainer +from .gdm import GatedSAETrainer +from .p_anneal import PAnnealTrainer +from .gated_anneal import GatedAnnealTrainer +from .top_k import TopKTrainer +from .jumprelu import JumpReluTrainer +from .batch_top_k import BatchTopKTrainer, BatchTopKSAE + + +__all__ = [ + "StandardTrainer", + "GatedSAETrainer", + "PAnnealTrainer", + "GatedAnnealTrainer", + "TopKTrainer", + "JumpReluTrainer", + "BatchTopKTrainer", + "BatchTopKSAE", +] diff --git a/dictionary_learning/trainers/batch_top_k.py b/dictionary_learning/trainers/batch_top_k.py new file mode 100644 index 0000000..686dc0a --- /dev/null +++ b/dictionary_learning/trainers/batch_top_k.py @@ -0,0 +1,312 @@ +import torch as t +import torch.nn as nn +import torch.nn.functional as F +import einops +from collections import namedtuple +from typing import Optional + +from ..dictionary import Dictionary +from ..trainers.trainer import ( + SAETrainer, + get_lr_schedule, + set_decoder_norm_to_unit_norm, + remove_gradient_parallel_to_decoder_directions, +) + + +class BatchTopKSAE(Dictionary, nn.Module): + def __init__(self, activation_dim: int, dict_size: int, k: int): + super().__init__() + self.activation_dim = activation_dim + self.dict_size = dict_size + + assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer" + self.register_buffer("k", t.tensor(k, dtype=t.int)) + self.register_buffer("threshold", t.tensor(-1.0, dtype=t.float32)) + + self.decoder = nn.Linear(dict_size, activation_dim, bias=False) + self.decoder.weight.data = set_decoder_norm_to_unit_norm( + self.decoder.weight, activation_dim, dict_size + ) + + self.encoder = nn.Linear(activation_dim, dict_size) + self.encoder.weight.data = self.decoder.weight.T.clone() + self.encoder.bias.data.zero_() + self.b_dec = nn.Parameter(t.zeros(activation_dim)) + + def encode(self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True): + post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec)) + + if use_threshold: + encoded_acts_BF = post_relu_feat_acts_BF * (post_relu_feat_acts_BF > self.threshold) + else: + # Flatten and perform batch top-k + flattened_acts = 
post_relu_feat_acts_BF.flatten() + post_topk = flattened_acts.topk(self.k * x.size(0), sorted=False, dim=-1) + + encoded_acts_BF = ( + t.zeros_like(post_relu_feat_acts_BF.flatten()) + .scatter_(-1, post_topk.indices, post_topk.values) + .reshape(post_relu_feat_acts_BF.shape) + ) + + if return_active: + return encoded_acts_BF, encoded_acts_BF.sum(0) > 0, post_relu_feat_acts_BF + else: + return encoded_acts_BF + + def decode(self, x: t.Tensor) -> t.Tensor: + return self.decoder(x) + self.b_dec + + def forward(self, x: t.Tensor, output_features: bool = False): + encoded_acts_BF = self.encode(x) + x_hat_BD = self.decode(encoded_acts_BF) + + if not output_features: + return x_hat_BD + else: + return x_hat_BD, encoded_acts_BF + + def scale_biases(self, scale: float): + self.encoder.bias.data *= scale + self.b_dec.data *= scale + if self.threshold >= 0: + self.threshold *= scale + + @classmethod + def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "BatchTopKSAE": + state_dict = t.load(path) + dict_size, activation_dim = state_dict["encoder.weight"].shape + if k is None: + k = state_dict["k"].item() + elif "k" in state_dict and k != state_dict["k"].item(): + raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']") + + autoencoder = cls(activation_dim, dict_size, k) + autoencoder.load_state_dict(state_dict) + if device is not None: + autoencoder.to(device) + return autoencoder + + +class BatchTopKTrainer(SAETrainer): + def __init__( + self, + steps: int, # total number of steps to train for + activation_dim: int, + dict_size: int, + k: int, + layer: int, + lm_name: str, + dict_class: type = BatchTopKSAE, + lr: Optional[float] = None, + auxk_alpha: float = 1 / 32, + warmup_steps: int = 1000, + decay_start: Optional[int] = None, # when does the lr decay start + threshold_beta: float = 0.999, + threshold_start_step: int = 1000, + seed: Optional[int] = None, + device: Optional[str] = None, + wandb_name: str = "BatchTopKSAE", + submodule_name: Optional[str] = None, + ): + super().__init__(seed) + assert layer is not None and lm_name is not None + self.layer = layer + self.lm_name = lm_name + self.submodule_name = submodule_name + self.wandb_name = wandb_name + self.steps = steps + self.decay_start = decay_start + self.warmup_steps = warmup_steps + self.k = k + self.threshold_beta = threshold_beta + self.threshold_start_step = threshold_start_step + + if seed is not None: + t.manual_seed(seed) + t.cuda.manual_seed_all(seed) + + self.ae = dict_class(activation_dim, dict_size, k) + + if device is None: + self.device = "cuda" if t.cuda.is_available() else "cpu" + else: + self.device = device + self.ae.to(self.device) + + if lr is not None: + self.lr = lr + else: + # Auto-select LR using 1 / sqrt(d) scaling law from Figure 3 of the paper + scale = dict_size / (2**14) + self.lr = 2e-4 / scale**0.5 + + self.auxk_alpha = auxk_alpha + self.dead_feature_threshold = 10_000_000 + self.top_k_aux = activation_dim // 2 # Heuristic from B.1 of the paper + self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device) + self.logging_parameters = ["effective_l0", "dead_features", "pre_norm_auxk_loss"] + self.effective_l0 = -1 + self.dead_features = -1 + self.pre_norm_auxk_loss = -1 + + self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) + + lr_fn = get_lr_schedule(steps, warmup_steps, decay_start=decay_start) + + self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) + + def get_auxiliary_loss(self, residual_BD: t.Tensor, 
post_relu_acts_BF: t.Tensor): + dead_features = self.num_tokens_since_fired >= self.dead_feature_threshold + self.dead_features = int(dead_features.sum()) + + if dead_features.sum() > 0: + k_aux = min(self.top_k_aux, dead_features.sum()) + + auxk_latents = t.where(dead_features[None], post_relu_acts_BF, -t.inf) + + # Top-k dead latents + auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False) + + auxk_buffer_BF = t.zeros_like(post_relu_acts_BF) + auxk_acts_BF = auxk_buffer_BF.scatter_(dim=-1, index=auxk_indices, src=auxk_acts) + + # Note: decoder(), not decode(), as we don't want to apply the bias + x_reconstruct_aux = self.ae.decoder(auxk_acts_BF) + l2_loss_aux = ( + (residual_BD.float() - x_reconstruct_aux.float()).pow(2).sum(dim=-1).mean() + ) + + self.pre_norm_auxk_loss = l2_loss_aux + + # normalization from OpenAI implementation: https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/kernels.py#L614 + residual_mu = residual_BD.mean(dim=0)[None, :].broadcast_to(residual_BD.shape) + loss_denom = (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean() + normalized_auxk_loss = l2_loss_aux / loss_denom + + return normalized_auxk_loss.nan_to_num(0.0) + else: + self.pre_norm_auxk_loss = -1 + return t.tensor(0, dtype=residual_BD.dtype, device=residual_BD.device) + + def update_threshold(self, f: t.Tensor): + device_type = "cuda" if f.is_cuda else "cpu" + with t.autocast(device_type=device_type, enabled=False), t.no_grad(): + active = f[f > 0] + + if active.size(0) == 0: + min_activation = 0.0 + else: + min_activation = active.min().detach().to(dtype=t.float32) + + if self.ae.threshold < 0: + self.ae.threshold = min_activation + else: + self.ae.threshold = (self.threshold_beta * self.ae.threshold) + ( + (1 - self.threshold_beta) * min_activation + ) + + def loss(self, x, step=None, logging=False): + f, active_indices_F, post_relu_acts_BF = self.ae.encode( + x, return_active=True, use_threshold=False + ) + # l0 = (f != 0).float().sum(dim=-1).mean().item() + + if step > self.threshold_start_step: + self.update_threshold(f) + + x_hat = self.ae.decode(f) + + e = x - x_hat + + self.effective_l0 = self.k + + num_tokens_in_step = x.size(0) + did_fire = t.zeros_like(self.num_tokens_since_fired, dtype=t.bool) + did_fire[active_indices_F] = True + self.num_tokens_since_fired += num_tokens_in_step + self.num_tokens_since_fired[did_fire] = 0 + + l2_loss = e.pow(2).sum(dim=-1).mean() + auxk_loss = self.get_auxiliary_loss(e.detach(), post_relu_acts_BF) + loss = l2_loss + self.auxk_alpha * auxk_loss + + if not logging: + return loss + else: + return namedtuple("LossLog", ["x", "x_hat", "f", "losses"])( + x, + x_hat, + f, + {"l2_loss": l2_loss.item(), "auxk_loss": auxk_loss.item(), "loss": loss.item()}, + ) + + def update(self, step, x): + if step == 0: + median = self.geometric_median(x) + median = median.to(self.ae.b_dec.dtype) + self.ae.b_dec.data = median + + x = x.to(self.device) + loss = self.loss(x, step=step) + loss.backward() + + self.ae.decoder.weight.grad = remove_gradient_parallel_to_decoder_directions( + self.ae.decoder.weight, + self.ae.decoder.weight.grad, + self.ae.activation_dim, + self.ae.dict_size, + ) + t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) + + self.optimizer.step() + self.optimizer.zero_grad() + self.scheduler.step() + + # Make sure the decoder is still unit-norm + self.ae.decoder.weight.data = set_decoder_norm_to_unit_norm( + self.ae.decoder.weight, self.ae.activation_dim, self.ae.dict_size + ) + + return loss.item() + + @property 
+ def config(self): + return { + "trainer_class": "BatchTopKTrainer", + "dict_class": "BatchTopKSAE", + "lr": self.lr, + "steps": self.steps, + "auxk_alpha": self.auxk_alpha, + "warmup_steps": self.warmup_steps, + "decay_start": self.decay_start, + "threshold_beta": self.threshold_beta, + "threshold_start_step": self.threshold_start_step, + "top_k_aux": self.top_k_aux, + "seed": self.seed, + "activation_dim": self.ae.activation_dim, + "dict_size": self.ae.dict_size, + "k": self.ae.k.item(), + "device": self.device, + "layer": self.layer, + "lm_name": self.lm_name, + "wandb_name": self.wandb_name, + "submodule_name": self.submodule_name, + } + + @staticmethod + def geometric_median(points: t.Tensor, max_iter: int = 100, tol: float = 1e-5): + guess = points.mean(dim=0) + prev = t.zeros_like(guess) + weights = t.ones(len(points), device=points.device) + + for _ in range(max_iter): + prev = guess + weights = 1 / t.norm(points - guess, dim=1) + weights /= weights.sum() + guess = (weights.unsqueeze(1) * points).sum(dim=0) + if t.norm(guess - prev) < tol: + break + + return guess diff --git a/trainers/gated_anneal.py b/dictionary_learning/trainers/gated_anneal.py similarity index 77% rename from trainers/gated_anneal.py rename to dictionary_learning/trainers/gated_anneal.py index 664904b..09f69e6 100644 --- a/trainers/gated_anneal.py +++ b/dictionary_learning/trainers/gated_anneal.py @@ -3,56 +3,40 @@ """ import torch as t -from ..trainers.trainer import SAETrainer +from typing import Optional + +from ..trainers.trainer import SAETrainer, get_lr_schedule, get_sparsity_warmup_fn, ConstrainedAdam from ..config import DEBUG from ..dictionary import GatedAutoEncoder from collections import namedtuple -class ConstrainedAdam(t.optim.Adam): - """ - A variant of Adam where some of the parameters are constrained to have unit norm. - """ - def __init__(self, params, constrained_params, lr): - super().__init__(params, lr=lr, betas=(0, 0.999)) - self.constrained_params = list(constrained_params) - - def step(self, closure=None): - with t.no_grad(): - for p in self.constrained_params: - normed_p = p / p.norm(dim=0, keepdim=True) - # project away the parallel component of the gradient - p.grad -= (p.grad * normed_p).sum(dim=0, keepdim=True) * normed_p - super().step(closure=closure) - with t.no_grad(): - for p in self.constrained_params: - # renormalize the constrained parameters - p /= p.norm(dim=0, keepdim=True) - class GatedAnnealTrainer(SAETrainer): """ Gated SAE training scheme with p-annealing. 
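    The exponent p of the sparsity penalty is annealed linearly from p_start (default 1) to
    p_end (default 0) between anneal_start and anneal_end, so the penalty moves from an
    L1-like norm toward an L0-like count. A rough sketch of the default 'Lp^p' option (the
    exact computation lives in lp_norm(), so treat this as illustrative):

        lp_loss = (f.abs() ** p).sum(dim=-1).mean()   # p decreases over training

    The sparsity coefficient itself is updated n_sparsity_updates times, using a queue of
    recent sparsity losses (sparsity_queue) to keep the overall penalty magnitude comparable
    as p changes.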
""" def __init__(self, - dict_class=GatedAutoEncoder, - activation_dim=512, - dict_size=64*512, - lr=3e-4, - warmup_steps=1000, # lr warmup period at start of training and after each resample - sparsity_function='Lp^p', # Lp or Lp^p - initial_sparsity_penalty=1e-1, # equal to l1 penalty in standard trainer - anneal_start=15000, # step at which to start annealing p - anneal_end=None, # step at which to stop annealing, defaults to steps-1 - p_start=1, # starting value of p (constant throughout warmup) - p_end=0, # annealing p_start to p_end linearly after warmup_steps, exact endpoint excluded - n_sparsity_updates = 10, # number of times to update the sparsity penalty, at most steps-anneal_start times - sparsity_queue_length = 10, # number of recent sparsity loss terms, onle needed for adaptive_sparsity_penalty - resample_steps=None, # number of steps after which to resample dead neurons - steps=None, # total number of steps to train for - device=None, - seed=42, - layer=None, - lm_name=None, - wandb_name='GatedAnnealTrainer', + steps: int, # total number of steps to train for + activation_dim: int, + dict_size: int, + layer: int, + lm_name: str, + dict_class: type = GatedAutoEncoder, + lr: float = 3e-4, + warmup_steps: int = 1000, # lr warmup period at start of training and after each resample + sparsity_warmup_steps: Optional[int] = 2000, # sparsity warmup period at start of training + decay_start: Optional[int] = None, # decay learning rate after this many steps + sparsity_function: str = 'Lp^p', # Lp or Lp^p + initial_sparsity_penalty: float = 1e-1, # equal to l1 penalty in standard trainer + anneal_start: int = 15000, # step at which to start annealing p + anneal_end: Optional[int] = None, # step at which to stop annealing, defaults to steps-1 + p_start: float = 1, # starting value of p (constant throughout warmup) + p_end: float = 0, # annealing p_start to p_end linearly after warmup_steps, exact endpoint excluded + n_sparsity_updates: int | str = 10, # number of times to update the sparsity penalty, at most steps-anneal_start times + sparsity_queue_length: int = 10, # number of recent sparsity loss terms, onle needed for adaptive_sparsity_penalty + resample_steps: Optional[int] = None, # number of steps after which to resample dead neurons + device: Optional[str] = None, + seed: Optional[int] = 42, + wandb_name: str = 'GatedAnnealTrainer', ): super().__init__(seed) @@ -98,6 +82,8 @@ def __init__(self, self.sparsity_queue = [] self.warmup_steps = warmup_steps + self.sparsity_warmup_steps = sparsity_warmup_steps + self.decay_start = decay_start self.steps = steps self.logging_parameters = ['p', 'next_p', 'lp_loss', 'scaled_lp_loss', 'sparsity_coeff'] self.seed = seed @@ -110,14 +96,11 @@ def __init__(self, else: self.steps_since_active = None - self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr) - if resample_steps is None: - def warmup_fn(step): - return min(step / warmup_steps, 1.) - else: - def warmup_fn(step): - return min((step % resample_steps) / warmup_steps, 1.) 
- self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup_fn) + self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr, betas=(0.0, 0.999)) + + lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps, sparsity_warmup_steps) + self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) + self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps) def resample_neurons(self, deads, activations): with t.no_grad(): @@ -160,7 +143,8 @@ def lp_norm(self, f, p): else: raise ValueError("Sparsity function must be 'Lp' or 'Lp^p'") - def loss(self, x, step, logging=False, **kwargs): + def loss(self, x:t.Tensor, step:int, logging=False, **kwargs): + sparsity_scale = self.sparsity_warmup_fn(step) f, f_gate = self.ae.encode(x, return_gate=True) x_hat = self.ae.decode(f) x_hat_gate = f_gate @ self.ae.decoder.weight.detach().T + self.ae.decoder_bias.detach() @@ -170,7 +154,7 @@ def loss(self, x, step, logging=False, **kwargs): fs = f_gate # feature activation that we use for sparsity term lp_loss = self.lp_norm(fs, self.p) - scaled_lp_loss = lp_loss * self.sparsity_coeff + scaled_lp_loss = lp_loss * self.sparsity_coeff * sparsity_scale self.lp_loss = lp_loss self.scaled_lp_loss = scaled_lp_loss @@ -263,6 +247,8 @@ def config(self): 'n_sparsity_updates' : self.n_sparsity_updates, 'warmup_steps' : self.warmup_steps, 'resample_steps' : self.resample_steps, + 'sparsity_warmup_steps' : self.sparsity_warmup_steps, + 'decay_start' : self.decay_start, 'steps' : self.steps, 'seed' : self.seed, 'layer' : self.layer, diff --git a/trainers/gdm.py b/dictionary_learning/trainers/gdm.py similarity index 62% rename from trainers/gdm.py rename to dictionary_learning/trainers/gdm.py index 47ea772..ecab59c 100644 --- a/trainers/gdm.py +++ b/dictionary_learning/trainers/gdm.py @@ -3,49 +3,33 @@ """ import torch as t -from ..trainers.trainer import SAETrainer +from typing import Optional + +from ..trainers.trainer import SAETrainer, get_lr_schedule, get_sparsity_warmup_fn, ConstrainedAdam from ..config import DEBUG from ..dictionary import GatedAutoEncoder from collections import namedtuple -class ConstrainedAdam(t.optim.Adam): - """ - A variant of Adam where some of the parameters are constrained to have unit norm. - """ - def __init__(self, params, constrained_params, lr): - super().__init__(params, lr=lr, betas=(0, 0.999)) - self.constrained_params = list(constrained_params) - - def step(self, closure=None): - with t.no_grad(): - for p in self.constrained_params: - normed_p = p / p.norm(dim=0, keepdim=True) - # project away the parallel component of the gradient - p.grad -= (p.grad * normed_p).sum(dim=0, keepdim=True) * normed_p - super().step(closure=closure) - with t.no_grad(): - for p in self.constrained_params: - # renormalize the constrained parameters - p /= p.norm(dim=0, keepdim=True) - class GatedSAETrainer(SAETrainer): """ Gated SAE training scheme. 
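    The objective combines a reconstruction term, an L1 penalty on the gate activations
    (warmed up over sparsity_warmup_steps), and an auxiliary reconstruction term computed
    through a detached copy of the decoder. Schematically (a sketch, not the verbatim code):

        loss = L_recon + l1_penalty * sparsity_scale * ||f_gate||_1 + L_aux

    where sparsity_scale ramps up to 1 over the sparsity warmup period.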
""" def __init__(self, - dict_class=GatedAutoEncoder, - activation_dim=512, - dict_size=64*512, - lr=5e-5, - l1_penalty=1e-1, - warmup_steps=1000, # lr warmup period at start of training and after each resample - resample_steps=None, # how often to resample neurons - seed=None, - device=None, - layer=None, - lm_name=None, - wandb_name='GatedSAETrainer', - submodule_name=None, + steps: int, # total number of steps to train for + activation_dim: int, + dict_size: int, + layer: int, + lm_name: str, + dict_class = GatedAutoEncoder, + lr: float = 5e-5, + l1_penalty: float = 1e-1, + warmup_steps: int = 1000, # lr warmup period at start of training and after each resample + sparsity_warmup_steps: Optional[int] = 2000, + decay_start:Optional[int]=None, # decay learning rate after this many steps + seed: Optional[int] = None, + device: Optional[str] = None, + wandb_name: Optional[str] = 'GatedSAETrainer', + submodule_name: Optional[str] = None, ): super().__init__(seed) @@ -64,6 +48,8 @@ def __init__(self, self.lr = lr self.l1_penalty=l1_penalty self.warmup_steps = warmup_steps + self.sparsity_warmup_steps = sparsity_warmup_steps + self.decay_start = decay_start self.wandb_name = wandb_name if device is None: @@ -75,13 +61,20 @@ def __init__(self, self.optimizer = ConstrainedAdam( self.ae.parameters(), self.ae.decoder.parameters(), - lr=lr + lr=lr, + betas=(0.0, 0.999), ) - def warmup_fn(step): - return min(1, step / warmup_steps) - self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, warmup_fn) - def loss(self, x, logging=False, **kwargs): + lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps=None, sparsity_warmup_steps=sparsity_warmup_steps) + + self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_fn) + self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps) + + + def loss(self, x:t.Tensor, step:int, logging:bool=False, **kwargs): + + sparsity_scale = self.sparsity_warmup_fn(step) + f, f_gate = self.ae.encode(x, return_gate=True) x_hat = self.ae.decode(f) x_hat_gate = f_gate @ self.ae.decoder.weight.detach().T + self.ae.decoder_bias.detach() @@ -90,7 +83,7 @@ def loss(self, x, logging=False, **kwargs): L_sparse = t.linalg.norm(f_gate, ord=1, dim=-1).mean() L_aux = (x - x_hat_gate).pow(2).sum(dim=-1).mean() - loss = L_recon + self.l1_penalty * L_sparse + L_aux + loss = L_recon + (self.l1_penalty * L_sparse * sparsity_scale) + L_aux if not logging: return loss @@ -108,7 +101,7 @@ def loss(self, x, logging=False, **kwargs): def update(self, step, x): x = x.to(self.device) self.optimizer.zero_grad() - loss = self.loss(x) + loss = self.loss(x, step) loss.backward() self.optimizer.step() self.scheduler.step() @@ -123,6 +116,9 @@ def config(self): 'lr' : self.lr, 'l1_penalty' : self.l1_penalty, 'warmup_steps' : self.warmup_steps, + 'sparsity_warmup_steps' : self.sparsity_warmup_steps, + 'decay_start' : self.decay_start, + 'seed' : self.seed, 'device' : self.device, 'layer' : self.layer, 'lm_name' : self.lm_name, diff --git a/trainers/jumprelu.py b/dictionary_learning/trainers/jumprelu.py similarity index 52% rename from trainers/jumprelu.py rename to dictionary_learning/trainers/jumprelu.py index f87785a..2ac717a 100644 --- a/trainers/jumprelu.py +++ b/dictionary_learning/trainers/jumprelu.py @@ -3,9 +3,16 @@ import torch import torch.autograd as autograd from torch import nn +from typing import Optional from ..dictionary import Dictionary, JumpReluAutoEncoder -from .trainer import SAETrainer +from ..trainers.trainer import ( + 
SAETrainer, + get_lr_schedule, + get_sparsity_warmup_fn, + set_decoder_norm_to_unit_norm, + remove_gradient_parallel_to_decoder_directions, +) class RectangleFunction(autograd.Function): @@ -53,36 +60,38 @@ def backward(ctx, grad_output): bandwidth = bandwidth_tensor.item() x_grad = torch.zeros_like(x) threshold_grad = ( - -(1.0 / bandwidth) - * RectangleFunction.apply((x - threshold) / bandwidth) - * grad_output + -(1.0 / bandwidth) * RectangleFunction.apply((x - threshold) / bandwidth) * grad_output ) return x_grad, threshold_grad, None # None for bandwidth -class TrainerJumpRelu(nn.Module, SAETrainer): +class JumpReluTrainer(nn.Module, SAETrainer): """ Trains a JumpReLU autoencoder. Note does not use learning rate or sparsity scheduling as in the paper. """ + def __init__( self, + steps: int, # total number of steps to train for + activation_dim: int, + dict_size: int, + layer: int, + lm_name: str, dict_class=JumpReluAutoEncoder, - activation_dim=512, - dict_size=8192, - steps=30000, - # XXX: Training decay is not implemented - seed=None, + seed: Optional[int] = None, # TODO: What's the default lr use in the paper? - lr=7e-5, - bandwidth=0.001, - sparsity_penalty=0.1, - device="cpu", - layer=None, - lm_name=None, - wandb_name="JumpRelu", - submodule_name=None, + lr: float = 7e-5, + bandwidth: float = 0.001, + sparsity_penalty: float = 1.0, + warmup_steps: int = 1000, # lr warmup period at start of training and after each resample + sparsity_warmup_steps: Optional[int] = 2000, # sparsity warmup period at start of training + decay_start: Optional[int] = None, # decay learning rate after this many steps + target_l0: float = 20.0, + device: str = "cpu", + wandb_name: str = "JumpRelu", + submodule_name: Optional[str] = None, ): super().__init__() @@ -99,6 +108,10 @@ def __init__( self.bandwidth = bandwidth self.sparsity_coefficient = sparsity_penalty + self.warmup_steps = warmup_steps + self.sparsity_warmup_steps = sparsity_warmup_steps + self.decay_start = decay_start + self.target_l0 = target_l0 # TODO: Better auto-naming (e.g. 
in BatchTopK package) self.wandb_name = wandb_name @@ -111,19 +124,54 @@ def __init__( ).to(self.device) # Parameters from the paper - self.optimizer = torch.optim.Adam( - self.ae.parameters(), lr=lr, betas=(0.0, 0.999), eps=1e-8 + self.optimizer = torch.optim.Adam(self.ae.parameters(), lr=lr, betas=(0.0, 0.999), eps=1e-8) + + lr_fn = get_lr_schedule( + steps, + warmup_steps, + decay_start, + resample_steps=None, + sparsity_warmup_steps=sparsity_warmup_steps, ) - self.logging_parameters = [] + self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) + + self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps) + + # Purely for logging purposes + self.dead_feature_threshold = 10_000_000 + self.num_tokens_since_fired = torch.zeros(dict_size, dtype=torch.long, device=device) + self.dead_features = -1 + self.logging_parameters = ["dead_features"] + + def loss(self, x: torch.Tensor, step: int, logging=False, **_): + # Note: We are using threshold, not log_threshold as in this notebook: + # https://colab.research.google.com/drive/1PlFzI_PWGTN9yCQLuBcSuPJUjgHL7GiD#scrollTo=yP828a6uIlSO + # I had poor results when using log_threshold and it would complicate the scale_biases() function + + sparsity_scale = self.sparsity_warmup_fn(step) + x = x.to(self.ae.W_enc.dtype) + + pre_jump = x @ self.ae.W_enc + self.ae.b_enc + f = JumpReLUFunction.apply(pre_jump, self.ae.threshold, self.bandwidth) + + active_indices = f.sum(0) > 0 + did_fire = torch.zeros_like(self.num_tokens_since_fired, dtype=torch.bool) + did_fire[active_indices] = True + self.num_tokens_since_fired += x.size(0) + self.num_tokens_since_fired[active_indices] = 0 + self.dead_features = ( + (self.num_tokens_since_fired > self.dead_feature_threshold).sum().item() + ) - def loss(self, x, logging=False, **_): - f = self.ae.encode(x) recon = self.ae.decode(f) recon_loss = (x - recon).pow(2).sum(dim=-1).mean() l0 = StepFunction.apply(f, self.ae.threshold, self.bandwidth).sum(dim=-1).mean() - sparsity_loss = self.sparsity_coefficient * l0 + + sparsity_loss = ( + self.sparsity_coefficient * ((l0 / self.target_l0) - 1).pow(2) * sparsity_scale + ) loss = recon_loss + sparsity_loss if not logging: @@ -144,16 +192,27 @@ def update(self, step, x): loss = self.loss(x, step=step) loss.backward() + # We must transpose because we are using nn.Parameter, not nn.Linear + self.ae.W_dec.grad = remove_gradient_parallel_to_decoder_directions( + self.ae.W_dec.T, self.ae.W_dec.grad.T, self.ae.activation_dim, self.ae.dict_size + ).T torch.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) self.optimizer.step() + self.scheduler.step() self.optimizer.zero_grad() + + # We must transpose because we are using nn.Parameter, not nn.Linear + self.ae.W_dec.data = set_decoder_norm_to_unit_norm( + self.ae.W_dec.T, self.ae.activation_dim, self.ae.dict_size + ).T + return loss.item() @property def config(self): return { - "trainer_class": "TrainerJumpRelu", + "trainer_class": "JumpReluTrainer", "dict_class": "JumpReluAutoEncoder", "lr": self.lr, "steps": self.steps, @@ -165,4 +224,8 @@ def config(self): "lm_name": self.lm_name, "wandb_name": self.wandb_name, "submodule_name": self.submodule_name, + "bandwidth": self.bandwidth, + "sparsity_penalty": self.sparsity_coefficient, + "sparsity_warmup_steps": self.sparsity_warmup_steps, + "target_l0": self.target_l0, } diff --git a/dictionary_learning/trainers/matryoshka_batch_top_k.py b/dictionary_learning/trainers/matryoshka_batch_top_k.py new file mode 100644 index 0000000..67647fb --- 
/dev/null +++ b/dictionary_learning/trainers/matryoshka_batch_top_k.py @@ -0,0 +1,388 @@ +import torch as t +import torch.nn as nn +import torch.nn.functional as F +import einops +from collections import namedtuple +from typing import Optional +from math import isclose + +from ..dictionary import Dictionary +from ..trainers.trainer import ( + SAETrainer, + get_lr_schedule, + set_decoder_norm_to_unit_norm, + remove_gradient_parallel_to_decoder_directions, +) + + +def apply_temperature(probabilities: list[float], temperature: float) -> list[float]: + """ + Apply temperature scaling to a list of probabilities using PyTorch. + + Args: + probabilities (list[float]): Initial probability distribution + temperature (float): Temperature parameter (> 0) + + Returns: + list[float]: Scaled and normalized probabilities + """ + probs_tensor = t.tensor(probabilities, dtype=t.float32) + logits = t.log(probs_tensor) + scaled_logits = logits / temperature + scaled_probs = t.nn.functional.softmax(scaled_logits, dim=0) + + return scaled_probs.tolist() + + +class MatryoshkaBatchTopKSAE(Dictionary, nn.Module): + def __init__(self, activation_dim: int, dict_size: int, k: int, group_sizes: list[int]): + super().__init__() + self.activation_dim = activation_dim + self.dict_size = dict_size + + assert sum(group_sizes) == dict_size, "group sizes must sum to dict_size" + assert all(s > 0 for s in group_sizes), "all group sizes must be positive" + + assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer" + self.register_buffer("k", t.tensor(k, dtype=t.int)) + self.register_buffer("threshold", t.tensor(-1.0, dtype=t.float32)) + + self.active_groups = len(group_sizes) + group_indices = [0] + list(t.cumsum(t.tensor(group_sizes), dim=0)) + self.group_indices = group_indices + + self.register_buffer("group_sizes", t.tensor(group_sizes)) + + self.W_enc = nn.Parameter(t.empty(activation_dim, dict_size)) + self.b_enc = nn.Parameter(t.zeros(dict_size)) + self.W_dec = nn.Parameter(t.nn.init.kaiming_uniform_(t.empty(dict_size, activation_dim))) + self.b_dec = nn.Parameter(t.zeros(activation_dim)) + + # We must transpose because we are using nn.Parameter, not nn.Linear + self.W_dec.data = set_decoder_norm_to_unit_norm( + self.W_dec.data.T, activation_dim, dict_size + ).T + self.W_enc.data = self.W_dec.data.clone().T + + def encode(self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True): + post_relu_feat_acts_BF = nn.functional.relu((x - self.b_dec) @ self.W_enc + self.b_enc) + + if use_threshold: + encoded_acts_BF = post_relu_feat_acts_BF * (post_relu_feat_acts_BF > self.threshold) + else: + # Flatten and perform batch top-k + flattened_acts = post_relu_feat_acts_BF.flatten() + post_topk = flattened_acts.topk(self.k * x.size(0), sorted=False, dim=-1) + + encoded_acts_BF = ( + t.zeros_like(post_relu_feat_acts_BF.flatten()) + .scatter_(-1, post_topk.indices, post_topk.values) + .reshape(post_relu_feat_acts_BF.shape) + ) + + max_act_index = self.group_indices[self.active_groups] + encoded_acts_BF[:, max_act_index:] = 0 + + if return_active: + return encoded_acts_BF, encoded_acts_BF.sum(0) > 0, post_relu_feat_acts_BF + else: + return encoded_acts_BF + + def decode(self, x: t.Tensor) -> t.Tensor: + return x @ self.W_dec + self.b_dec + + def forward(self, x: t.Tensor, output_features: bool = False): + encoded_acts_BF = self.encode(x) + x_hat_BD = self.decode(encoded_acts_BF) + + if not output_features: + return x_hat_BD + else: + return x_hat_BD, encoded_acts_BF + + @t.no_grad() + def 
scale_biases(self, scale: float): + self.b_enc.data *= scale + self.b_dec.data *= scale + if self.threshold >= 0: + self.threshold *= scale + + @classmethod + def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "MatryoshkaBatchTopKSAE": + state_dict = t.load(path) + activation_dim, dict_size = state_dict["W_enc"].shape + if k is None: + k = state_dict["k"].item() + elif "k" in state_dict and k != state_dict["k"].item(): + raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']") + + group_sizes = state_dict["group_sizes"].tolist() + + autoencoder = cls(activation_dim, dict_size, k=k, group_sizes=group_sizes) + autoencoder.load_state_dict(state_dict) + if device is not None: + autoencoder.to(device) + return autoencoder + + +class MatryoshkaBatchTopKTrainer(SAETrainer): + def __init__( + self, + steps: int, # total number of steps to train for + activation_dim: int, + dict_size: int, + k: int, + layer: int, + lm_name: str, + group_fractions: list[float], + group_weights: Optional[list[float]] = None, + dict_class: type = MatryoshkaBatchTopKSAE, + lr: Optional[float] = None, + auxk_alpha: float = 1 / 32, + warmup_steps: int = 1000, + decay_start: Optional[int] = None, # when does the lr decay start + threshold_beta: float = 0.999, + threshold_start_step: int = 1000, + seed: Optional[int] = None, + device: Optional[str] = None, + wandb_name: str = "BatchTopKSAE", + submodule_name: Optional[str] = None, + ): + super().__init__(seed) + assert layer is not None and lm_name is not None + self.layer = layer + self.lm_name = lm_name + self.submodule_name = submodule_name + self.wandb_name = wandb_name + self.steps = steps + self.decay_start = decay_start + self.warmup_steps = warmup_steps + self.k = k + self.threshold_beta = threshold_beta + self.threshold_start_step = threshold_start_step + + if seed is not None: + t.manual_seed(seed) + t.cuda.manual_seed_all(seed) + + assert isclose(sum(group_fractions), 1.0), "group_fractions must sum to 1.0" + # Calculate all groups except the last one + group_sizes = [int(f * dict_size) for f in group_fractions[:-1]] + # Put remainder in the last group + group_sizes.append(dict_size - sum(group_sizes)) + + if group_weights is None: + group_weights = [(1.0 / len(group_sizes))] * len(group_sizes) + + assert len(group_sizes) == len(group_weights), ( + "group_sizes and group_weights must have the same length" + ) + + self.group_fractions = group_fractions + self.group_sizes = group_sizes + self.group_weights = group_weights + + self.ae = dict_class(activation_dim, dict_size, k, group_sizes) + + if device is None: + self.device = "cuda" if t.cuda.is_available() else "cpu" + else: + self.device = device + self.ae.to(self.device) + + if lr is not None: + self.lr = lr + else: + # Auto-select LR using 1 / sqrt(d) scaling law from Figure 3 of the paper + scale = dict_size / (2**14) + self.lr = 2e-4 / scale**0.5 + self.auxk_alpha = auxk_alpha + self.dead_feature_threshold = 10_000_000 + self.top_k_aux = activation_dim // 2 # Heuristic from B.1 of the paper + + self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) + + lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps=None) + self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) + + self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device) + self.logging_parameters = ["effective_l0", "dead_features", "pre_norm_auxk_loss"] + self.effective_l0 = -1 + self.dead_features = -1 + self.pre_norm_auxk_loss = 
-1 + + def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor): + dead_features = self.num_tokens_since_fired >= self.dead_feature_threshold + self.dead_features = int(dead_features.sum()) + + if self.dead_features > 0: + k_aux = min(self.top_k_aux, self.dead_features) + + auxk_latents = t.where(dead_features[None], post_relu_acts_BF, -t.inf) + + # Top-k dead latents + auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False) + + auxk_buffer_BF = t.zeros_like(post_relu_acts_BF) + auxk_acts_BF = auxk_buffer_BF.scatter_(dim=-1, index=auxk_indices, src=auxk_acts) + + # We don't want to apply the bias + x_reconstruct_aux = auxk_acts_BF @ self.ae.W_dec + l2_loss_aux = ( + (residual_BD.float() - x_reconstruct_aux.float()).pow(2).sum(dim=-1).mean() + ) + + self.pre_norm_auxk_loss = l2_loss_aux + + # normalization from OpenAI implementation: https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/kernels.py#L614 + residual_mu = residual_BD.mean(dim=0)[None, :].broadcast_to(residual_BD.shape) + loss_denom = (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean() + normalized_auxk_loss = l2_loss_aux / loss_denom + + return normalized_auxk_loss.nan_to_num(0.0) + else: + self.pre_norm_auxk_loss = -1 + return t.tensor(0, dtype=residual_BD.dtype, device=residual_BD.device) + + def update_threshold(self, f: t.Tensor): + device_type = "cuda" if f.is_cuda else "cpu" + with t.autocast(device_type=device_type, enabled=False), t.no_grad(): + active = f[f > 0] + + if active.size(0) == 0: + min_activation = 0.0 + else: + min_activation = active.min().detach().to(dtype=t.float32) + + if self.ae.threshold < 0: + self.ae.threshold = min_activation + else: + self.ae.threshold = (self.threshold_beta * self.ae.threshold) + ( + (1 - self.threshold_beta) * min_activation + ) + + def loss(self, x, step=None, logging=False): + f, active_indices_F, post_relu_acts_BF = self.ae.encode( + x, return_active=True, use_threshold=False + ) + # l0 = (f != 0).float().sum(dim=-1).mean().item() + + if step > self.threshold_start_step: + self.update_threshold(f) + + x_reconstruct = t.zeros_like(x) + self.ae.b_dec + total_l2_loss = 0.0 + l2_losses = t.tensor([]).to(self.device) + + # We could potentially refactor the ae class to use W_dec_chunks instead of W_dec, may be more efficient + W_dec_chunks = t.split(self.ae.W_dec, self.ae.group_sizes.tolist(), dim=0) + f_chunks = t.split(f, self.ae.group_sizes.tolist(), dim=1) + + for i in range(self.ae.active_groups): + W_dec_slice = W_dec_chunks[i] + acts_slice = f_chunks[i] + x_reconstruct = x_reconstruct + acts_slice @ W_dec_slice + + l2_loss = (x - x_reconstruct).pow(2).sum(dim=-1).mean() * self.group_weights[i] + total_l2_loss += l2_loss + l2_losses = t.cat([l2_losses, l2_loss.unsqueeze(0)]) + + min_l2_loss = l2_losses.min().item() + max_l2_loss = l2_losses.max().item() + mean_l2_loss = l2_losses.mean() + + self.effective_l0 = self.k + + num_tokens_in_step = x.size(0) + did_fire = t.zeros_like(self.num_tokens_since_fired, dtype=t.bool) + did_fire[active_indices_F] = True + self.num_tokens_since_fired += num_tokens_in_step + self.num_tokens_since_fired[did_fire] = 0 + + auxk_loss = self.get_auxiliary_loss((x - x_reconstruct).detach(), post_relu_acts_BF) + loss = mean_l2_loss + self.auxk_alpha * auxk_loss + + if not logging: + return loss + else: + return namedtuple("LossLog", ["x", "x_hat", "f", "losses"])( + x, + x_reconstruct, + f, + { + "l2_loss": mean_l2_loss.item(), + "auxk_loss": auxk_loss.item(), + "loss": 
loss.item(), + "min_l2_loss": min_l2_loss, + "max_l2_loss": max_l2_loss, + }, + ) + + def update(self, step, x): + if step == 0: + median = self.geometric_median(x) + self.ae.b_dec.data = median + + x = x.to(self.device) + loss = self.loss(x, step=step) + loss.backward() + + # We must transpose because we are using nn.Parameter, not nn.Linear + self.ae.W_dec.grad = remove_gradient_parallel_to_decoder_directions( + self.ae.W_dec.T, self.ae.W_dec.grad.T, self.ae.activation_dim, self.ae.dict_size + ).T + t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) + + self.optimizer.step() + self.optimizer.zero_grad() + self.scheduler.step() + + # We must transpose because we are using nn.Parameter, not nn.Linear + self.ae.W_dec.data = set_decoder_norm_to_unit_norm( + self.ae.W_dec.T, self.ae.activation_dim, self.ae.dict_size + ).T + + return loss.item() + + @property + def config(self): + return { + "trainer_class": "MatryoshkaBatchTopKTrainer", + "dict_class": "MatryoshkaBatchTopKSAE", + "lr": self.lr, + "steps": self.steps, + "auxk_alpha": self.auxk_alpha, + "warmup_steps": self.warmup_steps, + "decay_start": self.decay_start, + "threshold_beta": self.threshold_beta, + "threshold_start_step": self.threshold_start_step, + "top_k_aux": self.top_k_aux, + "seed": self.seed, + "activation_dim": self.ae.activation_dim, + "dict_size": self.ae.dict_size, + "group_fractions": self.group_fractions, + "group_weights": self.group_weights, + "group_sizes": self.group_sizes, + "k": self.ae.k.item(), + "device": self.device, + "layer": self.layer, + "lm_name": self.lm_name, + "wandb_name": self.wandb_name, + "submodule_name": self.submodule_name, + } + + @staticmethod + def geometric_median(points: t.Tensor, max_iter: int = 100, tol: float = 1e-5): + guess = points.mean(dim=0) + prev = t.zeros_like(guess) + weights = t.ones(len(points), device=points.device) + + for _ in range(max_iter): + prev = guess + weights = 1 / t.norm(points - guess, dim=1) + weights /= weights.sum() + guess = (weights.unsqueeze(1) * points).sum(dim=0) + if t.norm(guess - prev) < tol: + break + + return guess diff --git a/trainers/p_anneal.py b/dictionary_learning/trainers/p_anneal.py similarity index 74% rename from trainers/p_anneal.py rename to dictionary_learning/trainers/p_anneal.py index 4a157b9..de9bdff 100644 --- a/trainers/p_anneal.py +++ b/dictionary_learning/trainers/p_anneal.py @@ -1,60 +1,42 @@ import torch as t - +from typing import Optional """ Implements the standard SAE training scheme. """ from ..dictionary import AutoEncoder -from ..trainers.trainer import SAETrainer +from ..trainers.trainer import SAETrainer, get_lr_schedule, get_sparsity_warmup_fn, ConstrainedAdam from ..config import DEBUG -class ConstrainedAdam(t.optim.Adam): - """ - A variant of Adam where some of the parameters are constrained to have unit norm. - """ - def __init__(self, params, constrained_params, lr): - super().__init__(params, lr=lr) - self.constrained_params = list(constrained_params) - - def step(self, closure=None): - with t.no_grad(): - for p in self.constrained_params: - normed_p = p / p.norm(dim=0, keepdim=True) - # project away the parallel component of the gradient - p.grad -= (p.grad * normed_p).sum(dim=0, keepdim=True) * normed_p - super().step(closure=closure) - with t.no_grad(): - for p in self.constrained_params: - # renormalize the constrained parameters - p /= p.norm(dim=0, keepdim=True) - class PAnnealTrainer(SAETrainer): """ SAE training scheme with the option to anneal the sparsity parameter p. 
You can further choose to use Lp or Lp^p sparsity. """ def __init__(self, - dict_class=AutoEncoder, - activation_dim=512, - dict_size=64*512, - lr=1e-3, - warmup_steps=1000, # lr warmup period at start of training and after each resample - sparsity_function='Lp', # Lp or Lp^p - initial_sparsity_penalty=1e-1, # equal to l1 penalty in standard trainer - anneal_start=15000, # step at which to start annealing p - anneal_end=None, # step at which to stop annealing, defaults to steps-1 - p_start=1, # starting value of p (constant throughout warmup) - p_end=0, # annealing p_start to p_end linearly after warmup_steps, exact endpoint excluded - n_sparsity_updates = 10, # number of times to update the sparsity penalty, at most steps-anneal_start times - sparsity_queue_length = 10, # number of recent sparsity loss terms, onle needed for adaptive_sparsity_penalty - resample_steps=None, # number of steps after which to resample dead neurons - steps=None, # total number of steps to train for - device=None, - seed=42, - layer=None, - lm_name=None, - wandb_name='PAnnealTrainer', - submodule_name: str = None, + steps: int, # total number of steps to train for + activation_dim: int, + dict_size: int, + layer: int, + lm_name: str, + dict_class: type = AutoEncoder, + lr: float = 1e-3, + warmup_steps: int = 1000, # lr warmup period at start of training and after each resample + decay_start: Optional[int] = None, # step at which to start decaying lr + sparsity_warmup_steps: Optional[int] = 2000, # number of steps to warm up sparsity penalty + sparsity_function: str = 'Lp', # Lp or Lp^p + initial_sparsity_penalty: float = 1e-1, # equal to l1 penalty in standard trainer + anneal_start: int = 15000, # step at which to start annealing p + anneal_end: Optional[int] = None, # step at which to stop annealing, defaults to steps-1 + p_start: float = 1, # starting value of p (constant throughout warmup) + p_end: float = 0, # annealing p_start to p_end linearly after warmup_steps, exact endpoint excluded + n_sparsity_updates: int | str = 10, # number of times to update the sparsity penalty, at most steps-anneal_start times + sparsity_queue_length: int = 10, # number of recent sparsity loss terms, only needed for adaptive_sparsity_penalty + resample_steps: Optional[int] = None, # number of steps after which to resample dead neurons + device: Optional[str] = None, + seed: int = 42, + wandb_name: str = 'PAnnealTrainer', + submodule_name: Optional[str] = None, ): super().__init__(seed) @@ -98,6 +80,8 @@ def __init__(self, self.sparsity_queue = [] self.warmup_steps = warmup_steps + self.sparsity_warmup_steps = sparsity_warmup_steps + self.decay_start = decay_start self.steps = steps self.logging_parameters = ['p', 'next_p', 'lp_loss', 'scaled_lp_loss', 'sparsity_coeff'] self.seed = seed @@ -111,13 +95,11 @@ def __init__(self, self.steps_since_active = None self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr) - if resample_steps is None: - def warmup_fn(step): - return min(step / warmup_steps, 1.) - else: - def warmup_fn(step): - return min((step % resample_steps) / warmup_steps, 1.)
- self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup_fn) + + lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps, sparsity_warmup_steps) + self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) + + self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps) if (self.sparsity_update_steps.unique(return_counts=True)[1] >1).any(): print("Warning! Duplicates om self.sparsity_update_steps detected!") @@ -163,12 +145,14 @@ def lp_norm(self, f, p): else: raise ValueError("Sparsity function must be 'Lp' or 'Lp^p'") - def loss(self, x, step, logging=False): + def loss(self, x: t.Tensor, step:int, logging=False): + sparsity_scale = self.sparsity_warmup_fn(step) + # Compute loss terms x_hat, f = self.ae(x, output_features=True) - l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() + recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean() lp_loss = self.lp_norm(f, self.p) - scaled_lp_loss = lp_loss * self.sparsity_coeff + scaled_lp_loss = lp_loss * self.sparsity_coeff * sparsity_scale self.lp_loss = lp_loss self.scaled_lp_loss = scaled_lp_loss @@ -201,7 +185,7 @@ def loss(self, x, step, logging=False): self.steps_since_active[~deads] = 0 if logging is False: - return l2_loss + scaled_lp_loss + return recon_loss + scaled_lp_loss else: loss_log = { 'p' : self.p, @@ -241,6 +225,8 @@ def config(self): 'sparsity_queue_length' : self.sparsity_queue_length, 'n_sparsity_updates' : self.n_sparsity_updates, 'warmup_steps' : self.warmup_steps, + 'sparsity_warmup_steps': self.sparsity_warmup_steps, + 'decay_start': self.decay_start, 'resample_steps' : self.resample_steps, 'steps' : self.steps, 'seed' : self.seed, diff --git a/dictionary_learning/trainers/standard.py b/dictionary_learning/trainers/standard.py new file mode 100644 index 0000000..a839737 --- /dev/null +++ b/dictionary_learning/trainers/standard.py @@ -0,0 +1,289 @@ +""" +Implements the standard SAE training scheme. +""" +import torch as t +from typing import Optional + +from ..trainers.trainer import SAETrainer, get_lr_schedule, get_sparsity_warmup_fn, ConstrainedAdam +from ..config import DEBUG +from ..dictionary import AutoEncoder +from collections import namedtuple + +class StandardTrainer(SAETrainer): + """ + Standard SAE training scheme following Towards Monosemanticity. Decoder column norms are constrained to 1. 
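+
+    A minimal usage sketch (the argument values and the activation iterator are illustrative,
+    not part of this module):
+
+        trainer = StandardTrainer(
+            steps=30_000,
+            activation_dim=512,
+            dict_size=16 * 512,
+            layer=3,
+            lm_name="EleutherAI/pythia-70m-deduped",
+            device="cuda",
+        )
+        for step, acts in enumerate(activation_batches):  # acts: [batch, activation_dim]
+            trainer.update(step, acts)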
+ """ + def __init__(self, + steps: int, # total number of steps to train for + activation_dim: int, + dict_size: int, + layer: int, + lm_name: str, + dict_class=AutoEncoder, + lr:float=1e-3, + l1_penalty:float=1e-1, + warmup_steps:int=1000, # lr warmup period at start of training and after each resample + sparsity_warmup_steps:Optional[int]=2000, # sparsity warmup period at start of training + decay_start:Optional[int]=None, # decay learning rate after this many steps + resample_steps:Optional[int]=None, # how often to resample neurons + seed:Optional[int]=None, + device=None, + wandb_name:Optional[str]='StandardTrainer', + submodule_name:Optional[str]=None, + ): + super().__init__(seed) + + assert layer is not None and lm_name is not None + self.layer = layer + self.lm_name = lm_name + self.submodule_name = submodule_name + + if seed is not None: + t.manual_seed(seed) + t.cuda.manual_seed_all(seed) + + # initialize dictionary + self.ae = dict_class(activation_dim, dict_size) + + self.lr = lr + self.l1_penalty=l1_penalty + self.warmup_steps = warmup_steps + self.sparsity_warmup_steps = sparsity_warmup_steps + self.steps = steps + self.decay_start = decay_start + self.wandb_name = wandb_name + + if device is None: + self.device = 'cuda' if t.cuda.is_available() else 'cpu' + else: + self.device = device + self.ae.to(self.device) + + self.resample_steps = resample_steps + if self.resample_steps is not None: + # how many steps since each neuron was last activated? + self.steps_since_active = t.zeros(self.ae.dict_size, dtype=int).to(self.device) + else: + self.steps_since_active = None + + self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr) + + lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps, sparsity_warmup_steps) + self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) + + self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps) + + def resample_neurons(self, deads, activations): + with t.no_grad(): + if deads.sum() == 0: return + print(f"resampling {deads.sum().item()} neurons") + + # compute loss for each activation + losses = (activations - self.ae(activations)).norm(dim=-1) + + # sample input to create encoder/decoder weights from + n_resample = min([deads.sum(), losses.shape[0]]) + indices = t.multinomial(losses, num_samples=n_resample, replacement=False) + sampled_vecs = activations[indices] + + # get norm of the living neurons + alive_norm = self.ae.encoder.weight[~deads].norm(dim=-1).mean() + + # resample first n_resample dead neurons + deads[deads.nonzero()[n_resample:]] = False + self.ae.encoder.weight[deads] = sampled_vecs * alive_norm * 0.2 + self.ae.decoder.weight[:,deads] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T + self.ae.encoder.bias[deads] = 0. + + + # reset Adam parameters for dead neurons + state_dict = self.optimizer.state_dict()['state'] + ## encoder weight + state_dict[1]['exp_avg'][deads] = 0. + state_dict[1]['exp_avg_sq'][deads] = 0. + ## encoder bias + state_dict[2]['exp_avg'][deads] = 0. + state_dict[2]['exp_avg_sq'][deads] = 0. + ## decoder weight + state_dict[3]['exp_avg'][:,deads] = 0. + state_dict[3]['exp_avg_sq'][:,deads] = 0. 
+ + def loss(self, x, step: int, logging=False, **kwargs): + + sparsity_scale = self.sparsity_warmup_fn(step) + + x_hat, f = self.ae(x, output_features=True) + l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() + recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean() + l1_loss = f.norm(p=1, dim=-1).mean() + + if self.steps_since_active is not None: + # update steps_since_active + deads = (f == 0).all(dim=0) + self.steps_since_active[deads] += 1 + self.steps_since_active[~deads] = 0 + + loss = recon_loss + self.l1_penalty * sparsity_scale * l1_loss + + if not logging: + return loss + else: + return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])( + x, x_hat, f, + { + 'l2_loss' : l2_loss.item(), + 'mse_loss' : recon_loss.item(), + 'sparsity_loss' : l1_loss.item(), + 'loss' : loss.item() + } + ) + + + def update(self, step, activations): + activations = activations.to(self.device) + + self.optimizer.zero_grad() + loss = self.loss(activations, step=step) + loss.backward() + self.optimizer.step() + self.scheduler.step() + + if self.resample_steps is not None and step % self.resample_steps == 0: + self.resample_neurons(self.steps_since_active > self.resample_steps / 2, activations) + + @property + def config(self): + return { + 'dict_class': 'AutoEncoder', + 'trainer_class' : 'StandardTrainer', + 'activation_dim': self.ae.activation_dim, + 'dict_size': self.ae.dict_size, + 'lr' : self.lr, + 'l1_penalty' : self.l1_penalty, + 'warmup_steps' : self.warmup_steps, + 'resample_steps' : self.resample_steps, + 'sparsity_warmup_steps' : self.sparsity_warmup_steps, + 'steps' : self.steps, + 'decay_start' : self.decay_start, + 'seed' : self.seed, + 'device' : self.device, + 'layer' : self.layer, + 'lm_name' : self.lm_name, + 'wandb_name': self.wandb_name, + 'submodule_name': self.submodule_name, + } + + +class StandardTrainerAprilUpdate(SAETrainer): + """ + Standard SAE training scheme following the Anthropic April update. Decoder column norms are NOT constrained to 1. + This trainer does not support resampling or ghost gradients. This trainer will have fewer dead neurons than the standard trainer. 
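+
+    The sparsity term weights each feature activation by the L2 norm of its decoder column
+    before taking the L1 sum (see loss() below):
+
+        l1_loss = (f * decoder.weight.norm(p=2, dim=0)).sum(dim=-1).mean()
+
+    so shrinking a decoder column is equivalent to shrinking the corresponding feature, which
+    removes the need for a unit-norm decoder constraint.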
+ """ + def __init__(self, + steps: int, # total number of steps to train for + activation_dim: int, + dict_size: int, + layer: int, + lm_name: str, + dict_class=AutoEncoder, + lr:float=1e-3, + l1_penalty:float=1e-1, + warmup_steps:int=1000, # lr warmup period at start of training + sparsity_warmup_steps:Optional[int]=2000, # sparsity warmup period at start of training + decay_start:Optional[int]=None, # decay learning rate after this many steps + seed:Optional[int]=None, + device=None, + wandb_name:Optional[str]='StandardTrainerAprilUpdate', + submodule_name:Optional[str]=None, + ): + super().__init__(seed) + + assert layer is not None and lm_name is not None + self.layer = layer + self.lm_name = lm_name + self.submodule_name = submodule_name + + if seed is not None: + t.manual_seed(seed) + t.cuda.manual_seed_all(seed) + + # initialize dictionary + self.ae = dict_class(activation_dim, dict_size) + + self.lr = lr + self.l1_penalty=l1_penalty + self.warmup_steps = warmup_steps + self.sparsity_warmup_steps = sparsity_warmup_steps + self.steps = steps + self.decay_start = decay_start + self.wandb_name = wandb_name + + if device is None: + self.device = 'cuda' if t.cuda.is_available() else 'cpu' + else: + self.device = device + self.ae.to(self.device) + + self.optimizer = t.optim.Adam(self.ae.parameters(), lr=lr) + + lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps=None, sparsity_warmup_steps=sparsity_warmup_steps) + self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) + + self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps) + + def loss(self, x, step: int, logging=False, **kwargs): + + sparsity_scale = self.sparsity_warmup_fn(step) + + x_hat, f = self.ae(x, output_features=True) + l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() + recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean() + l1_loss = (f * self.ae.decoder.weight.norm(p=2, dim=0)).sum(dim=-1).mean() + + loss = recon_loss + self.l1_penalty * sparsity_scale * l1_loss + + if not logging: + return loss + else: + return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])( + x, x_hat, f, + { + 'l2_loss' : l2_loss.item(), + 'mse_loss' : recon_loss.item(), + 'sparsity_loss' : l1_loss.item(), + 'loss' : loss.item() + } + ) + + + def update(self, step, activations): + activations = activations.to(self.device) + + self.optimizer.zero_grad() + loss = self.loss(activations, step=step) + loss.backward() + t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) + self.optimizer.step() + self.scheduler.step() + + @property + def config(self): + return { + 'dict_class': 'AutoEncoder', + 'trainer_class' : 'StandardTrainerAprilUpdate', + 'activation_dim': self.ae.activation_dim, + 'dict_size': self.ae.dict_size, + 'lr' : self.lr, + 'l1_penalty' : self.l1_penalty, + 'warmup_steps' : self.warmup_steps, + 'sparsity_warmup_steps' : self.sparsity_warmup_steps, + 'steps' : self.steps, + 'decay_start' : self.decay_start, + 'seed' : self.seed, + 'device' : self.device, + 'layer' : self.layer, + 'lm_name' : self.lm_name, + 'wandb_name': self.wandb_name, + 'submodule_name': self.submodule_name, + } + diff --git a/trainers/top_k.py b/dictionary_learning/trainers/top_k.py similarity index 53% rename from trainers/top_k.py rename to dictionary_learning/trainers/top_k.py index 33046f5..f6f5692 100644 --- a/trainers/top_k.py +++ b/dictionary_learning/trainers/top_k.py @@ -7,10 +7,16 @@ import torch as t import torch.nn as nn from collections import namedtuple +from typing import Optional from 
..config import DEBUG from ..dictionary import Dictionary -from ..trainers.trainer import SAETrainer +from ..trainers.trainer import ( + SAETrainer, + get_lr_schedule, + set_decoder_norm_to_unit_norm, + remove_gradient_parallel_to_decoder_directions, +) @t.no_grad() @@ -58,19 +64,33 @@ def __init__(self, activation_dim: int, dict_size: int, k: int): super().__init__() self.activation_dim = activation_dim self.dict_size = dict_size - self.k = k - self.encoder = nn.Linear(activation_dim, dict_size) - self.encoder.bias.data.zero_() + assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer" + self.register_buffer("k", t.tensor(k, dtype=t.int)) + self.register_buffer("threshold", t.tensor(-1.0, dtype=t.float32)) self.decoder = nn.Linear(dict_size, activation_dim, bias=False) - self.decoder.weight.data = self.encoder.weight.data.clone().T - self.set_decoder_norm_to_unit_norm() + self.decoder.weight.data = set_decoder_norm_to_unit_norm( + self.decoder.weight, activation_dim, dict_size + ) + + self.encoder = nn.Linear(activation_dim, dict_size) + self.encoder.weight.data = self.decoder.weight.T.clone() + self.encoder.bias.data.zero_() self.b_dec = nn.Parameter(t.zeros(activation_dim)) - def encode(self, x: t.Tensor, return_topk: bool = False): + def encode(self, x: t.Tensor, return_topk: bool = False, use_threshold: bool = False): post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec)) + + if use_threshold: + encoded_acts_BF = post_relu_feat_acts_BF * (post_relu_feat_acts_BF > self.threshold) + if return_topk: + post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1) + return encoded_acts_BF, post_topk.values, post_topk.indices, post_relu_feat_acts_BF + else: + return encoded_acts_BF + post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1) # We can't split immediately due to nnsight @@ -81,7 +101,7 @@ def encode(self, x: t.Tensor, return_topk: bool = False): encoded_acts_BF = buffer_BF.scatter_(dim=-1, index=top_indices_BK, src=tops_acts_BK) if return_topk: - return encoded_acts_BF, tops_acts_BK, top_indices_BK + return encoded_acts_BF, tops_acts_BK, top_indices_BK, post_relu_feat_acts_BF else: return encoded_acts_BF @@ -96,33 +116,24 @@ def forward(self, x: t.Tensor, output_features: bool = False): else: return x_hat_BD, encoded_acts_BF - @t.no_grad() - def set_decoder_norm_to_unit_norm(self): - eps = t.finfo(self.decoder.weight.dtype).eps - norm = t.norm(self.decoder.weight.data, dim=0, keepdim=True) - self.decoder.weight.data /= norm + eps - - @t.no_grad() - def remove_gradient_parallel_to_decoder_directions(self): - assert self.decoder.weight.grad is not None # keep pyright happy - - parallel_component = einops.einsum( - self.decoder.weight.grad, - self.decoder.weight.data, - "d_in d_sae, d_in d_sae -> d_sae", - ) - self.decoder.weight.grad -= einops.einsum( - parallel_component, - self.decoder.weight.data, - "d_sae, d_in d_sae -> d_in d_sae", - ) + def scale_biases(self, scale: float): + self.encoder.bias.data *= scale + self.b_dec.data *= scale + if self.threshold >= 0: + self.threshold *= scale - def from_pretrained(path, k: int, device=None): + def from_pretrained(path, k: Optional[int] = None, device=None): """ Load a pretrained autoencoder from a file. 
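        If k is not provided, it is read from the checkpoint's "k" buffer; passing a k that
        disagrees with the stored value raises a ValueError. A usage sketch (the path is
        illustrative):

            ae = AutoEncoderTopK.from_pretrained("checkpoints/ae.pt", device="cuda")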
""" state_dict = t.load(path) dict_size, activation_dim = state_dict["encoder.weight"].shape + + if k is None: + k = state_dict["k"].item() + elif "k" in state_dict and k != state_dict["k"].item(): + raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']") + autoencoder = AutoEncoderTopK(activation_dim, dict_size, k) autoencoder.load_state_dict(state_dict) if device is not None: @@ -130,26 +141,30 @@ def from_pretrained(path, k: int, device=None): return autoencoder -class TrainerTopK(SAETrainer): +class TopKTrainer(SAETrainer): """ Top-K SAE training scheme. """ def __init__( self, - dict_class=AutoEncoderTopK, - activation_dim=512, - dict_size=64 * 512, - k=100, - auxk_alpha=1 / 32, # see Appendix A.2 - decay_start=24000, # when does the lr decay start - steps=30000, # when when does training end - seed=None, - device=None, - layer=None, - lm_name=None, - wandb_name="AutoEncoderTopK", - submodule_name=None, + steps: int, # total number of steps to train for + activation_dim: int, + dict_size: int, + k: int, + layer: int, + lm_name: str, + dict_class: type = AutoEncoderTopK, + lr: Optional[float] = None, + auxk_alpha: float = 1 / 32, # see Appendix A.2 + warmup_steps: int = 1000, + decay_start: Optional[int] = None, # when does the lr decay start + threshold_beta: float = 0.999, + threshold_start_step: int = 1000, + seed: Optional[int] = None, + device: Optional[str] = None, + wandb_name: str = "AutoEncoderTopK", + submodule_name: Optional[str] = None, ): super().__init__(seed) @@ -160,7 +175,12 @@ def __init__( self.wandb_name = wandb_name self.steps = steps + self.decay_start = decay_start + self.warmup_steps = warmup_steps self.k = k + self.threshold_beta = threshold_beta + self.threshold_start_step = threshold_start_step + if seed is not None: t.manual_seed(seed) t.cuda.manual_seed_all(seed) @@ -173,87 +193,110 @@ def __init__( self.device = device self.ae.to(self.device) - # Auto-select LR using 1 / sqrt(d) scaling law from Figure 3 of the paper - scale = dict_size / (2**14) - self.lr = 2e-4 / scale**0.5 + if lr is not None: + self.lr = lr + else: + # Auto-select LR using 1 / sqrt(d) scaling law from Figure 3 of the paper + scale = dict_size / (2**14) + self.lr = 2e-4 / scale**0.5 + self.auxk_alpha = auxk_alpha self.dead_feature_threshold = 10_000_000 + self.top_k_aux = activation_dim // 2 # Heuristic from B.1 of the paper + self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device) + self.logging_parameters = ["effective_l0", "dead_features", "pre_norm_auxk_loss"] + self.effective_l0 = -1 + self.dead_features = -1 + self.pre_norm_auxk_loss = -1 # Optimizer and scheduler self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) - def lr_fn(step): - if step < decay_start: - return 1.0 - else: - return (steps - step) / (steps - decay_start) + lr_fn = get_lr_schedule(steps, warmup_steps, decay_start=decay_start) self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) - # Training parameters - self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device) + def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor): + dead_features = self.num_tokens_since_fired >= self.dead_feature_threshold + self.dead_features = int(dead_features.sum()) - # Log the effective L0, i.e. 
number of features actually used, which should a constant value (K) - # Note: The standard L0 is essentially a measure of dead features for Top-K SAEs) - self.logging_parameters = ["effective_l0", "dead_features"] - self.effective_l0 = -1 - self.dead_features = -1 + if self.dead_features > 0: + k_aux = min(self.top_k_aux, self.dead_features) + + auxk_latents = t.where(dead_features[None], post_relu_acts_BF, -t.inf) + + # Top-k dead latents + auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False) + + auxk_buffer_BF = t.zeros_like(post_relu_acts_BF) + auxk_acts_BF = auxk_buffer_BF.scatter_(dim=-1, index=auxk_indices, src=auxk_acts) + + # Note: decoder(), not decode(), as we don't want to apply the bias + x_reconstruct_aux = self.ae.decoder(auxk_acts_BF) + l2_loss_aux = ( + (residual_BD.float() - x_reconstruct_aux.float()).pow(2).sum(dim=-1).mean() + ) + + self.pre_norm_auxk_loss = l2_loss_aux + + # normalization from OpenAI implementation: https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/kernels.py#L614 + residual_mu = residual_BD.mean(dim=0)[None, :].broadcast_to(residual_BD.shape) + loss_denom = (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean() + normalized_auxk_loss = l2_loss_aux / loss_denom + + return normalized_auxk_loss.nan_to_num(0.0) + else: + self.pre_norm_auxk_loss = -1 + return t.tensor(0, dtype=residual_BD.dtype, device=residual_BD.device) + + def update_threshold(self, top_acts_BK: t.Tensor): + device_type = "cuda" if top_acts_BK.is_cuda else "cpu" + with t.autocast(device_type=device_type, enabled=False), t.no_grad(): + active = top_acts_BK.clone().detach() + active[active <= 0] = float("inf") + min_activations = active.min(dim=1).values.to(dtype=t.float32) + min_activation = min_activations.mean() + + B, K = active.shape + assert len(active.shape) == 2 + assert min_activations.shape == (B,) + + if self.ae.threshold < 0: + self.ae.threshold = min_activation + else: + self.ae.threshold = (self.threshold_beta * self.ae.threshold) + ( + (1 - self.threshold_beta) * min_activation + ) def loss(self, x, step=None, logging=False): # Run the SAE - f, top_acts, top_indices = self.ae.encode(x, return_topk=True) + f, top_acts_BK, top_indices_BK, post_relu_acts_BF = self.ae.encode( + x, return_topk=True, use_threshold=False + ) + + if step > self.threshold_start_step: + self.update_threshold(top_acts_BK) + x_hat = self.ae.decode(f) # Measure goodness of reconstruction - e = x_hat - x - total_variance = (x - x.mean(0)).pow(2).sum(0) + e = x - x_hat # Update the effective L0 (again, should just be K) - self.effective_l0 = top_acts.size(1) + self.effective_l0 = top_acts_BK.size(1) # Update "number of tokens since fired" for each features num_tokens_in_step = x.size(0) did_fire = t.zeros_like(self.num_tokens_since_fired, dtype=t.bool) - did_fire[top_indices.flatten()] = True + did_fire[top_indices_BK.flatten()] = True self.num_tokens_since_fired += num_tokens_in_step self.num_tokens_since_fired[did_fire] = 0 - # Compute dead feature mask based on "number of tokens since fired" - dead_mask = ( - self.num_tokens_since_fired > self.dead_feature_threshold - if self.auxk_alpha > 0 - else None - ).to(f.device) - self.dead_features = int(dead_mask.sum()) - - # If dead features: Second decoder pass for AuxK loss - if dead_mask is not None and (num_dead := int(dead_mask.sum())) > 0: - # Heuristic from Appendix B.1 in the paper - k_aux = x.shape[-1] // 2 - - # Reduce the scale of the loss if there are a small number of dead latents - scale = 
min(num_dead / k_aux, 1.0) - k_aux = min(k_aux, num_dead) - - # Don't include living latents in this loss - auxk_latents = t.where(dead_mask[None], f, -t.inf) - - # Top-k dead latents - auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False) - - auxk_buffer_BF = t.zeros_like(f) - auxk_acts_BF = auxk_buffer_BF.scatter_(dim=-1, index=auxk_indices, src=auxk_acts) - - # Encourage the top ~50% of dead latents to predict the residual of the - # top k living latents - e_hat = self.ae.decode(auxk_acts_BF) - auxk_loss = (e_hat - e).pow(2) # .sum(0) - auxk_loss = scale * t.mean(auxk_loss / total_variance) - else: - auxk_loss = x_hat.new_tensor(0.0) - l2_loss = e.pow(2).sum(dim=-1).mean() - auxk_loss = auxk_loss.sum(dim=-1).mean() + auxk_loss = ( + self.get_auxiliary_loss(e.detach(), post_relu_acts_BF) if self.auxk_alpha > 0 else 0 + ) + loss = l2_loss + self.auxk_alpha * auxk_loss if not logging: @@ -270,37 +313,51 @@ def update(self, step, x): # Initialise the decoder bias if step == 0: median = geometric_median(x) + median = median.to(self.ae.b_dec.dtype) self.ae.b_dec.data = median - # Make sure the decoder is still unit-norm - self.ae.set_decoder_norm_to_unit_norm() - # compute the loss x = x.to(self.device) loss = self.loss(x, step=step) loss.backward() # clip grad norm and remove grads parallel to decoder directions + self.ae.decoder.weight.grad = remove_gradient_parallel_to_decoder_directions( + self.ae.decoder.weight, + self.ae.decoder.weight.grad, + self.ae.activation_dim, + self.ae.dict_size, + ) t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) - self.ae.remove_gradient_parallel_to_decoder_directions() # do a training step self.optimizer.step() self.optimizer.zero_grad() self.scheduler.step() + + # Make sure the decoder is still unit-norm + self.ae.decoder.weight.data = set_decoder_norm_to_unit_norm( + self.ae.decoder.weight, self.ae.activation_dim, self.ae.dict_size + ) + return loss.item() @property def config(self): return { - "trainer_class": "TrainerTopK", + "trainer_class": "TopKTrainer", "dict_class": "AutoEncoderTopK", "lr": self.lr, "steps": self.steps, + "auxk_alpha": self.auxk_alpha, + "warmup_steps": self.warmup_steps, + "decay_start": self.decay_start, + "threshold_beta": self.threshold_beta, + "threshold_start_step": self.threshold_start_step, "seed": self.seed, "activation_dim": self.ae.activation_dim, "dict_size": self.ae.dict_size, - "k": self.ae.k, + "k": self.ae.k.item(), "device": self.device, "layer": self.layer, "lm_name": self.lm_name, diff --git a/dictionary_learning/trainers/trainer.py b/dictionary_learning/trainers/trainer.py new file mode 100644 index 0000000..15eb4ed --- /dev/null +++ b/dictionary_learning/trainers/trainer.py @@ -0,0 +1,195 @@ +from typing import Optional, Callable +import torch +import einops + + +class SAETrainer: + """ + Generic class for implementing SAE training algorithms + """ + + def __init__(self, seed=None): + self.seed = seed + self.logging_parameters = [] + + def update( + self, + step, # index of step in training + activations, # of shape [batch_size, d_submodule] + ): + pass # implemented by subclasses + + def get_logging_parameters(self): + stats = {} + for param in self.logging_parameters: + if hasattr(self, param): + stats[param] = getattr(self, param) + else: + print(f"Warning: {param} not found in {self}") + return stats + + @property + def config(self): + return { + "wandb_name": "trainer", + } + + +class ConstrainedAdam(torch.optim.Adam): + """ + A variant of Adam where some of the parameters are constrained to 
have unit norm. + Note: This should be used with a decoder that is nn.Linear, not nn.Parameter. + If nn.Parameter, the dim argument to norm should be 1. + """ + + def __init__( + self, params, constrained_params, lr: float, betas: tuple[float, float] = (0.9, 0.999) + ): + super().__init__(params, lr=lr, betas=betas) + self.constrained_params = list(constrained_params) + + def step(self, closure=None): + with torch.no_grad(): + for p in self.constrained_params: + normed_p = p / p.norm(dim=0, keepdim=True) + # project away the parallel component of the gradient + p.grad -= (p.grad * normed_p).sum(dim=0, keepdim=True) * normed_p + super().step(closure=closure) + with torch.no_grad(): + for p in self.constrained_params: + # renormalize the constrained parameters + p /= p.norm(dim=0, keepdim=True) + + +# The next two functions could be replaced with the ConstrainedAdam Optimizer +@torch.no_grad() +def set_decoder_norm_to_unit_norm( + W_dec_DF: torch.nn.Parameter, activation_dim: int, d_sae: int +) -> torch.Tensor: + """There's a major footgun here: we use this with both nn.Linear and nn.Parameter decoders. + nn.Linear stores the decoder weights in a transposed format (d_model, d_sae). So, we pass the dimensions in + to catch this error.""" + + D, F = W_dec_DF.shape + + assert D == activation_dim + assert F == d_sae + + eps = torch.finfo(W_dec_DF.dtype).eps + norm = torch.norm(W_dec_DF.data, dim=0, keepdim=True) + W_dec_DF.data /= norm + eps + return W_dec_DF.data + + +@torch.no_grad() +def remove_gradient_parallel_to_decoder_directions( + W_dec_DF: torch.Tensor, + W_dec_DF_grad: torch.Tensor, + activation_dim: int, + d_sae: int, +) -> torch.Tensor: + """There's a major footgun here: we use this with both nn.Linear and nn.Parameter decoders. + nn.Linear stores the decoder weights in a transposed format (d_model, d_sae). So, we pass the dimensions in + to catch this error.""" + + D, F = W_dec_DF.shape + assert D == activation_dim + assert F == d_sae + + normed_W_dec_DF = W_dec_DF / (torch.norm(W_dec_DF, dim=0, keepdim=True) + 1e-6) + + parallel_component = einops.einsum( + W_dec_DF_grad, + normed_W_dec_DF, + "d_in d_sae, d_in d_sae -> d_sae", + ) + W_dec_DF_grad -= einops.einsum( + parallel_component, + normed_W_dec_DF, + "d_sae, d_in d_sae -> d_in d_sae", + ) + return W_dec_DF_grad + + +def get_lr_schedule( + total_steps: int, + warmup_steps: int, + decay_start: Optional[int] = None, + resample_steps: Optional[int] = None, + sparsity_warmup_steps: Optional[int] = None, +) -> Callable[[int], float]: + """ + Creates a learning rate schedule function with linear warmup followed by an optional decay phase. + + Note: resample_steps creates a repeating warmup pattern instead of the standard phases, but + is rarely used in practice. + + Args: + total_steps: Total number of training steps + warmup_steps: Steps for linear warmup from 0 to 1 + decay_start: Optional step to begin linear decay to 0 + resample_steps: Optional period for repeating warmup pattern + sparsity_warmup_steps: Used for validation with decay_start + + Returns: + Function that computes LR scale factor for a given step + """ + if decay_start is not None: + assert resample_steps is None, ( + "decay_start and resample_steps are currently mutually exclusive." + ) + assert 0 <= decay_start < total_steps, "decay_start must be >= 0 and < steps." + assert decay_start > warmup_steps, "decay_start must be > warmup_steps." 
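The two helpers defined just above, set_decoder_norm_to_unit_norm and remove_gradient_parallel_to_decoder_directions, replace the per-class methods deleted from the trainers: decoder columns are renormalized to unit norm after each step, and the gradient component parallel to each column is projected out so the optimizer update cannot change column norms to first order. A small self-contained check of that invariant, written against plain tensors rather than the helpers themselves (the dimensions and the toy loss are arbitrary):

```python
import torch as t

d_model, d_sae = 8, 32
W = t.nn.Parameter(t.randn(d_model, d_sae))
W.data /= W.data.norm(dim=0, keepdim=True)      # unit-norm decoder columns

loss = (W.sum(dim=1) ** 2).sum()                 # arbitrary toy loss
loss.backward()

# Project out the gradient component parallel to each column (the same
# operation remove_gradient_parallel_to_decoder_directions does with einops).
normed_W = W.data / (W.data.norm(dim=0, keepdim=True) + 1e-6)
parallel = (W.grad * normed_W).sum(dim=0, keepdim=True)
W.grad -= parallel * normed_W

# To first order, a small step along the projected gradient leaves column norms at 1.
W_after = W.data - 1e-3 * W.grad
print(W_after.norm(dim=0))                       # all ~1.0
```

The ConstrainedAdam class above bundles the same projection and renormalization into the optimizer step; the standalone helpers exist for trainers such as TopKTrainer that use a plain Adam and apply them around the step manually.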
+ if sparsity_warmup_steps is not None: + assert decay_start > sparsity_warmup_steps, ( + "decay_start must be > sparsity_warmup_steps." + ) + + assert 0 <= warmup_steps < total_steps, "warmup_steps must be >= 0 and < steps." + + if resample_steps is None: + + def lr_schedule(step: int) -> float: + if step < warmup_steps: + # Warm-up phase + return step / warmup_steps + + if decay_start is not None and step >= decay_start: + # Decay phase + return (total_steps - step) / (total_steps - decay_start) + + # Constant phase + return 1.0 + else: + assert 0 < resample_steps < total_steps, "resample_steps must be > 0 and < steps." + + def lr_schedule(step: int) -> float: + return min((step % resample_steps) / warmup_steps, 1.0) + + return lr_schedule + + +def get_sparsity_warmup_fn( + total_steps: int, sparsity_warmup_steps: Optional[int] = None +) -> Callable[[int], float]: + """ + Return a function that computes a scale factor for sparsity penalty at a given step. + + If `sparsity_warmup_steps` is None or 0, returns 1.0 for all steps. + Otherwise, scales from 0.0 up to 1.0 across `sparsity_warmup_steps`. + """ + + if sparsity_warmup_steps is not None: + assert 0 <= sparsity_warmup_steps < total_steps, ( + "sparsity_warmup_steps must be >= 0 and < steps." + ) + + def scale_fn(step: int) -> float: + if not sparsity_warmup_steps: + # If it's None or zero, we just return 1.0 + return 1.0 + else: + # Gradually increase from 0.0 -> 1.0 as step goes from 0 -> sparsity_warmup_steps + return min(step / sparsity_warmup_steps, 1.0) + + return scale_fn diff --git a/dictionary_learning/training.py b/dictionary_learning/training.py new file mode 100644 index 0000000..0671f31 --- /dev/null +++ b/dictionary_learning/training.py @@ -0,0 +1,273 @@ +""" +Training dictionaries +""" + +import json +import torch.multiprocessing as mp +import os +from queue import Empty +from typing import Optional +from contextlib import nullcontext + +import torch as t +from tqdm import tqdm + +import wandb + +from .dictionary import AutoEncoder +from .evaluation import evaluate +from .trainers.standard import StandardTrainer + + +def new_wandb_process(config, log_queue, entity, project): + wandb.init(entity=entity, project=project, config=config, name=config["wandb_name"]) + while True: + try: + log = log_queue.get(timeout=1) + if log == "DONE": + break + wandb.log(log) + except Empty: + continue + wandb.finish() + + +def log_stats( + trainers, + step: int, + act: t.Tensor, + activations_split_by_head: bool, + transcoder: bool, + log_queues: list=[], + verbose: bool=False, +): + with t.no_grad(): + # quick hack to make sure all trainers get the same x + z = act.clone() + for i, trainer in enumerate(trainers): + log = {} + act = z.clone() + if activations_split_by_head: # x.shape: [batch, pos, n_heads, d_head] + act = act[..., i, :] + if not transcoder: + act, act_hat, f, losslog = trainer.loss(act, step=step, logging=True) + + # L0 + l0 = (f != 0).float().sum(dim=-1).mean().item() + # fraction of variance explained + total_variance = t.var(act, dim=0).sum() + residual_variance = t.var(act - act_hat, dim=0).sum() + frac_variance_explained = 1 - residual_variance / total_variance + log[f"frac_variance_explained"] = frac_variance_explained.item() + else: # transcoder + x, x_hat, f, losslog = trainer.loss(act, step=step, logging=True) + + # L0 + l0 = (f != 0).float().sum(dim=-1).mean().item() + + if verbose: + print(f"Step {step}: L0 = {l0}, frac_variance_explained = {frac_variance_explained}") + + # log parameters from training + 
log.update({f"{k}": v.cpu().item() if isinstance(v, t.Tensor) else v for k, v in losslog.items()}) + log[f"l0"] = l0 + trainer_log = trainer.get_logging_parameters() + for name, value in trainer_log.items(): + if isinstance(value, t.Tensor): + value = value.cpu().item() + log[f"{name}"] = value + + if log_queues: + log_queues[i].put(log) + +def get_norm_factor(data, steps: int) -> float: + """Per Section 3.1, find a fixed scalar factor so activation vectors have unit mean squared norm. + This is very helpful for hyperparameter transfer between different layers and models. + Use more steps for more accurate results. + https://arxiv.org/pdf/2408.05147 + + If experiencing troubles with hyperparameter transfer between models, it may be worth instead normalizing to the square root of d_model. + https://transformer-circuits.pub/2024/april-update/index.html#training-saes""" + total_mean_squared_norm = 0 + count = 0 + + for step, act_BD in enumerate(tqdm(data, total=steps, desc="Calculating norm factor")): + if step > steps: + break + + count += 1 + mean_squared_norm = t.mean(t.sum(act_BD ** 2, dim=1)) + total_mean_squared_norm += mean_squared_norm + + average_mean_squared_norm = total_mean_squared_norm / count + norm_factor = t.sqrt(average_mean_squared_norm).item() + + print(f"Average mean squared norm: {average_mean_squared_norm}") + print(f"Norm factor: {norm_factor}") + + return norm_factor + + + +def trainSAE( + data, + trainer_configs: list[dict], + steps: int, + use_wandb:bool=False, + wandb_entity:str="", + wandb_project:str="", + save_steps:Optional[list[int]]=None, + save_dir:Optional[str]=None, + log_steps:Optional[int]=None, + activations_split_by_head:bool=False, + transcoder:bool=False, + run_cfg:dict={}, + normalize_activations:bool=False, + verbose:bool=False, + device:str="cuda", + autocast_dtype: t.dtype = t.float32, + backup_steps:Optional[int]=None, +): + """ + Train SAEs using the given trainers + + If normalize_activations is True, the activations will be normalized to have unit mean squared norm. + The autoencoders weights will be scaled before saving, so the activations don't need to be scaled during inference. + This is very helpful for hyperparameter transfer between different layers and models. + + Setting autocast_dtype to t.bfloat16 provides a significant speedup with minimal change in performance. 
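To make the normalization described above concrete: get_norm_factor estimates the square root of the mean squared activation norm over a sample of batches, the training loop divides every activation by that factor, and scale_biases(norm_factor) is applied before each save so the stored weights work on raw, unscaled activations. A tiny numeric sketch, with a made-up scale and shape:

```python
import torch as t

acts = t.randn(10_000, 512) * 3.7                 # fake activations with a non-unit scale

# The statistic get_norm_factor estimates: sqrt of the mean squared L2 norm.
norm_factor = t.sqrt((acts ** 2).sum(dim=1).mean())

scaled = acts / norm_factor
print((scaled ** 2).sum(dim=1).mean())            # ~1.0, i.e. unit mean squared norm
```

Only bias-like quantities (encoder bias, b_dec, and the top-k threshold) need rescaling before saving, because scaling those by norm_factor makes the forward pass commute with the activation scaling, so reconstructions come out in raw activation space.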
+ """ + + device_type = "cuda" if "cuda" in device else "cpu" + autocast_context = nullcontext() if device_type == "cpu" else t.autocast(device_type=device_type, dtype=autocast_dtype) + + trainers = [] + for i, config in enumerate(trainer_configs): + if "wandb_name" in config: + config["wandb_name"] = f"{config['wandb_name']}_trainer_{i}" + trainer_class = config["trainer"] + del config["trainer"] + trainers.append(trainer_class(**config)) + + wandb_processes = [] + log_queues = [] + + if use_wandb: + # Note: If encountering wandb and CUDA related errors, try setting start method to spawn in the if __name__ == "__main__" block + # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.set_start_method + # Everything should work fine with the default fork method but it may not be as robust + for i, trainer in enumerate(trainers): + log_queue = mp.Queue() + log_queues.append(log_queue) + wandb_config = trainer.config | run_cfg + # Make sure wandb config doesn't contain any CUDA tensors + wandb_config = {k: v.cpu().item() if isinstance(v, t.Tensor) else v + for k, v in wandb_config.items()} + wandb_process = mp.Process( + target=new_wandb_process, + args=(wandb_config, log_queue, wandb_entity, wandb_project), + ) + wandb_process.start() + wandb_processes.append(wandb_process) + + # make save dirs, export config + if save_dir is not None: + save_dirs = [ + os.path.join(save_dir, f"trainer_{i}") for i in range(len(trainer_configs)) + ] + for trainer, dir in zip(trainers, save_dirs): + os.makedirs(dir, exist_ok=True) + # save config + config = {"trainer": trainer.config} + try: + config["buffer"] = data.config + except: + pass + with open(os.path.join(dir, "config.json"), "w") as f: + json.dump(config, f, indent=4) + else: + save_dirs = [None for _ in trainer_configs] + + if normalize_activations: + norm_factor = get_norm_factor(data, steps=100) + + for trainer in trainers: + trainer.config["norm_factor"] = norm_factor + # Verify that all autoencoders have a scale_biases method + trainer.ae.scale_biases(1.0) + + for step, act in enumerate(tqdm(data, total=steps)): + + act = act.to(dtype=autocast_dtype) + + if normalize_activations: + act /= norm_factor + + if step >= steps: + break + + # logging + if (use_wandb or verbose) and step % log_steps == 0: + log_stats( + trainers, step, act, activations_split_by_head, transcoder, log_queues=log_queues, verbose=verbose + ) + + # saving + if save_steps is not None and step in save_steps: + for dir, trainer in zip(save_dirs, trainers): + if dir is None: + continue + + if normalize_activations: + # Temporarily scale up biases for checkpoint saving + trainer.ae.scale_biases(norm_factor) + + if not os.path.exists(os.path.join(dir, "checkpoints")): + os.mkdir(os.path.join(dir, "checkpoints")) + + checkpoint = {k: v.cpu() for k, v in trainer.ae.state_dict().items()} + t.save( + checkpoint, + os.path.join(dir, "checkpoints", f"ae_{step}.pt"), + ) + + if normalize_activations: + trainer.ae.scale_biases(1 / norm_factor) + + # backup + if backup_steps is not None and step % backup_steps == 0: + for save_dir, trainer in zip(save_dirs, trainers): + if save_dir is None: + continue + # save the current state of the trainer for resume if training is interrupted + # this will be overwritten by the next checkpoint and at the end of training + t.save( + { + "step": step, + "ae": trainer.ae.state_dict(), + "optimizer": trainer.optimizer.state_dict(), + "config": trainer.config, + "norm_factor": norm_factor, + }, + os.path.join(save_dir, "ae.pt"), + ) + + # 
training + for trainer in trainers: + with autocast_context: + trainer.update(step, act) + + # save final SAEs + for save_dir, trainer in zip(save_dirs, trainers): + if normalize_activations: + trainer.ae.scale_biases(norm_factor) + if save_dir is not None: + final = {k: v.cpu() for k, v in trainer.ae.state_dict().items()} + t.save(final, os.path.join(save_dir, "ae.pt")) + + # Signal wandb processes to finish + if use_wandb: + for queue in log_queues: + queue.put("DONE") + for process in wandb_processes: + process.join() diff --git a/dictionary_learning/utils.py b/dictionary_learning/utils.py new file mode 100644 index 0000000..6f2d2c0 --- /dev/null +++ b/dictionary_learning/utils.py @@ -0,0 +1,341 @@ +from datasets import load_dataset +import zstandard as zstd +import io +import json +import os +from transformers import AutoModelForCausalLM +from fractions import Fraction +import random +from transformers import AutoTokenizer +import torch as t + +from .trainers.top_k import AutoEncoderTopK +from .trainers.batch_top_k import BatchTopKSAE +from .trainers.matryoshka_batch_top_k import MatryoshkaBatchTopKSAE +from .dictionary import ( + AutoEncoder, + GatedAutoEncoder, + AutoEncoderNew, + JumpReluAutoEncoder, +) + + +def hf_dataset_to_generator(dataset_name, split="train", streaming=True): + dataset = load_dataset(dataset_name, split=split, streaming=streaming) + + def gen(): + for x in iter(dataset): + yield x["text"] + + return gen() + + +def zst_to_generator(data_path): + """ + Load a dataset from a .jsonl.zst file. + The jsonl entries is assumed to have a 'text' field + """ + compressed_file = open(data_path, "rb") + dctx = zstd.ZstdDecompressor() + reader = dctx.stream_reader(compressed_file) + text_stream = io.TextIOWrapper(reader, encoding="utf-8") + + def generator(): + for line in text_stream: + yield json.loads(line)["text"] + + return generator() + + +def randomly_remove_system_prompt( + text: str, freq: float, system_prompt: str | None = None +) -> str: + if system_prompt and random.random() < freq: + assert system_prompt in text + text = text.replace(system_prompt, "") + return text + + +def hf_mixed_dataset_to_generator( + tokenizer: AutoTokenizer, + pretrain_dataset: str = "HuggingFaceFW/fineweb", + chat_dataset: str = "lmsys/lmsys-chat-1m", + min_chars: int = 1, + pretrain_frac: float = 0.9, # 0.9 → 90 % pretrain, 10 % chat + split: str = "train", + streaming: bool = True, + pretrain_key: str = "text", + chat_key: str = "conversation", + sequence_pack_pretrain: bool = True, + sequence_pack_chat: bool = False, + system_prompt_to_remove: str | None = None, + system_prompt_removal_freq: float = 0.9, +): + """Get a mix of pretrain and chat data at a specified ratio. By default, 90% of the data will be pretrain and 10% will be chat. + + Default datasets: + pretrain_dataset: "HuggingFaceFW/fineweb" + chat_dataset: "lmsys/lmsys-chat-1m" + + Note that you will have to request permission for lmsys (instant approval on HuggingFace). + + min_chars: minimum number of characters per sample. To perform sequence packing, set it to ~4x sequence length in tokens. + Samples will be joined with the eos token. + If it's low (like 1), each sample will just be a single row from the dataset, padded to the max length. Sometimes this will fill the context, sometimes it won't. + + Why use strings instead of tokens? Because dictionary learning expects an iterator of strings, and this is simple and good enough. + + Implicit assumption: each sample will be truncated to sequence length when tokenized. 
+ + By default, we sequence pack the pretrain data and DO NOT sequence pack the chat data, as it would look kind of weird. The EOS token is used to separate + user / assistant messages, not to separate conversations from different users. + If you want to sequence pack the chat data, set sequence_pack_chat to True. + + Pretrain format will be: texttexttext... + Chat format will be Optionally: ... + + Other parameters: + - system_prompt_to_remove: an optional string that will be removed from the chat data with a given frequency. + You probably want to verify that the system prompt you pass in is correct. + - system_prompt_removal_freq: the frequency with which the system prompt will be removed + + Why? Well, we probably don't want to have 1000's of copies of the system prompt in the training dataset. But we also may not want to remove it entirely. + And we may want to use the LLM with no system prompt when comparing between models. + IDK, this is a complicated and annoying detail. At least this constrains the complexity to the dataset generator. + """ + if not 0 < pretrain_frac < 1: + raise ValueError("main_frac must be between 0 and 1 (exclusive)") + + assert min_chars > 0 + + # Load both datasets as iterable streams + pretrain_ds = iter(load_dataset(pretrain_dataset, split=split, streaming=streaming)) + chat_ds = iter(load_dataset(chat_dataset, split=split, streaming=streaming)) + + # Convert the fraction to two small integers (e.g. 0.9 → 9 / 10) + frac = Fraction(pretrain_frac).limit_denominator() + n_pretrain = frac.numerator + n_chat = frac.denominator - n_pretrain + eos_token = tokenizer.eos_token + + bos_token = tokenizer.bos_token if tokenizer.bos_token else eos_token + + def gen(): + while True: + for _ in range(n_pretrain): + if sequence_pack_pretrain: + length = 0 + samples = [] + while length < min_chars: + # Add bos token to the beginning of the sample + sample = next(pretrain_ds)[pretrain_key] + samples.append(sample) + length += len(sample) + samples = bos_token + eos_token.join(samples) + yield samples + else: + sample = bos_token + next(pretrain_ds)[pretrain_key] + yield sample + for _ in range(n_chat): + if sequence_pack_chat: + length = 0 + samples = [] + while length < min_chars: + sample = next(chat_ds)[chat_key] + # Apply chat template also includes bos token + sample = tokenizer.apply_chat_template(sample, tokenize=False) + sample = randomly_remove_system_prompt( + sample, system_prompt_removal_freq, system_prompt_to_remove + ) + samples.append(sample) + length += len(sample) + samples = "".join(samples) + yield samples + else: + sample = tokenizer.apply_chat_template( + next(chat_ds)[chat_key], tokenize=False + ) + sample = randomly_remove_system_prompt( + sample, system_prompt_removal_freq, system_prompt_to_remove + ) + yield sample + + return gen() + + +def hf_sequence_packing_dataset_to_generator( + tokenizer: AutoTokenizer, + pretrain_dataset: str = "HuggingFaceFW/fineweb", + min_chars: int = 1, + split: str = "train", + streaming: bool = True, + pretrain_key: str = "text", + sequence_pack_pretrain: bool = True, +): + """min_chars: minimum number of characters per sample. To perform sequence packing, set it to ~4x sequence length in tokens. + Samples will be joined with the eos token. + If it's low (like 1), each sample will just be a single row from the dataset, padded to the max length. 
Sometimes this will fill the context, sometimes it won't.""" + assert min_chars > 0 + + # Load both datasets as iterable streams + pretrain_ds = iter(load_dataset(pretrain_dataset, split=split, streaming=streaming)) + + eos_token = tokenizer.eos_token + + bos_token = tokenizer.bos_token if tokenizer.bos_token else eos_token + + def gen(): + while True: + if sequence_pack_pretrain: + length = 0 + samples = [] + while length < min_chars: + # Add bos token to the beginning of the sample + sample = next(pretrain_ds)[pretrain_key] + samples.append(sample) + length += len(sample) + samples = bos_token + eos_token.join(samples) + yield samples + else: + sample = bos_token + next(pretrain_ds)[pretrain_key] + yield sample + + return gen() + + +def simple_hf_mixed_dataset_to_generator( + main_name: str, + aux_name: str, + main_frac: float = 0.9, # 0.9 → 90 % main, 10 % aux + split: str = "train", + streaming: bool = True, + main_key: str = "text", + aux_key: str = "text", +): + if not 0 < main_frac < 1: + raise ValueError("main_frac must be between 0 and 1 (exclusive)") + + # Load both datasets as iterable streams + main_ds = iter(load_dataset(main_name, split=split, streaming=streaming)) + aux_ds = iter(load_dataset(aux_name, split=split, streaming=streaming)) + + # Convert the fraction to two small integers (e.g. 0.9 → 9 / 10) + frac = Fraction(main_frac).limit_denominator() + n_main = frac.numerator + n_aux = frac.denominator - n_main + + def gen(): + while True: + # Yield `n_main` items from the main dataset + for _ in range(n_main): + yield next(main_ds)[main_key] + # Yield `n_aux` items from the auxiliary dataset + for _ in range(n_aux): + yield next(aux_ds)[aux_key] + + return gen() + + +def get_nested_folders(path: str) -> list[str]: + """ + Recursively get a list of folders that contain an ae.pt file, starting the search from the given path + """ + folder_names = [] + + for root, dirs, files in os.walk(path): + if "ae.pt" in files: + folder_names.append(root) + + return folder_names + + +def load_dictionary(base_path: str, device: str) -> tuple: + ae_path = f"{base_path}/ae.pt" + config_path = f"{base_path}/config.json" + + with open(config_path, "r") as f: + config = json.load(f) + + dict_class = config["trainer"]["dict_class"] + + if dict_class == "AutoEncoder": + dictionary = AutoEncoder.from_pretrained(ae_path, device=device) + elif dict_class == "GatedAutoEncoder": + dictionary = GatedAutoEncoder.from_pretrained(ae_path, device=device) + elif dict_class == "AutoEncoderNew": + dictionary = AutoEncoderNew.from_pretrained(ae_path, device=device) + elif dict_class == "AutoEncoderTopK": + k = config["trainer"]["k"] + dictionary = AutoEncoderTopK.from_pretrained(ae_path, k=k, device=device) + elif dict_class == "BatchTopKSAE": + k = config["trainer"]["k"] + dictionary = BatchTopKSAE.from_pretrained(ae_path, k=k, device=device) + elif dict_class == "MatryoshkaBatchTopKSAE": + k = config["trainer"]["k"] + dictionary = MatryoshkaBatchTopKSAE.from_pretrained(ae_path, k=k, device=device) + elif dict_class == "JumpReluAutoEncoder": + dictionary = JumpReluAutoEncoder.from_pretrained(ae_path, device=device) + else: + raise ValueError(f"Dictionary class {dict_class} not supported") + + return dictionary, config + + +def get_submodule(model: AutoModelForCausalLM, layer: int): + """Gets the residual stream submodule""" + model_name = model.name_or_path + + if model.config.architectures[0] == "GPTNeoXForCausalLM": + return model.gpt_neox.layers[layer] + elif ( + model.config.architectures[0] == 
"Qwen2ForCausalLM" + or model.config.architectures[0] == "Gemma2ForCausalLM" + ): + return model.model.layers[layer] + else: + raise ValueError(f"Please add submodule for model {model_name}") + + +def truncate_model(model: AutoModelForCausalLM, layer: int): + """From tilde-research/activault + https://github.com/tilde-research/activault/blob/db6d1e4e36c2d3eb4fdce79e72be94f387eccee1/pipeline/setup.py#L74 + This provides significant memory savings by deleting all layers that aren't needed for the given layer. + You should probably test this before using it""" + import gc + + total_params_before = sum(p.numel() for p in model.parameters()) + print(f"Model parameters before truncation: {total_params_before:,}") + + if ( + model.config.architectures[0] == "Qwen2ForCausalLM" + or model.config.architectures[0] == "Gemma2ForCausalLM" + ): + removed_layers = model.model.layers[layer + 1 :] + + model.model.layers = model.model.layers[: layer + 1] + + del removed_layers + del model.lm_head + + model.lm_head = t.nn.Identity() + + elif model.config.architectures[0] == "GPTNeoXForCausalLM": + removed_layers = model.gpt_neox.layers[layer + 1 :] + + model.gpt_neox.layers = model.gpt_neox.layers[: layer + 1] + + del removed_layers + del model.embed_out + + model.embed_out = t.nn.Identity() + + else: + raise ValueError(f"Please add truncation for model {model.name_or_path}") + + total_params_after = sum(p.numel() for p in model.parameters()) + print(f"Model parameters after truncation: {total_params_after:,}") + + gc.collect() + t.cuda.empty_cache() + + return model diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..28b2b9e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,45 @@ +[tool.poetry] +name = "dictionary-learning" +version = "0.1.0" +description = "Dictionary learning via sparse autoencoders on neural network activations" +authors = ["Samuel Marks", "Adam Karvonen", "Aaron Mueller"] +packages = [{ include = "dictionary_learning" }] +license = "MIT" +readme = "README.md" +keywords = [ + "deep-learning", + "sparse-autoencoders", + "mechanistic-interpretability", + "PyTorch", +] +classifiers = ["Topic :: Scientific/Engineering :: Artificial Intelligence"] +repository = "https://github.com/saprmarks/dictionary_learning" +homepage = "https://github.com/saprmarks/dictionary_learning" + + +[tool.poetry.dependencies] +python = "^3.10" +circuitsvis = ">=1.43.2" +datasets = ">=2.18.0" +einops = ">=0.7.0" +nnsight = ">=0.3.0,<0.4.0" +pandas = ">=2.2.1" +plotly = ">=5.18.0" +tqdm = ">=4.66.1" +zstandard = ">=0.22.0" +wandb = ">=0.12.0" +umap-learn = ">=0.5.6" +llvmlite = ">=0.40.0" + +[tool.poetry.group.dev.dependencies] +pytest = "^8.3.4" + +[build-system] +requires = ["poetry-core>=2.0.0,<3.0.0"] +build-backend = "poetry.core.masonry.api" + +[tool.semantic_release] +version_variables = ["dictionary_learning/__init__.py:__version__"] +version_toml = ["pyproject.toml:tool.poetry.version"] +branch = "main" +build_command = "pip install poetry && poetry build" diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 7366e63..0000000 --- a/requirements.txt +++ /dev/null @@ -1,12 +0,0 @@ -circuitsvis>=1.43.2 -datasets>=2.18.0 -einops>=0.7.0 -matplotlib>=3.8.3 -nnsight>=0.2.11 -pandas>=2.2.1 -plotly>=5.18.0 -torch>=2.1.2 -tqdm>=4.66.1 -umap-learn>=0.5.6 -zstandard>=0.22.0 -wandb diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py new file mode 100644 index 0000000..055db18 --- /dev/null +++ b/tests/test_end_to_end.py @@ -0,0 +1,262 @@ +import torch as t 
+from nnsight import LanguageModel +import os +import json +import random + +from dictionary_learning.training import trainSAE +from dictionary_learning.trainers.standard import StandardTrainer +from dictionary_learning.trainers.top_k import TopKTrainer, AutoEncoderTopK +from dictionary_learning.utils import ( + hf_dataset_to_generator, + get_nested_folders, + load_dictionary, +) +from dictionary_learning.buffer import ActivationBuffer +from dictionary_learning.dictionary import ( + AutoEncoder, + GatedAutoEncoder, + AutoEncoderNew, + JumpReluAutoEncoder, +) +from dictionary_learning.evaluation import evaluate + +EXPECTED_RESULTS = { + "AutoEncoderTopK": { + "l2_loss": 4.362327718734742, + "l1_loss": 50.94957427978515, + "l0": 40.0, + "frac_variance_explained": 0.9578053653240204, + "cossim": 0.9478691875934601, + "l2_ratio": 0.9478908002376556, + "relative_reconstruction_bias": 0.999762898683548, + "loss_original": 3.3361297130584715, + "loss_reconstructed": 3.8404462814331053, + "loss_zero": 13.251659297943116, + "frac_recovered": 0.948982036113739, + "frac_alive": 0.99951171875, + }, + "AutoEncoder": { + "l2_loss": 6.822444677352905, + "l1_loss": 19.382131576538086, + "l0": 37.45087890625, + "frac_variance_explained": 0.8993501663208008, + "cossim": 0.8791120409965515, + "l2_ratio": 0.74552041888237, + "relative_reconstruction_bias": 0.9595054805278778, + "loss_original": 3.3361297130584715, + "loss_reconstructed": 5.208198881149292, + "loss_zero": 13.251659297943116, + "frac_recovered": 0.8106247961521149, + "frac_alive": 0.99658203125, + }, +} + +DEVICE = "cuda:0" +SAVE_DIR = "./test_data" +MODEL_NAME = "EleutherAI/pythia-70m-deduped" +RANDOM_SEED = 42 +LAYER = 3 +DATASET_NAME = "monology/pile-uncopyrighted" + +EVAL_TOLERANCE_PERCENT = 0.005 + + +def test_sae_training(): + """End to end test for training an SAE. Takes ~2 minutes on an RTX 3090. + This isn't a nice suite of unit tests, but it's better than nothing. + I have observed that results can slightly vary with library versions. For full determinism, + use pytorch 2.5.1 and nnsight 0.3.7. + Unfortunately an RTX 3090 is also required for full determinism. 
On an H100 the results are off by ~0.3%, meaning this test will + not be within the EVAL_TOLERANCE.""" + + random.seed(RANDOM_SEED) + t.manual_seed(RANDOM_SEED) + + model = LanguageModel(MODEL_NAME, dispatch=True, device_map=DEVICE) + + context_length = 128 + llm_batch_size = 512 # Fits on a 24GB GPU + sae_batch_size = 8192 + num_contexts_per_sae_batch = sae_batch_size // context_length + + num_inputs_in_buffer = num_contexts_per_sae_batch * 20 + + num_tokens = 10_000_000 + + # sae training parameters + k = 40 + sparsity_penalty = 2.0 + expansion_factor = 8 + + steps = int(num_tokens / sae_batch_size) # Total number of batches to train + save_steps = None + warmup_steps = 1000 # Warmup period at start of training and after each resample + resample_steps = None + + # standard sae training parameters + learning_rate = 3e-4 + + # topk sae training parameters + decay_start = None + auxk_alpha = 1 / 32 + + submodule = model.gpt_neox.layers[LAYER] + submodule_name = f"resid_post_layer_{LAYER}" + io = "out" + activation_dim = model.config.hidden_size + + generator = hf_dataset_to_generator(DATASET_NAME) + + activation_buffer = ActivationBuffer( + generator, + model, + submodule, + n_ctxs=num_inputs_in_buffer, + ctx_len=context_length, + refresh_batch_size=llm_batch_size, + out_batch_size=sae_batch_size, + io=io, + d_submodule=activation_dim, + device=DEVICE, + ) + + # create the list of configs + trainer_configs = [] + trainer_configs.extend( + [ + { + "trainer": TopKTrainer, + "dict_class": AutoEncoderTopK, + "lr": None, + "activation_dim": activation_dim, + "dict_size": expansion_factor * activation_dim, + "k": k, + "auxk_alpha": auxk_alpha, # see Appendix A.2 + "warmup_steps": 0, + "decay_start": decay_start, # when does the lr decay start + "steps": steps, # when when does training end + "seed": RANDOM_SEED, + "wandb_name": f"TopKTrainer-{MODEL_NAME}-{submodule_name}", + "device": DEVICE, + "layer": LAYER, + "lm_name": MODEL_NAME, + "submodule_name": submodule_name, + }, + ] + ) + trainer_configs.extend( + [ + { + "trainer": StandardTrainer, + "dict_class": AutoEncoder, + "activation_dim": activation_dim, + "dict_size": expansion_factor * activation_dim, + "lr": learning_rate, + "l1_penalty": sparsity_penalty, + "warmup_steps": warmup_steps, + "sparsity_warmup_steps": None, + "decay_start": decay_start, + "steps": steps, + "resample_steps": resample_steps, + "seed": RANDOM_SEED, + "wandb_name": f"StandardTrainer-{MODEL_NAME}-{submodule_name}", + "layer": LAYER, + "lm_name": MODEL_NAME, + "device": DEVICE, + "submodule_name": submodule_name, + }, + ] + ) + + print(f"len trainer configs: {len(trainer_configs)}") + output_dir = f"{SAVE_DIR}/{submodule_name}" + + trainSAE( + data=activation_buffer, + trainer_configs=trainer_configs, + steps=steps, + save_steps=save_steps, + save_dir=output_dir, + ) + + folders = get_nested_folders(output_dir) + + assert len(folders) == 2 + + for folder in folders: + dictionary, config = load_dictionary(folder, DEVICE) + + assert dictionary is not None + assert config is not None + + +def test_evaluation(): + random.seed(RANDOM_SEED) + t.manual_seed(RANDOM_SEED) + + model = LanguageModel(MODEL_NAME, dispatch=True, device_map=DEVICE) + ae_paths = get_nested_folders(SAVE_DIR) + + context_length = 128 + llm_batch_size = 100 + sae_batch_size = 4096 + n_batches = 10 + buffer_size = 256 + io = "out" + + generator = hf_dataset_to_generator(DATASET_NAME) + submodule = model.gpt_neox.layers[LAYER] + + input_strings = [] + for i, example in enumerate(generator): + 
input_strings.append(example) + if i > buffer_size * n_batches: + break + + for ae_path in ae_paths: + dictionary, config = load_dictionary(ae_path, DEVICE) + dictionary = dictionary.to(dtype=model.dtype) + + activation_dim = config["trainer"]["activation_dim"] + context_length = config["buffer"]["ctx_len"] + + activation_buffer_data = iter(input_strings) + + activation_buffer = ActivationBuffer( + activation_buffer_data, + model, + submodule, + n_ctxs=buffer_size, + ctx_len=context_length, + refresh_batch_size=llm_batch_size, + out_batch_size=sae_batch_size, + io=io, + d_submodule=activation_dim, + device=DEVICE, + ) + + eval_results = evaluate( + dictionary, + activation_buffer, + context_length, + llm_batch_size, + io=io, + device=DEVICE, + n_batches=n_batches, + ) + + print(eval_results) + + dict_class = config["trainer"]["dict_class"] + expected_results = EXPECTED_RESULTS[dict_class] + + max_diff = 0 + max_diff_percent = 0 + for key, value in expected_results.items(): + diff = abs(eval_results[key] - value) + max_diff = max(max_diff, diff) + max_diff_percent = max(max_diff_percent, diff / value) + + print(f"Max diff: {max_diff}, max diff %: {max_diff_percent}") + assert max_diff_percent < EVAL_TOLERANCE_PERCENT diff --git a/tests/test_pytorch_end_to_end.py b/tests/test_pytorch_end_to_end.py new file mode 100644 index 0000000..79ef5b3 --- /dev/null +++ b/tests/test_pytorch_end_to_end.py @@ -0,0 +1,261 @@ +import torch as t +from transformers import AutoModelForCausalLM, AutoTokenizer +import os +import json +import random + +from dictionary_learning.training import trainSAE +from dictionary_learning.trainers.standard import StandardTrainer +from dictionary_learning.trainers.top_k import TopKTrainer, AutoEncoderTopK +from dictionary_learning.utils import ( + hf_dataset_to_generator, + get_nested_folders, + load_dictionary, +) + +# from dictionary_learning.buffer import ActivationBuffer +from dictionary_learning.pytorch_buffer import ActivationBuffer +from dictionary_learning.dictionary import ( + AutoEncoder, + GatedAutoEncoder, + AutoEncoderNew, + JumpReluAutoEncoder, +) +from dictionary_learning.evaluation import evaluate + +EXPECTED_RESULTS = { + "AutoEncoderTopK": { + "l2_loss": 4.358876752853393, + "l1_loss": 50.90618553161621, + "l0": 40.0, + "frac_variance_explained": 0.9577824175357819, + "cossim": 0.9476200461387634, + "l2_ratio": 0.9476299166679383, + "relative_reconstruction_bias": 0.9996505916118622, + "frac_alive": 1.0, + }, + "AutoEncoder": { + "l2_loss": 6.8308186531066895, + "l1_loss": 19.398421669006346, + "l0": 37.4469970703125, + "frac_variance_explained": 0.9003101229667664, + "cossim": 0.8782103300094605, + "l2_ratio": 0.7444103538990021, + "relative_reconstruction_bias": 0.960041344165802, + "frac_alive": 0.9970703125, + }, +} + +DEVICE = "cuda:0" +SAVE_DIR = "./test_data" +MODEL_NAME = "EleutherAI/pythia-70m-deduped" +RANDOM_SEED = 42 +LAYER = 3 +DATASET_NAME = "monology/pile-uncopyrighted" + +EVAL_TOLERANCE_PERCENT = 0.005 + + +def test_sae_training(): + """End to end test for training an SAE. Takes ~2 minutes on an RTX 3090. + This isn't a nice suite of unit tests, but it's better than nothing. + I have observed that results can slightly vary with library versions. For full determinism, + use pytorch 2.5.1 and nnsight 0.3.7. + Unfortunately an RTX 3090 is also required for full determinism. 
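The pass/fail criterion used by these end-to-end tests (and repeated in the PyTorch-buffer variant below) is the worst-case relative deviation across all logged metrics, compared against EVAL_TOLERANCE_PERCENT = 0.005. A compact sketch with made-up metric values; only the tolerance comes from the patch:

```python
EVAL_TOLERANCE_PERCENT = 0.005

expected = {"l2_loss": 4.36, "frac_variance_explained": 0.958, "l0": 40.0}   # illustrative
actual   = {"l2_loss": 4.37, "frac_variance_explained": 0.955, "l0": 40.0}   # illustrative

# Worst-case relative deviation across all metrics, as in the tests' final assert.
max_diff_percent = max(abs(actual[k] - v) / v for k, v in expected.items())
assert max_diff_percent < EVAL_TOLERANCE_PERCENT, f"off by {max_diff_percent:.4%}"
```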
On an H100 the results are off by ~0.3%, meaning this test will + not be within the EVAL_TOLERANCE.""" + + random.seed(RANDOM_SEED) + t.manual_seed(RANDOM_SEED) + + # model = LanguageModel(MODEL_NAME, dispatch=True, device_map=DEVICE) + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, device_map="auto", torch_dtype=t.float32 + ).to(DEVICE) + + context_length = 128 + llm_batch_size = 512 # Fits on a 24GB GPU + sae_batch_size = 8192 + num_contexts_per_sae_batch = sae_batch_size // context_length + + num_inputs_in_buffer = num_contexts_per_sae_batch * 20 + + num_tokens = 10_000_000 + + # sae training parameters + k = 40 + sparsity_penalty = 2.0 + expansion_factor = 8 + + steps = int(num_tokens / sae_batch_size) # Total number of batches to train + save_steps = None + warmup_steps = 1000 # Warmup period at start of training and after each resample + resample_steps = None + + # standard sae training parameters + learning_rate = 3e-4 + + # topk sae training parameters + decay_start = None + auxk_alpha = 1 / 32 + + submodule = model.gpt_neox.layers[LAYER] + submodule_name = f"resid_post_layer_{LAYER}" + io = "out" + activation_dim = model.config.hidden_size + + generator = hf_dataset_to_generator(DATASET_NAME) + + activation_buffer = ActivationBuffer( + generator, + model, + submodule, + n_ctxs=num_inputs_in_buffer, + ctx_len=context_length, + refresh_batch_size=llm_batch_size, + out_batch_size=sae_batch_size, + io=io, + d_submodule=activation_dim, + device=DEVICE, + ) + + # create the list of configs + trainer_configs = [] + trainer_configs.extend( + [ + { + "trainer": TopKTrainer, + "dict_class": AutoEncoderTopK, + "lr": None, + "activation_dim": activation_dim, + "dict_size": expansion_factor * activation_dim, + "k": k, + "auxk_alpha": auxk_alpha, # see Appendix A.2 + "warmup_steps": 0, + "decay_start": decay_start, # when does the lr decay start + "steps": steps, # when when does training end + "seed": RANDOM_SEED, + "wandb_name": f"TopKTrainer-{MODEL_NAME}-{submodule_name}", + "device": DEVICE, + "layer": LAYER, + "lm_name": MODEL_NAME, + "submodule_name": submodule_name, + }, + ] + ) + trainer_configs.extend( + [ + { + "trainer": StandardTrainer, + "dict_class": AutoEncoder, + "activation_dim": activation_dim, + "dict_size": expansion_factor * activation_dim, + "lr": learning_rate, + "l1_penalty": sparsity_penalty, + "warmup_steps": warmup_steps, + "sparsity_warmup_steps": None, + "decay_start": decay_start, + "steps": steps, + "resample_steps": resample_steps, + "seed": RANDOM_SEED, + "wandb_name": f"StandardTrainer-{MODEL_NAME}-{submodule_name}", + "layer": LAYER, + "lm_name": MODEL_NAME, + "device": DEVICE, + "submodule_name": submodule_name, + }, + ] + ) + + print(f"len trainer configs: {len(trainer_configs)}") + output_dir = f"{SAVE_DIR}/{submodule_name}" + + trainSAE( + data=activation_buffer, + trainer_configs=trainer_configs, + steps=steps, + save_steps=save_steps, + save_dir=output_dir, + ) + + folders = get_nested_folders(output_dir) + + assert len(folders) == 2 + + for folder in folders: + dictionary, config = load_dictionary(folder, DEVICE) + + assert dictionary is not None + assert config is not None + + +def test_evaluation(): + random.seed(RANDOM_SEED) + t.manual_seed(RANDOM_SEED) + + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, device_map="auto", torch_dtype=t.float32 + ).to(DEVICE) + ae_paths = get_nested_folders(SAVE_DIR) + + context_length = 128 + llm_batch_size = 100 + sae_batch_size = 4096 + n_batches = 10 + buffer_size = 256 + io = "out" + + 
generator = hf_dataset_to_generator(DATASET_NAME) + submodule = model.gpt_neox.layers[LAYER] + + input_strings = [] + for i, example in enumerate(generator): + input_strings.append(example) + if i > buffer_size * n_batches: + break + + for ae_path in ae_paths: + dictionary, config = load_dictionary(ae_path, DEVICE) + dictionary = dictionary.to(dtype=model.dtype) + + activation_dim = config["trainer"]["activation_dim"] + context_length = config["buffer"]["ctx_len"] + + activation_buffer_data = iter(input_strings) + + activation_buffer = ActivationBuffer( + activation_buffer_data, + model, + submodule, + n_ctxs=buffer_size, + ctx_len=context_length, + refresh_batch_size=llm_batch_size, + out_batch_size=sae_batch_size, + io=io, + d_submodule=activation_dim, + device=DEVICE, + ) + + eval_results = evaluate( + dictionary, + activation_buffer, + context_length, + llm_batch_size, + io=io, + device=DEVICE, + n_batches=n_batches, + ) + + print(eval_results) + + dict_class = config["trainer"]["dict_class"] + expected_results = EXPECTED_RESULTS[dict_class] + + max_diff = 0 + max_diff_percent = 0 + for key, value in expected_results.items(): + diff = abs(eval_results[key] - value) + max_diff = max(max_diff, diff) + max_diff_percent = max(max_diff_percent, diff / value) + + print(f"Max diff: {max_diff}, max diff %: {max_diff_percent}") + assert max_diff_percent < EVAL_TOLERANCE_PERCENT diff --git a/tests/unit/test_dictionary.py b/tests/unit/test_dictionary.py new file mode 100644 index 0000000..232eb0e --- /dev/null +++ b/tests/unit/test_dictionary.py @@ -0,0 +1,136 @@ +import torch as t +import pytest +from dictionary_learning.dictionary import ( + AutoEncoder, + GatedAutoEncoder, + AutoEncoderNew, + JumpReluAutoEncoder, +) + + +@pytest.mark.parametrize( + "sae_cls", [AutoEncoder, GatedAutoEncoder, JumpReluAutoEncoder] +) +def test_forward_equals_decode_encode(sae_cls: type) -> None: + """Test that forward pass equals decode(encode(x)) for all SAE types""" + batch_size = 4 + act_dim = 8 + dict_size = 6 + x = t.randn(batch_size, act_dim) + + sae = sae_cls(activation_dim=act_dim, dict_size=dict_size) + + # Test without output_features + forward_out = sae(x) + encode_decode = sae.decode(sae.encode(x)) + assert t.allclose(forward_out, encode_decode) + + # Test with output_features + forward_out, features = sae(x, output_features=True) + encode_features = sae.encode(x) + assert t.allclose(features, encode_features) + + +def test_simple_autoencoder() -> None: + """Test AutoEncoder with simple weight matrices""" + sae = AutoEncoder(activation_dim=2, dict_size=2) + + # Set simple weights + with t.no_grad(): + sae.encoder.weight.data = t.tensor([[1.0, 0.0], [0.0, 1.0]]) + sae.decoder.weight.data = t.tensor([[1.0, 0.0], [0.0, 1.0]]) + sae.encoder.bias.data = t.zeros(2) + sae.bias.data = t.zeros(2) + + # Test encoding + x = t.tensor([[2.0, -1.0]]) + encoded = sae.encode(x) + assert t.allclose(encoded, t.tensor([[2.0, 0.0]])) # ReLU clips negative value + + # Test decoding + decoded = sae.decode(encoded) + assert t.allclose(decoded, t.tensor([[2.0, 0.0]])) + + +def test_simple_gated_autoencoder() -> None: + """Test GatedAutoEncoder with simple weight matrices""" + sae = GatedAutoEncoder(activation_dim=2, dict_size=2) + + # Set simple weights and biases + with t.no_grad(): + sae.encoder.weight.data = t.tensor([[1.0, 0.0], [0.0, 1.0]]) + sae.decoder.weight.data = t.tensor([[1.0, 0.0], [0.0, 1.0]]) + sae.gate_bias.data = t.zeros(2) + sae.mag_bias.data = t.zeros(2) + sae.r_mag.data = t.zeros(2) + 
sae.decoder_bias.data = t.zeros(2) + + x = t.tensor([[2.0, -1.0]]) + encoded = sae.encode(x) + assert t.allclose( + encoded, t.tensor([[2.0, 0.0]]) + ) # Only positive values pass through + + +def test_normalize_decoder() -> None: + """Test that normalize_decoder maintains output while normalizing weights""" + sae = AutoEncoder(activation_dim=4, dict_size=3) + x = t.randn(2, 4) + + # Get initial output + initial_output = sae(x) + + # Normalize decoder + sae.normalize_decoder() + + # Check decoder weights are normalized + norms = t.norm(sae.decoder.weight, dim=0) + assert t.allclose(norms, t.ones_like(norms)) + + # Check output is maintained + new_output = sae(x) + assert t.allclose(initial_output, new_output, atol=1e-4) + + +def test_scale_biases() -> None: + """Test that scale_biases correctly scales all bias terms""" + sae = AutoEncoder(activation_dim=4, dict_size=3) + + # Record initial biases + initial_encoder_bias = sae.encoder.bias.data.clone() + initial_bias = sae.bias.data.clone() + + scale = 2.0 + sae.scale_biases(scale) + + assert t.allclose(sae.encoder.bias.data, initial_encoder_bias * scale) + assert t.allclose(sae.bias.data, initial_bias * scale) + + +@pytest.mark.parametrize( + "sae_cls", [AutoEncoder, GatedAutoEncoder, AutoEncoderNew, JumpReluAutoEncoder] +) +def test_output_shapes(sae_cls: type) -> None: + """Test that output shapes are correct for all operations""" + batch_size = 3 + act_dim = 4 + dict_size = 5 + x = t.randn(batch_size, act_dim) + + sae = sae_cls(activation_dim=act_dim, dict_size=dict_size) + + # Test encode shape + encoded = sae.encode(x) + assert encoded.shape == (batch_size, dict_size) + + # Test decode shape + decoded = sae.decode(encoded) + assert decoded.shape == (batch_size, act_dim) + + # Test forward shape with and without features + output = sae(x) + assert output.shape == (batch_size, act_dim) + + output, features = sae(x, output_features=True) + assert output.shape == (batch_size, act_dim) + assert features.shape == (batch_size, dict_size) diff --git a/trainers/__init__.py b/trainers/__init__.py deleted file mode 100644 index 461af62..0000000 --- a/trainers/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .standard import StandardTrainer -from .gdm import GatedSAETrainer -from .p_anneal import PAnnealTrainer -from .gated_anneal import GatedAnnealTrainer -from .top_k import TrainerTopK -from .jumprelu import TrainerJumpRelu -from .batch_top_k import TrainerBatchTopK, BatchTopKSAE diff --git a/trainers/batch_top_k.py b/trainers/batch_top_k.py deleted file mode 100644 index e684d9a..0000000 --- a/trainers/batch_top_k.py +++ /dev/null @@ -1,243 +0,0 @@ -import torch as t -import torch.nn as nn -import einops -from collections import namedtuple - -from ..config import DEBUG -from ..dictionary import Dictionary -from ..trainers.trainer import SAETrainer - - -class BatchTopKSAE(Dictionary, nn.Module): - def __init__(self, activation_dim: int, dict_size: int, k: int): - super().__init__() - self.activation_dim = activation_dim - self.dict_size = dict_size - self.k = k - - self.encoder = nn.Linear(activation_dim, dict_size) - self.encoder.bias.data.zero_() - self.decoder = nn.Linear(dict_size, activation_dim, bias=False) - self.decoder.weight.data = self.encoder.weight.data.clone().T - self.set_decoder_norm_to_unit_norm() - self.b_dec = nn.Parameter(t.zeros(activation_dim)) - - def encode(self, x: t.Tensor, return_active: bool = False): - post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec)) - - # Flatten and perform batch top-k - 
flattened_acts = post_relu_feat_acts_BF.flatten() - post_topk = flattened_acts.topk(self.k * x.size(0), sorted=False, dim=-1) - - buffer_BF = t.zeros_like(post_relu_feat_acts_BF) - encoded_acts_BF = ( - buffer_BF.flatten() - .scatter(-1, post_topk.indices, post_topk.values) - .reshape(buffer_BF.shape) - ) - - if return_active: - return encoded_acts_BF, encoded_acts_BF.sum(0) > 0 - else: - return encoded_acts_BF - - def decode(self, x: t.Tensor) -> t.Tensor: - return self.decoder(x) + self.b_dec - - def forward(self, x: t.Tensor, output_features: bool = False): - encoded_acts_BF = self.encode(x) - x_hat_BD = self.decode(encoded_acts_BF) - - if not output_features: - return x_hat_BD - else: - return x_hat_BD, encoded_acts_BF - - @t.no_grad() - def set_decoder_norm_to_unit_norm(self): - eps = t.finfo(self.decoder.weight.dtype).eps - norm = t.norm(self.decoder.weight.data, dim=0, keepdim=True) - self.decoder.weight.data /= norm + eps - - @t.no_grad() - def remove_gradient_parallel_to_decoder_directions(self): - assert self.decoder.weight.grad is not None - parallel_component = einops.einsum( - self.decoder.weight.grad, - self.decoder.weight.data, - "d_in d_sae, d_in d_sae -> d_sae", - ) - self.decoder.weight.grad -= einops.einsum( - parallel_component, - self.decoder.weight.data, - "d_sae, d_in d_sae -> d_in d_sae", - ) - - -class TrainerBatchTopK(SAETrainer): - def __init__( - self, - dict_class=BatchTopKSAE, - activation_dim=512, - dict_size=64 * 512, - k=8, - auxk_alpha=1 / 32, - decay_start=24000, - steps=30000, - top_k_aux=512, - seed=None, - device=None, - layer=None, - lm_name=None, - wandb_name="BatchTopKSAE", - submodule_name=None, - ): - super().__init__(seed) - assert layer is not None and lm_name is not None - self.layer = layer - self.lm_name = lm_name - self.submodule_name = submodule_name - self.wandb_name = wandb_name - self.steps = steps - self.k = k - - if seed is not None: - t.manual_seed(seed) - t.cuda.manual_seed_all(seed) - - self.ae = dict_class(activation_dim, dict_size, k) - - if device is None: - self.device = "cuda" if t.cuda.is_available() else "cpu" - else: - self.device = device - self.ae.to(self.device) - - scale = dict_size / (2**14) - self.lr = 2e-4 / scale**0.5 - self.auxk_alpha = auxk_alpha - self.dead_feature_threshold = 10_000_000 - self.top_k_aux = top_k_aux - - self.optimizer = t.optim.Adam( - self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999) - ) - - def lr_fn(step): - if step < decay_start: - return 1.0 - else: - return (steps - step) / (steps - decay_start) - - self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) - - self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device) - self.logging_parameters = ["effective_l0", "dead_features"] - self.effective_l0 = -1 - self.dead_features = -1 - - def get_auxiliary_loss(self, x, x_reconstruct, acts): - dead_features = self.num_tokens_since_fired >= self.dead_feature_threshold - if dead_features.sum() > 0: - residual = x.float() - x_reconstruct.float() - acts_topk_aux = t.topk( - acts[:, dead_features], - min(self.top_k_aux, dead_features.sum()), - dim=-1, - ) - acts_aux = t.zeros_like(acts[:, dead_features]).scatter( - -1, acts_topk_aux.indices, acts_topk_aux.values - ) - x_reconstruct_aux = acts_aux @ self.W_dec[dead_features] - l2_loss_aux = ( - self.auxk_alpha - * (x_reconstruct_aux.float() - residual.float()).pow(2).mean() - ) - return l2_loss_aux - else: - return t.tensor(0, dtype=x.dtype, device=x.device) - - def loss(self, x, step=None, logging=False): - 
f, active_indices = self.ae.encode(x, return_active=True) - l0 = (f != 0).float().sum(dim=-1).mean().item() - x_hat = self.ae.decode(f) - - e = x_hat - x - - self.effective_l0 = self.k - - num_tokens_in_step = x.size(0) - did_fire = t.zeros_like(self.num_tokens_since_fired, dtype=t.bool) - did_fire[active_indices] = True - self.num_tokens_since_fired += num_tokens_in_step - self.num_tokens_since_fired[did_fire] = 0 - - auxk_loss = self.get_auxiliary_loss(x, x_hat, f) - - l2_loss = e.pow(2).sum(dim=-1).mean() - auxk_loss = auxk_loss.sum(dim=-1).mean() - loss = l2_loss + self.auxk_alpha * auxk_loss - - if not logging: - return loss - else: - return namedtuple("LossLog", ["x", "x_hat", "f", "losses"])( - x, - x_hat, - f, - {"l2_loss": l2_loss.item(), "auxk_loss": auxk_loss.item(), "loss": loss.item()}, - ) - - def update(self, step, x): - if step == 0: - median = self.geometric_median(x) - self.ae.b_dec.data = median - - self.ae.set_decoder_norm_to_unit_norm() - - x = x.to(self.device) - loss = self.loss(x, step=step) - loss.backward() - - t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) - self.ae.remove_gradient_parallel_to_decoder_directions() - - self.optimizer.step() - self.optimizer.zero_grad() - self.scheduler.step() - - return loss.item() - - @property - def config(self): - return { - "trainer_class": "TrainerBatchTopK", - "dict_class": "BatchTopKSAE", - "lr": self.lr, - "steps": self.steps, - "seed": self.seed, - "activation_dim": self.ae.activation_dim, - "dict_size": self.ae.dict_size, - "k": self.ae.k, - "device": self.device, - "layer": self.layer, - "lm_name": self.lm_name, - "wandb_name": self.wandb_name, - "submodule_name": self.submodule_name, - } - - @staticmethod - def geometric_median(points: t.Tensor, max_iter: int = 100, tol: float = 1e-5): - guess = points.mean(dim=0) - prev = t.zeros_like(guess) - weights = t.ones(len(points), device=points.device) - - for _ in range(max_iter): - prev = guess - weights = 1 / t.norm(points - guess, dim=1) - weights /= weights.sum() - guess = (weights.unsqueeze(1) * points).sum(dim=0) - if t.norm(guess - prev) < tol: - break - - return guess diff --git a/trainers/standard.py b/trainers/standard.py deleted file mode 100644 index 2cfbb6a..0000000 --- a/trainers/standard.py +++ /dev/null @@ -1,183 +0,0 @@ -""" -Implements the standard SAE training scheme. -""" -import torch as t -from ..trainers.trainer import SAETrainer -from ..config import DEBUG -from ..dictionary import AutoEncoder -from collections import namedtuple - -class ConstrainedAdam(t.optim.Adam): - """ - A variant of Adam where some of the parameters are constrained to have unit norm. - """ - def __init__(self, params, constrained_params, lr): - super().__init__(params, lr=lr) - self.constrained_params = list(constrained_params) - - def step(self, closure=None): - with t.no_grad(): - for p in self.constrained_params: - normed_p = p / p.norm(dim=0, keepdim=True) - # project away the parallel component of the gradient - p.grad -= (p.grad * normed_p).sum(dim=0, keepdim=True) * normed_p - super().step(closure=closure) - with t.no_grad(): - for p in self.constrained_params: - # renormalize the constrained parameters - p /= p.norm(dim=0, keepdim=True) - -class StandardTrainer(SAETrainer): - """ - Standard SAE training scheme. 
- """ - def __init__(self, - dict_class=AutoEncoder, - activation_dim=512, - dict_size=64*512, - lr=1e-3, - l1_penalty=1e-1, - warmup_steps=1000, # lr warmup period at start of training and after each resample - resample_steps=None, # how often to resample neurons - seed=None, - device=None, - layer=None, - lm_name=None, - wandb_name='StandardTrainer', - submodule_name=None, - ): - super().__init__(seed) - - assert layer is not None and lm_name is not None - self.layer = layer - self.lm_name = lm_name - self.submodule_name = submodule_name - - if seed is not None: - t.manual_seed(seed) - t.cuda.manual_seed_all(seed) - - # initialize dictionary - self.ae = dict_class(activation_dim, dict_size) - - self.lr = lr - self.l1_penalty=l1_penalty - self.warmup_steps = warmup_steps - self.wandb_name = wandb_name - - if device is None: - self.device = 'cuda' if t.cuda.is_available() else 'cpu' - else: - self.device = device - self.ae.to(self.device) - - self.resample_steps = resample_steps - - - if self.resample_steps is not None: - # how many steps since each neuron was last activated? - self.steps_since_active = t.zeros(self.ae.dict_size, dtype=int).to(self.device) - else: - self.steps_since_active = None - - self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr) - if resample_steps is None: - def warmup_fn(step): - return min(step / warmup_steps, 1.) - else: - def warmup_fn(step): - return min((step % resample_steps) / warmup_steps, 1.) - self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup_fn) - - def resample_neurons(self, deads, activations): - with t.no_grad(): - if deads.sum() == 0: return - print(f"resampling {deads.sum().item()} neurons") - - # compute loss for each activation - losses = (activations - self.ae(activations)).norm(dim=-1) - - # sample input to create encoder/decoder weights from - n_resample = min([deads.sum(), losses.shape[0]]) - indices = t.multinomial(losses, num_samples=n_resample, replacement=False) - sampled_vecs = activations[indices] - - # get norm of the living neurons - alive_norm = self.ae.encoder.weight[~deads].norm(dim=-1).mean() - - # resample first n_resample dead neurons - deads[deads.nonzero()[n_resample:]] = False - self.ae.encoder.weight[deads] = sampled_vecs * alive_norm * 0.2 - self.ae.decoder.weight[:,deads] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T - self.ae.encoder.bias[deads] = 0. - - - # reset Adam parameters for dead neurons - state_dict = self.optimizer.state_dict()['state'] - ## encoder weight - state_dict[1]['exp_avg'][deads] = 0. - state_dict[1]['exp_avg_sq'][deads] = 0. - ## encoder bias - state_dict[2]['exp_avg'][deads] = 0. - state_dict[2]['exp_avg_sq'][deads] = 0. - ## decoder weight - state_dict[3]['exp_avg'][:,deads] = 0. - state_dict[3]['exp_avg_sq'][:,deads] = 0. 
- - def loss(self, x, logging=False, **kwargs): - x_hat, f = self.ae(x, output_features=True) - l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() - l1_loss = f.norm(p=1, dim=-1).mean() - - if self.steps_since_active is not None: - # update steps_since_active - deads = (f == 0).all(dim=0) - self.steps_since_active[deads] += 1 - self.steps_since_active[~deads] = 0 - - loss = l2_loss + self.l1_penalty * l1_loss - - if not logging: - return loss - else: - return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])( - x, x_hat, f, - { - 'l2_loss' : l2_loss.item(), - 'mse_loss' : (x - x_hat).pow(2).sum(dim=-1).mean().item(), - 'sparsity_loss' : l1_loss.item(), - 'loss' : loss.item() - } - ) - - - def update(self, step, activations): - activations = activations.to(self.device) - - self.optimizer.zero_grad() - loss = self.loss(activations) - loss.backward() - self.optimizer.step() - self.scheduler.step() - - if self.resample_steps is not None and step % self.resample_steps == 0: - self.resample_neurons(self.steps_since_active > self.resample_steps / 2, activations) - - @property - def config(self): - return { - 'dict_class': 'AutoEncoder', - 'trainer_class' : 'StandardTrainer', - 'activation_dim': self.ae.activation_dim, - 'dict_size': self.ae.dict_size, - 'lr' : self.lr, - 'l1_penalty' : self.l1_penalty, - 'warmup_steps' : self.warmup_steps, - 'resample_steps' : self.resample_steps, - 'device' : self.device, - 'layer' : self.layer, - 'lm_name' : self.lm_name, - 'wandb_name': self.wandb_name, - 'submodule_name': self.submodule_name, - } - diff --git a/trainers/trainer.py b/trainers/trainer.py deleted file mode 100644 index 04170b9..0000000 --- a/trainers/trainer.py +++ /dev/null @@ -1,28 +0,0 @@ -class SAETrainer: - """ - Generic class for implementing SAE training algorithms - """ - def __init__(self, seed=None): - self.seed = seed - self.logging_parameters = [] - - def update(self, - step, # index of step in training - activations, # of shape [batch_size, d_submodule] - ): - pass # implemented by subclasses - - def get_logging_parameters(self): - stats = {} - for param in self.logging_parameters: - if hasattr(self, param): - stats[param] = getattr(self, param) - else: - print(f"Warning: {param} not found in {self}") - return stats - - @property - def config(self): - return { - 'wandb_name': 'trainer', - } diff --git a/training.py b/training.py deleted file mode 100644 index f100fee..0000000 --- a/training.py +++ /dev/null @@ -1,167 +0,0 @@ -""" -Training dictionaries -""" - -import json -import multiprocessing as mp -import os -from queue import Empty - -import torch as t -from tqdm import tqdm - -import wandb - -from .dictionary import AutoEncoder -from .evaluation import evaluate -from .trainers.standard import StandardTrainer - - -def new_wandb_process(config, log_queue, entity, project): - wandb.init(entity=entity, project=project, config=config, name=config["wandb_name"]) - while True: - try: - log = log_queue.get(timeout=1) - if log == "DONE": - break - wandb.log(log) - except Empty: - continue - wandb.finish() - - -def log_stats( - trainers, - step: int, - act: t.Tensor, - activations_split_by_head: bool, - transcoder: bool, - log_queues: list=[], -): - with t.no_grad(): - # quick hack to make sure all trainers get the same x - z = act.clone() - for i, trainer in enumerate(trainers): - log = {} - act = z.clone() - if activations_split_by_head: # x.shape: [batch, pos, n_heads, d_head] - act = act[..., i, :] - if not transcoder: - act, act_hat, f, losslog = trainer.loss(act, step=step, 
logging=True) - - # L0 - l0 = (f != 0).float().sum(dim=-1).mean().item() - # fraction of variance explained - total_variance = t.var(act, dim=0).sum() - residual_variance = t.var(act - act_hat, dim=0).sum() - frac_variance_explained = 1 - residual_variance / total_variance - log[f"frac_variance_explained"] = frac_variance_explained.item() - else: # transcoder - x, x_hat, f, losslog = trainer.loss(act, step=step, logging=True) - - # L0 - l0 = (f != 0).float().sum(dim=-1).mean().item() - - # log parameters from training - log.update({f"{k}": v for k, v in losslog.items()}) - log[f"l0"] = l0 - trainer_log = trainer.get_logging_parameters() - for name, value in trainer_log.items(): - log[f"{name}"] = value - - if log_queues: - log_queues[i].put(log) - - -def trainSAE( - data, - trainer_configs, - use_wandb=False, - wandb_entity="", - wandb_project="", - steps=None, - save_steps=None, - save_dir=None, - log_steps=None, - activations_split_by_head=False, - transcoder=False, - run_cfg={}, -): - """ - Train SAEs using the given trainers - """ - trainers = [] - for config in trainer_configs: - trainer_class = config["trainer"] - del config["trainer"] - trainers.append(trainer_class(**config)) - - wandb_processes = [] - log_queues = [] - - if use_wandb: - for i, trainer in enumerate(trainers): - log_queue = mp.Queue() - log_queues.append(log_queue) - wandb_config = trainer.config | run_cfg - wandb_process = mp.Process( - target=new_wandb_process, - args=(wandb_config, log_queue, wandb_entity, wandb_project), - ) - wandb_process.start() - wandb_processes.append(wandb_process) - - # make save dirs, export config - if save_dir is not None: - save_dirs = [ - os.path.join(save_dir, f"trainer_{i}") for i in range(len(trainer_configs)) - ] - for trainer, dir in zip(trainers, save_dirs): - os.makedirs(dir, exist_ok=True) - # save config - config = {"trainer": trainer.config} - try: - config["buffer"] = data.config - except: - pass - with open(os.path.join(dir, "config.json"), "w") as f: - json.dump(config, f, indent=4) - else: - save_dirs = [None for _ in trainer_configs] - - for step, act in enumerate(tqdm(data, total=steps)): - if steps is not None and step >= steps: - break - - # logging - if log_steps is not None and step % log_steps == 0: - log_stats( - trainers, step, act, activations_split_by_head, transcoder, log_queues=log_queues - ) - - # saving - if save_steps is not None and step % save_steps == 0: - for dir, trainer in zip(save_dirs, trainers): - if dir is not None: - if not os.path.exists(os.path.join(dir, "checkpoints")): - os.mkdir(os.path.join(dir, "checkpoints")) - t.save( - trainer.ae.state_dict(), - os.path.join(dir, "checkpoints", f"ae_{step}.pt"), - ) - - # training - for trainer in trainers: - trainer.update(step, act) - - # save final SAEs - for save_dir, trainer in zip(save_dirs, trainers): - if save_dir is not None: - t.save(trainer.ae.state_dict(), os.path.join(save_dir, "ae.pt")) - - # Signal wandb processes to finish - if use_wandb: - for queue in log_queues: - queue.put("DONE") - for process in wandb_processes: - process.join() diff --git a/utils.py b/utils.py deleted file mode 100644 index 8641f05..0000000 --- a/utils.py +++ /dev/null @@ -1,27 +0,0 @@ -from datasets import load_dataset -import zstandard as zstd -import io -import json - -def hf_dataset_to_generator(dataset_name, split='train', streaming=True): - dataset = load_dataset(dataset_name, split=split, streaming=streaming) - - def gen(): - for x in iter(dataset): - yield x['text'] - - return gen() - -def 
zst_to_generator(data_path): - """ - Load a dataset from a .jsonl.zst file. - The jsonl entries are assumed to have a 'text' field - """ - compressed_file = open(data_path, 'rb') - dctx = zstd.ZstdDecompressor() - reader = dctx.stream_reader(compressed_file) - text_stream = io.TextIOWrapper(reader, encoding='utf-8') - def generator(): - for line in text_stream: - yield json.loads(line)['text'] - return generator() \ No newline at end of file
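A few hedged sketches of the techniques in the deleted modules above, for reference. First, the deleted BatchTopKSAE.encode keeps the k * batch_size largest post-ReLU activations across the whole flattened batch, rather than k per row. A minimal standalone sketch of that selection step; the function name and the __main__ check are illustrative, and the real class also subtracts a decoder bias and applies the encoder first.

import torch as t

def batch_topk(post_relu_acts: t.Tensor, k: int) -> t.Tensor:
    """Keep the k * batch_size largest activations across the whole batch.

    post_relu_acts: [batch, dict_size] non-negative feature activations.
    Returns a tensor of the same shape with everything else zeroed.
    """
    batch_size = post_relu_acts.size(0)
    flat = post_relu_acts.flatten()
    topk = flat.topk(k * batch_size, sorted=False, dim=-1)

    out = t.zeros_like(flat)
    out.scatter_(-1, topk.indices, topk.values)
    return out.reshape(post_relu_acts.shape)

if __name__ == "__main__":
    acts = t.relu(t.randn(4, 16))
    sparse = batch_topk(acts, k=3)
    # On average each row keeps k features, but rows with stronger
    # activations can claim more of the shared budget.
    print((sparse != 0).sum().item())  # ~= 3 * 4 = 12, barring selected zeros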
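In the deleted TrainerBatchTopK, get_auxiliary_loss already multiplies the dead-feature reconstruction term by auxk_alpha and loss() multiplies by auxk_alpha again; it also references self.W_dec, which the trainer never defines. A hedged sketch of the same idea with the coefficient applied once and the decoder weight taken from the SAE itself; the function name and signature are illustrative, not the project's API.

import torch as t
import torch.nn as nn

def auxk_loss(
    x: t.Tensor,          # [batch, d_in] inputs
    x_hat: t.Tensor,      # [batch, d_in] reconstructions
    acts: t.Tensor,       # [batch, d_sae] feature activations
    decoder: nn.Linear,   # maps d_sae -> d_in, weight is [d_in, d_sae]
    dead_mask: t.Tensor,  # [d_sae] bool, True for dead features
    top_k_aux: int,
) -> t.Tensor:
    """Reconstruct the residual using only the strongest dead features."""
    if dead_mask.sum() == 0:
        return x.new_zeros(())

    # Detach so the auxiliary term only trains the dead-feature pathway.
    residual = (x - x_hat).detach().float()

    dead_acts = acts[:, dead_mask].float()              # [batch, n_dead]
    k = min(top_k_aux, int(dead_mask.sum()))
    topk = dead_acts.topk(k, dim=-1)
    sparse_dead = t.zeros_like(dead_acts).scatter(-1, topk.indices, topk.values)

    # decoder.weight[:, dead_mask] is [d_in, n_dead]; its transpose maps
    # dead-feature activations back to input space.
    residual_hat = sparse_dead @ decoder.weight[:, dead_mask].T.float()
    return (residual_hat - residual).pow(2).mean()

# usage sketch: total = l2_loss + auxk_alpha * auxk_loss(...)  # coefficient applied once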
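TrainerBatchTopK initializes b_dec from the geometric median of the first activation batch, computed with a fixed-point (Weiszfeld-style) iteration. A standalone sketch of that routine with a small sanity check; the tolerance and iteration cap mirror the deleted defaults, and the test data are made up.

import torch as t

def geometric_median(points: t.Tensor, max_iter: int = 100, tol: float = 1e-5) -> t.Tensor:
    """Weiszfeld iteration: reweight points by inverse distance to the current guess."""
    guess = points.mean(dim=0)
    for _ in range(max_iter):
        prev = guess
        # Like the deleted code, this assumes no point coincides exactly with
        # the current guess (which would divide by zero).
        weights = 1 / t.norm(points - guess, dim=1)
        weights = weights / weights.sum()
        guess = (weights.unsqueeze(1) * points).sum(dim=0)
        if t.norm(guess - prev) < tol:
            break
    return guess

if __name__ == "__main__":
    pts = t.randn(1000, 16)
    pts[:10] += 100.0  # a few gross outliers
    median = geometric_median(pts)
    # The geometric median is far less affected by the outliers than the mean.
    print(median.norm().item(), pts.mean(dim=0).norm().item())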
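Both the deleted StandardTrainer (through ConstrainedAdam) and BatchTopKSAE keep decoder columns at unit norm and project out the gradient component parallel to each column before the optimizer step, so the step cannot change a column's norm to first order. A sketch of those two operations on a bare nn.Linear decoder; the decoder is assumed to map dict_size -> activation_dim, so its weight is [activation_dim, dict_size] and each column is one dictionary direction.

import torch as t
import torch.nn as nn

@t.no_grad()
def set_decoder_norm_to_unit_norm(decoder: nn.Linear) -> None:
    # Normalize each dictionary direction (column of the weight matrix).
    eps = t.finfo(decoder.weight.dtype).eps
    norm = decoder.weight.norm(dim=0, keepdim=True)
    decoder.weight /= norm + eps

@t.no_grad()
def remove_gradient_parallel_to_decoder_directions(decoder: nn.Linear) -> None:
    # Subtract the component of each column's gradient that is parallel to
    # the (unit-norm) column itself.
    assert decoder.weight.grad is not None
    w = decoder.weight        # [d_in, d_sae]
    g = decoder.weight.grad
    parallel = (g * w).sum(dim=0, keepdim=True)  # per-column dot products, [1, d_sae]
    g -= parallel * w

if __name__ == "__main__":
    dec = nn.Linear(8, 4, bias=False)
    set_decoder_norm_to_unit_norm(dec)
    dec(t.randn(3, 8)).pow(2).sum().backward()
    remove_gradient_parallel_to_decoder_directions(dec)
    # After projection, each column's gradient is orthogonal to the column.
    print((dec.weight.grad * dec.weight).sum(dim=0).abs().max())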
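Both trainers drive their learning rate through torch.optim.lr_scheduler.LambdaLR: StandardTrainer ramps the LR up linearly over warmup_steps (restarting the ramp after each resample when resample_steps is set), while TrainerBatchTopK holds it flat and then decays linearly to zero between decay_start and the final step. A sketch of the two multipliers; the step counts are illustrative values, not project defaults beyond those visible above.

import torch as t

warmup_steps, resample_steps = 1000, 3000        # illustrative
decay_start, total_steps = 24_000, 30_000        # mirrors TrainerBatchTopK defaults

def standard_warmup(step: int) -> float:
    # Linear warmup, restarting whenever neurons are resampled.
    return min((step % resample_steps) / warmup_steps, 1.0)

def batch_topk_decay(step: int) -> float:
    # Flat LR until decay_start, then linear decay to zero at total_steps.
    if step < decay_start:
        return 1.0
    return (total_steps - step) / (total_steps - decay_start)

model = t.nn.Linear(4, 4)
opt = t.optim.Adam(model.parameters(), lr=1e-3)
sched = t.optim.lr_scheduler.LambdaLR(opt, lr_lambda=batch_topk_decay)
for step in range(5):
    opt.step()   # normally preceded by loss.backward()
    sched.step()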
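StandardTrainer.resample_neurons re-initializes dead dictionary elements from inputs the SAE currently reconstructs poorly: it samples inputs with probability proportional to their reconstruction loss, points the dead decoder columns at the normalized samples, sets the dead encoder rows to a down-scaled copy, and zeroes the encoder bias. A condensed sketch of that weight surgery; it omits the Adam-moment reset the deleted code also performs, and the 0.2 scale follows the original.

import torch as t
import torch.nn as nn

@t.no_grad()
def resample_neurons(encoder: nn.Linear, decoder: nn.Linear,
                     dead: t.Tensor, activations: t.Tensor, losses: t.Tensor) -> None:
    """dead: [dict_size] bool; activations: [N, d_in]; losses: [N] reconstruction errors."""
    if dead.sum() == 0:
        return
    n_resample = min(int(dead.sum()), losses.shape[0])
    # Sample poorly-reconstructed inputs with probability proportional to their loss.
    indices = t.multinomial(losses, num_samples=n_resample, replacement=False)
    sampled = activations[indices]                     # [n_resample, d_in]

    # Match the typical scale of the living encoder rows.
    alive_norm = encoder.weight[~dead].norm(dim=-1).mean()

    # If there are more dead neurons than samples, only revive the first n_resample.
    dead = dead.clone()
    dead[dead.nonzero()[n_resample:]] = False

    encoder.weight[dead] = sampled * alive_norm * 0.2
    decoder.weight[:, dead] = (sampled / sampled.norm(dim=-1, keepdim=True)).T
    encoder.bias[dead] = 0.0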
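log_stats reports two quick health metrics per trainer: the mean L0 of the feature activations and the fraction of activation variance explained by the reconstruction. A small sketch of both computations on dummy tensors; the tensor shapes are placeholders.

import torch as t

acts = t.randn(128, 512)                    # stand-in activation batch [batch, d_in]
feats = t.relu(t.randn(128, 4096))          # stand-in SAE features [batch, dict_size]
recon = acts + 0.1 * t.randn_like(acts)     # stand-in reconstruction

# Mean number of active features per input.
l0 = (feats != 0).float().sum(dim=-1).mean().item()

# Fraction of variance explained: 1 - Var(residual) / Var(input),
# with variances summed over input dimensions.
total_variance = t.var(acts, dim=0).sum()
residual_variance = t.var(acts - recon, dim=0).sum()
frac_variance_explained = (1 - residual_variance / total_variance).item()

print(l0, frac_variance_explained)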
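The deleted trainSAE consumes an iterable of activation batches plus a list of trainer_configs, where each config carries a "trainer" class and the remaining keys are forwarded to that class's constructor (StandardTrainer additionally requires layer and lm_name). A hedged usage sketch: the import paths assume the modules now live under a dictionary_learning package after this restructuring, and the random "data" stream and model/layer values are placeholders that merely exercise the loop.

import torch as t

# Assumed package paths; the modules were moved in this diff.
from dictionary_learning.training import trainSAE
from dictionary_learning.trainers.standard import StandardTrainer
from dictionary_learning.dictionary import AutoEncoder

activation_dim = 512
# Stand-in for the project's activation buffer: 2000 random activation batches.
data = (t.randn(64, activation_dim) for _ in range(2000))

trainer_configs = [
    {
        "trainer": StandardTrainer,
        "dict_class": AutoEncoder,
        "activation_dim": activation_dim,
        "dict_size": 16 * activation_dim,
        "lr": 1e-3,
        "l1_penalty": 1e-1,
        "layer": 8,                                   # required by StandardTrainer
        "lm_name": "EleutherAI/pythia-70m-deduped",   # placeholder model name
        "wandb_name": "StandardTrainer",
        "device": "cpu",  # keep the sketch device-agnostic for the random data
    }
]

trainSAE(
    data,
    trainer_configs,
    steps=2000,
    save_steps=500,
    save_dir="./checkpoints",
    log_steps=100,
    use_wandb=False,
)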
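Finally, utils.hf_dataset_to_generator streams the 'text' field of a Hugging Face dataset, and zst_to_generator does the same for a local .jsonl.zst dump. A usage sketch; the import path, dataset name, and file path are placeholders rather than the project's documented defaults.

from dictionary_learning.utils import hf_dataset_to_generator, zst_to_generator

# Streaming text from a Hugging Face dataset (placeholder dataset name).
gen = hf_dataset_to_generator("monology/pile-uncopyrighted", split="train", streaming=True)
print(next(gen)[:80])  # first 80 characters of the first document

# Streaming text from a local compressed jsonl file (hypothetical path).
local_gen = zst_to_generator("data/pile_shard_00.jsonl.zst")
for i, text in enumerate(local_gen):
    if i >= 3:
        break
    print(len(text))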