From 4aea5388811284f4fd3daa8fb97916073bfe8841 Mon Sep 17 00:00:00 2001 From: Mateusz Date: Mon, 14 Oct 2024 20:53:52 +0000 Subject: [PATCH 01/70] Fix BatchTopKSAE training --- dictionary.py | 2 +- trainers/batch_top_k.py | 25 ++++++++++++++++++++++--- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/dictionary.py b/dictionary.py index f0eb176..7950cf7 100644 --- a/dictionary.py +++ b/dictionary.py @@ -2,7 +2,7 @@ Defines the dictionary classes """ -from abc import ABC, abstractclassmethod, abstractmethod +from abc import ABC, abstractmethod import torch as t import torch.nn as nn import torch.nn.init as init diff --git a/trainers/batch_top_k.py b/trainers/batch_top_k.py index e684d9a..c65195f 100644 --- a/trainers/batch_top_k.py +++ b/trainers/batch_top_k.py @@ -1,9 +1,9 @@ import torch as t import torch.nn as nn +import torch.nn.functional as F import einops from collections import namedtuple -from ..config import DEBUG from ..dictionary import Dictionary from ..trainers.trainer import SAETrainer @@ -13,7 +13,9 @@ def __init__(self, activation_dim: int, dict_size: int, k: int): super().__init__() self.activation_dim = activation_dim self.dict_size = dict_size - self.k = k + + assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer" + self.register_buffer("k", t.tensor(k)) self.encoder = nn.Linear(activation_dim, dict_size) self.encoder.bias.data.zero_() @@ -73,6 +75,21 @@ def remove_gradient_parallel_to_decoder_directions(self): "d_sae, d_in d_sae -> d_in d_sae", ) + @classmethod + def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "BatchTopKSAE": + state_dict = t.load(path) + dict_size, activation_dim = state_dict['encoder.weight'].shape + if k is None: + k = state_dict['k'].item() + elif 'k' in state_dict and k != state_dict['k'].item(): + raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']") + + autoencoder = cls(activation_dim, dict_size, k) + autoencoder.load_state_dict(state_dict) + if device is not None: + autoencoder.to(device) + return autoencoder + class TrainerBatchTopK(SAETrainer): def __init__( @@ -148,7 +165,9 @@ def get_auxiliary_loss(self, x, x_reconstruct, acts): acts_aux = t.zeros_like(acts[:, dead_features]).scatter( -1, acts_topk_aux.indices, acts_topk_aux.values ) - x_reconstruct_aux = acts_aux @ self.W_dec[dead_features] + x_reconstruct_aux = F.linear( + acts_aux, self.ae.decoder.weight[:, dead_features] + ) l2_loss_aux = ( self.auxk_alpha * (x_reconstruct_aux.float() - residual.float()).pow(2).mean() From fe54b001cba976ca96d46add8539580268dc5cb6 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Tue, 17 Dec 2024 23:33:01 +0000 Subject: [PATCH 02/70] Add a simple end to end test --- requirements.txt | 3 +- tests/test_end_to_end.py | 282 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 284 insertions(+), 1 deletion(-) create mode 100644 tests/test_end_to_end.py diff --git a/requirements.txt b/requirements.txt index 7366e63..5b9f3c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,5 @@ torch>=2.1.2 tqdm>=4.66.1 umap-learn>=0.5.6 zstandard>=0.22.0 -wandb +wandb>=0.12.0 +pytest>=6.2.4 \ No newline at end of file diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py new file mode 100644 index 0000000..8cb4b95 --- /dev/null +++ b/tests/test_end_to_end.py @@ -0,0 +1,282 @@ +import torch as t +from nnsight import LanguageModel +import os +import json +import random + +from dictionary_learning.training import trainSAE +from dictionary_learning.trainers.standard import StandardTrainer +from dictionary_learning.trainers.top_k import TrainerTopK, AutoEncoderTopK +from dictionary_learning.utils import hf_dataset_to_generator +from dictionary_learning.buffer import ActivationBuffer +from dictionary_learning.dictionary import ( + AutoEncoder, + GatedAutoEncoder, + AutoEncoderNew, + JumpReluAutoEncoder, +) +from dictionary_learning.evaluation import evaluate + +EXPECTED_RESULTS = { + "AutoEncoderTopK": { + "l2_loss": 4.470372676849365, + "l1_loss": 44.47749710083008, + "l0": 40.0, + "frac_alive": 0.000244140625, + "frac_variance_explained": 0.9372208118438721, + "cossim": 0.9471381902694702, + "l2_ratio": 0.9523985981941223, + "relative_reconstruction_bias": 0.9996458888053894, + "loss_original": 3.186223268508911, + "loss_reconstructed": 3.690929412841797, + "loss_zero": 12.936649322509766, + "frac_recovered": 0.9482374787330627, + }, + "AutoEncoder": { + "l2_loss": 6.72230863571167, + "l1_loss": 28.893749237060547, + "l0": 61.12999725341797, + "frac_alive": 0.000244140625, + "frac_variance_explained": 0.6076533794403076, + "cossim": 0.869738757610321, + "l2_ratio": 0.8005934953689575, + "relative_reconstruction_bias": 0.9304398894309998, + "loss_original": 3.186223268508911, + "loss_reconstructed": 5.501500129699707, + "loss_zero": 12.936649322509766, + "frac_recovered": 0.7625460624694824, + }, +} + +DEVICE = "cuda:0" +SAVE_DIR = "./test_data" +MODEL_NAME = "EleutherAI/pythia-70m-deduped" +RANDOM_SEED = 42 +LAYER = 3 +DATASET_NAME = "monology/pile-uncopyrighted" + +EVAL_TOLERANCE = 0.01 + + +def get_nested_folders(path: str) -> list[str]: + """ + Recursively get a list of folders that contain an ae.pt file, starting the search from the given path + """ + folder_names = [] + + for root, dirs, files in os.walk(path): + if "ae.pt" in files: + folder_names.append(root) + + return folder_names + + +def load_dictionary(base_path: str, device: str) -> tuple: + ae_path = f"{base_path}/ae.pt" + config_path = f"{base_path}/config.json" + + with open(config_path, "r") as f: + config = json.load(f) + + # TODO: Save the submodule name in the config? + # submodule_str = config["trainer"]["submodule_name"] + dict_class = config["trainer"]["dict_class"] + + if dict_class == "AutoEncoder": + dictionary = AutoEncoder.from_pretrained(ae_path, device=device) + elif dict_class == "GatedAutoEncoder": + dictionary = GatedAutoEncoder.from_pretrained(ae_path, device=device) + elif dict_class == "AutoEncoderNew": + dictionary = AutoEncoderNew.from_pretrained(ae_path, device=device) + elif dict_class == "AutoEncoderTopK": + k = config["trainer"]["k"] + dictionary = AutoEncoderTopK.from_pretrained(ae_path, k=k, device=device) + elif dict_class == "JumpReluAutoEncoder": + dictionary = JumpReluAutoEncoder.from_pretrained(ae_path, device=device) + else: + raise ValueError(f"Dictionary class {dict_class} not supported") + + return dictionary, config + + +def test_sae_training(): + """End to end test for training an SAE. Takes ~3 minutes on an RTX 3090. + This isn't a nice suite of unit tests, but it's better than nothing.""" + random.seed(RANDOM_SEED) + t.manual_seed(RANDOM_SEED) + + MODEL_NAME = "EleutherAI/pythia-70m-deduped" + model = LanguageModel(MODEL_NAME, dispatch=True, device_map=DEVICE) + layer = 3 + + context_length = 128 + llm_batch_size = 512 # Fits on a 24GB GPU + sae_batch_size = 8192 + num_contexts_per_sae_batch = sae_batch_size // context_length + + num_inputs_in_buffer = num_contexts_per_sae_batch * 20 + + num_tokens = 10_000_000 + + # sae training parameters + random_seed = 42 + k = 40 + sparsity_penalty = 0.05 + expansion_factor = 8 + + steps = int(num_tokens / sae_batch_size) # Total number of batches to train + save_steps = None + warmup_steps = 1000 # Warmup period at start of training and after each resample + resample_steps = None + + # standard sae training parameters + learning_rate = 3e-4 + + # topk sae training parameters + decay_start = 24000 + auxk_alpha = 1 / 32 + + submodule = model.gpt_neox.layers[LAYER] + submodule_name = f"resid_post_layer_{LAYER}" + io = "out" + activation_dim = model.config.hidden_size + + generator = hf_dataset_to_generator(DATASET_NAME) + + activation_buffer = ActivationBuffer( + generator, + model, + submodule, + n_ctxs=num_inputs_in_buffer, + ctx_len=context_length, + refresh_batch_size=llm_batch_size, + out_batch_size=sae_batch_size, + io=io, + d_submodule=activation_dim, + device=DEVICE, + ) + + # create the list of configs + trainer_configs = [] + trainer_configs.extend( + [ + { + "trainer": TrainerTopK, + "dict_class": AutoEncoderTopK, + "activation_dim": activation_dim, + "dict_size": expansion_factor * activation_dim, + "k": k, + "auxk_alpha": auxk_alpha, # see Appendix A.2 + "decay_start": decay_start, # when does the lr decay start + "steps": steps, # when when does training end + "seed": random_seed, + "wandb_name": f"TopKTrainer-{MODEL_NAME}-{submodule_name}", + "device": DEVICE, + "layer": layer, + "lm_name": MODEL_NAME, + "submodule_name": submodule_name, + }, + ] + ) + trainer_configs.extend( + [ + { + "trainer": StandardTrainer, + "dict_class": AutoEncoder, + "activation_dim": activation_dim, + "dict_size": expansion_factor * activation_dim, + "lr": learning_rate, + "l1_penalty": sparsity_penalty, + "warmup_steps": warmup_steps, + "resample_steps": resample_steps, + "seed": random_seed, + "wandb_name": f"StandardTrainer-{MODEL_NAME}-{submodule_name}", + "layer": layer, + "lm_name": MODEL_NAME, + "device": DEVICE, + "submodule_name": submodule_name, + }, + ] + ) + + print(f"len trainer configs: {len(trainer_configs)}") + output_dir = f"{SAVE_DIR}/{submodule_name}" + + trainSAE( + data=activation_buffer, + trainer_configs=trainer_configs, + steps=steps, + save_steps=save_steps, + save_dir=output_dir, + ) + + folders = get_nested_folders(output_dir) + + assert len(folders) == 2 + + for folder in folders: + dictionary, config = load_dictionary(folder, DEVICE) + + assert dictionary is not None + assert config is not None + + +def test_evaluation(): + random.seed(RANDOM_SEED) + t.manual_seed(RANDOM_SEED) + + model = LanguageModel(MODEL_NAME, dispatch=True, device_map=DEVICE) + ae_paths = get_nested_folders(SAVE_DIR) + + context_length = 128 + llm_batch_size = 100 + buffer_size = 256 + io = "out" + + generator = hf_dataset_to_generator(DATASET_NAME) + submodule = model.gpt_neox.layers[LAYER] + + input_strings = [] + for i, example in enumerate(generator): + input_strings.append(example) + if i > buffer_size * 2: + break + + for ae_path in ae_paths: + dictionary, config = load_dictionary(ae_path, DEVICE) + dictionary = dictionary.to(dtype=model.dtype) + + activation_dim = config["trainer"]["activation_dim"] + context_length = config["buffer"]["ctx_len"] + + activation_buffer_data = iter(input_strings) + + activation_buffer = ActivationBuffer( + activation_buffer_data, + model, + submodule, + n_ctxs=buffer_size, + ctx_len=context_length, + refresh_batch_size=llm_batch_size, + out_batch_size=llm_batch_size, + io=io, + d_submodule=activation_dim, + device=DEVICE, + ) + + eval_results = evaluate( + dictionary, + activation_buffer, + context_length, + llm_batch_size, + io=io, + device=DEVICE, + ) + + print(eval_results) + + dict_class = config["trainer"]["dict_class"] + expected_results = EXPECTED_RESULTS[dict_class] + + for key, value in expected_results.items(): + assert abs(eval_results[key] - value) < EVAL_TOLERANCE From 9ed4af245a22e095e932d6065d368c58947d9a3d Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Tue, 17 Dec 2024 23:41:15 +0000 Subject: [PATCH 03/70] Rename input to inputs per nnsight 0.3.0 --- buffer.py | 4 ++-- evaluation.py | 2 +- requirements.txt | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/buffer.py b/buffer.py index 86f24f9..178ea3e 100644 --- a/buffer.py +++ b/buffer.py @@ -124,7 +124,7 @@ def refresh(self): hidden_states = self.submodule.input[0].save() else: hidden_states = self.submodule.output.save() - input = self.model.input.save() + input = self.model.inputs.save() attn_mask = input.value[1]["attention_mask"] hidden_states = hidden_states.value if isinstance(hidden_states, tuple): @@ -251,7 +251,7 @@ def refresh(self): while len(self.activations) < self.n_ctxs * self.ctx_len: with t.no_grad(): with self.model.trace(self.text_batch(), **tracer_kwargs, invoker_args={'truncation': True, 'max_length': self.ctx_len}, remote=self.remote): - input = self.model.input.save() + input = self.model.inputs.save() hidden_states = self.model.model.layers[self.layer].self_attn.o_proj.input[0][0]#.save() if isinstance(hidden_states, tuple): hidden_states = hidden_states[0] diff --git a/evaluation.py b/evaluation.py index 6b3b0e5..13bf4fa 100644 --- a/evaluation.py +++ b/evaluation.py @@ -110,7 +110,7 @@ def loss_recovered( else: raise ValueError(f"Invalid value for io: {io}") - input = model.input.save() + input = model.inputs.save() logits_zero = model.output.save() logits_zero = logits_zero.value diff --git a/requirements.txt b/requirements.txt index 5b9f3c6..bda16d1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ circuitsvis>=1.43.2 datasets>=2.18.0 einops>=0.7.0 matplotlib>=3.8.3 -nnsight>=0.2.11 +nnsight>=0.3.0 pandas>=2.2.1 plotly>=5.18.0 torch>=2.1.2 From 807f6ef735872a5cab68773a315f15bc920c3d72 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Wed, 18 Dec 2024 02:18:19 +0000 Subject: [PATCH 04/70] Complete nnsight 0.2 to 0.3 changes --- buffer.py | 6 +++--- evaluation.py | 46 +++++++++++++--------------------------- interp.py | 2 +- tests/test_end_to_end.py | 42 ++++++++++++++++++------------------ 4 files changed, 40 insertions(+), 56 deletions(-) diff --git a/buffer.py b/buffer.py index 178ea3e..be3a745 100644 --- a/buffer.py +++ b/buffer.py @@ -121,7 +121,7 @@ def refresh(self): invoker_args={"truncation": True, "max_length": self.ctx_len}, ): if self.io == "in": - hidden_states = self.submodule.input[0].save() + hidden_states = self.submodule.inputs[0].save() else: hidden_states = self.submodule.output.save() input = self.model.inputs.save() @@ -252,7 +252,7 @@ def refresh(self): with t.no_grad(): with self.model.trace(self.text_batch(), **tracer_kwargs, invoker_args={'truncation': True, 'max_length': self.ctx_len}, remote=self.remote): input = self.model.inputs.save() - hidden_states = self.model.model.layers[self.layer].self_attn.o_proj.input[0][0]#.save() + hidden_states = self.model.model.layers[self.layer].self_attn.o_proj.inputs[0][0]#.save() if isinstance(hidden_states, tuple): hidden_states = hidden_states[0] @@ -416,7 +416,7 @@ def refresh(self): invoker_args={"truncation": True, "max_length": self.ctx_len}, ): if self.io in ["in", "in_and_out"]: - hidden_states_in = self.submodule.input[0].save() + hidden_states_in = self.submodule.inputs[0].save() if self.io in ["out", "in_and_out"]: hidden_states_out = self.submodule.output.save() diff --git a/evaluation.py b/evaluation.py index 13bf4fa..558d6c8 100644 --- a/evaluation.py +++ b/evaluation.py @@ -36,21 +36,17 @@ def loss_recovered( # logits when replacing component activations with reconstruction by autoencoder with model.trace(text, **tracer_args, invoker_args=invoker_args): if io == 'in': - x = submodule.input[0] - if type(submodule.input.shape) == tuple: x = x[0] + x = submodule.input if normalize_batch: scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x = x * scale elif io == 'out': - x = submodule.output - if type(submodule.output.shape) == tuple: x = x[0] + x = submodule.output[0] if normalize_batch: scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x = x * scale elif io == 'in_and_out': - x = submodule.input[0] - if type(submodule.input.shape) == tuple: x = x[0] - print(f'x.shape: {x.shape}') + x = submodule.input if normalize_batch: scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x = x * scale @@ -58,35 +54,28 @@ def loss_recovered( raise ValueError(f"Invalid value for io: {io}") x = x.save() - # pull this out so dictionary can be written without FakeTensor (top_k needs this) - x_hat = dictionary(x.view(-1, x.shape[-1])).view(x.shape).to(model.dtype) + x_hat = dictionary(x).to(model.dtype) # intervene with `x_hat` with model.trace(text, **tracer_args, invoker_args=invoker_args): if io == 'in': - x = submodule.input[0] + x = submodule.input if normalize_batch: scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x_hat = x_hat / scale - if type(submodule.input.shape) == tuple: - submodule.input[0][:] = x_hat - else: - submodule.input = x_hat + submodule.input[:] = x_hat elif io == 'out': - x = submodule.output + x = submodule.output[0] if normalize_batch: scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x_hat = x_hat / scale - if type(submodule.output.shape) == tuple: - submodule.output = (x_hat,) - else: - submodule.output = x_hat + submodule.output[0][:] = x_hat elif io == 'in_and_out': - x = submodule.input[0] + x = submodule.input if normalize_batch: scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x_hat = x_hat / scale - submodule.output = x_hat + submodule.output[0][:] = x_hat else: raise ValueError(f"Invalid value for io: {io}") @@ -96,22 +85,17 @@ def loss_recovered( # logits when replacing component activations with zeros with model.trace(text, **tracer_args, invoker_args=invoker_args): if io == 'in': - x = submodule.input[0] - if type(submodule.input.shape) == tuple: - submodule.input[0][:] = t.zeros_like(x[0]) - else: - submodule.input = t.zeros_like(x) + x = submodule.input + submodule.input[:] = t.zeros_like(x) elif io in ['out', 'in_and_out']: - x = submodule.output - if type(submodule.output.shape) == tuple: - submodule.output[0][:] = t.zeros_like(x[0]) - else: - submodule.output = t.zeros_like(x) + x = submodule.output[0] + submodule.output[0][:] = t.zeros_like(x) else: raise ValueError(f"Invalid value for io: {io}") input = model.inputs.save() logits_zero = model.output.save() + logits_zero = logits_zero.value # get everything into the right format diff --git a/interp.py b/interp.py index 283965b..e721eb9 100644 --- a/interp.py +++ b/interp.py @@ -101,7 +101,7 @@ def _list_decode(x): inputs = buffer.tokenized_batch(batch_size=n_inputs) with t.no_grad(), model.trace(inputs, **tracer_kwargs): - tokens = model.input[1][ + tokens = model.inputs[1][ "input_ids" ].save() # if you're getting errors, check here; might only work for pythia models activations = submodule.output diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 8cb4b95..c6eb0a3 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -19,32 +19,32 @@ EXPECTED_RESULTS = { "AutoEncoderTopK": { - "l2_loss": 4.470372676849365, - "l1_loss": 44.47749710083008, + "l2_loss": 4.462644577026367, + "l1_loss": 44.446834564208984, "l0": 40.0, "frac_alive": 0.000244140625, - "frac_variance_explained": 0.9372208118438721, - "cossim": 0.9471381902694702, - "l2_ratio": 0.9523985981941223, - "relative_reconstruction_bias": 0.9996458888053894, - "loss_original": 3.186223268508911, - "loss_reconstructed": 3.690929412841797, - "loss_zero": 12.936649322509766, - "frac_recovered": 0.9482374787330627, + "frac_variance_explained": 0.9372867941856384, + "cossim": 0.9471449851989746, + "l2_ratio": 0.9524278044700623, + "relative_reconstruction_bias": 0.9986423254013062, + "loss_original": 3.1832079887390137, + "loss_reconstructed": 3.713366985321045, + "loss_zero": 12.936450958251953, + "frac_recovered": 0.9456427693367004, }, "AutoEncoder": { - "l2_loss": 6.72230863571167, - "l1_loss": 28.893749237060547, - "l0": 61.12999725341797, + "l2_loss": 6.721538066864014, + "l1_loss": 28.914989471435547, + "l0": 61.29999923706055, "frac_alive": 0.000244140625, - "frac_variance_explained": 0.6076533794403076, - "cossim": 0.869738757610321, - "l2_ratio": 0.8005934953689575, - "relative_reconstruction_bias": 0.9304398894309998, - "loss_original": 3.186223268508911, - "loss_reconstructed": 5.501500129699707, - "loss_zero": 12.936649322509766, - "frac_recovered": 0.7625460624694824, + "frac_variance_explained": 0.6077123880386353, + "cossim": 0.869745135307312, + "l2_ratio": 0.801030695438385, + "relative_reconstruction_bias": 0.9309902191162109, + "loss_original": 3.1832079887390137, + "loss_reconstructed": 5.499264717102051, + "loss_zero": 12.936450958251953, + "frac_recovered": 0.7625347375869751, }, } From dc3072089c24ce1eb8bc40e9f5248c69a92f5174 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Wed, 18 Dec 2024 02:53:40 +0000 Subject: [PATCH 05/70] Fix frac_alive calculation, perform evaluation over multiple batches --- evaluation.py | 51 +++++++++++++++++++++++----------------- tests/test_end_to_end.py | 4 ++-- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/evaluation.py b/evaluation.py index 558d6c8..9097713 100644 --- a/evaluation.py +++ b/evaluation.py @@ -3,6 +3,8 @@ """ import torch as t +from collections import defaultdict + from .buffer import ActivationBuffer, NNsightActivationBuffer from nnsight import LanguageModel from .config import DEBUG @@ -128,7 +130,7 @@ def loss_recovered( return tuple(losses) - +@t.no_grad() def evaluate( dictionary, # a dictionary activations, # a generator of activations; if an ActivationBuffer, also compute loss recovered @@ -138,26 +140,28 @@ def evaluate( normalize_batch=False, # normalize batch before passing through dictionary tracer_args={'use_cache': False, 'output_attentions': False}, # minimize cache during model trace. device="cpu", + n_batches: int = 1, ): - with t.no_grad(): - - out = {} # dict of results + assert n_batches > 0 + out = defaultdict(float) + active_features = t.zeros(dictionary.dict_size, dtype=t.float32, device=device) + for _ in range(n_batches): try: x = next(activations).to(device) if normalize_batch: x = x / x.norm(dim=-1).mean() * (dictionary.activation_dim ** 0.5) - except StopIteration: raise StopIteration( "Not enough activations in buffer. Pass a buffer with a smaller batch size or more data." ) - x_hat, f = dictionary(x, output_features=True) l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() l1_loss = f.norm(p=1, dim=-1).mean() l0 = (f != 0).float().sum(dim=-1).mean() - frac_alive = t.flatten(f, start_dim=0, end_dim=1).any(dim=0).sum() / dictionary.dict_size + + features_BF = t.flatten(f, start_dim=0, end_dim=-2).to(dtype=t.float32) # If f is shape (B, L, D), flatten to (B*L, D) + active_features += features_BF.sum(dim=0) # cosine similarity between x and x_hat x_normed = x / t.linalg.norm(x, dim=-1, keepdim=True) @@ -177,18 +181,17 @@ def evaluate( x_dot_x_hat = (x * x_hat).sum(dim=-1) relative_reconstruction_bias = x_hat_norm_squared.mean() / x_dot_x_hat.mean() - out["l2_loss"] = l2_loss.item() - out["l1_loss"] = l1_loss.item() - out["l0"] = l0.item() - out["frac_alive"] = frac_alive.item() - out["frac_variance_explained"] = frac_variance_explained.item() - out["cossim"] = cossim.item() - out["l2_ratio"] = l2_ratio.item() - out['relative_reconstruction_bias'] = relative_reconstruction_bias.item() + out["l2_loss"] += l2_loss.item() + out["l1_loss"] += l1_loss.item() + out["l0"] += l0.item() + out["frac_variance_explained"] += frac_variance_explained.item() + out["cossim"] += cossim.item() + out["l2_ratio"] += l2_ratio.item() + out['relative_reconstruction_bias'] += relative_reconstruction_bias.item() if not isinstance(activations, (ActivationBuffer, NNsightActivationBuffer)): - return out - + continue + # compute loss recovered loss_original, loss_reconstructed, loss_zero = loss_recovered( activations.text_batch(batch_size=batch_size), @@ -202,9 +205,13 @@ def evaluate( ) frac_recovered = (loss_reconstructed - loss_zero) / (loss_original - loss_zero) - out["loss_original"] = loss_original.item() - out["loss_reconstructed"] = loss_reconstructed.item() - out["loss_zero"] = loss_zero.item() - out["frac_recovered"] = frac_recovered.item() + out["loss_original"] += loss_original.item() + out["loss_reconstructed"] += loss_reconstructed.item() + out["loss_zero"] += loss_zero.item() + out["frac_recovered"] += frac_recovered.item() + + out = {key: value / n_batches for key, value in out.items()} + frac_alive = (active_features != 0).float().sum() / dictionary.dict_size + out["frac_alive"] = frac_alive.item() - return out + return out \ No newline at end of file diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index c6eb0a3..c41a433 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -22,7 +22,7 @@ "l2_loss": 4.462644577026367, "l1_loss": 44.446834564208984, "l0": 40.0, - "frac_alive": 0.000244140625, + "frac_alive": 0.45458984375, "frac_variance_explained": 0.9372867941856384, "cossim": 0.9471449851989746, "l2_ratio": 0.9524278044700623, @@ -36,7 +36,7 @@ "l2_loss": 6.721538066864014, "l1_loss": 28.914989471435547, "l0": 61.29999923706055, - "frac_alive": 0.000244140625, + "frac_alive": 0.14404296875, "frac_variance_explained": 0.6077123880386353, "cossim": 0.869745135307312, "l2_ratio": 0.801030695438385, From 067bf7b05470f61b9ed4f38b95be55c5ac45fb8f Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Wed, 18 Dec 2024 03:55:05 +0000 Subject: [PATCH 06/70] Obtain better test results using multiple batches --- evaluation.py | 5 +++- tests/test_end_to_end.py | 59 ++++++++++++++++++++-------------------- 2 files changed, 34 insertions(+), 30 deletions(-) diff --git a/evaluation.py b/evaluation.py index 9097713..99fddef 100644 --- a/evaluation.py +++ b/evaluation.py @@ -161,6 +161,9 @@ def evaluate( l0 = (f != 0).float().sum(dim=-1).mean() features_BF = t.flatten(f, start_dim=0, end_dim=-2).to(dtype=t.float32) # If f is shape (B, L, D), flatten to (B*L, D) + assert features_BF.shape[-1] == dictionary.dict_size + assert len(features_BF.shape) == 2 + active_features += features_BF.sum(dim=0) # cosine similarity between x and x_hat @@ -191,7 +194,7 @@ def evaluate( if not isinstance(activations, (ActivationBuffer, NNsightActivationBuffer)): continue - + # compute loss recovered loss_original, loss_reconstructed, loss_zero = loss_recovered( activations.text_batch(batch_size=batch_size), diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index c41a433..ce5a1cf 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -19,32 +19,32 @@ EXPECTED_RESULTS = { "AutoEncoderTopK": { - "l2_loss": 4.462644577026367, - "l1_loss": 44.446834564208984, + "l2_loss": 4.325331306457519, + "l1_loss": 47.92763671875, "l0": 40.0, - "frac_alive": 0.45458984375, - "frac_variance_explained": 0.9372867941856384, - "cossim": 0.9471449851989746, - "l2_ratio": 0.9524278044700623, - "relative_reconstruction_bias": 0.9986423254013062, - "loss_original": 3.1832079887390137, - "loss_reconstructed": 3.713366985321045, - "loss_zero": 12.936450958251953, - "frac_recovered": 0.9456427693367004, + "frac_variance_explained": 0.9584966480731965, + "cossim": 0.948570293188095, + "l2_ratio": 0.94872345328331, + "relative_reconstruction_bias": 0.9998040139675141, + "loss_original": 3.328495955467224, + "loss_reconstructed": 3.819682216644287, + "loss_zero": 13.250199031829833, + "frac_recovered": 0.9503251194953919, + "frac_alive": 0.99951171875, }, "AutoEncoder": { - "l2_loss": 6.721538066864014, - "l1_loss": 28.914989471435547, - "l0": 61.29999923706055, - "frac_alive": 0.14404296875, - "frac_variance_explained": 0.6077123880386353, - "cossim": 0.869745135307312, - "l2_ratio": 0.801030695438385, - "relative_reconstruction_bias": 0.9309902191162109, - "loss_original": 3.1832079887390137, - "loss_reconstructed": 5.499264717102051, - "loss_zero": 12.936450958251953, - "frac_recovered": 0.7625347375869751, + "l2_loss": 6.5741173267364506, + "l1_loss": 32.06615734100342, + "l0": 60.9147216796875, + "frac_variance_explained": 0.9042629599571228, + "cossim": 0.8782194256782532, + "l2_ratio": 0.814234834909439, + "relative_reconstruction_bias": 0.9813631415367127, + "loss_original": 3.328495955467224, + "loss_reconstructed": 5.7899915218353275, + "loss_zero": 13.250199031829833, + "frac_recovered": 0.754741370677948, + "frac_alive": 0.9921875, }, } @@ -105,9 +105,7 @@ def test_sae_training(): random.seed(RANDOM_SEED) t.manual_seed(RANDOM_SEED) - MODEL_NAME = "EleutherAI/pythia-70m-deduped" model = LanguageModel(MODEL_NAME, dispatch=True, device_map=DEVICE) - layer = 3 context_length = 128 llm_batch_size = 512 # Fits on a 24GB GPU @@ -172,7 +170,7 @@ def test_sae_training(): "seed": random_seed, "wandb_name": f"TopKTrainer-{MODEL_NAME}-{submodule_name}", "device": DEVICE, - "layer": layer, + "layer": LAYER, "lm_name": MODEL_NAME, "submodule_name": submodule_name, }, @@ -191,7 +189,7 @@ def test_sae_training(): "resample_steps": resample_steps, "seed": random_seed, "wandb_name": f"StandardTrainer-{MODEL_NAME}-{submodule_name}", - "layer": layer, + "layer": LAYER, "lm_name": MODEL_NAME, "device": DEVICE, "submodule_name": submodule_name, @@ -230,6 +228,8 @@ def test_evaluation(): context_length = 128 llm_batch_size = 100 + sae_batch_size = 4096 + n_batches = 10 buffer_size = 256 io = "out" @@ -239,7 +239,7 @@ def test_evaluation(): input_strings = [] for i, example in enumerate(generator): input_strings.append(example) - if i > buffer_size * 2: + if i > buffer_size * n_batches: break for ae_path in ae_paths: @@ -258,7 +258,7 @@ def test_evaluation(): n_ctxs=buffer_size, ctx_len=context_length, refresh_batch_size=llm_batch_size, - out_batch_size=llm_batch_size, + out_batch_size=sae_batch_size, io=io, d_submodule=activation_dim, device=DEVICE, @@ -271,6 +271,7 @@ def test_evaluation(): llm_batch_size, io=io, device=DEVICE, + n_batches=n_batches, ) print(eval_results) From 05fe179f5b0616310253deaf758c370071f534fa Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Wed, 18 Dec 2024 03:55:27 +0000 Subject: [PATCH 07/70] Add early stopping in forward pass --- buffer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/buffer.py b/buffer.py index be3a745..d997596 100644 --- a/buffer.py +++ b/buffer.py @@ -125,6 +125,8 @@ def refresh(self): else: hidden_states = self.submodule.output.save() input = self.model.inputs.save() + + self.submodule.output.stop() attn_mask = input.value[1]["attention_mask"] hidden_states = hidden_states.value if isinstance(hidden_states, tuple): From f1b9b800bc8e2cc308d4d14690df71f854b30fce Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Wed, 18 Dec 2024 04:00:52 +0000 Subject: [PATCH 08/70] Change save_steps to a list of ints --- training.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/training.py b/training.py index f100fee..13fd4b3 100644 --- a/training.py +++ b/training.py @@ -6,6 +6,7 @@ import multiprocessing as mp import os from queue import Empty +from typing import Optional import torch as t from tqdm import tqdm @@ -75,17 +76,17 @@ def log_stats( def trainSAE( data, - trainer_configs, - use_wandb=False, - wandb_entity="", - wandb_project="", - steps=None, - save_steps=None, - save_dir=None, - log_steps=None, - activations_split_by_head=False, - transcoder=False, - run_cfg={}, + trainer_configs: list[dict], + use_wandb:bool=False, + wandb_entity:str="", + wandb_project:str="", + steps:Optional[int]=None, + save_steps:Optional[list[int]]=None, + save_dir:Optional[str]=None, + log_steps:Optional[int]=None, + activations_split_by_head:bool=False, + transcoder:bool=False, + run_cfg:dict={}, ): """ Train SAEs using the given trainers @@ -140,7 +141,7 @@ def trainSAE( ) # saving - if save_steps is not None and step % save_steps == 0: + if save_steps is not None and step in save_steps: for dir, trainer in zip(save_dirs, trainers): if dir is not None: if not os.path.exists(os.path.join(dir, "checkpoints")): From d350415e119cacb6547703eb9733daf8ef57075b Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Wed, 18 Dec 2024 16:30:15 +0000 Subject: [PATCH 09/70] Check for is_tuple to support mlp / attn submodules --- evaluation.py | 37 ++++++++++++++++++++++++++++++------- tests/test_end_to_end.py | 7 +++---- 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/evaluation.py b/evaluation.py index 99fddef..ba56437 100644 --- a/evaluation.py +++ b/evaluation.py @@ -24,12 +24,21 @@ def loss_recovered( How much of the model's loss is recovered by replacing the component output with the reconstruction by the autoencoder? """ - + if max_len is None: invoker_args = {} else: invoker_args = {"truncation": True, "max_length": max_len } + with model.trace("_"): + temp_output = submodule.output.save() + + output_is_tuple = False + # Note: isinstance() won't work here as torch.Size is a subclass of tuple, + # so isinstance(temp_output.shape, tuple) would return True even for torch.Size. + if type(temp_output.shape) == tuple: + output_is_tuple = True + # unmodified logits with model.trace(text, invoker_args=invoker_args): logits_original = model.output.save() @@ -43,7 +52,8 @@ def loss_recovered( scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x = x * scale elif io == 'out': - x = submodule.output[0] + x = submodule.output + if output_is_tuple: x = x[0] if normalize_batch: scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x = x * scale @@ -56,6 +66,9 @@ def loss_recovered( raise ValueError(f"Invalid value for io: {io}") x = x.save() + # If we incorrectly handle output_is_tuple, such as with some mlp submodules, we will get an error here. + assert len(x.shape) == 3, f"Expected x to have shape (B, L, D), got {x.shape}, output_is_tuple: {output_is_tuple}" + x_hat = dictionary(x).to(model.dtype) # intervene with `x_hat` @@ -67,17 +80,24 @@ def loss_recovered( x_hat = x_hat / scale submodule.input[:] = x_hat elif io == 'out': - x = submodule.output[0] + x = submodule.output + if output_is_tuple: x = x[0] if normalize_batch: scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x_hat = x_hat / scale - submodule.output[0][:] = x_hat + if output_is_tuple: + submodule.output[0][:] = x_hat + else: + submodule.output[:] = x_hat elif io == 'in_and_out': x = submodule.input if normalize_batch: scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() x_hat = x_hat / scale - submodule.output[0][:] = x_hat + if output_is_tuple: + submodule.output[0][:] = x_hat + else: + submodule.output[:] = x_hat else: raise ValueError(f"Invalid value for io: {io}") @@ -90,8 +110,11 @@ def loss_recovered( x = submodule.input submodule.input[:] = t.zeros_like(x) elif io in ['out', 'in_and_out']: - x = submodule.output[0] - submodule.output[0][:] = t.zeros_like(x) + x = submodule.output + if output_is_tuple: + submodule.output[0][:] = t.zeros_like(x[0]) + else: + submodule.output[:] = t.zeros_like(x) else: raise ValueError(f"Invalid value for io: {io}") diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index ce5a1cf..8e93cab 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -100,7 +100,7 @@ def load_dictionary(base_path: str, device: str) -> tuple: def test_sae_training(): - """End to end test for training an SAE. Takes ~3 minutes on an RTX 3090. + """End to end test for training an SAE. Takes ~2 minutes on an RTX 3090. This isn't a nice suite of unit tests, but it's better than nothing.""" random.seed(RANDOM_SEED) t.manual_seed(RANDOM_SEED) @@ -117,7 +117,6 @@ def test_sae_training(): num_tokens = 10_000_000 # sae training parameters - random_seed = 42 k = 40 sparsity_penalty = 0.05 expansion_factor = 8 @@ -167,7 +166,7 @@ def test_sae_training(): "auxk_alpha": auxk_alpha, # see Appendix A.2 "decay_start": decay_start, # when does the lr decay start "steps": steps, # when when does training end - "seed": random_seed, + "seed": RANDOM_SEED, "wandb_name": f"TopKTrainer-{MODEL_NAME}-{submodule_name}", "device": DEVICE, "layer": LAYER, @@ -187,7 +186,7 @@ def test_sae_training(): "l1_penalty": sparsity_penalty, "warmup_steps": warmup_steps, "resample_steps": resample_steps, - "seed": random_seed, + "seed": RANDOM_SEED, "wandb_name": f"StandardTrainer-{MODEL_NAME}-{submodule_name}", "layer": LAYER, "lm_name": MODEL_NAME, From d416eab5de1edfe8ea75c972cdf78d9de68642c2 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Fri, 20 Dec 2024 03:46:51 +0000 Subject: [PATCH 10/70] Ensure activation buffer has the correct dtype --- buffer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/buffer.py b/buffer.py index d997596..6cbf8e3 100644 --- a/buffer.py +++ b/buffer.py @@ -40,7 +40,7 @@ def __init__(self, d_submodule = submodule.out_features except: raise ValueError("d_submodule cannot be inferred and must be specified directly") - self.activations = t.empty(0, d_submodule, device=device) + self.activations = t.empty(0, d_submodule, device=device, dtype=model.dtype) self.read = t.zeros(0).bool() self.data = data @@ -105,7 +105,7 @@ def refresh(self): self.activations = self.activations[~self.read] current_idx = len(self.activations) - new_activations = t.empty(self.activation_buffer_size, self.d_submodule, device=self.device) + new_activations = t.empty(self.activation_buffer_size, self.d_submodule, device=self.device, dtype=self.model.dtype) new_activations[: len(self.activations)] = self.activations self.activations = new_activations From 552a8c2c12d41b5d520c99bf3534dff5329f0fde Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Fri, 20 Dec 2024 03:48:12 +0000 Subject: [PATCH 11/70] Fix JumpReLU training and loading --- dictionary.py | 3 ++- trainers/__init__.py | 2 +- trainers/jumprelu.py | 11 +++++++---- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/dictionary.py b/dictionary.py index 7950cf7..ababd85 100644 --- a/dictionary.py +++ b/dictionary.py @@ -284,9 +284,10 @@ def from_pretrained( """ if not load_from_sae_lens: state_dict = t.load(path) - dict_size, activation_dim = state_dict['W_enc'].shape + activation_dim, dict_size = state_dict['W_enc'].shape autoencoder = JumpReluAutoEncoder(activation_dim, dict_size) autoencoder.load_state_dict(state_dict) + autoencoder = autoencoder.to(dtype=dtype, device=device) else: from sae_lens import SAE sae, cfg_dict, _ = SAE.from_pretrained(**kwargs) diff --git a/trainers/__init__.py b/trainers/__init__.py index 461af62..99b015b 100644 --- a/trainers/__init__.py +++ b/trainers/__init__.py @@ -3,5 +3,5 @@ from .p_anneal import PAnnealTrainer from .gated_anneal import GatedAnnealTrainer from .top_k import TrainerTopK -from .jumprelu import TrainerJumpRelu +from .jumprelu import JumpReluTrainer from .batch_top_k import TrainerBatchTopK, BatchTopKSAE diff --git a/trainers/jumprelu.py b/trainers/jumprelu.py index f87785a..a3a6371 100644 --- a/trainers/jumprelu.py +++ b/trainers/jumprelu.py @@ -60,7 +60,7 @@ def backward(ctx, grad_output): return x_grad, threshold_grad, None # None for bandwidth -class TrainerJumpRelu(nn.Module, SAETrainer): +class JumpReluTrainer(nn.Module, SAETrainer): """ Trains a JumpReLU autoencoder. @@ -77,7 +77,8 @@ def __init__( # TODO: What's the default lr use in the paper? lr=7e-5, bandwidth=0.001, - sparsity_penalty=0.1, + sparsity_penalty=1.0, + target_l0=20.0, device="cpu", layer=None, lm_name=None, @@ -99,6 +100,7 @@ def __init__( self.bandwidth = bandwidth self.sparsity_coefficient = sparsity_penalty + self.target_l0 = target_l0 # TODO: Better auto-naming (e.g. in BatchTopK package) self.wandb_name = wandb_name @@ -123,7 +125,8 @@ def loss(self, x, logging=False, **_): recon_loss = (x - recon).pow(2).sum(dim=-1).mean() l0 = StepFunction.apply(f, self.ae.threshold, self.bandwidth).sum(dim=-1).mean() - sparsity_loss = self.sparsity_coefficient * l0 + + sparsity_loss = self.sparsity_coefficient * ((l0 / self.target_l0) - 1).pow(2) loss = recon_loss + sparsity_loss if not logging: @@ -153,7 +156,7 @@ def update(self, step, x): @property def config(self): return { - "trainer_class": "TrainerJumpRelu", + "trainer_class": "JumpReluTrainer", "dict_class": "JumpReluAutoEncoder", "lr": self.lr, "steps": self.steps, From 712eb98f78d9537aa3ff01a1d9e007361e67c267 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Fri, 20 Dec 2024 03:48:32 +0000 Subject: [PATCH 12/70] Begin creation of demo script --- demo.py | 432 +++++++++++++++++++++++++++++++++++++++ graphing.ipynb | 205 +++++++++++++++++++ tests/test_end_to_end.py | 56 ++--- utils.py | 61 +++++- 4 files changed, 709 insertions(+), 45 deletions(-) create mode 100644 demo.py create mode 100644 graphing.ipynb diff --git a/demo.py b/demo.py new file mode 100644 index 0000000..fbac8de --- /dev/null +++ b/demo.py @@ -0,0 +1,432 @@ +import torch as t +from nnsight import LanguageModel +import argparse +import itertools +import os +import json + +from dictionary_learning.training import trainSAE +from dictionary_learning.trainers.standard import StandardTrainer +from dictionary_learning.trainers.top_k import TrainerTopK, AutoEncoderTopK +from dictionary_learning.trainers.gdm import GatedSAETrainer +from dictionary_learning.trainers.p_anneal import PAnnealTrainer +from dictionary_learning.trainers.jumprelu import JumpReluTrainer +from dictionary_learning.utils import hf_dataset_to_generator +from dictionary_learning.buffer import ActivationBuffer +from dictionary_learning.dictionary import AutoEncoder, GatedAutoEncoder, AutoEncoderNew, JumpReluAutoEncoder +from dictionary_learning.evaluation import evaluate +import dictionary_learning.utils as utils + + +DEVICE = "cuda:0" + +LLM_CONFIG = { + "EleutherAI/pythia-70m-deduped": { + "llm_batch_size": 512, + "context_length": 128, + "sae_batch_size": 4096, + "dtype": t.float32, + }, + "google/gemma-2-2b": { + "llm_batch_size": 32, + "context_length": 128, + "sae_batch_size": 2048, + "dtype": t.bfloat16, + }, +} + +SPARSITY_PENALTIES = { + "EleutherAI/pythia-70m-deduped": { + "standard": [0.01, 0.05, 0.075, 0.1, 0.125, 0.15], + "p_anneal": [0.02, 0.03, 0.035, 0.04, 0.05, 0.075], + "gated": [0.1, 0.3, 0.5, 0.7, 0.9, 1.1], + }, + "google/gemma-2-2b": { + "standard": [0.025, 0.035, 0.04, 0.05, 0.06, 0.07], + "p_anneal": [-1, -1, -1, -1, -1, -1], + "gated": [-1, -1, -1, -1, -1, -1], + }, +} + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--save_dir", type=str, required=True, help="where to store sweep") + parser.add_argument("--use_wandb", action="store_true", help="use wandb logging") + parser.add_argument("--dry_run", action="store_true", help="dry run sweep") + parser.add_argument( + "--layers", type=int, nargs="+", required=True, help="layers to train SAE on" + ) + parser.add_argument( + "--model_name", + type=str, + required=True, + help="which language model to use", + ) + parser.add_argument( + "--architectures", + type=str, + nargs="+", + choices=["standard", "standard_new", "top_k", "gated", "p_anneal", "jump_relu"], + required=True, + help="which SAE architectures to train", + ) + args = parser.parse_args() + return args + + +def run_sae_training( + model_name: str, + layer: int, + save_dir: str, + device: str, + architectures: list, + dry_run: bool = False, + use_wandb: bool = False, + save_checkpoints: bool = False, +): + # model and data parameters + context_length = LLM_CONFIG[model_name]["context_length"] + + llm_batch_size = LLM_CONFIG[model_name]["llm_batch_size"] + sae_batch_size = LLM_CONFIG[model_name]["sae_batch_size"] + dtype = LLM_CONFIG[model_name]["dtype"] + num_tokens = 50_000_000 + + num_contexts_per_sae_batch = sae_batch_size // context_length + buffer_size = num_contexts_per_sae_batch * 20 + + # sae training parameters + # random_seeds = t.arange(10).tolist() + random_seeds = [0] + expansion_factors = [8] + + num_sparsities = 6 + sparsity_indices = t.arange(num_sparsities).tolist() + standard_sparsity_penalties = SPARSITY_PENALTIES[model_name]["standard"] + p_anneal_sparsity_penalties = SPARSITY_PENALTIES[model_name]["p_anneal"] + gated_sparsity_penalties = SPARSITY_PENALTIES[model_name]["gated"] + ks = [20, 40, 80, 160, 320, 640] + + assert len(standard_sparsity_penalties) == num_sparsities + assert len(p_anneal_sparsity_penalties) == num_sparsities + assert len(gated_sparsity_penalties) == num_sparsities + assert len(ks) == num_sparsities + + steps = int(num_tokens / sae_batch_size) # Total number of batches to train + warmup_steps = 1000 # Warmup period at start of training and after each resample + resample_steps = None + + # note: learning rate is not used for topk + learning_rates = [3e-4] + + # topk sae training parameters + decay_start = 24000 + auxk_alpha = 1 / 32 + + # p_anneal sae training parameters + p_start = 1 + p_end = 0.2 + anneal_end = None # steps - int(steps/10) + sparsity_queue_length = 10 + anneal_start = 10000 + n_sparsity_updates = 10 + + # jumprelu sae training parameters + jumprelu_bandwidth = 0.001 + jumprelu_sparsity_penalty = 1.0 # per figure 9 in the paper + + if save_checkpoints: + # Creates checkpoints at 0.1%, 0.316%, 1%, 3.16%, 10%, 31.6%, 100% of training + desired_checkpoints = t.logspace(-3, 0, 7).tolist() + desired_checkpoints = [0.0] + desired_checkpoints[:-1] + desired_checkpoints.sort() + print(f"desired_checkpoints: {desired_checkpoints}") + + save_steps = [int(steps * step) for step in desired_checkpoints] + save_steps.sort() + print(f"save_steps: {save_steps}") + else: + save_steps = None + + log_steps = 100 # Log the training on wandb + if not use_wandb: + log_steps = None + + model = LanguageModel(model_name, dispatch=True, device_map=DEVICE) + model = model.to(dtype=dtype) + submodule = utils.get_submodule(model, layer) + submodule_name = f"resid_post_layer_{layer}" + io = "out" + activation_dim = model.config.hidden_size + + generator = hf_dataset_to_generator("monology/pile-uncopyrighted") + + activation_buffer = ActivationBuffer( + generator, + model, + submodule, + n_ctxs=buffer_size, + ctx_len=context_length, + refresh_batch_size=llm_batch_size, + out_batch_size=sae_batch_size, + io=io, + d_submodule=activation_dim, + device=device, + ) + + # create the list of configs + trainer_configs = [] + + for seed, sparsity_index, expansion_factor, learning_rate in itertools.product( + random_seeds, sparsity_indices, expansion_factors, learning_rates + ): + if "p_anneal" in architectures: + trainer_configs.append( + { + "trainer": PAnnealTrainer, + "dict_class": AutoEncoder, + "activation_dim": activation_dim, + "dict_size": expansion_factor * activation_dim, + "lr": learning_rate, + "sparsity_function": "Lp^p", + "initial_sparsity_penalty": p_anneal_sparsity_penalties[sparsity_index], + "p_start": p_start, + "p_end": p_end, + "anneal_start": int(anneal_start), + "anneal_end": anneal_end, + "sparsity_queue_length": sparsity_queue_length, + "n_sparsity_updates": n_sparsity_updates, + "warmup_steps": warmup_steps, + "resample_steps": resample_steps, + "steps": steps, + "seed": seed, + "wandb_name": f"PAnnealTrainer-pythia70m-{layer}", + "layer": layer, + "lm_name": model_name, + "device": device, + "submodule_name": submodule_name, + }, + ) + if "standard" in architectures: + trainer_configs.append( + { + "trainer": StandardTrainer, + "dict_class": AutoEncoder, + "activation_dim": activation_dim, + "dict_size": expansion_factor * activation_dim, + "lr": learning_rate, + "l1_penalty": standard_sparsity_penalties[sparsity_index], + "warmup_steps": warmup_steps, + "resample_steps": resample_steps, + "seed": seed, + "wandb_name": f"StandardTrainer-{model_name}-{submodule_name}", + "layer": layer, + "lm_name": model_name, + "device": device, + "submodule_name": submodule_name, + } + ) + if "standard_new" in architectures: + trainer_configs.append( + { + "trainer": StandardTrainer, + "dict_class": AutoEncoderNew, + "activation_dim": activation_dim, + "dict_size": expansion_factor * activation_dim, + "lr": learning_rate, + "l1_penalty": standard_sparsity_penalties[sparsity_index], + "warmup_steps": warmup_steps, + "resample_steps": resample_steps, + "seed": seed, + "wandb_name": f"StandardTrainerNew-{model_name}-{submodule_name}", + "layer": layer, + "lm_name": model_name, + "device": device, + "submodule_name": submodule_name, + } + ) + if "top_k" in architectures: + trainer_configs.append( + { + "trainer": TrainerTopK, + "dict_class": AutoEncoderTopK, + "activation_dim": activation_dim, + "dict_size": expansion_factor * activation_dim, + "k": ks[sparsity_index], + "auxk_alpha": auxk_alpha, # see Appendix A.2 + "decay_start": decay_start, # when does the lr decay start + "steps": steps, # when when does training end + "seed": seed, + "wandb_name": f"TopKTrainer-{model_name}-{submodule_name}", + "device": device, + "layer": layer, + "lm_name": model_name, + "submodule_name": submodule_name, + } + ) + if "gated" in architectures: + trainer_configs.append( + { + "trainer": GatedSAETrainer, + "dict_class": GatedAutoEncoder, + "activation_dim": activation_dim, + "dict_size": expansion_factor * activation_dim, + "lr": learning_rate, + "l1_penalty": gated_sparsity_penalties[sparsity_index], + "warmup_steps": warmup_steps, + "resample_steps": resample_steps, + "seed": seed, + "wandb_name": f"GatedSAETrainer-{model_name}-{submodule_name}", + "device": device, + "layer": layer, + "lm_name": model_name, + "submodule_name": submodule_name, + } + ) + if "jump_relu" in architectures: + trainer_configs.append( + { + "trainer": JumpReluTrainer, + "dict_class": JumpReluAutoEncoder, + "activation_dim": activation_dim, + "dict_size": expansion_factor * activation_dim, + "lr": learning_rate, + "target_l0": ks[sparsity_index], + "sparsity_penalty": jumprelu_sparsity_penalty, + "bandwidth": jumprelu_bandwidth, + "seed": seed, + "wandb_name": f"JumpReLUSAETrainer-{model_name}-{submodule_name}", + "device": device, + "layer": layer, + "lm_name": model_name, + "submodule_name": submodule_name, + } + ) + + print(f"len trainer configs: {len(trainer_configs)}") + save_dir = f"{save_dir}/{submodule_name}" + + if not dry_run: + # actually run the sweep + trainSAE( + data=activation_buffer, + trainer_configs=trainer_configs, + use_wandb=use_wandb, + steps=steps, + save_steps=save_steps, + save_dir=save_dir, + log_steps=log_steps, + ) + +@t.no_grad() +def eval_saes( + model_name: str, + ae_paths: list[str], + n_inputs: int, + device: str, + overwrite_prev_results: bool = False, + transcoder: bool = False, +) -> dict: + + if transcoder: + io = "in_and_out" + else: + io = "out" + + context_length = LLM_CONFIG[model_name]["context_length"] + llm_batch_size = LLM_CONFIG[model_name]["llm_batch_size"] + loss_recovered_batch_size = llm_batch_size // 5 + sae_batch_size = loss_recovered_batch_size * context_length + dtype = LLM_CONFIG[model_name]["dtype"] + + model = LanguageModel(model_name, dispatch=True, device_map=DEVICE) + model = model.to(dtype=dtype) + + buffer_size = n_inputs + io = "out" + n_batches = n_inputs // loss_recovered_batch_size + + generator = hf_dataset_to_generator("monology/pile-uncopyrighted") + + input_strings = [] + for i, example in enumerate(generator): + input_strings.append(example) + if i > n_inputs * 5: + break + + eval_results = {} + + for ae_path in ae_paths: + output_filename = f"{ae_path}/eval_results.json" + if not overwrite_prev_results: + if os.path.exists(output_filename): + print(f"Skipping {ae_path} as eval results already exist") + continue + + dictionary, config = utils.load_dictionary(ae_path, device) + dictionary = dictionary.to(dtype=model.dtype) + + layer = config["trainer"]["layer"] + submodule = utils.get_submodule(model, layer) + + activation_dim = config["trainer"]["activation_dim"] + + activation_buffer = ActivationBuffer( + iter(input_strings), + model, + submodule, + n_ctxs=buffer_size, + ctx_len=context_length, + refresh_batch_size=llm_batch_size, + out_batch_size=sae_batch_size, + io=io, + d_submodule=activation_dim, + device=device, + ) + + eval_results = evaluate( + dictionary, + activation_buffer, + context_length, + loss_recovered_batch_size, + io=io, + device=device, + n_batches=n_batches, + ) + + hyperparameters = { + "n_inputs": n_inputs, + "context_length": context_length, + } + eval_results["hyperparameters"] = hyperparameters + + print(eval_results) + + with open(output_filename, "w") as f: + json.dump(eval_results, f) + + # return the final eval_results for testing purposes + return eval_results + + +if __name__ == "__main__": + """python pythia.py --save_dir ./run2 --model_name EleutherAI/pythia-70m-deduped --layers 3 --architectures standard standard_new top_k gated --use_wandb + python pythia.py --save_dir ./run3 --model_name google/gemma-2-2b --layers 12 --architectures standard top_k --use_wandb + python pythia.py --save_dir ./jumprelu --model_name EleutherAI/pythia-70m-deduped --layers 3 --architectures jump_relu --use_wandb""" + args = get_args() + for layer in args.layers: + run_sae_training( + model_name=args.model_name, + layer=layer, + save_dir=args.save_dir, + device="cuda:0", + architectures=args.architectures, + dry_run=args.dry_run, + use_wandb=args.use_wandb, + ) + + ae_paths = utils.get_nested_folders(args.save_dir) + + eval_saes(args.model_name, ae_paths, 1000, DEVICE) + + diff --git a/graphing.ipynb b/graphing.ipynb new file mode 100644 index 0000000..2b6dc10 --- /dev/null +++ b/graphing.ipynb @@ -0,0 +1,205 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import json\n", + "from typing import Optional\n", + "\n", + "import dictionary_learning.utils as utils\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "TRAINER_LABELS = {\n", + " \"StandardTrainer\": \"Standard\",\n", + " \"JumpReluTrainer\": \"JumpReLU\",\n", + " \"TrainerTopK\": \"Top K\",\n", + " \"GatedSAETrainer\": \"Gated\",\n", + " \"PAnnealTrainer\": \"P-Anneal\",\n", + "}\n", + "\n", + "TRAINER_MARKERS = {\n", + " \"StandardTrainer\": \"o\",\n", + " \"JumpReluTrainer\": \"X\",\n", + " \"TrainerTopK\": \"^\",\n", + " \"GatedSAETrainer\": \"d\",\n", + " \"PAnnealTrainer\": \"s\",\n", + "}\n", + "\n", + "TRAINER_COLORS = {\n", + " \"StandardTrainer\": \"blue\",\n", + " \"JumpReluTrainer\": \"orange\",\n", + " \"TrainerTopK\": \"green\",\n", + " \"GatedSAETrainer\": \"red\",\n", + " \"PAnnealTrainer\": \"purple\",\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "save_dirs = [\"./run2\", \"./jumprelu\"]\n", + "# save_dirs = [\"./run2\"]\n", + "ae_paths = []\n", + "\n", + "for save_dir in save_dirs:\n", + " ae_paths.extend(utils.get_nested_folders(save_dir))\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plotting_results = {}\n", + "\n", + "for ae_path in ae_paths:\n", + " with open(ae_path + \"/config.json\") as f:\n", + " config = json.load(f)\n", + "\n", + " with open(ae_path + \"/eval_results.json\") as f:\n", + " eval_results = json.load(f)\n", + "\n", + " ae_results = {}\n", + "\n", + " ae_results[\"l0\"] = eval_results[\"l0\"]\n", + " ae_results[\"frac_recovered\"] = eval_results[\"frac_recovered\"]\n", + " ae_results[\"trainer_class\"] = config[\"trainer\"][\"trainer_class\"]\n", + " ae_results[\"dict_size\"] = config[\"trainer\"][\"dict_size\"]\n", + "\n", + " plotting_results[ae_path] = ae_results\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_2var_graph(\n", + " results: dict[str, dict[str, float]],\n", + " custom_metric: str,\n", + " title: str = \"L0 vs Custom Metric\",\n", + " y_label: str = \"Custom Metric\",\n", + " xlims: Optional[tuple[float, float]] = None,\n", + " ylims: Optional[tuple[float, float]] = None,\n", + " output_filename: Optional[str] = None,\n", + " legend_location: str = \"lower right\",\n", + " x_axis_key: str = \"l0\",\n", + " return_fig: bool = False,\n", + "):\n", + " # Extract data from results\n", + " l0_values = [data[x_axis_key] for data in results.values()]\n", + " custom_metric_values = [data[custom_metric] for data in results.values()]\n", + "\n", + " # Create the scatter plot\n", + " fig, ax = plt.subplots(figsize=(10, 6))\n", + "\n", + " handles, labels = [], []\n", + "\n", + " for trainer, marker in TRAINER_MARKERS.items():\n", + " # Filter data for this trainer\n", + " trainer_data = {k: v for k, v in results.items() if v[\"trainer_class\"] == trainer}\n", + "\n", + " if not trainer_data:\n", + " continue # Skip this trainer if no data points\n", + "\n", + " l0_values = [data[x_axis_key] for data in trainer_data.values()]\n", + " custom_metric_values = [data[custom_metric] for data in trainer_data.values()]\n", + "\n", + " # Plot data points\n", + " scatter = ax.scatter(\n", + " l0_values,\n", + " custom_metric_values,\n", + " marker=marker,\n", + " s=100,\n", + " label=trainer,\n", + " color=TRAINER_COLORS[trainer],\n", + " edgecolor=\"black\",\n", + " )\n", + "\n", + " # Create custom legend handle with both marker and color\n", + " legend_handle = plt.scatter(\n", + " [], [], marker=marker, s=100, color=TRAINER_COLORS[trainer], edgecolor=\"black\"\n", + " )\n", + " handles.append(legend_handle)\n", + "\n", + " if trainer in TRAINER_LABELS:\n", + " trainer_label = TRAINER_LABELS[trainer]\n", + " else:\n", + " trainer_label = trainer.capitalize()\n", + " labels.append(trainer_label)\n", + "\n", + " # Set labels and title\n", + " ax.set_xlabel(\"L0 (Sparsity)\")\n", + " ax.set_ylabel(y_label)\n", + " ax.set_title(title)\n", + "\n", + " ax.legend(handles, labels, loc=legend_location)\n", + "\n", + " # Set axis limits\n", + " if xlims:\n", + " ax.set_xlim(*xlims)\n", + " if ylims:\n", + " ax.set_ylim(*ylims)\n", + "\n", + " plt.tight_layout()\n", + "\n", + " # Save and show the plot\n", + " if output_filename:\n", + " plt.savefig(output_filename, bbox_inches=\"tight\")\n", + "\n", + " if return_fig:\n", + " return fig\n", + "\n", + " plt.show()\n", + " \n", + "plt.rcParams.update({\"font.size\": 20})\n", + "plot_2var_graph(plotting_results, \"frac_recovered\", title=\"Fraction Recovered vs L0\", y_label=\"Fraction Recovered\", output_filename=\"frac_recovered_vs_l0.png\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 8e93cab..84aeab5 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -7,7 +7,7 @@ from dictionary_learning.training import trainSAE from dictionary_learning.trainers.standard import StandardTrainer from dictionary_learning.trainers.top_k import TrainerTopK, AutoEncoderTopK -from dictionary_learning.utils import hf_dataset_to_generator +from dictionary_learning.utils import hf_dataset_to_generator, get_nested_folders, load_dictionary from dictionary_learning.buffer import ActivationBuffer from dictionary_learning.dictionary import ( AutoEncoder, @@ -58,50 +58,11 @@ EVAL_TOLERANCE = 0.01 -def get_nested_folders(path: str) -> list[str]: - """ - Recursively get a list of folders that contain an ae.pt file, starting the search from the given path - """ - folder_names = [] - - for root, dirs, files in os.walk(path): - if "ae.pt" in files: - folder_names.append(root) - - return folder_names - - -def load_dictionary(base_path: str, device: str) -> tuple: - ae_path = f"{base_path}/ae.pt" - config_path = f"{base_path}/config.json" - - with open(config_path, "r") as f: - config = json.load(f) - - # TODO: Save the submodule name in the config? - # submodule_str = config["trainer"]["submodule_name"] - dict_class = config["trainer"]["dict_class"] - - if dict_class == "AutoEncoder": - dictionary = AutoEncoder.from_pretrained(ae_path, device=device) - elif dict_class == "GatedAutoEncoder": - dictionary = GatedAutoEncoder.from_pretrained(ae_path, device=device) - elif dict_class == "AutoEncoderNew": - dictionary = AutoEncoderNew.from_pretrained(ae_path, device=device) - elif dict_class == "AutoEncoderTopK": - k = config["trainer"]["k"] - dictionary = AutoEncoderTopK.from_pretrained(ae_path, k=k, device=device) - elif dict_class == "JumpReluAutoEncoder": - dictionary = JumpReluAutoEncoder.from_pretrained(ae_path, device=device) - else: - raise ValueError(f"Dictionary class {dict_class} not supported") - - return dictionary, config - - def test_sae_training(): """End to end test for training an SAE. Takes ~2 minutes on an RTX 3090. - This isn't a nice suite of unit tests, but it's better than nothing.""" + This isn't a nice suite of unit tests, but it's better than nothing. + I have observed that results can slightly vary with library versions. For full determinism, + use pytorch 2.2.0 and nnsight 0.3.3.""" random.seed(RANDOM_SEED) t.manual_seed(RANDOM_SEED) @@ -278,5 +239,12 @@ def test_evaluation(): dict_class = config["trainer"]["dict_class"] expected_results = EXPECTED_RESULTS[dict_class] + max_diff = 0 + max_diff_percent = 0 for key, value in expected_results.items(): - assert abs(eval_results[key] - value) < EVAL_TOLERANCE + diff = abs(eval_results[key] - value) + max_diff = max(max_diff, diff) + max_diff_percent = max(max_diff_percent, diff / value) + + print(f"Max diff: {max_diff}, max diff %: {max_diff_percent}") + assert max_diff < EVAL_TOLERANCE diff --git a/utils.py b/utils.py index 8641f05..27a2188 100644 --- a/utils.py +++ b/utils.py @@ -2,6 +2,16 @@ import zstandard as zstd import io import json +import os +from nnsight import LanguageModel + +from dictionary_learning.trainers.top_k import AutoEncoderTopK +from dictionary_learning.dictionary import ( + AutoEncoder, + GatedAutoEncoder, + AutoEncoderNew, + JumpReluAutoEncoder, +) def hf_dataset_to_generator(dataset_name, split='train', streaming=True): dataset = load_dataset(dataset_name, split=split, streaming=streaming) @@ -24,4 +34,53 @@ def zst_to_generator(data_path): def generator(): for line in text_stream: yield json.loads(line)['text'] - return generator() \ No newline at end of file + return generator() + +def get_nested_folders(path: str) -> list[str]: + """ + Recursively get a list of folders that contain an ae.pt file, starting the search from the given path + """ + folder_names = [] + + for root, dirs, files in os.walk(path): + if "ae.pt" in files: + folder_names.append(root) + + return folder_names + + +def load_dictionary(base_path: str, device: str) -> tuple: + ae_path = f"{base_path}/ae.pt" + config_path = f"{base_path}/config.json" + + with open(config_path, "r") as f: + config = json.load(f) + + dict_class = config["trainer"]["dict_class"] + + if dict_class == "AutoEncoder": + dictionary = AutoEncoder.from_pretrained(ae_path, device=device) + elif dict_class == "GatedAutoEncoder": + dictionary = GatedAutoEncoder.from_pretrained(ae_path, device=device) + elif dict_class == "AutoEncoderNew": + dictionary = AutoEncoderNew.from_pretrained(ae_path, device=device) + elif dict_class == "AutoEncoderTopK": + k = config["trainer"]["k"] + dictionary = AutoEncoderTopK.from_pretrained(ae_path, k=k, device=device) + elif dict_class == "JumpReluAutoEncoder": + dictionary = JumpReluAutoEncoder.from_pretrained(ae_path, device=device) + else: + raise ValueError(f"Dictionary class {dict_class} not supported") + + return dictionary, config + +def get_submodule(model: LanguageModel, layer: int): + """Gets the residual stream submodule""" + model_name = model._model_key + + if "pythia" in model_name: + return model.gpt_neox.layers[layer] + elif "gemma" in model_name: + return model.model.layers[layer] + else: + raise ValueError(f"Please add submodule for model {model_name}") \ No newline at end of file From dcc02f04e504331011a54ce851a91976daf15170 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Sat, 21 Dec 2024 14:32:09 -0500 Subject: [PATCH 13/70] Modularize demo script --- demo.py | 470 ++++++++++++++++++++++++++++++++------------------------ 1 file changed, 269 insertions(+), 201 deletions(-) diff --git a/demo.py b/demo.py index fbac8de..e4dd062 100644 --- a/demo.py +++ b/demo.py @@ -4,50 +4,248 @@ import itertools import os import json - -from dictionary_learning.training import trainSAE -from dictionary_learning.trainers.standard import StandardTrainer -from dictionary_learning.trainers.top_k import TrainerTopK, AutoEncoderTopK -from dictionary_learning.trainers.gdm import GatedSAETrainer -from dictionary_learning.trainers.p_anneal import PAnnealTrainer -from dictionary_learning.trainers.jumprelu import JumpReluTrainer -from dictionary_learning.utils import hf_dataset_to_generator -from dictionary_learning.buffer import ActivationBuffer -from dictionary_learning.dictionary import AutoEncoder, GatedAutoEncoder, AutoEncoderNew, JumpReluAutoEncoder -from dictionary_learning.evaluation import evaluate -import dictionary_learning.utils as utils - - -DEVICE = "cuda:0" +from dataclasses import dataclass, field, asdict +from typing import Optional, Type, Any +from enum import Enum + +from training import trainSAE +from trainers.standard import StandardTrainer +from trainers.top_k import TopKTrainer, AutoEncoderTopK +from trainers.gdm import GatedSAETrainer +from trainers.p_anneal import PAnnealTrainer +from trainers.jumprelu import JumpReluTrainer +from utils import hf_dataset_to_generator +from buffer import ActivationBuffer +from dictionary import AutoEncoder, GatedAutoEncoder, AutoEncoderNew, JumpReluAutoEncoder +from evaluation import evaluate +import utils as utils + + +class TrainerType(Enum): + STANDARD = "standard" + STANDARD_NEW = "standard_new" + TOP_K = "top_k" + BATCH_TOP_K = "batch_top_k" + GATED = "gated" + P_ANNEAL = "p_anneal" + JUMP_RELU = "jump_relu" + + +@dataclass +class LLMConfig: + llm_batch_size: int + context_length: int + sae_batch_size: int + dtype: t.dtype + + +@dataclass +class SparsityPenalties: + standard: list[float] + p_anneal: list[float] + gated: list[float] + + +# TODO: Move all of these to a config file +num_tokens = 50_000_000 +eval_num_inputs = 1_000 +random_seeds = [0] +expansion_factors = [8] + +# note: learning rate is not used for topk +learning_rates = [3e-4] LLM_CONFIG = { - "EleutherAI/pythia-70m-deduped": { - "llm_batch_size": 512, - "context_length": 128, - "sae_batch_size": 4096, - "dtype": t.float32, - }, - "google/gemma-2-2b": { - "llm_batch_size": 32, - "context_length": 128, - "sae_batch_size": 2048, - "dtype": t.bfloat16, - }, + "EleutherAI/pythia-70m-deduped": LLMConfig( + llm_batch_size=512, context_length=128, sae_batch_size=4096, dtype=t.float32 + ), + "google/gemma-2-2b": LLMConfig( + llm_batch_size=32, context_length=128, sae_batch_size=2048, dtype=t.bfloat16 + ), } + +# NOTE: In the current setup, the length of each sparsity penalty and target_l0 should be the same SPARSITY_PENALTIES = { - "EleutherAI/pythia-70m-deduped": { - "standard": [0.01, 0.05, 0.075, 0.1, 0.125, 0.15], - "p_anneal": [0.02, 0.03, 0.035, 0.04, 0.05, 0.075], - "gated": [0.1, 0.3, 0.5, 0.7, 0.9, 1.1], - }, - "google/gemma-2-2b": { - "standard": [0.025, 0.035, 0.04, 0.05, 0.06, 0.07], - "p_anneal": [-1, -1, -1, -1, -1, -1], - "gated": [-1, -1, -1, -1, -1, -1], - }, + "EleutherAI/pythia-70m-deduped": SparsityPenalties( + standard=[0.01, 0.05, 0.075, 0.1, 0.125, 0.15], + p_anneal=[0.02, 0.03, 0.035, 0.04, 0.05, 0.075], + gated=[0.1, 0.3, 0.5, 0.7, 0.9, 1.1], + ), + "google/gemma-2-2b": SparsityPenalties( + standard=[0.025, 0.035, 0.04, 0.05, 0.06, 0.07], + p_anneal=[-1] * 6, + gated=[-1] * 6, + ), } + +TARGET_L0s = [20, 40, 80, 160, 320, 640] + + +@dataclass +class BaseTrainerConfig: + activation_dim: int + dict_size: int + seed: int + device: str + layer: str + lm_name: str + submodule_name: str + trainer: Type[Any] + dict_class: Type[Any] + wandb_name: str + steps: Optional[int] = None + + +@dataclass +class WarmupConfig: + warmup_steps: int = 1000 + resample_steps: Optional[int] = None + + +@dataclass +class StandardTrainerConfig(BaseTrainerConfig, WarmupConfig): + lr: float + l1_penalty: float + + +@dataclass +class StandardNewTrainerConfig(BaseTrainerConfig, WarmupConfig): + lr: float + l1_penalty: float + + +@dataclass +class PAnnealTrainerConfig(BaseTrainerConfig, WarmupConfig): + lr: float + initial_sparsity_penalty: float + sparsity_function: str = "Lp^p" + p_start: float = 1.0 + p_end: float = 0.2 + anneal_start: int = 10000 + anneal_end: Optional[int] = None + sparsity_queue_length: int = 10 + n_sparsity_updates: int = 10 + + +@dataclass +class TopKTrainerConfig(BaseTrainerConfig): + k: int + auxk_alpha: float = 1 / 32 + decay_start: int = 24000 + threshold_beta: float = 0.999 + + +@dataclass +class GatedTrainerConfig(BaseTrainerConfig, WarmupConfig): + lr: float + l1_penalty: float + + +@dataclass +class JumpReluTrainerConfig(BaseTrainerConfig): + lr: float + target_l0: int + sparsity_penalty: float = 1.0 + bandwidth: float = 0.001 + + +def get_trainer_configs( + architectures: list[str], + learning_rate: float, + sparsity_index: int, + seed: int, + activation_dim: int, + dict_size: int, + model_name: str, + device: str, + layer: str, + submodule_name: str, + steps: int, +) -> list[dict]: + trainer_configs = [] + + base_config = { + "activation_dim": activation_dim, + "dict_size": dict_size, + "seed": seed, + "device": device, + "layer": layer, + "lm_name": model_name, + "submodule_name": submodule_name, + } + + if TrainerType.P_ANNEAL.value in architectures: + config = PAnnealTrainerConfig( + **base_config, + trainer=PAnnealTrainer, + dict_class=AutoEncoder, + lr=learning_rate, + initial_sparsity_penalty=SPARSITY_PENALTIES[model_name].p_anneal[sparsity_index], + steps=steps, + wandb_name=f"PAnnealTrainer-{model_name}-{submodule_name}", + ) + trainer_configs.append(asdict(config)) + + if TrainerType.STANDARD.value in architectures: + config = StandardTrainerConfig( + **base_config, + trainer=StandardTrainer, + dict_class=AutoEncoder, + lr=learning_rate, + l1_penalty=SPARSITY_PENALTIES[model_name].standard[sparsity_index], + wandb_name=f"StandardTrainer-{model_name}-{submodule_name}", + ) + trainer_configs.append(asdict(config)) + + if TrainerType.STANDARD_NEW.value in architectures: + config = StandardNewTrainerConfig( + **base_config, + trainer=StandardTrainer, + dict_class=AutoEncoderNew, + lr=learning_rate, + l1_penalty=SPARSITY_PENALTIES[model_name].standard[sparsity_index], + wandb_name=f"StandardTrainerNew-{model_name}-{submodule_name}", + ) + trainer_configs.append(asdict(config)) + + if TrainerType.TOP_K.value in architectures: + config = TopKTrainerConfig( + **base_config, + trainer=TopKTrainer, + dict_class=AutoEncoderTopK, + k=TARGET_L0s[sparsity_index], + steps=steps, + wandb_name=f"TopKTrainer-{model_name}-{submodule_name}", + ) + trainer_configs.append(asdict(config)) + + if TrainerType.GATED.value in architectures: + config = GatedTrainerConfig( + **base_config, + trainer=GatedSAETrainer, + dict_class=GatedAutoEncoder, + lr=learning_rate, + l1_penalty=SPARSITY_PENALTIES[model_name].gated[sparsity_index], + wandb_name=f"GatedTrainer-{model_name}-{submodule_name}", + ) + trainer_configs.append(asdict(config)) + + if TrainerType.JUMP_RELU.value in architectures: + config = JumpReluTrainerConfig( + **base_config, + trainer=JumpReluTrainer, + dict_class=JumpReluAutoEncoder, + lr=learning_rate, + target_l0=TARGET_L0s[sparsity_index], + wandb_name=f"JumpReluTrainer-{model_name}-{submodule_name}", + ) + trainer_configs.append(asdict(config)) + + return trainer_configs + + def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--save_dir", type=str, required=True, help="where to store sweep") @@ -66,7 +264,7 @@ def get_args(): "--architectures", type=str, nargs="+", - choices=["standard", "standard_new", "top_k", "gated", "p_anneal", "jump_relu"], + choices=[e.value for e in TrainerType], required=True, help="which SAE architectures to train", ) @@ -80,9 +278,14 @@ def run_sae_training( save_dir: str, device: str, architectures: list, + num_tokens: int, + random_seeds: list[int], + expansion_factors: list[float], + learning_rates: list[float], dry_run: bool = False, use_wandb: bool = False, save_checkpoints: bool = False, + buffer_scaling_factor: int = 20, ): # model and data parameters context_length = LLM_CONFIG[model_name]["context_length"] @@ -90,50 +293,17 @@ def run_sae_training( llm_batch_size = LLM_CONFIG[model_name]["llm_batch_size"] sae_batch_size = LLM_CONFIG[model_name]["sae_batch_size"] dtype = LLM_CONFIG[model_name]["dtype"] - num_tokens = 50_000_000 num_contexts_per_sae_batch = sae_batch_size // context_length - buffer_size = num_contexts_per_sae_batch * 20 + buffer_size = num_contexts_per_sae_batch * buffer_scaling_factor # sae training parameters # random_seeds = t.arange(10).tolist() - random_seeds = [0] - expansion_factors = [8] - num_sparsities = 6 + num_sparsities = len(TARGET_L0s) sparsity_indices = t.arange(num_sparsities).tolist() - standard_sparsity_penalties = SPARSITY_PENALTIES[model_name]["standard"] - p_anneal_sparsity_penalties = SPARSITY_PENALTIES[model_name]["p_anneal"] - gated_sparsity_penalties = SPARSITY_PENALTIES[model_name]["gated"] - ks = [20, 40, 80, 160, 320, 640] - - assert len(standard_sparsity_penalties) == num_sparsities - assert len(p_anneal_sparsity_penalties) == num_sparsities - assert len(gated_sparsity_penalties) == num_sparsities - assert len(ks) == num_sparsities steps = int(num_tokens / sae_batch_size) # Total number of batches to train - warmup_steps = 1000 # Warmup period at start of training and after each resample - resample_steps = None - - # note: learning rate is not used for topk - learning_rates = [3e-4] - - # topk sae training parameters - decay_start = 24000 - auxk_alpha = 1 / 32 - - # p_anneal sae training parameters - p_start = 1 - p_end = 0.2 - anneal_end = None # steps - int(steps/10) - sparsity_queue_length = 10 - anneal_start = 10000 - n_sparsity_updates = 10 - - # jumprelu sae training parameters - jumprelu_bandwidth = 0.001 - jumprelu_sparsity_penalty = 1.0 # per figure 9 in the paper if save_checkpoints: # Creates checkpoints at 0.1%, 0.316%, 1%, 3.16%, 10%, 31.6%, 100% of training @@ -152,7 +322,7 @@ def run_sae_training( if not use_wandb: log_steps = None - model = LanguageModel(model_name, dispatch=True, device_map=DEVICE) + model = LanguageModel(model_name, dispatch=True, device_map=device) model = model.to(dtype=dtype) submodule = utils.get_submodule(model, layer) submodule_name = f"resid_post_layer_{layer}" @@ -180,128 +350,21 @@ def run_sae_training( for seed, sparsity_index, expansion_factor, learning_rate in itertools.product( random_seeds, sparsity_indices, expansion_factors, learning_rates ): - if "p_anneal" in architectures: - trainer_configs.append( - { - "trainer": PAnnealTrainer, - "dict_class": AutoEncoder, - "activation_dim": activation_dim, - "dict_size": expansion_factor * activation_dim, - "lr": learning_rate, - "sparsity_function": "Lp^p", - "initial_sparsity_penalty": p_anneal_sparsity_penalties[sparsity_index], - "p_start": p_start, - "p_end": p_end, - "anneal_start": int(anneal_start), - "anneal_end": anneal_end, - "sparsity_queue_length": sparsity_queue_length, - "n_sparsity_updates": n_sparsity_updates, - "warmup_steps": warmup_steps, - "resample_steps": resample_steps, - "steps": steps, - "seed": seed, - "wandb_name": f"PAnnealTrainer-pythia70m-{layer}", - "layer": layer, - "lm_name": model_name, - "device": device, - "submodule_name": submodule_name, - }, - ) - if "standard" in architectures: - trainer_configs.append( - { - "trainer": StandardTrainer, - "dict_class": AutoEncoder, - "activation_dim": activation_dim, - "dict_size": expansion_factor * activation_dim, - "lr": learning_rate, - "l1_penalty": standard_sparsity_penalties[sparsity_index], - "warmup_steps": warmup_steps, - "resample_steps": resample_steps, - "seed": seed, - "wandb_name": f"StandardTrainer-{model_name}-{submodule_name}", - "layer": layer, - "lm_name": model_name, - "device": device, - "submodule_name": submodule_name, - } - ) - if "standard_new" in architectures: - trainer_configs.append( - { - "trainer": StandardTrainer, - "dict_class": AutoEncoderNew, - "activation_dim": activation_dim, - "dict_size": expansion_factor * activation_dim, - "lr": learning_rate, - "l1_penalty": standard_sparsity_penalties[sparsity_index], - "warmup_steps": warmup_steps, - "resample_steps": resample_steps, - "seed": seed, - "wandb_name": f"StandardTrainerNew-{model_name}-{submodule_name}", - "layer": layer, - "lm_name": model_name, - "device": device, - "submodule_name": submodule_name, - } - ) - if "top_k" in architectures: - trainer_configs.append( - { - "trainer": TrainerTopK, - "dict_class": AutoEncoderTopK, - "activation_dim": activation_dim, - "dict_size": expansion_factor * activation_dim, - "k": ks[sparsity_index], - "auxk_alpha": auxk_alpha, # see Appendix A.2 - "decay_start": decay_start, # when does the lr decay start - "steps": steps, # when when does training end - "seed": seed, - "wandb_name": f"TopKTrainer-{model_name}-{submodule_name}", - "device": device, - "layer": layer, - "lm_name": model_name, - "submodule_name": submodule_name, - } - ) - if "gated" in architectures: - trainer_configs.append( - { - "trainer": GatedSAETrainer, - "dict_class": GatedAutoEncoder, - "activation_dim": activation_dim, - "dict_size": expansion_factor * activation_dim, - "lr": learning_rate, - "l1_penalty": gated_sparsity_penalties[sparsity_index], - "warmup_steps": warmup_steps, - "resample_steps": resample_steps, - "seed": seed, - "wandb_name": f"GatedSAETrainer-{model_name}-{submodule_name}", - "device": device, - "layer": layer, - "lm_name": model_name, - "submodule_name": submodule_name, - } - ) - if "jump_relu" in architectures: - trainer_configs.append( - { - "trainer": JumpReluTrainer, - "dict_class": JumpReluAutoEncoder, - "activation_dim": activation_dim, - "dict_size": expansion_factor * activation_dim, - "lr": learning_rate, - "target_l0": ks[sparsity_index], - "sparsity_penalty": jumprelu_sparsity_penalty, - "bandwidth": jumprelu_bandwidth, - "seed": seed, - "wandb_name": f"JumpReLUSAETrainer-{model_name}-{submodule_name}", - "device": device, - "layer": layer, - "lm_name": model_name, - "submodule_name": submodule_name, - } + dict_size = int(expansion_factor * activation_dim) + trainer_configs.extend( + get_trainer_configs( + architectures, + learning_rate, + sparsity_index, + seed, + activation_dim, + dict_size, + model_name, + device, + submodule_name, + steps, ) + ) print(f"len trainer configs: {len(trainer_configs)}") save_dir = f"{save_dir}/{submodule_name}" @@ -318,6 +381,7 @@ def run_sae_training( log_steps=log_steps, ) + @t.no_grad() def eval_saes( model_name: str, @@ -327,7 +391,6 @@ def eval_saes( overwrite_prev_results: bool = False, transcoder: bool = False, ) -> dict: - if transcoder: io = "in_and_out" else: @@ -339,7 +402,7 @@ def eval_saes( sae_batch_size = loss_recovered_batch_size * context_length dtype = LLM_CONFIG[model_name]["dtype"] - model = LanguageModel(model_name, dispatch=True, device_map=DEVICE) + model = LanguageModel(model_name, dispatch=True, device_map=device) model = model.to(dtype=dtype) buffer_size = n_inputs @@ -414,19 +477,24 @@ def eval_saes( python pythia.py --save_dir ./run3 --model_name google/gemma-2-2b --layers 12 --architectures standard top_k --use_wandb python pythia.py --save_dir ./jumprelu --model_name EleutherAI/pythia-70m-deduped --layers 3 --architectures jump_relu --use_wandb""" args = get_args() + + device = "cuda:0" + for layer in args.layers: run_sae_training( model_name=args.model_name, layer=layer, save_dir=args.save_dir, - device="cuda:0", + device=device, architectures=args.architectures, + num_tokens=num_tokens, + random_seeds=random_seeds, + expansion_factors=expansion_factors, + learning_rates=learning_rates, dry_run=args.dry_run, use_wandb=args.use_wandb, ) ae_paths = utils.get_nested_folders(args.save_dir) - eval_saes(args.model_name, ae_paths, 1000, DEVICE) - - + eval_saes(args.model_name, ae_paths, eval_num_inputs, device) From 32d198f738c61b0c1109f1803c43e01afb977d3e Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Sat, 21 Dec 2024 14:32:43 -0500 Subject: [PATCH 14/70] Track threshold for batchtopk, rename for consistency --- trainers/__init__.py | 4 ++-- trainers/batch_top_k.py | 52 +++++++++++++++++++++++++++++------------ trainers/top_k.py | 2 +- 3 files changed, 40 insertions(+), 18 deletions(-) diff --git a/trainers/__init__.py b/trainers/__init__.py index 99b015b..81998af 100644 --- a/trainers/__init__.py +++ b/trainers/__init__.py @@ -2,6 +2,6 @@ from .gdm import GatedSAETrainer from .p_anneal import PAnnealTrainer from .gated_anneal import GatedAnnealTrainer -from .top_k import TrainerTopK +from .top_k import TopKTrainer from .jumprelu import JumpReluTrainer -from .batch_top_k import TrainerBatchTopK, BatchTopKSAE +from .batch_top_k import BatchTopKTrainer, BatchTopKSAE diff --git a/trainers/batch_top_k.py b/trainers/batch_top_k.py index c65195f..f306301 100644 --- a/trainers/batch_top_k.py +++ b/trainers/batch_top_k.py @@ -16,6 +16,7 @@ def __init__(self, activation_dim: int, dict_size: int, k: int): assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer" self.register_buffer("k", t.tensor(k)) + self.register_buffer("threshold", None) self.encoder = nn.Linear(activation_dim, dict_size) self.encoder.bias.data.zero_() @@ -24,9 +25,16 @@ def __init__(self, activation_dim: int, dict_size: int, k: int): self.set_decoder_norm_to_unit_norm() self.b_dec = nn.Parameter(t.zeros(activation_dim)) - def encode(self, x: t.Tensor, return_active: bool = False): + def encode(self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True): post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec)) + if use_threshold: + encoded_acts_BF = post_relu_feat_acts_BF * (post_relu_feat_acts_BF > self.threshold) + if return_active: + return encoded_acts_BF, encoded_acts_BF.sum(0) > 0 + else: + return encoded_acts_BF + # Flatten and perform batch top-k flattened_acts = post_relu_feat_acts_BF.flatten() post_topk = flattened_acts.topk(self.k * x.size(0), sorted=False, dim=-1) @@ -78,10 +86,10 @@ def remove_gradient_parallel_to_decoder_directions(self): @classmethod def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "BatchTopKSAE": state_dict = t.load(path) - dict_size, activation_dim = state_dict['encoder.weight'].shape + dict_size, activation_dim = state_dict["encoder.weight"].shape if k is None: - k = state_dict['k'].item() - elif 'k' in state_dict and k != state_dict['k'].item(): + k = state_dict["k"].item() + elif "k" in state_dict and k != state_dict["k"].item(): raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']") autoencoder = cls(activation_dim, dict_size, k) @@ -91,7 +99,7 @@ def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "BatchTopKSAE": return autoencoder -class TrainerBatchTopK(SAETrainer): +class BatchTopKTrainer(SAETrainer): def __init__( self, dict_class=BatchTopKSAE, @@ -100,6 +108,7 @@ def __init__( k=8, auxk_alpha=1 / 32, decay_start=24000, + threshold_beta=0.999, steps=30000, top_k_aux=512, seed=None, @@ -117,6 +126,7 @@ def __init__( self.wandb_name = wandb_name self.steps = steps self.k = k + self.threshold_beta = threshold_beta if seed is not None: t.manual_seed(seed) @@ -136,9 +146,7 @@ def __init__( self.dead_feature_threshold = 10_000_000 self.top_k_aux = top_k_aux - self.optimizer = t.optim.Adam( - self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999) - ) + self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) def lr_fn(step): if step < decay_start: @@ -165,20 +173,34 @@ def get_auxiliary_loss(self, x, x_reconstruct, acts): acts_aux = t.zeros_like(acts[:, dead_features]).scatter( -1, acts_topk_aux.indices, acts_topk_aux.values ) - x_reconstruct_aux = F.linear( - acts_aux, self.ae.decoder.weight[:, dead_features] - ) + x_reconstruct_aux = F.linear(acts_aux, self.ae.decoder.weight[:, dead_features]) l2_loss_aux = ( - self.auxk_alpha - * (x_reconstruct_aux.float() - residual.float()).pow(2).mean() + self.auxk_alpha * (x_reconstruct_aux.float() - residual.float()).pow(2).mean() ) return l2_loss_aux else: return t.tensor(0, dtype=x.dtype, device=x.device) def loss(self, x, step=None, logging=False): - f, active_indices = self.ae.encode(x, return_active=True) - l0 = (f != 0).float().sum(dim=-1).mean().item() + f, active_indices = self.ae.encode(x, return_active=True, use_threshold=False) + # l0 = (f != 0).float().sum(dim=-1).mean().item() + + active = f[f > 0] + + if active.size(0) == 0: + min_activation = 0.0 + else: + min_activation = active.min() + + print(f"min_activation: {min_activation}") + + if self.threshold is None: + self.threshold = min_activation + else: + self.threshold = (self.threshold_beta * self.threshold) + ( + (1 - self.threshold_beta) * min_activation + ) + x_hat = self.ae.decode(f) e = x_hat - x diff --git a/trainers/top_k.py b/trainers/top_k.py index 33046f5..02c50b2 100644 --- a/trainers/top_k.py +++ b/trainers/top_k.py @@ -130,7 +130,7 @@ def from_pretrained(path, k: int, device=None): return autoencoder -class TrainerTopK(SAETrainer): +class TopKTrainer(SAETrainer): """ Top-K SAE training scheme. """ From b5821fd87e3676e7a9ab6b87d423c03c57a344dd Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Sun, 22 Dec 2024 04:30:56 +0000 Subject: [PATCH 15/70] Track thresholds for topk and batchtopk during training --- tests/test_end_to_end.py | 4 +-- trainers/batch_top_k.py | 34 +++++++++++++------------ trainers/top_k.py | 55 +++++++++++++++++++++++++++++++++++----- utils.py | 25 ++++++++++++------ 4 files changed, 86 insertions(+), 32 deletions(-) diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 84aeab5..b2374ec 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -6,7 +6,7 @@ from dictionary_learning.training import trainSAE from dictionary_learning.trainers.standard import StandardTrainer -from dictionary_learning.trainers.top_k import TrainerTopK, AutoEncoderTopK +from dictionary_learning.trainers.top_k import TopKTrainer, AutoEncoderTopK from dictionary_learning.utils import hf_dataset_to_generator, get_nested_folders, load_dictionary from dictionary_learning.buffer import ActivationBuffer from dictionary_learning.dictionary import ( @@ -119,7 +119,7 @@ def test_sae_training(): trainer_configs.extend( [ { - "trainer": TrainerTopK, + "trainer": TopKTrainer, "dict_class": AutoEncoderTopK, "activation_dim": activation_dim, "dict_size": expansion_factor * activation_dim, diff --git a/trainers/batch_top_k.py b/trainers/batch_top_k.py index f306301..9cbe6a7 100644 --- a/trainers/batch_top_k.py +++ b/trainers/batch_top_k.py @@ -16,7 +16,7 @@ def __init__(self, activation_dim: int, dict_size: int, k: int): assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer" self.register_buffer("k", t.tensor(k)) - self.register_buffer("threshold", None) + self.register_buffer("threshold", t.tensor(-1.0)) self.encoder = nn.Linear(activation_dim, dict_size) self.encoder.bias.data.zero_() @@ -109,6 +109,7 @@ def __init__( auxk_alpha=1 / 32, decay_start=24000, threshold_beta=0.999, + threshold_start_step=1000, steps=30000, top_k_aux=512, seed=None, @@ -127,6 +128,7 @@ def __init__( self.steps = steps self.k = k self.threshold_beta = threshold_beta + self.threshold_start_step = threshold_start_step if seed is not None: t.manual_seed(seed) @@ -185,21 +187,21 @@ def loss(self, x, step=None, logging=False): f, active_indices = self.ae.encode(x, return_active=True, use_threshold=False) # l0 = (f != 0).float().sum(dim=-1).mean().item() - active = f[f > 0] + if step > self.threshold_start_step: + with t.no_grad(): + active = f[f > 0] - if active.size(0) == 0: - min_activation = 0.0 - else: - min_activation = active.min() - - print(f"min_activation: {min_activation}") + if active.size(0) == 0: + min_activation = 0.0 + else: + min_activation = active.min().detach() - if self.threshold is None: - self.threshold = min_activation - else: - self.threshold = (self.threshold_beta * self.threshold) + ( - (1 - self.threshold_beta) * min_activation - ) + if self.ae.threshold < 0: + self.ae.threshold = min_activation + else: + self.ae.threshold = (self.threshold_beta * self.ae.threshold) + ( + (1 - self.threshold_beta) * min_activation + ) x_hat = self.ae.decode(f) @@ -252,14 +254,14 @@ def update(self, step, x): @property def config(self): return { - "trainer_class": "TrainerBatchTopK", + "trainer_class": "BatchTopKTrainer", "dict_class": "BatchTopKSAE", "lr": self.lr, "steps": self.steps, "seed": self.seed, "activation_dim": self.ae.activation_dim, "dict_size": self.ae.dict_size, - "k": self.ae.k, + "k": self.ae.k.item(), "device": self.device, "layer": self.layer, "lm_name": self.lm_name, diff --git a/trainers/top_k.py b/trainers/top_k.py index 02c50b2..e5ca9ae 100644 --- a/trainers/top_k.py +++ b/trainers/top_k.py @@ -7,6 +7,7 @@ import torch as t import torch.nn as nn from collections import namedtuple +from typing import Optional from ..config import DEBUG from ..dictionary import Dictionary @@ -58,7 +59,10 @@ def __init__(self, activation_dim: int, dict_size: int, k: int): super().__init__() self.activation_dim = activation_dim self.dict_size = dict_size - self.k = k + + assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer" + self.register_buffer("k", t.tensor(k)) + self.register_buffer("threshold", t.tensor(-1.0)) self.encoder = nn.Linear(activation_dim, dict_size) self.encoder.bias.data.zero_() @@ -69,8 +73,17 @@ def __init__(self, activation_dim: int, dict_size: int, k: int): self.b_dec = nn.Parameter(t.zeros(activation_dim)) - def encode(self, x: t.Tensor, return_topk: bool = False): + def encode(self, x: t.Tensor, return_topk: bool = False, use_threshold: bool = False): post_relu_feat_acts_BF = nn.functional.relu(self.encoder(x - self.b_dec)) + + if use_threshold: + encoded_acts_BF = post_relu_feat_acts_BF * (post_relu_feat_acts_BF > self.threshold) + if return_topk: + post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1) + return encoded_acts_BF, post_topk.values, post_topk.indices + else: + return encoded_acts_BF + post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1) # We can't split immediately due to nnsight @@ -117,12 +130,18 @@ def remove_gradient_parallel_to_decoder_directions(self): "d_sae, d_in d_sae -> d_in d_sae", ) - def from_pretrained(path, k: int, device=None): + def from_pretrained(path, k: Optional[int] = None, device=None): """ Load a pretrained autoencoder from a file. """ state_dict = t.load(path) dict_size, activation_dim = state_dict["encoder.weight"].shape + + if k is None: + k = state_dict["k"].item() + elif "k" in state_dict and k != state_dict["k"].item(): + raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']") + autoencoder = AutoEncoderTopK(activation_dim, dict_size, k) autoencoder.load_state_dict(state_dict) if device is not None: @@ -143,6 +162,8 @@ def __init__( k=100, auxk_alpha=1 / 32, # see Appendix A.2 decay_start=24000, # when does the lr decay start + threshold_beta=0.999, + threshold_start_step=1000, steps=30000, # when when does training end seed=None, device=None, @@ -161,6 +182,9 @@ def __init__( self.wandb_name = wandb_name self.steps = steps self.k = k + self.threshold_beta = threshold_beta + self.threshold_start_step = threshold_start_step + if seed is not None: t.manual_seed(seed) t.cuda.manual_seed_all(seed) @@ -201,7 +225,26 @@ def lr_fn(step): def loss(self, x, step=None, logging=False): # Run the SAE - f, top_acts, top_indices = self.ae.encode(x, return_topk=True) + f, top_acts, top_indices = self.ae.encode(x, return_topk=True, use_threshold=False) + + if step > self.threshold_start_step: + with t.no_grad(): + active = top_acts.clone().detach() + active[active <= 0] = float("inf") + min_activations = active.min(dim=1).values + min_activation = min_activations.mean() + + B, K = active.shape + assert len(active.shape) == 2 + assert min_activations.shape == (B,) + + if self.ae.threshold < 0: + self.ae.threshold = min_activation + else: + self.ae.threshold = (self.threshold_beta * self.ae.threshold) + ( + (1 - self.threshold_beta) * min_activation + ) + x_hat = self.ae.decode(f) # Measure goodness of reconstruction @@ -293,14 +336,14 @@ def update(self, step, x): @property def config(self): return { - "trainer_class": "TrainerTopK", + "trainer_class": "TopKTrainer", "dict_class": "AutoEncoderTopK", "lr": self.lr, "steps": self.steps, "seed": self.seed, "activation_dim": self.ae.activation_dim, "dict_size": self.ae.dict_size, - "k": self.ae.k, + "k": self.ae.k.item(), "device": self.device, "layer": self.layer, "lm_name": self.lm_name, diff --git a/utils.py b/utils.py index 27a2188..4f34a4e 100644 --- a/utils.py +++ b/utils.py @@ -6,6 +6,7 @@ from nnsight import LanguageModel from dictionary_learning.trainers.top_k import AutoEncoderTopK +from dictionary_learning.trainers.batch_top_k import BatchTopKSAE from dictionary_learning.dictionary import ( AutoEncoder, GatedAutoEncoder, @@ -13,29 +14,34 @@ JumpReluAutoEncoder, ) -def hf_dataset_to_generator(dataset_name, split='train', streaming=True): + +def hf_dataset_to_generator(dataset_name, split="train", streaming=True): dataset = load_dataset(dataset_name, split=split, streaming=streaming) - + def gen(): for x in iter(dataset): - yield x['text'] - + yield x["text"] + return gen() + def zst_to_generator(data_path): """ Load a dataset from a .jsonl.zst file. The jsonl entries is assumed to have a 'text' field """ - compressed_file = open(data_path, 'rb') + compressed_file = open(data_path, "rb") dctx = zstd.ZstdDecompressor() reader = dctx.stream_reader(compressed_file) - text_stream = io.TextIOWrapper(reader, encoding='utf-8') + text_stream = io.TextIOWrapper(reader, encoding="utf-8") + def generator(): for line in text_stream: - yield json.loads(line)['text'] + yield json.loads(line)["text"] + return generator() + def get_nested_folders(path: str) -> list[str]: """ Recursively get a list of folders that contain an ae.pt file, starting the search from the given path @@ -67,6 +73,8 @@ def load_dictionary(base_path: str, device: str) -> tuple: elif dict_class == "AutoEncoderTopK": k = config["trainer"]["k"] dictionary = AutoEncoderTopK.from_pretrained(ae_path, k=k, device=device) + elif dict_class == "BatchTopKSAE": + dictionary = BatchTopKSAE.from_pretrained(ae_path, device=device) elif dict_class == "JumpReluAutoEncoder": dictionary = JumpReluAutoEncoder.from_pretrained(ae_path, device=device) else: @@ -74,6 +82,7 @@ def load_dictionary(base_path: str, device: str) -> tuple: return dictionary, config + def get_submodule(model: LanguageModel, layer: int): """Gets the residual stream submodule""" model_name = model._model_key @@ -83,4 +92,4 @@ def get_submodule(model: LanguageModel, layer: int): elif "gemma" in model_name: return model.model.layers[layer] else: - raise ValueError(f"Please add submodule for model {model_name}") \ No newline at end of file + raise ValueError(f"Please add submodule for model {model_name}") From 57f451b5635c4677ab47a4172aa588a5bdffdb4e Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Sun, 22 Dec 2024 04:31:15 +0000 Subject: [PATCH 16/70] Remove demo script and graphing notebook --- demo.py | 500 ------------------------------------------------- graphing.ipynb | 205 -------------------- 2 files changed, 705 deletions(-) delete mode 100644 demo.py delete mode 100644 graphing.ipynb diff --git a/demo.py b/demo.py deleted file mode 100644 index e4dd062..0000000 --- a/demo.py +++ /dev/null @@ -1,500 +0,0 @@ -import torch as t -from nnsight import LanguageModel -import argparse -import itertools -import os -import json -from dataclasses import dataclass, field, asdict -from typing import Optional, Type, Any -from enum import Enum - -from training import trainSAE -from trainers.standard import StandardTrainer -from trainers.top_k import TopKTrainer, AutoEncoderTopK -from trainers.gdm import GatedSAETrainer -from trainers.p_anneal import PAnnealTrainer -from trainers.jumprelu import JumpReluTrainer -from utils import hf_dataset_to_generator -from buffer import ActivationBuffer -from dictionary import AutoEncoder, GatedAutoEncoder, AutoEncoderNew, JumpReluAutoEncoder -from evaluation import evaluate -import utils as utils - - -class TrainerType(Enum): - STANDARD = "standard" - STANDARD_NEW = "standard_new" - TOP_K = "top_k" - BATCH_TOP_K = "batch_top_k" - GATED = "gated" - P_ANNEAL = "p_anneal" - JUMP_RELU = "jump_relu" - - -@dataclass -class LLMConfig: - llm_batch_size: int - context_length: int - sae_batch_size: int - dtype: t.dtype - - -@dataclass -class SparsityPenalties: - standard: list[float] - p_anneal: list[float] - gated: list[float] - - -# TODO: Move all of these to a config file -num_tokens = 50_000_000 -eval_num_inputs = 1_000 -random_seeds = [0] -expansion_factors = [8] - -# note: learning rate is not used for topk -learning_rates = [3e-4] - -LLM_CONFIG = { - "EleutherAI/pythia-70m-deduped": LLMConfig( - llm_batch_size=512, context_length=128, sae_batch_size=4096, dtype=t.float32 - ), - "google/gemma-2-2b": LLMConfig( - llm_batch_size=32, context_length=128, sae_batch_size=2048, dtype=t.bfloat16 - ), -} - - -# NOTE: In the current setup, the length of each sparsity penalty and target_l0 should be the same -SPARSITY_PENALTIES = { - "EleutherAI/pythia-70m-deduped": SparsityPenalties( - standard=[0.01, 0.05, 0.075, 0.1, 0.125, 0.15], - p_anneal=[0.02, 0.03, 0.035, 0.04, 0.05, 0.075], - gated=[0.1, 0.3, 0.5, 0.7, 0.9, 1.1], - ), - "google/gemma-2-2b": SparsityPenalties( - standard=[0.025, 0.035, 0.04, 0.05, 0.06, 0.07], - p_anneal=[-1] * 6, - gated=[-1] * 6, - ), -} - - -TARGET_L0s = [20, 40, 80, 160, 320, 640] - - -@dataclass -class BaseTrainerConfig: - activation_dim: int - dict_size: int - seed: int - device: str - layer: str - lm_name: str - submodule_name: str - trainer: Type[Any] - dict_class: Type[Any] - wandb_name: str - steps: Optional[int] = None - - -@dataclass -class WarmupConfig: - warmup_steps: int = 1000 - resample_steps: Optional[int] = None - - -@dataclass -class StandardTrainerConfig(BaseTrainerConfig, WarmupConfig): - lr: float - l1_penalty: float - - -@dataclass -class StandardNewTrainerConfig(BaseTrainerConfig, WarmupConfig): - lr: float - l1_penalty: float - - -@dataclass -class PAnnealTrainerConfig(BaseTrainerConfig, WarmupConfig): - lr: float - initial_sparsity_penalty: float - sparsity_function: str = "Lp^p" - p_start: float = 1.0 - p_end: float = 0.2 - anneal_start: int = 10000 - anneal_end: Optional[int] = None - sparsity_queue_length: int = 10 - n_sparsity_updates: int = 10 - - -@dataclass -class TopKTrainerConfig(BaseTrainerConfig): - k: int - auxk_alpha: float = 1 / 32 - decay_start: int = 24000 - threshold_beta: float = 0.999 - - -@dataclass -class GatedTrainerConfig(BaseTrainerConfig, WarmupConfig): - lr: float - l1_penalty: float - - -@dataclass -class JumpReluTrainerConfig(BaseTrainerConfig): - lr: float - target_l0: int - sparsity_penalty: float = 1.0 - bandwidth: float = 0.001 - - -def get_trainer_configs( - architectures: list[str], - learning_rate: float, - sparsity_index: int, - seed: int, - activation_dim: int, - dict_size: int, - model_name: str, - device: str, - layer: str, - submodule_name: str, - steps: int, -) -> list[dict]: - trainer_configs = [] - - base_config = { - "activation_dim": activation_dim, - "dict_size": dict_size, - "seed": seed, - "device": device, - "layer": layer, - "lm_name": model_name, - "submodule_name": submodule_name, - } - - if TrainerType.P_ANNEAL.value in architectures: - config = PAnnealTrainerConfig( - **base_config, - trainer=PAnnealTrainer, - dict_class=AutoEncoder, - lr=learning_rate, - initial_sparsity_penalty=SPARSITY_PENALTIES[model_name].p_anneal[sparsity_index], - steps=steps, - wandb_name=f"PAnnealTrainer-{model_name}-{submodule_name}", - ) - trainer_configs.append(asdict(config)) - - if TrainerType.STANDARD.value in architectures: - config = StandardTrainerConfig( - **base_config, - trainer=StandardTrainer, - dict_class=AutoEncoder, - lr=learning_rate, - l1_penalty=SPARSITY_PENALTIES[model_name].standard[sparsity_index], - wandb_name=f"StandardTrainer-{model_name}-{submodule_name}", - ) - trainer_configs.append(asdict(config)) - - if TrainerType.STANDARD_NEW.value in architectures: - config = StandardNewTrainerConfig( - **base_config, - trainer=StandardTrainer, - dict_class=AutoEncoderNew, - lr=learning_rate, - l1_penalty=SPARSITY_PENALTIES[model_name].standard[sparsity_index], - wandb_name=f"StandardTrainerNew-{model_name}-{submodule_name}", - ) - trainer_configs.append(asdict(config)) - - if TrainerType.TOP_K.value in architectures: - config = TopKTrainerConfig( - **base_config, - trainer=TopKTrainer, - dict_class=AutoEncoderTopK, - k=TARGET_L0s[sparsity_index], - steps=steps, - wandb_name=f"TopKTrainer-{model_name}-{submodule_name}", - ) - trainer_configs.append(asdict(config)) - - if TrainerType.GATED.value in architectures: - config = GatedTrainerConfig( - **base_config, - trainer=GatedSAETrainer, - dict_class=GatedAutoEncoder, - lr=learning_rate, - l1_penalty=SPARSITY_PENALTIES[model_name].gated[sparsity_index], - wandb_name=f"GatedTrainer-{model_name}-{submodule_name}", - ) - trainer_configs.append(asdict(config)) - - if TrainerType.JUMP_RELU.value in architectures: - config = JumpReluTrainerConfig( - **base_config, - trainer=JumpReluTrainer, - dict_class=JumpReluAutoEncoder, - lr=learning_rate, - target_l0=TARGET_L0s[sparsity_index], - wandb_name=f"JumpReluTrainer-{model_name}-{submodule_name}", - ) - trainer_configs.append(asdict(config)) - - return trainer_configs - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--save_dir", type=str, required=True, help="where to store sweep") - parser.add_argument("--use_wandb", action="store_true", help="use wandb logging") - parser.add_argument("--dry_run", action="store_true", help="dry run sweep") - parser.add_argument( - "--layers", type=int, nargs="+", required=True, help="layers to train SAE on" - ) - parser.add_argument( - "--model_name", - type=str, - required=True, - help="which language model to use", - ) - parser.add_argument( - "--architectures", - type=str, - nargs="+", - choices=[e.value for e in TrainerType], - required=True, - help="which SAE architectures to train", - ) - args = parser.parse_args() - return args - - -def run_sae_training( - model_name: str, - layer: int, - save_dir: str, - device: str, - architectures: list, - num_tokens: int, - random_seeds: list[int], - expansion_factors: list[float], - learning_rates: list[float], - dry_run: bool = False, - use_wandb: bool = False, - save_checkpoints: bool = False, - buffer_scaling_factor: int = 20, -): - # model and data parameters - context_length = LLM_CONFIG[model_name]["context_length"] - - llm_batch_size = LLM_CONFIG[model_name]["llm_batch_size"] - sae_batch_size = LLM_CONFIG[model_name]["sae_batch_size"] - dtype = LLM_CONFIG[model_name]["dtype"] - - num_contexts_per_sae_batch = sae_batch_size // context_length - buffer_size = num_contexts_per_sae_batch * buffer_scaling_factor - - # sae training parameters - # random_seeds = t.arange(10).tolist() - - num_sparsities = len(TARGET_L0s) - sparsity_indices = t.arange(num_sparsities).tolist() - - steps = int(num_tokens / sae_batch_size) # Total number of batches to train - - if save_checkpoints: - # Creates checkpoints at 0.1%, 0.316%, 1%, 3.16%, 10%, 31.6%, 100% of training - desired_checkpoints = t.logspace(-3, 0, 7).tolist() - desired_checkpoints = [0.0] + desired_checkpoints[:-1] - desired_checkpoints.sort() - print(f"desired_checkpoints: {desired_checkpoints}") - - save_steps = [int(steps * step) for step in desired_checkpoints] - save_steps.sort() - print(f"save_steps: {save_steps}") - else: - save_steps = None - - log_steps = 100 # Log the training on wandb - if not use_wandb: - log_steps = None - - model = LanguageModel(model_name, dispatch=True, device_map=device) - model = model.to(dtype=dtype) - submodule = utils.get_submodule(model, layer) - submodule_name = f"resid_post_layer_{layer}" - io = "out" - activation_dim = model.config.hidden_size - - generator = hf_dataset_to_generator("monology/pile-uncopyrighted") - - activation_buffer = ActivationBuffer( - generator, - model, - submodule, - n_ctxs=buffer_size, - ctx_len=context_length, - refresh_batch_size=llm_batch_size, - out_batch_size=sae_batch_size, - io=io, - d_submodule=activation_dim, - device=device, - ) - - # create the list of configs - trainer_configs = [] - - for seed, sparsity_index, expansion_factor, learning_rate in itertools.product( - random_seeds, sparsity_indices, expansion_factors, learning_rates - ): - dict_size = int(expansion_factor * activation_dim) - trainer_configs.extend( - get_trainer_configs( - architectures, - learning_rate, - sparsity_index, - seed, - activation_dim, - dict_size, - model_name, - device, - submodule_name, - steps, - ) - ) - - print(f"len trainer configs: {len(trainer_configs)}") - save_dir = f"{save_dir}/{submodule_name}" - - if not dry_run: - # actually run the sweep - trainSAE( - data=activation_buffer, - trainer_configs=trainer_configs, - use_wandb=use_wandb, - steps=steps, - save_steps=save_steps, - save_dir=save_dir, - log_steps=log_steps, - ) - - -@t.no_grad() -def eval_saes( - model_name: str, - ae_paths: list[str], - n_inputs: int, - device: str, - overwrite_prev_results: bool = False, - transcoder: bool = False, -) -> dict: - if transcoder: - io = "in_and_out" - else: - io = "out" - - context_length = LLM_CONFIG[model_name]["context_length"] - llm_batch_size = LLM_CONFIG[model_name]["llm_batch_size"] - loss_recovered_batch_size = llm_batch_size // 5 - sae_batch_size = loss_recovered_batch_size * context_length - dtype = LLM_CONFIG[model_name]["dtype"] - - model = LanguageModel(model_name, dispatch=True, device_map=device) - model = model.to(dtype=dtype) - - buffer_size = n_inputs - io = "out" - n_batches = n_inputs // loss_recovered_batch_size - - generator = hf_dataset_to_generator("monology/pile-uncopyrighted") - - input_strings = [] - for i, example in enumerate(generator): - input_strings.append(example) - if i > n_inputs * 5: - break - - eval_results = {} - - for ae_path in ae_paths: - output_filename = f"{ae_path}/eval_results.json" - if not overwrite_prev_results: - if os.path.exists(output_filename): - print(f"Skipping {ae_path} as eval results already exist") - continue - - dictionary, config = utils.load_dictionary(ae_path, device) - dictionary = dictionary.to(dtype=model.dtype) - - layer = config["trainer"]["layer"] - submodule = utils.get_submodule(model, layer) - - activation_dim = config["trainer"]["activation_dim"] - - activation_buffer = ActivationBuffer( - iter(input_strings), - model, - submodule, - n_ctxs=buffer_size, - ctx_len=context_length, - refresh_batch_size=llm_batch_size, - out_batch_size=sae_batch_size, - io=io, - d_submodule=activation_dim, - device=device, - ) - - eval_results = evaluate( - dictionary, - activation_buffer, - context_length, - loss_recovered_batch_size, - io=io, - device=device, - n_batches=n_batches, - ) - - hyperparameters = { - "n_inputs": n_inputs, - "context_length": context_length, - } - eval_results["hyperparameters"] = hyperparameters - - print(eval_results) - - with open(output_filename, "w") as f: - json.dump(eval_results, f) - - # return the final eval_results for testing purposes - return eval_results - - -if __name__ == "__main__": - """python pythia.py --save_dir ./run2 --model_name EleutherAI/pythia-70m-deduped --layers 3 --architectures standard standard_new top_k gated --use_wandb - python pythia.py --save_dir ./run3 --model_name google/gemma-2-2b --layers 12 --architectures standard top_k --use_wandb - python pythia.py --save_dir ./jumprelu --model_name EleutherAI/pythia-70m-deduped --layers 3 --architectures jump_relu --use_wandb""" - args = get_args() - - device = "cuda:0" - - for layer in args.layers: - run_sae_training( - model_name=args.model_name, - layer=layer, - save_dir=args.save_dir, - device=device, - architectures=args.architectures, - num_tokens=num_tokens, - random_seeds=random_seeds, - expansion_factors=expansion_factors, - learning_rates=learning_rates, - dry_run=args.dry_run, - use_wandb=args.use_wandb, - ) - - ae_paths = utils.get_nested_folders(args.save_dir) - - eval_saes(args.model_name, ae_paths, eval_num_inputs, device) diff --git a/graphing.ipynb b/graphing.ipynb deleted file mode 100644 index 2b6dc10..0000000 --- a/graphing.ipynb +++ /dev/null @@ -1,205 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "import json\n", - "from typing import Optional\n", - "\n", - "import dictionary_learning.utils as utils\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "TRAINER_LABELS = {\n", - " \"StandardTrainer\": \"Standard\",\n", - " \"JumpReluTrainer\": \"JumpReLU\",\n", - " \"TrainerTopK\": \"Top K\",\n", - " \"GatedSAETrainer\": \"Gated\",\n", - " \"PAnnealTrainer\": \"P-Anneal\",\n", - "}\n", - "\n", - "TRAINER_MARKERS = {\n", - " \"StandardTrainer\": \"o\",\n", - " \"JumpReluTrainer\": \"X\",\n", - " \"TrainerTopK\": \"^\",\n", - " \"GatedSAETrainer\": \"d\",\n", - " \"PAnnealTrainer\": \"s\",\n", - "}\n", - "\n", - "TRAINER_COLORS = {\n", - " \"StandardTrainer\": \"blue\",\n", - " \"JumpReluTrainer\": \"orange\",\n", - " \"TrainerTopK\": \"green\",\n", - " \"GatedSAETrainer\": \"red\",\n", - " \"PAnnealTrainer\": \"purple\",\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "save_dirs = [\"./run2\", \"./jumprelu\"]\n", - "# save_dirs = [\"./run2\"]\n", - "ae_paths = []\n", - "\n", - "for save_dir in save_dirs:\n", - " ae_paths.extend(utils.get_nested_folders(save_dir))\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plotting_results = {}\n", - "\n", - "for ae_path in ae_paths:\n", - " with open(ae_path + \"/config.json\") as f:\n", - " config = json.load(f)\n", - "\n", - " with open(ae_path + \"/eval_results.json\") as f:\n", - " eval_results = json.load(f)\n", - "\n", - " ae_results = {}\n", - "\n", - " ae_results[\"l0\"] = eval_results[\"l0\"]\n", - " ae_results[\"frac_recovered\"] = eval_results[\"frac_recovered\"]\n", - " ae_results[\"trainer_class\"] = config[\"trainer\"][\"trainer_class\"]\n", - " ae_results[\"dict_size\"] = config[\"trainer\"][\"dict_size\"]\n", - "\n", - " plotting_results[ae_path] = ae_results\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def plot_2var_graph(\n", - " results: dict[str, dict[str, float]],\n", - " custom_metric: str,\n", - " title: str = \"L0 vs Custom Metric\",\n", - " y_label: str = \"Custom Metric\",\n", - " xlims: Optional[tuple[float, float]] = None,\n", - " ylims: Optional[tuple[float, float]] = None,\n", - " output_filename: Optional[str] = None,\n", - " legend_location: str = \"lower right\",\n", - " x_axis_key: str = \"l0\",\n", - " return_fig: bool = False,\n", - "):\n", - " # Extract data from results\n", - " l0_values = [data[x_axis_key] for data in results.values()]\n", - " custom_metric_values = [data[custom_metric] for data in results.values()]\n", - "\n", - " # Create the scatter plot\n", - " fig, ax = plt.subplots(figsize=(10, 6))\n", - "\n", - " handles, labels = [], []\n", - "\n", - " for trainer, marker in TRAINER_MARKERS.items():\n", - " # Filter data for this trainer\n", - " trainer_data = {k: v for k, v in results.items() if v[\"trainer_class\"] == trainer}\n", - "\n", - " if not trainer_data:\n", - " continue # Skip this trainer if no data points\n", - "\n", - " l0_values = [data[x_axis_key] for data in trainer_data.values()]\n", - " custom_metric_values = [data[custom_metric] for data in trainer_data.values()]\n", - "\n", - " # Plot data points\n", - " scatter = ax.scatter(\n", - " l0_values,\n", - " custom_metric_values,\n", - " marker=marker,\n", - " s=100,\n", - " label=trainer,\n", - " color=TRAINER_COLORS[trainer],\n", - " edgecolor=\"black\",\n", - " )\n", - "\n", - " # Create custom legend handle with both marker and color\n", - " legend_handle = plt.scatter(\n", - " [], [], marker=marker, s=100, color=TRAINER_COLORS[trainer], edgecolor=\"black\"\n", - " )\n", - " handles.append(legend_handle)\n", - "\n", - " if trainer in TRAINER_LABELS:\n", - " trainer_label = TRAINER_LABELS[trainer]\n", - " else:\n", - " trainer_label = trainer.capitalize()\n", - " labels.append(trainer_label)\n", - "\n", - " # Set labels and title\n", - " ax.set_xlabel(\"L0 (Sparsity)\")\n", - " ax.set_ylabel(y_label)\n", - " ax.set_title(title)\n", - "\n", - " ax.legend(handles, labels, loc=legend_location)\n", - "\n", - " # Set axis limits\n", - " if xlims:\n", - " ax.set_xlim(*xlims)\n", - " if ylims:\n", - " ax.set_ylim(*ylims)\n", - "\n", - " plt.tight_layout()\n", - "\n", - " # Save and show the plot\n", - " if output_filename:\n", - " plt.savefig(output_filename, bbox_inches=\"tight\")\n", - "\n", - " if return_fig:\n", - " return fig\n", - "\n", - " plt.show()\n", - " \n", - "plt.rcParams.update({\"font.size\": 20})\n", - "plot_2var_graph(plotting_results, \"frac_recovered\", title=\"Fraction Recovered vs L0\", y_label=\"Fraction Recovered\", output_filename=\"frac_recovered_vs_l0.png\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "base", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.13" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 81968f2659082996539f08ea3188a5d2ed327696 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Thu, 26 Dec 2024 03:51:08 +0000 Subject: [PATCH 17/70] Add option to normalize dataset activations --- dictionary.py | 15 ++++++++++++ trainers/jumprelu.py | 3 +++ training.py | 55 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 73 insertions(+) diff --git a/dictionary.py b/dictionary.py index ababd85..b26aa20 100644 --- a/dictionary.py +++ b/dictionary.py @@ -86,6 +86,10 @@ def forward(self, x, output_features=False, ghost_mask=None): return x_hat, x_ghost, f else: return x_hat, x_ghost + + def scale_biases(self, scale: float): + self.encoder.bias.data *= scale + self.bias.data *= scale @classmethod def from_pretrained(cls, path, dtype=t.float, device=None): @@ -204,6 +208,11 @@ def forward(self, x, output_features=False): else: return x_hat + def scale_biases(self, scale: float): + self.decoder_bias.data *= scale + self.mag_bias.data *= scale + self.gate_bias.data *= scale + def from_pretrained(path, device=None): """ Load a pretrained autoencoder from a file. @@ -215,6 +224,7 @@ def from_pretrained(path, device=None): if device is not None: autoencoder.to(device) return autoencoder + class JumpReluAutoEncoder(Dictionary, nn.Module): """ @@ -267,6 +277,11 @@ def forward(self, x, output_features=False): return x_hat, f else: return x_hat + + def scale_biases(self, scale: float): + self.b_dec.data *= scale + self.b_enc.data *= scale + self.threshold.data *= scale @classmethod def from_pretrained( diff --git a/trainers/jumprelu.py b/trainers/jumprelu.py index a3a6371..e27e1b3 100644 --- a/trainers/jumprelu.py +++ b/trainers/jumprelu.py @@ -168,4 +168,7 @@ def config(self): "lm_name": self.lm_name, "wandb_name": self.wandb_name, "submodule_name": self.submodule_name, + "bandwidth": self.bandwidth, + "sparsity_penalty": self.sparsity_coefficient, + "target_l0": self.target_l0, } diff --git a/training.py b/training.py index 13fd4b3..19b8af8 100644 --- a/training.py +++ b/training.py @@ -73,6 +73,31 @@ def log_stats( if log_queues: log_queues[i].put(log) +def get_norm_factor(data, steps: int) -> float: + """Per Section 3.1, find a fixed scalar factor so activation vectors have unit mean squared norm. + This is very helpful for hyperparameter transfer between different layers and models. + Use more steps for more accurate results. + https://arxiv.org/pdf/2408.05147""" + total_mean_squared_norm = 0 + count = 0 + + for step, act_BD in enumerate(tqdm(data, total=steps)): + if step > steps: + break + + count += 1 + mean_squared_norm = t.mean(t.sum(act_BD ** 2, dim=1)) + total_mean_squared_norm += mean_squared_norm + + average_mean_squared_norm = total_mean_squared_norm / count + norm_factor = t.sqrt(average_mean_squared_norm).item() + + print(f"Average mean squared norm: {average_mean_squared_norm}") + print(f"Norm factor: {norm_factor}") + + return norm_factor + + def trainSAE( data, @@ -87,10 +112,16 @@ def trainSAE( activations_split_by_head:bool=False, transcoder:bool=False, run_cfg:dict={}, + normalize_activations:bool=False, ): """ Train SAEs using the given trainers + + If normalize_activations is True, the activations will be normalized to have unit mean squared norm. + The autoencoders weights will be scaled before saving, so the activations don't need to be scaled during inference. + This is very helpful for hyperparameter transfer between different layers and models. """ + trainers = [] for config in trainer_configs: trainer_class = config["trainer"] @@ -130,7 +161,21 @@ def trainSAE( else: save_dirs = [None for _ in trainer_configs] + if normalize_activations: + norm_factor = get_norm_factor(data, steps=100) + + for trainer in trainers: + trainer.config["norm_factor"] = norm_factor + # Verify that all autoencoders have a scale_biases method + trainer.ae.scale_biases(1.0) + for step, act in enumerate(tqdm(data, total=steps)): + + act = act.to(dtype=t.float32) + + if normalize_activations: + act /= norm_factor + if steps is not None and step >= steps: break @@ -144,6 +189,11 @@ def trainSAE( if save_steps is not None and step in save_steps: for dir, trainer in zip(save_dirs, trainers): if dir is not None: + + if normalize_activations: + # Temporarily scale up biases for checkpoint saving + trainer.ae.scale_biases(norm_factor) + if not os.path.exists(os.path.join(dir, "checkpoints")): os.mkdir(os.path.join(dir, "checkpoints")) t.save( @@ -151,12 +201,17 @@ def trainSAE( os.path.join(dir, "checkpoints", f"ae_{step}.pt"), ) + if normalize_activations: + trainer.ae.scale_biases(1 / norm_factor) + # training for trainer in trainers: trainer.update(step, act) # save final SAEs for save_dir, trainer in zip(save_dirs, trainers): + if normalize_activations: + trainer.ae.scale_biases(norm_factor) if save_dir is not None: t.save(trainer.ae.state_dict(), os.path.join(save_dir, "ae.pt")) From 488a1545922249cdb9ce5a5885c1931a5c21a37f Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Thu, 26 Dec 2024 03:51:24 +0000 Subject: [PATCH 18/70] Fix topk bfloat16 dtype error --- trainers/top_k.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trainers/top_k.py b/trainers/top_k.py index e5ca9ae..4c62287 100644 --- a/trainers/top_k.py +++ b/trainers/top_k.py @@ -313,6 +313,7 @@ def update(self, step, x): # Initialise the decoder bias if step == 0: median = geometric_median(x) + median = median.to(self.ae.b_dec.dtype) self.ae.b_dec.data = median # Make sure the decoder is still unit-norm From 484ca01f405e5791968883123718fd67ee35f299 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Thu, 26 Dec 2024 18:15:41 +0000 Subject: [PATCH 19/70] Add bias scaling to topk saes --- trainers/batch_top_k.py | 4 ++++ trainers/top_k.py | 4 ++++ training.py | 7 +++++-- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/trainers/batch_top_k.py b/trainers/batch_top_k.py index 9cbe6a7..72d144b 100644 --- a/trainers/batch_top_k.py +++ b/trainers/batch_top_k.py @@ -83,6 +83,10 @@ def remove_gradient_parallel_to_decoder_directions(self): "d_sae, d_in d_sae -> d_in d_sae", ) + def scale_biases(self, scale: float): + self.encoder.bias.data *= scale + self.b_dec.data *= scale + @classmethod def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "BatchTopKSAE": state_dict = t.load(path) diff --git a/trainers/top_k.py b/trainers/top_k.py index 4c62287..4afa176 100644 --- a/trainers/top_k.py +++ b/trainers/top_k.py @@ -130,6 +130,10 @@ def remove_gradient_parallel_to_decoder_directions(self): "d_sae, d_in d_sae -> d_in d_sae", ) + def scale_biases(self, scale: float): + self.encoder.bias.data *= scale + self.b_dec.data *= scale + def from_pretrained(path, k: Optional[int] = None, device=None): """ Load a pretrained autoencoder from a file. diff --git a/training.py b/training.py index 19b8af8..d4c6a38 100644 --- a/training.py +++ b/training.py @@ -77,11 +77,14 @@ def get_norm_factor(data, steps: int) -> float: """Per Section 3.1, find a fixed scalar factor so activation vectors have unit mean squared norm. This is very helpful for hyperparameter transfer between different layers and models. Use more steps for more accurate results. - https://arxiv.org/pdf/2408.05147""" + https://arxiv.org/pdf/2408.05147 + + If experiencing troubles with hyperparameter transfer between models, it may be worth instead normalizing to the square root of d_model. + https://transformer-circuits.pub/2024/april-update/index.html#training-saes""" total_mean_squared_norm = 0 count = 0 - for step, act_BD in enumerate(tqdm(data, total=steps)): + for step, act_BD in enumerate(tqdm(data, total=steps, desc="Calculating norm factor")): if step > steps: break From 8b95ec9b6e9a6d8d6255092e51b7580dccac70d6 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Thu, 26 Dec 2024 22:20:16 +0000 Subject: [PATCH 20/70] Use the correct standard SAE reconstruction loss, initialize W_dec to W_enc.T --- dictionary.py | 14 +++++++++----- trainers/p_anneal.py | 4 ++-- trainers/standard.py | 5 +++-- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/dictionary.py b/dictionary.py index b26aa20..09e80ab 100644 --- a/dictionary.py +++ b/dictionary.py @@ -47,12 +47,16 @@ def __init__(self, activation_dim, dict_size): self.dict_size = dict_size self.bias = nn.Parameter(t.zeros(activation_dim)) self.encoder = nn.Linear(activation_dim, dict_size, bias=True) - - # rows of decoder weight matrix are unit vectors self.decoder = nn.Linear(dict_size, activation_dim, bias=False) - dec_weight = t.randn_like(self.decoder.weight) - dec_weight = dec_weight / dec_weight.norm(dim=0, keepdim=True) - self.decoder.weight = nn.Parameter(dec_weight) + + # initialize encoder and decoder weights + w = t.randn(activation_dim, dict_size) + ## normalize columns of w + w = w / w.norm(dim=0, keepdim=True) * 0.1 + ## set encoder and decoder weights + self.encoder.weight = nn.Parameter(w.clone().T) + self.decoder.weight = nn.Parameter(w.clone()) + def encode(self, x): return nn.ReLU()(self.encoder(x - self.bias)) diff --git a/trainers/p_anneal.py b/trainers/p_anneal.py index 4a157b9..0138547 100644 --- a/trainers/p_anneal.py +++ b/trainers/p_anneal.py @@ -166,7 +166,7 @@ def lp_norm(self, f, p): def loss(self, x, step, logging=False): # Compute loss terms x_hat, f = self.ae(x, output_features=True) - l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() + recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean() lp_loss = self.lp_norm(f, self.p) scaled_lp_loss = lp_loss * self.sparsity_coeff self.lp_loss = lp_loss @@ -201,7 +201,7 @@ def loss(self, x, step, logging=False): self.steps_since_active[~deads] = 0 if logging is False: - return l2_loss + scaled_lp_loss + return recon_loss + scaled_lp_loss else: loss_log = { 'p' : self.p, diff --git a/trainers/standard.py b/trainers/standard.py index 2cfbb6a..07b9b67 100644 --- a/trainers/standard.py +++ b/trainers/standard.py @@ -127,6 +127,7 @@ def resample_neurons(self, deads, activations): def loss(self, x, logging=False, **kwargs): x_hat, f = self.ae(x, output_features=True) l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() + recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean() l1_loss = f.norm(p=1, dim=-1).mean() if self.steps_since_active is not None: @@ -135,7 +136,7 @@ def loss(self, x, logging=False, **kwargs): self.steps_since_active[deads] += 1 self.steps_since_active[~deads] = 0 - loss = l2_loss + self.l1_penalty * l1_loss + loss = recon_loss + self.l1_penalty * sparsity_warmup * l1_loss if not logging: return loss @@ -144,7 +145,7 @@ def loss(self, x, logging=False, **kwargs): x, x_hat, f, { 'l2_loss' : l2_loss.item(), - 'mse_loss' : (x - x_hat).pow(2).sum(dim=-1).mean().item(), + 'mse_loss' : recon_loss.item(), 'sparsity_loss' : l1_loss.item(), 'loss' : loss.item() } From efd76b138f429bb8e5e969e2e45926e886fdd71b Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Thu, 26 Dec 2024 22:20:51 +0000 Subject: [PATCH 21/70] Also scale topk thresholds when scaling biases --- trainers/batch_top_k.py | 2 ++ trainers/top_k.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/trainers/batch_top_k.py b/trainers/batch_top_k.py index 72d144b..a7fbdc8 100644 --- a/trainers/batch_top_k.py +++ b/trainers/batch_top_k.py @@ -86,6 +86,8 @@ def remove_gradient_parallel_to_decoder_directions(self): def scale_biases(self, scale: float): self.encoder.bias.data *= scale self.b_dec.data *= scale + if self.threshold >= 0: + self.threshold *= scale @classmethod def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "BatchTopKSAE": diff --git a/trainers/top_k.py b/trainers/top_k.py index 4afa176..12c549a 100644 --- a/trainers/top_k.py +++ b/trainers/top_k.py @@ -133,6 +133,8 @@ def remove_gradient_parallel_to_decoder_directions(self): def scale_biases(self, scale: float): self.encoder.bias.data *= scale self.b_dec.data *= scale + if self.threshold >= 0: + self.threshold *= scale def from_pretrained(path, k: Optional[int] = None, device=None): """ From 9687bb9858ef05306227309af99cd5c09d91642a Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Thu, 26 Dec 2024 23:47:27 +0000 Subject: [PATCH 22/70] Remove leftover variable, update expected results with standard SAE improvements --- tests/test_end_to_end.py | 29 ++++++++++++++++------------- trainers/standard.py | 2 +- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index b2374ec..31cb314 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -33,18 +33,18 @@ "frac_alive": 0.99951171875, }, "AutoEncoder": { - "l2_loss": 6.5741173267364506, - "l1_loss": 32.06615734100342, - "l0": 60.9147216796875, - "frac_variance_explained": 0.9042629599571228, - "cossim": 0.8782194256782532, - "l2_ratio": 0.814234834909439, - "relative_reconstruction_bias": 0.9813631415367127, - "loss_original": 3.328495955467224, - "loss_reconstructed": 5.7899915218353275, + "l2_loss": 6.822399997711182, + "l1_loss": 19.381900978088378, + "l0": 37.4492919921875, + "frac_variance_explained": 0.8993505954742431, + "cossim": 0.8791077017784119, + "l2_ratio": 0.7455410599708557, + "relative_reconstruction_bias": 0.9595056653022767, + "loss_original": 3.3284960985183716, + "loss_reconstructed": 5.203806638717651, "loss_zero": 13.250199031829833, - "frac_recovered": 0.754741370677948, - "frac_alive": 0.9921875, + "frac_recovered": 0.8104169845581055, + "frac_alive": 0.99658203125, }, } @@ -62,7 +62,10 @@ def test_sae_training(): """End to end test for training an SAE. Takes ~2 minutes on an RTX 3090. This isn't a nice suite of unit tests, but it's better than nothing. I have observed that results can slightly vary with library versions. For full determinism, - use pytorch 2.2.0 and nnsight 0.3.3.""" + use pytorch 2.2.0 and nnsight 0.3.3. + + NOTE: `dictionary_learning` is meant to be used as a submodule. Thus, to run this test, you need to use `dictionary_learning` as a submodule + and run the test from the root of the repository using `pytest -s`. Refer to https://github.com/adamkarvonen/dictionary_learning_demo for an example""" random.seed(RANDOM_SEED) t.manual_seed(RANDOM_SEED) @@ -79,7 +82,7 @@ def test_sae_training(): # sae training parameters k = 40 - sparsity_penalty = 0.05 + sparsity_penalty = 2.0 expansion_factor = 8 steps = int(num_tokens / sae_batch_size) # Total number of batches to train diff --git a/trainers/standard.py b/trainers/standard.py index 07b9b67..506a5c0 100644 --- a/trainers/standard.py +++ b/trainers/standard.py @@ -136,7 +136,7 @@ def loss(self, x, logging=False, **kwargs): self.steps_since_active[deads] += 1 self.steps_since_active[~deads] = 0 - loss = recon_loss + self.l1_penalty * sparsity_warmup * l1_loss + loss = recon_loss + self.l1_penalty * l1_loss if not logging: return loss From f0bb66d1c25bcb7dc8df62d8dbc3bfd47d26b14c Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Fri, 27 Dec 2024 05:00:16 +0000 Subject: [PATCH 23/70] Track lr decay implementation --- trainers/standard_lr_decay.py | 216 ++++++++++++++++++++++++++++++++++ 1 file changed, 216 insertions(+) create mode 100644 trainers/standard_lr_decay.py diff --git a/trainers/standard_lr_decay.py b/trainers/standard_lr_decay.py new file mode 100644 index 0000000..1119ab5 --- /dev/null +++ b/trainers/standard_lr_decay.py @@ -0,0 +1,216 @@ +""" +Implements the standard SAE training scheme. +""" +import torch as t +from typing import Optional + +from ..trainers.trainer import SAETrainer +from ..config import DEBUG +from ..dictionary import AutoEncoder +from collections import namedtuple + +class ConstrainedAdam(t.optim.Adam): + """ + A variant of Adam where some of the parameters are constrained to have unit norm. + """ + def __init__(self, params, constrained_params, lr): + super().__init__(params, lr=lr) + self.constrained_params = list(constrained_params) + + def step(self, closure=None): + with t.no_grad(): + for p in self.constrained_params: + normed_p = p / p.norm(dim=0, keepdim=True) + # project away the parallel component of the gradient + p.grad -= (p.grad * normed_p).sum(dim=0, keepdim=True) * normed_p + super().step(closure=closure) + with t.no_grad(): + for p in self.constrained_params: + # renormalize the constrained parameters + p /= p.norm(dim=0, keepdim=True) + +class StandardTrainer(SAETrainer): + """ + Standard SAE training scheme. + """ + def __init__(self, + dict_class=AutoEncoder, + activation_dim:int=512, + dict_size:int=64*512, + lr:float=1e-3, + l1_penalty:float=1e-1, + warmup_steps:int=1000, # lr warmup period at start of training and after each resample + sparsity_warmup_steps:Optional[int]=2000, # sparsity warmup period at start of training + lr_decay_steps_fraction:Optional[float]=0.2, + final_lr_fraction:Optional[float]=0.1, + steps: Optional[int]=None, # total of steps to train for + resample_steps:Optional[int]=None, # how often to resample neurons + seed:Optional[int]=None, + device=None, + layer:Optional[int]=None, + lm_name:Optional[str]=None, + wandb_name:Optional[str]='StandardTrainer', + submodule_name:Optional[str]=None, + ): + """Options: + warump_steps: LR linear warmup period at start of training and after each resample + sparsity_warmup_steps: Sparsity linear warmup period at start of training + lr_decay_steps_fraction: LR linear decay for the last fraction of training""" + super().__init__(seed) + + assert layer is not None and lm_name is not None + self.layer = layer + self.lm_name = lm_name + self.submodule_name = submodule_name + + if seed is not None: + t.manual_seed(seed) + t.cuda.manual_seed_all(seed) + + # initialize dictionary + self.ae = dict_class(activation_dim, dict_size) + + self.lr = lr + self.l1_penalty=l1_penalty + self.warmup_steps = warmup_steps + self.wandb_name = wandb_name + + if device is None: + self.device = 'cuda' if t.cuda.is_available() else 'cpu' + else: + self.device = device + self.ae.to(self.device) + + if lr_decay_steps_fraction is not None: + assert steps is not None, "total number of steps must be specified for lr decay" + assert resample_steps is None, "lr decay not implemented for resampling" + assert lr_decay_steps_fraction < 1 and lr_decay_steps_fraction > 0 + assert final_lr_fraction <= 1 and final_lr_fraction >= 0 + + self.steps = steps + self.sparsity_warmup_steps = sparsity_warmup_steps + self.lr_decay_steps_fraction = lr_decay_steps_fraction + self.final_lr_fraction = final_lr_fraction + + self.resample_steps = resample_steps + if self.resample_steps is not None: + # how many steps since each neuron was last activated? + self.steps_since_active = t.zeros(self.ae.dict_size, dtype=int).to(self.device) + else: + self.steps_since_active = None + + self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr) + if resample_steps is None: + def warmup_fn(step): + warmup_scale = min(step / warmup_steps, 1.) + + if self.lr_decay_steps_fraction is not None: + cooldown_start = self.steps * (1 - self.lr_decay_steps_fraction) + if step >= cooldown_start: + cooldown = 1.0 + (self.final_lr_fraction - 1.0) * (step - cooldown_start) / (self.steps - cooldown_start) + return max(cooldown, self.final_lr_fraction) + return warmup_scale + else: + def warmup_fn(step): + return min((step % resample_steps) / warmup_steps, 1.) + self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup_fn) + + def resample_neurons(self, deads, activations): + with t.no_grad(): + if deads.sum() == 0: return + print(f"resampling {deads.sum().item()} neurons") + + # compute loss for each activation + losses = (activations - self.ae(activations)).norm(dim=-1) + + # sample input to create encoder/decoder weights from + n_resample = min([deads.sum(), losses.shape[0]]) + indices = t.multinomial(losses, num_samples=n_resample, replacement=False) + sampled_vecs = activations[indices] + + # get norm of the living neurons + alive_norm = self.ae.encoder.weight[~deads].norm(dim=-1).mean() + + # resample first n_resample dead neurons + deads[deads.nonzero()[n_resample:]] = False + self.ae.encoder.weight[deads] = sampled_vecs * alive_norm * 0.2 + self.ae.decoder.weight[:,deads] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T + self.ae.encoder.bias[deads] = 0. + + + # reset Adam parameters for dead neurons + state_dict = self.optimizer.state_dict()['state'] + ## encoder weight + state_dict[1]['exp_avg'][deads] = 0. + state_dict[1]['exp_avg_sq'][deads] = 0. + ## encoder bias + state_dict[2]['exp_avg'][deads] = 0. + state_dict[2]['exp_avg_sq'][deads] = 0. + ## decoder weight + state_dict[3]['exp_avg'][:,deads] = 0. + state_dict[3]['exp_avg_sq'][:,deads] = 0. + + def loss(self, x, step: int, logging=False, **kwargs): + + if self.sparsity_warmup_steps is not None: + sparsity_scale = min(step / self.sparsity_warmup_steps, 1.) + else: + sparsity_scale = 1. + + x_hat, f = self.ae(x, output_features=True) + l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() + recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean() + l1_loss = f.norm(p=1, dim=-1).mean() + + if self.steps_since_active is not None: + # update steps_since_active + deads = (f == 0).all(dim=0) + self.steps_since_active[deads] += 1 + self.steps_since_active[~deads] = 0 + + loss = recon_loss + self.l1_penalty * sparsity_scale * l1_loss + + if not logging: + return loss + else: + return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])( + x, x_hat, f, + { + 'l2_loss' : l2_loss.item(), + 'mse_loss' : recon_loss.item(), + 'sparsity_loss' : l1_loss.item(), + 'loss' : loss.item() + } + ) + + + def update(self, step, activations): + activations = activations.to(self.device) + + self.optimizer.zero_grad() + loss = self.loss(activations, step=step) + loss.backward() + self.optimizer.step() + self.scheduler.step() + + if self.resample_steps is not None and step % self.resample_steps == 0: + self.resample_neurons(self.steps_since_active > self.resample_steps / 2, activations) + + @property + def config(self): + return { + 'dict_class': 'AutoEncoder', + 'trainer_class' : 'StandardTrainer', + 'activation_dim': self.ae.activation_dim, + 'dict_size': self.ae.dict_size, + 'lr' : self.lr, + 'l1_penalty' : self.l1_penalty, + 'warmup_steps' : self.warmup_steps, + 'resample_steps' : self.resample_steps, + 'device' : self.device, + 'layer' : self.layer, + 'lm_name' : self.lm_name, + 'wandb_name': self.wandb_name, + 'submodule_name': self.submodule_name, + } + From e0db40b8fadcdd1e24c1945829ecd4eb57451fa8 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Fri, 27 Dec 2024 05:00:34 +0000 Subject: [PATCH 24/70] Clean up lr decay --- trainers/standard_lr_decay.py | 216 ---------------------------------- 1 file changed, 216 deletions(-) delete mode 100644 trainers/standard_lr_decay.py diff --git a/trainers/standard_lr_decay.py b/trainers/standard_lr_decay.py deleted file mode 100644 index 1119ab5..0000000 --- a/trainers/standard_lr_decay.py +++ /dev/null @@ -1,216 +0,0 @@ -""" -Implements the standard SAE training scheme. -""" -import torch as t -from typing import Optional - -from ..trainers.trainer import SAETrainer -from ..config import DEBUG -from ..dictionary import AutoEncoder -from collections import namedtuple - -class ConstrainedAdam(t.optim.Adam): - """ - A variant of Adam where some of the parameters are constrained to have unit norm. - """ - def __init__(self, params, constrained_params, lr): - super().__init__(params, lr=lr) - self.constrained_params = list(constrained_params) - - def step(self, closure=None): - with t.no_grad(): - for p in self.constrained_params: - normed_p = p / p.norm(dim=0, keepdim=True) - # project away the parallel component of the gradient - p.grad -= (p.grad * normed_p).sum(dim=0, keepdim=True) * normed_p - super().step(closure=closure) - with t.no_grad(): - for p in self.constrained_params: - # renormalize the constrained parameters - p /= p.norm(dim=0, keepdim=True) - -class StandardTrainer(SAETrainer): - """ - Standard SAE training scheme. - """ - def __init__(self, - dict_class=AutoEncoder, - activation_dim:int=512, - dict_size:int=64*512, - lr:float=1e-3, - l1_penalty:float=1e-1, - warmup_steps:int=1000, # lr warmup period at start of training and after each resample - sparsity_warmup_steps:Optional[int]=2000, # sparsity warmup period at start of training - lr_decay_steps_fraction:Optional[float]=0.2, - final_lr_fraction:Optional[float]=0.1, - steps: Optional[int]=None, # total of steps to train for - resample_steps:Optional[int]=None, # how often to resample neurons - seed:Optional[int]=None, - device=None, - layer:Optional[int]=None, - lm_name:Optional[str]=None, - wandb_name:Optional[str]='StandardTrainer', - submodule_name:Optional[str]=None, - ): - """Options: - warump_steps: LR linear warmup period at start of training and after each resample - sparsity_warmup_steps: Sparsity linear warmup period at start of training - lr_decay_steps_fraction: LR linear decay for the last fraction of training""" - super().__init__(seed) - - assert layer is not None and lm_name is not None - self.layer = layer - self.lm_name = lm_name - self.submodule_name = submodule_name - - if seed is not None: - t.manual_seed(seed) - t.cuda.manual_seed_all(seed) - - # initialize dictionary - self.ae = dict_class(activation_dim, dict_size) - - self.lr = lr - self.l1_penalty=l1_penalty - self.warmup_steps = warmup_steps - self.wandb_name = wandb_name - - if device is None: - self.device = 'cuda' if t.cuda.is_available() else 'cpu' - else: - self.device = device - self.ae.to(self.device) - - if lr_decay_steps_fraction is not None: - assert steps is not None, "total number of steps must be specified for lr decay" - assert resample_steps is None, "lr decay not implemented for resampling" - assert lr_decay_steps_fraction < 1 and lr_decay_steps_fraction > 0 - assert final_lr_fraction <= 1 and final_lr_fraction >= 0 - - self.steps = steps - self.sparsity_warmup_steps = sparsity_warmup_steps - self.lr_decay_steps_fraction = lr_decay_steps_fraction - self.final_lr_fraction = final_lr_fraction - - self.resample_steps = resample_steps - if self.resample_steps is not None: - # how many steps since each neuron was last activated? - self.steps_since_active = t.zeros(self.ae.dict_size, dtype=int).to(self.device) - else: - self.steps_since_active = None - - self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr) - if resample_steps is None: - def warmup_fn(step): - warmup_scale = min(step / warmup_steps, 1.) - - if self.lr_decay_steps_fraction is not None: - cooldown_start = self.steps * (1 - self.lr_decay_steps_fraction) - if step >= cooldown_start: - cooldown = 1.0 + (self.final_lr_fraction - 1.0) * (step - cooldown_start) / (self.steps - cooldown_start) - return max(cooldown, self.final_lr_fraction) - return warmup_scale - else: - def warmup_fn(step): - return min((step % resample_steps) / warmup_steps, 1.) - self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup_fn) - - def resample_neurons(self, deads, activations): - with t.no_grad(): - if deads.sum() == 0: return - print(f"resampling {deads.sum().item()} neurons") - - # compute loss for each activation - losses = (activations - self.ae(activations)).norm(dim=-1) - - # sample input to create encoder/decoder weights from - n_resample = min([deads.sum(), losses.shape[0]]) - indices = t.multinomial(losses, num_samples=n_resample, replacement=False) - sampled_vecs = activations[indices] - - # get norm of the living neurons - alive_norm = self.ae.encoder.weight[~deads].norm(dim=-1).mean() - - # resample first n_resample dead neurons - deads[deads.nonzero()[n_resample:]] = False - self.ae.encoder.weight[deads] = sampled_vecs * alive_norm * 0.2 - self.ae.decoder.weight[:,deads] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T - self.ae.encoder.bias[deads] = 0. - - - # reset Adam parameters for dead neurons - state_dict = self.optimizer.state_dict()['state'] - ## encoder weight - state_dict[1]['exp_avg'][deads] = 0. - state_dict[1]['exp_avg_sq'][deads] = 0. - ## encoder bias - state_dict[2]['exp_avg'][deads] = 0. - state_dict[2]['exp_avg_sq'][deads] = 0. - ## decoder weight - state_dict[3]['exp_avg'][:,deads] = 0. - state_dict[3]['exp_avg_sq'][:,deads] = 0. - - def loss(self, x, step: int, logging=False, **kwargs): - - if self.sparsity_warmup_steps is not None: - sparsity_scale = min(step / self.sparsity_warmup_steps, 1.) - else: - sparsity_scale = 1. - - x_hat, f = self.ae(x, output_features=True) - l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() - recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean() - l1_loss = f.norm(p=1, dim=-1).mean() - - if self.steps_since_active is not None: - # update steps_since_active - deads = (f == 0).all(dim=0) - self.steps_since_active[deads] += 1 - self.steps_since_active[~deads] = 0 - - loss = recon_loss + self.l1_penalty * sparsity_scale * l1_loss - - if not logging: - return loss - else: - return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])( - x, x_hat, f, - { - 'l2_loss' : l2_loss.item(), - 'mse_loss' : recon_loss.item(), - 'sparsity_loss' : l1_loss.item(), - 'loss' : loss.item() - } - ) - - - def update(self, step, activations): - activations = activations.to(self.device) - - self.optimizer.zero_grad() - loss = self.loss(activations, step=step) - loss.backward() - self.optimizer.step() - self.scheduler.step() - - if self.resample_steps is not None and step % self.resample_steps == 0: - self.resample_neurons(self.steps_since_active > self.resample_steps / 2, activations) - - @property - def config(self): - return { - 'dict_class': 'AutoEncoder', - 'trainer_class' : 'StandardTrainer', - 'activation_dim': self.ae.activation_dim, - 'dict_size': self.ae.dict_size, - 'lr' : self.lr, - 'l1_penalty' : self.l1_penalty, - 'warmup_steps' : self.warmup_steps, - 'resample_steps' : self.resample_steps, - 'device' : self.device, - 'layer' : self.layer, - 'lm_name' : self.lm_name, - 'wandb_name': self.wandb_name, - 'submodule_name': self.submodule_name, - } - From 911b95890e20998df92710a01d158f4663d6834b Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Fri, 27 Dec 2024 05:01:39 +0000 Subject: [PATCH 25/70] Add sparsity warmup for trainers with a sparsity penalty --- tests/test_end_to_end.py | 1 + trainers/gdm.py | 43 +++++++++++++++++++++++++--------------- trainers/jumprelu.py | 40 +++++++++++++++++++++++-------------- trainers/standard.py | 43 +++++++++++++++++++++++++--------------- 4 files changed, 80 insertions(+), 47 deletions(-) diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 31cb314..8aa6cfc 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -149,6 +149,7 @@ def test_sae_training(): "lr": learning_rate, "l1_penalty": sparsity_penalty, "warmup_steps": warmup_steps, + "sparsity_warmup_steps": None, "resample_steps": resample_steps, "seed": RANDOM_SEED, "wandb_name": f"StandardTrainer-{MODEL_NAME}-{submodule_name}", diff --git a/trainers/gdm.py b/trainers/gdm.py index 47ea772..792e64e 100644 --- a/trainers/gdm.py +++ b/trainers/gdm.py @@ -3,6 +3,8 @@ """ import torch as t +from typing import Optional + from ..trainers.trainer import SAETrainer from ..config import DEBUG from ..dictionary import GatedAutoEncoder @@ -33,19 +35,19 @@ class GatedSAETrainer(SAETrainer): Gated SAE training scheme. """ def __init__(self, - dict_class=GatedAutoEncoder, - activation_dim=512, - dict_size=64*512, - lr=5e-5, - l1_penalty=1e-1, - warmup_steps=1000, # lr warmup period at start of training and after each resample - resample_steps=None, # how often to resample neurons - seed=None, - device=None, - layer=None, - lm_name=None, - wandb_name='GatedSAETrainer', - submodule_name=None, + dict_class = GatedAutoEncoder, + activation_dim: int = 512, + dict_size: int = 64*512, + lr: float = 5e-5, + l1_penalty: float = 1e-1, + warmup_steps: int = 1000, # lr warmup period at start of training and after each resample + sparsity_warmup_steps: int = 2000, + seed: Optional[int] = None, + device: Optional[str] = None, + layer: Optional[int] = None, + lm_name: Optional[str] = None, + wandb_name: Optional[str] = 'GatedSAETrainer', + submodule_name: Optional[str] = None, ): super().__init__(seed) @@ -64,6 +66,7 @@ def __init__(self, self.lr = lr self.l1_penalty=l1_penalty self.warmup_steps = warmup_steps + self.sparsity_warmup_steps = sparsity_warmup_steps self.wandb_name = wandb_name if device is None: @@ -81,7 +84,13 @@ def warmup_fn(step): return min(1, step / warmup_steps) self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, warmup_fn) - def loss(self, x, logging=False, **kwargs): + def loss(self, x:t.Tensor, step:int, logging:bool=False, **kwargs): + + if self.sparsity_warmup_steps is not None: + sparsity_scale = min(step / self.sparsity_warmup_steps, 1.0) + else: + sparsity_scale = 1.0 + f, f_gate = self.ae.encode(x, return_gate=True) x_hat = self.ae.decode(f) x_hat_gate = f_gate @ self.ae.decoder.weight.detach().T + self.ae.decoder_bias.detach() @@ -90,7 +99,7 @@ def loss(self, x, logging=False, **kwargs): L_sparse = t.linalg.norm(f_gate, ord=1, dim=-1).mean() L_aux = (x - x_hat_gate).pow(2).sum(dim=-1).mean() - loss = L_recon + self.l1_penalty * L_sparse + L_aux + loss = L_recon + (self.l1_penalty * L_sparse * sparsity_scale) + L_aux if not logging: return loss @@ -108,7 +117,7 @@ def loss(self, x, logging=False, **kwargs): def update(self, step, x): x = x.to(self.device) self.optimizer.zero_grad() - loss = self.loss(x) + loss = self.loss(x, step) loss.backward() self.optimizer.step() self.scheduler.step() @@ -123,6 +132,8 @@ def config(self): 'lr' : self.lr, 'l1_penalty' : self.l1_penalty, 'warmup_steps' : self.warmup_steps, + 'sparsity_warmup_steps' : self.sparsity_warmup_steps, + 'seed' : self.seed, 'device' : self.device, 'layer' : self.layer, 'lm_name' : self.lm_name, diff --git a/trainers/jumprelu.py b/trainers/jumprelu.py index e27e1b3..586313e 100644 --- a/trainers/jumprelu.py +++ b/trainers/jumprelu.py @@ -3,6 +3,7 @@ import torch import torch.autograd as autograd from torch import nn +from typing import Optional from ..dictionary import Dictionary, JumpReluAutoEncoder from .trainer import SAETrainer @@ -69,21 +70,22 @@ class JumpReluTrainer(nn.Module, SAETrainer): def __init__( self, dict_class=JumpReluAutoEncoder, - activation_dim=512, - dict_size=8192, - steps=30000, + activation_dim: int = 512, + dict_size: int = 8192, + steps: int = 30000, # XXX: Training decay is not implemented - seed=None, + seed: Optional[int] = None, # TODO: What's the default lr use in the paper? - lr=7e-5, - bandwidth=0.001, - sparsity_penalty=1.0, - target_l0=20.0, - device="cpu", - layer=None, - lm_name=None, - wandb_name="JumpRelu", - submodule_name=None, + lr: float = 7e-5, + bandwidth: float = 0.001, + sparsity_penalty: float = 1.0, + sparsity_warmup_steps: int = 2000, + target_l0: float = 20.0, + device: str = "cpu", + layer: Optional[int] = None, + lm_name: Optional[str] = None, + wandb_name: str = "JumpRelu", + submodule_name: Optional[str] = None, ): super().__init__() @@ -100,6 +102,7 @@ def __init__( self.bandwidth = bandwidth self.sparsity_coefficient = sparsity_penalty + self.sparsity_warmup_steps = sparsity_warmup_steps self.target_l0 = target_l0 # TODO: Better auto-naming (e.g. in BatchTopK package) @@ -119,14 +122,20 @@ def __init__( self.logging_parameters = [] - def loss(self, x, logging=False, **_): + def loss(self, x: torch.Tensor, step: int, logging=False, **_): + + if self.sparsity_warmup_steps is not None: + sparsity_scale = min(step / self.sparsity_warmup_steps, 1.0) + else: + sparsity_scale = 1.0 + f = self.ae.encode(x) recon = self.ae.decode(f) recon_loss = (x - recon).pow(2).sum(dim=-1).mean() l0 = StepFunction.apply(f, self.ae.threshold, self.bandwidth).sum(dim=-1).mean() - sparsity_loss = self.sparsity_coefficient * ((l0 / self.target_l0) - 1).pow(2) + sparsity_loss = self.sparsity_coefficient * ((l0 / self.target_l0) - 1).pow(2) * sparsity_scale loss = recon_loss + sparsity_loss if not logging: @@ -170,5 +179,6 @@ def config(self): "submodule_name": self.submodule_name, "bandwidth": self.bandwidth, "sparsity_penalty": self.sparsity_coefficient, + "sparsity_warmup_steps": self.sparsity_warmup_steps, "target_l0": self.target_l0, } diff --git a/trainers/standard.py b/trainers/standard.py index 506a5c0..8b5157f 100644 --- a/trainers/standard.py +++ b/trainers/standard.py @@ -2,6 +2,8 @@ Implements the standard SAE training scheme. """ import torch as t +from typing import Optional + from ..trainers.trainer import SAETrainer from ..config import DEBUG from ..dictionary import AutoEncoder @@ -33,18 +35,19 @@ class StandardTrainer(SAETrainer): """ def __init__(self, dict_class=AutoEncoder, - activation_dim=512, - dict_size=64*512, - lr=1e-3, - l1_penalty=1e-1, - warmup_steps=1000, # lr warmup period at start of training and after each resample - resample_steps=None, # how often to resample neurons - seed=None, + activation_dim:int=512, + dict_size:int=64*512, + lr:float=1e-3, + l1_penalty:float=1e-1, + warmup_steps:int=1000, # lr warmup period at start of training and after each resample + sparsity_warmup_steps:Optional[int]=2000, # sparsity warmup period at start of training + resample_steps:Optional[int]=None, # how often to resample neurons + seed:Optional[int]=None, device=None, - layer=None, - lm_name=None, - wandb_name='StandardTrainer', - submodule_name=None, + layer:Optional[int]=None, + lm_name:Optional[str]=None, + wandb_name:Optional[str]='StandardTrainer', + submodule_name:Optional[str]=None, ): super().__init__(seed) @@ -70,10 +73,10 @@ def __init__(self, else: self.device = device self.ae.to(self.device) + + self.sparsity_warmup_steps = sparsity_warmup_steps self.resample_steps = resample_steps - - if self.resample_steps is not None: # how many steps since each neuron was last activated? self.steps_since_active = t.zeros(self.ae.dict_size, dtype=int).to(self.device) @@ -124,7 +127,13 @@ def resample_neurons(self, deads, activations): state_dict[3]['exp_avg'][:,deads] = 0. state_dict[3]['exp_avg_sq'][:,deads] = 0. - def loss(self, x, logging=False, **kwargs): + def loss(self, x, step: int, logging=False, **kwargs): + + if self.sparsity_warmup_steps is not None: + sparsity_scale = min(step / self.sparsity_warmup_steps, 1.0) + else: + sparsity_scale = 1.0 + x_hat, f = self.ae(x, output_features=True) l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean() @@ -136,7 +145,7 @@ def loss(self, x, logging=False, **kwargs): self.steps_since_active[deads] += 1 self.steps_since_active[~deads] = 0 - loss = recon_loss + self.l1_penalty * l1_loss + loss = recon_loss + self.l1_penalty * sparsity_scale * l1_loss if not logging: return loss @@ -156,7 +165,7 @@ def update(self, step, activations): activations = activations.to(self.device) self.optimizer.zero_grad() - loss = self.loss(activations) + loss = self.loss(activations, step=step) loss.backward() self.optimizer.step() self.scheduler.step() @@ -175,6 +184,8 @@ def config(self): 'l1_penalty' : self.l1_penalty, 'warmup_steps' : self.warmup_steps, 'resample_steps' : self.resample_steps, + 'sparsity_warmup_steps' : self.sparsity_warmup_steps, + 'seed' : self.seed, 'device' : self.device, 'layer' : self.layer, 'lm_name' : self.lm_name, From a2d6c43e94ef068821441d47fef8ae7b3215d09e Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Sun, 29 Dec 2024 05:34:33 +0000 Subject: [PATCH 26/70] Standardize learning rate and sparsity schedules --- README.md | 17 ++++++--- trainers/batch_top_k.py | 57 ++++++++++++++++++---------- trainers/gated_anneal.py | 77 ++++++++++++++++++++++++++------------ trainers/gdm.py | 35 +++++++++++++++--- trainers/jumprelu.py | 38 ++++++++++++++++--- trainers/p_anneal.py | 80 +++++++++++++++++++++++++++------------- trainers/standard.py | 40 ++++++++++++++++---- trainers/top_k.py | 53 +++++++++++++++++--------- training.py | 4 +- 9 files changed, 288 insertions(+), 113 deletions(-) diff --git a/README.md b/README.md index c2e625c..1766b69 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ This is a repository for doing dictionary learning via sparse autoencoders on ne For accessing, saving, and intervening on NN activations, we use the [`nnsight`](http://nnsight.net/) package; as of March 2024, `nnsight` is under active development and may undergo breaking changes. That said, `nnsight` is easy to use and quick to learn; if you plan to modify this repo, then we recommend going through the main `nnsight` demo [here](https://nnsight.net/notebooks/tutorials/walkthrough/). -Some dictionaries trained using this repository (and asociated training checkpoints) can be accessed at [https://baulab.us/u/smarks/autoencoders/](https://baulab.us/u/smarks/autoencoders/). See below for more information about these dictionaries. +Some dictionaries trained using this repository (and associated training checkpoints) can be accessed at [https://baulab.us/u/smarks/autoencoders/](https://baulab.us/u/smarks/autoencoders/). See below for more information about these dictionaries. # Set-up @@ -15,6 +15,8 @@ pip install -r requirements.txt To use `dictionary_learning`, include it as a subdirectory in some project's directory and import it; see the examples below. +We also provide a [demonstration](https://github.com/adamkarvonen/dictionary_learning_demo), which trains and evaluates 2 SAEs in ~30 minutes before plotting the results. + # Using trained dictionaries You can load and used a pretrained dictionary as follows @@ -61,7 +63,9 @@ This allows us to implement different training protocols (e.g. p-annealing) for Specifically, this repository supports the following trainers: - [`StandardTrainer`](trainers/standard.py): Implements a training scheme similar to that of [Bricken et al., 2023](https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder). - [`GatedSAETrainer`](trainers/gdm.py): Implements the training scheme for Gated SAEs described in [Rajamanoharan et al., 2024](https://arxiv.org/abs/2404.16014). -- [`AutoEncoderTopK`](trainers/top_k.py): Implemented the training scheme for Top-K SAEs described in [Gao et al., 2024](https://arxiv.org/abs/2406.04093). +- [`TopKSAETrainer`](trainers/top_k.py): Implemented the training scheme for Top-K SAEs described in [Gao et al., 2024](https://arxiv.org/abs/2406.04093). +- [`BatchTopKSAETrainer`](trainers/batch_top_k.py): Implemented the training scheme for Batch Top-K SAEs described in [Bussmann et al., 2024](https://arxiv.org/abs/2412.06410). +- [`JumpReluTrainer`](trainers/jumprelu.py): Implemented the training scheme for JumpReLU SAEs described in [Rajamanoharan et al., 2024](https://arxiv.org/abs/2407.14435). - [`PAnnealTrainer`](trainers/p_anneal.py): Extends the `StandardTrainer` by providing the option to anneal the sparsity parameter p. - [`GatedAnnealTrainer`](trainers/gated_anneal.py): Extends the `GatedSAETrainer` by providing the option for p-annealing, similar to `PAnnealTrainer`. @@ -121,8 +125,11 @@ ae = trainSAE( ``` Some technical notes our training infrastructure and supported features: * Training uses the `ConstrainedAdam` optimizer defined in `training.py`. This is a variant of Adam which supports constraining the `AutoEncoder`'s decoder weights to be norm 1. -* Neuron resampling: if a `resample_steps` argument is passed to `trainSAE`, then dead neurons will periodically be resampled according to the procedure specified [here](https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-resampling). -* Learning rate warmup: if a `warmup_steps` argument is passed to `trainSAE`, then a linear LR warmup is used at the start of training and, if doing neuron resampling, also after every time neurons are resampled. +* Neuron resampling: if a `resample_steps` argument is passed to the Trainer, then dead neurons will periodically be resampled according to the procedure specified [here](https://transformer-circuits.pub/2023/monosemantic-features/index.html#appendix-autoencoder-resampling). +* Learning rate warmup: if a `warmup_steps` argument is passed to the Trainer, then a linear LR warmup is used at the start of training and, if doing neuron resampling, also after every time neurons are resampled. +* Sparsity penalty warmup: if a `sparsity_warmup_steps` is passed to the Trainer, then a linear warmup is applied to the sparsity penalty at the start of training. +* Learning rate decay: if a `decay_start` is passed to the Trainer, then a linear LR decay is used from `decay_start` to the end of training. +* If `normalize_activations` is True and passed to `trainSAE`, then the activations will be normalized to have unit mean squared norm. The autoencoders weights will be scaled before saving, so the activations don't need to be scaled during inference. This is very helpful for hyperparameter transfer between different layers and models. If `submodule` is a model component where the activations are tuples (e.g. this is common when working with residual stream activations), then the buffer yields the first coordinate of the tuple. @@ -204,5 +211,3 @@ We've included support for some experimental features. We briefly investigated t * h/t to Max Li for this suggestion. * **Replacing L1 loss with entropy**. Based on the ideas in this [post](https://transformer-circuits.pub/2023/may-update/index.html#simple-factorization), we experimented with using entropy to regularize a dictionary's hidden state instead of L1 loss. This seemed to cause the features to split into dead features (which never fired) and very high-frequency features which fired on nearly every input, which was not the desired behavior. But plausibly there is a way to make this work better. * **Ghost grads**, as described [here](https://transformer-circuits.pub/2024/jan-update/index.html). - - diff --git a/trainers/batch_top_k.py b/trainers/batch_top_k.py index a7fbdc8..dd59f1a 100644 --- a/trainers/batch_top_k.py +++ b/trainers/batch_top_k.py @@ -3,6 +3,7 @@ import torch.nn.functional as F import einops from collections import namedtuple +from typing import Optional from ..dictionary import Dictionary from ..trainers.trainer import SAETrainer @@ -108,22 +109,23 @@ def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "BatchTopKSAE": class BatchTopKTrainer(SAETrainer): def __init__( self, - dict_class=BatchTopKSAE, - activation_dim=512, - dict_size=64 * 512, - k=8, - auxk_alpha=1 / 32, - decay_start=24000, - threshold_beta=0.999, - threshold_start_step=1000, - steps=30000, - top_k_aux=512, - seed=None, - device=None, - layer=None, - lm_name=None, - wandb_name="BatchTopKSAE", - submodule_name=None, + steps: int, # total number of steps to train for + activation_dim: int, + dict_size: int, + k: int, + layer: int, + lm_name: str, + dict_class: type = BatchTopKSAE, + auxk_alpha: float = 1 / 32, + warmup_steps: int = 1000, + decay_start: Optional[int] = None, # when does the lr decay start + threshold_beta: float = 0.999, + threshold_start_step: int = 1000, + top_k_aux: int = 512, + seed: Optional[int] = None, + device: Optional[str] = None, + wandb_name: str = "BatchTopKSAE", + submodule_name: Optional[str] = None, ): super().__init__(seed) assert layer is not None and lm_name is not None @@ -132,6 +134,8 @@ def __init__( self.submodule_name = submodule_name self.wandb_name = wandb_name self.steps = steps + self.decay_start = decay_start + self.warmup_steps = warmup_steps self.k = k self.threshold_beta = threshold_beta self.threshold_start_step = threshold_start_step @@ -156,12 +160,21 @@ def __init__( self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) + if decay_start is not None: + assert 0 <= decay_start < steps, "decay_start must be >= 0 and < steps." + assert decay_start > warmup_steps, "decay_start must be > warmup_steps." + + assert 0 <= warmup_steps < steps, "warmup_steps must be >= 0 and < steps." + def lr_fn(step): - if step < decay_start: - return 1.0 - else: + if step < warmup_steps: + return step / warmup_steps + + if decay_start is not None and step >= decay_start: return (steps - step) / (steps - decay_start) + return 1.0 + self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device) @@ -264,6 +277,12 @@ def config(self): "dict_class": "BatchTopKSAE", "lr": self.lr, "steps": self.steps, + "auxk_alpha": self.auxk_alpha, + "warmup_steps": self.warmup_steps, + "decay_start": self.decay_start, + "threshold_beta": self.threshold_beta, + "threshold_start_step": self.threshold_start_step, + "top_k_aux": self.top_k_aux, "seed": self.seed, "activation_dim": self.ae.activation_dim, "dict_size": self.ae.dict_size, diff --git a/trainers/gated_anneal.py b/trainers/gated_anneal.py index 664904b..6b4e774 100644 --- a/trainers/gated_anneal.py +++ b/trainers/gated_anneal.py @@ -3,6 +3,8 @@ """ import torch as t +from typing import Optional + from ..trainers.trainer import SAETrainer from ..config import DEBUG from ..dictionary import GatedAutoEncoder @@ -33,26 +35,28 @@ class GatedAnnealTrainer(SAETrainer): Gated SAE training scheme with p-annealing. """ def __init__(self, - dict_class=GatedAutoEncoder, - activation_dim=512, - dict_size=64*512, - lr=3e-4, - warmup_steps=1000, # lr warmup period at start of training and after each resample - sparsity_function='Lp^p', # Lp or Lp^p - initial_sparsity_penalty=1e-1, # equal to l1 penalty in standard trainer - anneal_start=15000, # step at which to start annealing p - anneal_end=None, # step at which to stop annealing, defaults to steps-1 - p_start=1, # starting value of p (constant throughout warmup) - p_end=0, # annealing p_start to p_end linearly after warmup_steps, exact endpoint excluded - n_sparsity_updates = 10, # number of times to update the sparsity penalty, at most steps-anneal_start times - sparsity_queue_length = 10, # number of recent sparsity loss terms, onle needed for adaptive_sparsity_penalty - resample_steps=None, # number of steps after which to resample dead neurons - steps=None, # total number of steps to train for - device=None, - seed=42, - layer=None, - lm_name=None, - wandb_name='GatedAnnealTrainer', + steps: int, # total number of steps to train for + activation_dim: int, + dict_size: int, + layer: int, + lm_name: str, + dict_class: type = GatedAutoEncoder, + lr: float = 3e-4, + warmup_steps: int = 1000, # lr warmup period at start of training and after each resample + sparsity_warmup_steps: Optional[int] = 2000, # sparsity warmup period at start of training + decay_start: Optional[int] = None, # decay learning rate after this many steps + sparsity_function: str = 'Lp^p', # Lp or Lp^p + initial_sparsity_penalty: float = 1e-1, # equal to l1 penalty in standard trainer + anneal_start: int = 15000, # step at which to start annealing p + anneal_end: Optional[int] = None, # step at which to stop annealing, defaults to steps-1 + p_start: float = 1, # starting value of p (constant throughout warmup) + p_end: float = 0, # annealing p_start to p_end linearly after warmup_steps, exact endpoint excluded + n_sparsity_updates: int | str = 10, # number of times to update the sparsity penalty, at most steps-anneal_start times + sparsity_queue_length: int = 10, # number of recent sparsity loss terms, onle needed for adaptive_sparsity_penalty + resample_steps: Optional[int] = None, # number of steps after which to resample dead neurons + device: Optional[str] = None, + seed: Optional[int] = 42, + wandb_name: str = 'GatedAnnealTrainer', ): super().__init__(seed) @@ -98,6 +102,8 @@ def __init__(self, self.sparsity_queue = [] self.warmup_steps = warmup_steps + self.sparsity_warmup_steps = sparsity_warmup_steps + self.decay_start = decay_start self.steps = steps self.logging_parameters = ['p', 'next_p', 'lp_loss', 'scaled_lp_loss', 'sparsity_coeff'] self.seed = seed @@ -111,9 +117,28 @@ def __init__(self, self.steps_since_active = None self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr) + + if decay_start is not None: + assert resample_steps is None, "decay_start and resample_steps are currently mutually exclusive." + assert 0 <= decay_start < steps, "decay_start must be >= 0 and < steps." + assert decay_start > warmup_steps, "decay_start must be > warmup_steps." + if sparsity_warmup_steps is not None: + assert decay_start > sparsity_warmup_steps, "decay_start must be > sparsity_warmup_steps." + + assert 0 <= warmup_steps < anneal_start, "warmup_steps must be >= 0 and < anneal_start." + + if sparsity_warmup_steps is not None: + assert 0 <= sparsity_warmup_steps < anneal_start, "sparsity_warmup_steps must be >= 0 and < anneal_start." + if resample_steps is None: def warmup_fn(step): - return min(step / warmup_steps, 1.) + if step < warmup_steps: + return step / warmup_steps + + if decay_start is not None and step >= decay_start: + return (steps - step) / (steps - decay_start) + + return 1.0 else: def warmup_fn(step): return min((step % resample_steps) / warmup_steps, 1.) @@ -160,7 +185,11 @@ def lp_norm(self, f, p): else: raise ValueError("Sparsity function must be 'Lp' or 'Lp^p'") - def loss(self, x, step, logging=False, **kwargs): + def loss(self, x:t.Tensor, step:int, logging=False, **kwargs): + if self.sparsity_warmup_steps is not None: + sparsity_scale = min(step / self.sparsity_warmup_steps, 1.0) + else: + sparsity_scale = 1.0 f, f_gate = self.ae.encode(x, return_gate=True) x_hat = self.ae.decode(f) x_hat_gate = f_gate @ self.ae.decoder.weight.detach().T + self.ae.decoder_bias.detach() @@ -170,7 +199,7 @@ def loss(self, x, step, logging=False, **kwargs): fs = f_gate # feature activation that we use for sparsity term lp_loss = self.lp_norm(fs, self.p) - scaled_lp_loss = lp_loss * self.sparsity_coeff + scaled_lp_loss = lp_loss * self.sparsity_coeff * sparsity_scale self.lp_loss = lp_loss self.scaled_lp_loss = scaled_lp_loss @@ -263,6 +292,8 @@ def config(self): 'n_sparsity_updates' : self.n_sparsity_updates, 'warmup_steps' : self.warmup_steps, 'resample_steps' : self.resample_steps, + 'sparsity_warmup_steps' : self.sparsity_warmup_steps, + 'decay_start' : self.decay_start, 'steps' : self.steps, 'seed' : self.seed, 'layer' : self.layer, diff --git a/trainers/gdm.py b/trainers/gdm.py index 792e64e..bb8ff0f 100644 --- a/trainers/gdm.py +++ b/trainers/gdm.py @@ -35,17 +35,19 @@ class GatedSAETrainer(SAETrainer): Gated SAE training scheme. """ def __init__(self, + steps: int, # total number of steps to train for + activation_dim: int, + dict_size: int, + layer: int, + lm_name: str, dict_class = GatedAutoEncoder, - activation_dim: int = 512, - dict_size: int = 64*512, lr: float = 5e-5, l1_penalty: float = 1e-1, warmup_steps: int = 1000, # lr warmup period at start of training and after each resample - sparsity_warmup_steps: int = 2000, + sparsity_warmup_steps: Optional[int] = 2000, + decay_start:Optional[int]=None, # decay learning rate after this many steps seed: Optional[int] = None, device: Optional[str] = None, - layer: Optional[int] = None, - lm_name: Optional[str] = None, wandb_name: Optional[str] = 'GatedSAETrainer', submodule_name: Optional[str] = None, ): @@ -67,6 +69,7 @@ def __init__(self, self.l1_penalty=l1_penalty self.warmup_steps = warmup_steps self.sparsity_warmup_steps = sparsity_warmup_steps + self.decay_start = decay_start self.wandb_name = wandb_name if device is None: @@ -80,8 +83,27 @@ def __init__(self, self.ae.decoder.parameters(), lr=lr ) + + if decay_start is not None: + assert 0 <= decay_start < steps, "decay_start must be >= 0 and < steps." + assert decay_start > warmup_steps, "decay_start must be > warmup_steps." + if sparsity_warmup_steps is not None: + assert decay_start > sparsity_warmup_steps, "decay_start must be > sparsity_warmup_steps." + + assert 0 <= warmup_steps < steps, "warmup_steps must be >= 0 and < steps." + + if sparsity_warmup_steps is not None: + assert 0 <= sparsity_warmup_steps < steps, "sparsity_warmup_steps must be >= 0 and < steps." + def warmup_fn(step): - return min(1, step / warmup_steps) + if step < warmup_steps: + return step / warmup_steps + + if decay_start is not None and step >= decay_start: + return (steps - step) / (steps - decay_start) + + return 1.0 + self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, warmup_fn) def loss(self, x:t.Tensor, step:int, logging:bool=False, **kwargs): @@ -133,6 +155,7 @@ def config(self): 'l1_penalty' : self.l1_penalty, 'warmup_steps' : self.warmup_steps, 'sparsity_warmup_steps' : self.sparsity_warmup_steps, + 'decay_start' : self.decay_start, 'seed' : self.seed, 'device' : self.device, 'layer' : self.layer, diff --git a/trainers/jumprelu.py b/trainers/jumprelu.py index 586313e..c0c351a 100644 --- a/trainers/jumprelu.py +++ b/trainers/jumprelu.py @@ -69,21 +69,23 @@ class JumpReluTrainer(nn.Module, SAETrainer): """ def __init__( self, + steps: int, # total number of steps to train for + activation_dim: int, + dict_size: int, + layer: int, + lm_name: str, dict_class=JumpReluAutoEncoder, - activation_dim: int = 512, - dict_size: int = 8192, - steps: int = 30000, # XXX: Training decay is not implemented seed: Optional[int] = None, # TODO: What's the default lr use in the paper? lr: float = 7e-5, bandwidth: float = 0.001, sparsity_penalty: float = 1.0, - sparsity_warmup_steps: int = 2000, + warmup_steps:int=1000, # lr warmup period at start of training and after each resample + sparsity_warmup_steps:Optional[int]=2000, # sparsity warmup period at start of training + decay_start:Optional[int]=None, # decay learning rate after this many steps target_l0: float = 20.0, device: str = "cpu", - layer: Optional[int] = None, - lm_name: Optional[str] = None, wandb_name: str = "JumpRelu", submodule_name: Optional[str] = None, ): @@ -102,7 +104,9 @@ def __init__( self.bandwidth = bandwidth self.sparsity_coefficient = sparsity_penalty + self.warmup_steps = warmup_steps self.sparsity_warmup_steps = sparsity_warmup_steps + self.decay_start = decay_start self.target_l0 = target_l0 # TODO: Better auto-naming (e.g. in BatchTopK package) @@ -120,6 +124,28 @@ def __init__( self.ae.parameters(), lr=lr, betas=(0.0, 0.999), eps=1e-8 ) + if decay_start is not None: + assert 0 <= decay_start < steps, "decay_start must be >= 0 and < steps." + assert decay_start > warmup_steps, "decay_start must be > warmup_steps." + if sparsity_warmup_steps is not None: + assert decay_start > sparsity_warmup_steps, "decay_start must be > sparsity_warmup_steps." + + assert 0 <= warmup_steps < steps, "warmup_steps must be >= 0 and < steps." + + if sparsity_warmup_steps is not None: + assert 0 <= sparsity_warmup_steps < steps, "sparsity_warmup_steps must be >= 0 and < steps." + + def warmup_fn(step): + if step < warmup_steps: + return step / warmup_steps + + if decay_start is not None and step >= decay_start: + return (steps - step) / (steps - decay_start) + + return 1.0 + + self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup_fn) + self.logging_parameters = [] def loss(self, x: torch.Tensor, step: int, logging=False, **_): diff --git a/trainers/p_anneal.py b/trainers/p_anneal.py index 0138547..cf886ef 100644 --- a/trainers/p_anneal.py +++ b/trainers/p_anneal.py @@ -1,5 +1,5 @@ import torch as t - +from typing import Optional """ Implements the standard SAE training scheme. """ @@ -34,27 +34,29 @@ class PAnnealTrainer(SAETrainer): You can further choose to use Lp or Lp^p sparsity. """ def __init__(self, - dict_class=AutoEncoder, - activation_dim=512, - dict_size=64*512, - lr=1e-3, - warmup_steps=1000, # lr warmup period at start of training and after each resample - sparsity_function='Lp', # Lp or Lp^p - initial_sparsity_penalty=1e-1, # equal to l1 penalty in standard trainer - anneal_start=15000, # step at which to start annealing p - anneal_end=None, # step at which to stop annealing, defaults to steps-1 - p_start=1, # starting value of p (constant throughout warmup) - p_end=0, # annealing p_start to p_end linearly after warmup_steps, exact endpoint excluded - n_sparsity_updates = 10, # number of times to update the sparsity penalty, at most steps-anneal_start times - sparsity_queue_length = 10, # number of recent sparsity loss terms, onle needed for adaptive_sparsity_penalty - resample_steps=None, # number of steps after which to resample dead neurons - steps=None, # total number of steps to train for - device=None, - seed=42, - layer=None, - lm_name=None, - wandb_name='PAnnealTrainer', - submodule_name: str = None, + steps: int, # total number of steps to train for + activation_dim: int, + dict_size: int, + layer: int, + lm_name: str, + dict_class: type = AutoEncoder, + lr: float = 1e-3, + warmup_steps: int = 1000, # lr warmup period at start of training and after each resample + decay_start: Optional[int] = None, # step at which to start decaying lr + sparsity_warmup_steps: Optional[int] = 2000, # number of steps to warm up sparsity penalty + sparsity_function: str = 'Lp', # Lp or Lp^p + initial_sparsity_penalty: float = 1e-1, # equal to l1 penalty in standard trainer + anneal_start: int = 15000, # step at which to start annealing p + anneal_end: Optional[int] = None, # step at which to stop annealing, defaults to steps-1 + p_start: float = 1, # starting value of p (constant throughout warmup) + p_end: float = 0, # annealing p_start to p_end linearly after warmup_steps, exact endpoint excluded + n_sparsity_updates: int | str = 10, # number of times to update the sparsity penalty, at most steps-anneal_start times + sparsity_queue_length: int = 10, # number of recent sparsity loss terms, onle needed for adaptive_sparsity_penalty + resample_steps: Optional[int] = None, # number of steps after which to resample dead neurons + device: Optional[str] = None, + seed: int = 42, + wandb_name: str = 'PAnnealTrainer', + submodule_name: Optional[str] = None, ): super().__init__(seed) @@ -98,6 +100,8 @@ def __init__(self, self.sparsity_queue = [] self.warmup_steps = warmup_steps + self.sparsity_warmup_steps = sparsity_warmup_steps + self.decay_start = decay_start self.steps = steps self.logging_parameters = ['p', 'next_p', 'lp_loss', 'scaled_lp_loss', 'sparsity_coeff'] self.seed = seed @@ -111,9 +115,28 @@ def __init__(self, self.steps_since_active = None self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr) + + if decay_start is not None: + assert resample_steps is None, "decay_start and resample_steps are currently mutually exclusive." + assert 0 <= decay_start < steps, "decay_start must be >= 0 and < steps." + assert decay_start > warmup_steps, "decay_start must be > warmup_steps." + if sparsity_warmup_steps is not None: + assert decay_start > sparsity_warmup_steps, "decay_start must be > sparsity_warmup_steps." + + assert 0 <= warmup_steps < anneal_start, "warmup_steps must be >= 0 and < anneal_start." + + if sparsity_warmup_steps is not None: + assert 0 <= sparsity_warmup_steps < anneal_start, "sparsity_warmup_steps must be >= 0 and < anneal_start." + if resample_steps is None: def warmup_fn(step): - return min(step / warmup_steps, 1.) + if step < warmup_steps: + return step / warmup_steps + + if decay_start is not None and step >= decay_start: + return (steps - step) / (steps - decay_start) + + return 1.0 else: def warmup_fn(step): return min((step % resample_steps) / warmup_steps, 1.) @@ -163,12 +186,17 @@ def lp_norm(self, f, p): else: raise ValueError("Sparsity function must be 'Lp' or 'Lp^p'") - def loss(self, x, step, logging=False): + def loss(self, x: t.Tensor, step:int, logging=False): + if self.sparsity_warmup_steps is not None: + sparsity_scale = min(step / self.sparsity_warmup_steps, 1.0) + else: + sparsity_scale = 1.0 + # Compute loss terms x_hat, f = self.ae(x, output_features=True) recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean() lp_loss = self.lp_norm(f, self.p) - scaled_lp_loss = lp_loss * self.sparsity_coeff + scaled_lp_loss = lp_loss * self.sparsity_coeff * sparsity_scale self.lp_loss = lp_loss self.scaled_lp_loss = scaled_lp_loss @@ -241,6 +269,8 @@ def config(self): 'sparsity_queue_length' : self.sparsity_queue_length, 'n_sparsity_updates' : self.n_sparsity_updates, 'warmup_steps' : self.warmup_steps, + 'sparsity_warmup_steps': self.sparsity_warmup_steps, + 'decay_start': self.decay_start, 'resample_steps' : self.resample_steps, 'steps' : self.steps, 'seed' : self.seed, diff --git a/trainers/standard.py b/trainers/standard.py index 8b5157f..a466cc7 100644 --- a/trainers/standard.py +++ b/trainers/standard.py @@ -34,18 +34,20 @@ class StandardTrainer(SAETrainer): Standard SAE training scheme. """ def __init__(self, + steps: int, # total number of steps to train for + activation_dim: int, + dict_size: int, + layer: int, + lm_name: str, dict_class=AutoEncoder, - activation_dim:int=512, - dict_size:int=64*512, - lr:float=1e-3, + lr:float=1e-3, l1_penalty:float=1e-1, warmup_steps:int=1000, # lr warmup period at start of training and after each resample sparsity_warmup_steps:Optional[int]=2000, # sparsity warmup period at start of training + decay_start:Optional[int]=None, # decay learning rate after this many steps resample_steps:Optional[int]=None, # how often to resample neurons seed:Optional[int]=None, device=None, - layer:Optional[int]=None, - lm_name:Optional[str]=None, wandb_name:Optional[str]='StandardTrainer', submodule_name:Optional[str]=None, ): @@ -66,6 +68,9 @@ def __init__(self, self.lr = lr self.l1_penalty=l1_penalty self.warmup_steps = warmup_steps + self.sparsity_warmup_steps = sparsity_warmup_steps + self.steps = steps + self.decay_start = decay_start self.wandb_name = wandb_name if device is None: @@ -73,8 +78,6 @@ def __init__(self, else: self.device = device self.ae.to(self.device) - - self.sparsity_warmup_steps = sparsity_warmup_steps self.resample_steps = resample_steps if self.resample_steps is not None: @@ -84,9 +87,28 @@ def __init__(self, self.steps_since_active = None self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr) + + if decay_start is not None: + assert resample_steps is None, "decay_start and resample_steps are currently mutually exclusive." + assert 0 <= decay_start < steps, "decay_start must be >= 0 and < steps." + assert decay_start > warmup_steps, "decay_start must be > warmup_steps." + if sparsity_warmup_steps is not None: + assert decay_start > sparsity_warmup_steps, "decay_start must be > sparsity_warmup_steps." + + assert 0 <= warmup_steps < steps, "warmup_steps must be >= 0 and < steps." + + if sparsity_warmup_steps is not None: + assert 0 <= sparsity_warmup_steps < steps, "sparsity_warmup_steps must be >= 0 and < steps." + if resample_steps is None: def warmup_fn(step): - return min(step / warmup_steps, 1.) + if step < warmup_steps: + return step / warmup_steps + + if decay_start is not None and step >= decay_start: + return (steps - step) / (steps - decay_start) + + return 1.0 else: def warmup_fn(step): return min((step % resample_steps) / warmup_steps, 1.) @@ -185,6 +207,8 @@ def config(self): 'warmup_steps' : self.warmup_steps, 'resample_steps' : self.resample_steps, 'sparsity_warmup_steps' : self.sparsity_warmup_steps, + 'steps' : self.steps, + 'decay_start' : self.decay_start, 'seed' : self.seed, 'device' : self.device, 'layer' : self.layer, diff --git a/trainers/top_k.py b/trainers/top_k.py index 12c549a..64ca2ab 100644 --- a/trainers/top_k.py +++ b/trainers/top_k.py @@ -162,21 +162,22 @@ class TopKTrainer(SAETrainer): def __init__( self, - dict_class=AutoEncoderTopK, - activation_dim=512, - dict_size=64 * 512, - k=100, - auxk_alpha=1 / 32, # see Appendix A.2 - decay_start=24000, # when does the lr decay start - threshold_beta=0.999, - threshold_start_step=1000, - steps=30000, # when when does training end - seed=None, - device=None, - layer=None, - lm_name=None, - wandb_name="AutoEncoderTopK", - submodule_name=None, + steps: int, # total number of steps to train for + activation_dim: int, + dict_size: int, + k: int, + layer: int, + lm_name: str, + dict_class: type = AutoEncoderTopK, + auxk_alpha: float = 1 / 32, # see Appendix A.2 + warmup_steps: int = 1000, + decay_start: Optional[int] = None, # when does the lr decay start + threshold_beta: float = 0.999, + threshold_start_step: int = 1000, + seed: Optional[int] = None, + device: Optional[str] = None, + wandb_name: str = "AutoEncoderTopK", + submodule_name: Optional[str] = None, ): super().__init__(seed) @@ -187,6 +188,8 @@ def __init__( self.wandb_name = wandb_name self.steps = steps + self.decay_start = decay_start + self.warmup_steps = warmup_steps self.k = k self.threshold_beta = threshold_beta self.threshold_start_step = threshold_start_step @@ -212,12 +215,21 @@ def __init__( # Optimizer and scheduler self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) + if decay_start is not None: + assert 0 <= decay_start < steps, "decay_start must be >= 0 and < steps." + assert decay_start > warmup_steps, "decay_start must be > warmup_steps." + + assert 0 <= warmup_steps < steps, "warmup_steps must be >= 0 and < steps." + def lr_fn(step): - if step < decay_start: - return 1.0 - else: + if step < warmup_steps: + return step / warmup_steps + + if decay_start is not None and step >= decay_start: return (steps - step) / (steps - decay_start) + return 1.0 + self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) # Training parameters @@ -347,6 +359,11 @@ def config(self): "dict_class": "AutoEncoderTopK", "lr": self.lr, "steps": self.steps, + "auxk_alpha": self.auxk_alpha, + "warmup_steps": self.warmup_steps, + "decay_start": self.decay_start, + "threshold_beta": self.threshold_beta, + "threshold_start_step": self.threshold_start_step, "seed": self.seed, "activation_dim": self.ae.activation_dim, "dict_size": self.ae.dict_size, diff --git a/training.py b/training.py index d4c6a38..34c644e 100644 --- a/training.py +++ b/training.py @@ -105,10 +105,10 @@ def get_norm_factor(data, steps: int) -> float: def trainSAE( data, trainer_configs: list[dict], + steps: int, use_wandb:bool=False, wandb_entity:str="", wandb_project:str="", - steps:Optional[int]=None, save_steps:Optional[list[int]]=None, save_dir:Optional[str]=None, log_steps:Optional[int]=None, @@ -179,7 +179,7 @@ def trainSAE( if normalize_activations: act /= norm_factor - if steps is not None and step >= steps: + if step >= steps: break # logging From e00fd643050584f4cfe15ad41e6a01e29e3c0766 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Mon, 30 Dec 2024 16:53:56 +0000 Subject: [PATCH 27/70] Properly set new parameters in end to end test --- tests/test_end_to_end.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 8aa6cfc..fdbce15 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -94,7 +94,7 @@ def test_sae_training(): learning_rate = 3e-4 # topk sae training parameters - decay_start = 24000 + decay_start = None auxk_alpha = 1 / 32 submodule = model.gpt_neox.layers[LAYER] @@ -128,6 +128,7 @@ def test_sae_training(): "dict_size": expansion_factor * activation_dim, "k": k, "auxk_alpha": auxk_alpha, # see Appendix A.2 + "warmup_steps": 0, "decay_start": decay_start, # when does the lr decay start "steps": steps, # when when does training end "seed": RANDOM_SEED, @@ -150,6 +151,8 @@ def test_sae_training(): "l1_penalty": sparsity_penalty, "warmup_steps": warmup_steps, "sparsity_warmup_steps": None, + "decay_start": decay_start, + "steps": steps, "resample_steps": resample_steps, "seed": RANDOM_SEED, "wandb_name": f"StandardTrainer-{MODEL_NAME}-{submodule_name}", From 1df47d83d9dea07d2fb905509b635ac6139bcd48 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Mon, 30 Dec 2024 17:46:11 +0000 Subject: [PATCH 28/70] Make sure we step the learning rate scheduler --- trainers/jumprelu.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trainers/jumprelu.py b/trainers/jumprelu.py index c0c351a..b2f7155 100644 --- a/trainers/jumprelu.py +++ b/trainers/jumprelu.py @@ -185,6 +185,7 @@ def update(self, step, x): torch.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) self.optimizer.step() + self.scheduler.step() self.optimizer.zero_grad() return loss.item() From 8ade55b6eb57ed7c7b06a70187ee68e1056bb95b Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Tue, 31 Dec 2024 17:39:44 +0000 Subject: [PATCH 29/70] Initial matroyshka implementation --- trainers/matroyshka_batch_top_k.py | 366 +++++++++++++++++++++++++++++ 1 file changed, 366 insertions(+) create mode 100644 trainers/matroyshka_batch_top_k.py diff --git a/trainers/matroyshka_batch_top_k.py b/trainers/matroyshka_batch_top_k.py new file mode 100644 index 0000000..23f28e7 --- /dev/null +++ b/trainers/matroyshka_batch_top_k.py @@ -0,0 +1,366 @@ +import torch as t +import torch.nn as nn +import torch.nn.functional as F +import einops +from collections import namedtuple +from typing import Optional +from math import isclose + +from ..dictionary import Dictionary +from ..trainers.trainer import SAETrainer + + +class MatroyshkaBatchTopKSAE(Dictionary, nn.Module): + def __init__(self, activation_dim: int, dict_size: int, k: int, group_sizes: list[int]): + super().__init__() + self.activation_dim = activation_dim + self.dict_size = dict_size + + assert sum(group_sizes) == dict_size, "group sizes must sum to dict_size" + assert all(s > 0 for s in group_sizes), "all group sizes must be positive" + + assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer" + self.register_buffer("k", t.tensor(k)) + self.register_buffer("threshold", t.tensor(-1.0)) + + self.group_sizes = group_sizes + self.active_groups = len(group_sizes) + group_indices = [0] + list(t.cumsum(t.tensor(group_sizes), dim=0)) + + self.register_buffer("group_indices", t.tensor(group_indices)) + + self.W_enc = nn.Parameter(t.empty(activation_dim, dict_size)) + self.b_enc = nn.Parameter(t.zeros(dict_size)) + self.W_dec = nn.Parameter(t.empty(dict_size, activation_dim)) + self.b_dec = nn.Parameter(t.zeros(activation_dim)) + + self.W_dec.data = t.randn_like(self.W_dec) + self.set_decoder_norm_to_unit_norm() + self.W_enc.data = self.W_dec.data.clone().T + + def encode(self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True): + post_relu_feat_acts_BF = nn.functional.relu((x - self.b_dec) @ self.W_enc + self.b_enc) + + if use_threshold: + encoded_acts_BF = post_relu_feat_acts_BF * (post_relu_feat_acts_BF > self.threshold) + else: + # Flatten and perform batch top-k + flattened_acts = post_relu_feat_acts_BF.flatten() + post_topk = flattened_acts.topk(self.k * x.size(0), sorted=False, dim=-1) + + buffer_BF = t.zeros_like(post_relu_feat_acts_BF) + encoded_acts_BF = ( + buffer_BF.flatten() + .scatter(-1, post_topk.indices, post_topk.values) + .reshape(buffer_BF.shape) + ) + + max_act_index = self.group_indices[self.active_groups] + encoded_acts_BF[:, max_act_index:] = 0 + + if return_active: + return encoded_acts_BF, encoded_acts_BF.sum(0) > 0 + else: + return encoded_acts_BF + + def decode(self, x: t.Tensor) -> t.Tensor: + return x @ self.W_dec + self.b_dec + + def forward(self, x: t.Tensor, output_features: bool = False): + encoded_acts_BF = self.encode(x) + x_hat_BD = self.decode(encoded_acts_BF) + + if not output_features: + return x_hat_BD + else: + return x_hat_BD, encoded_acts_BF + + @t.no_grad() + def set_decoder_norm_to_unit_norm(self): + eps = t.finfo(self.W_dec.dtype).eps + norm = t.norm(self.W_dec.data, dim=0, keepdim=True) + self.W_dec.data /= norm + eps + + @t.no_grad() + def remove_gradient_parallel_to_decoder_directions(self): + assert self.W_dec.grad is not None + + parallel_component = einops.einsum( + self.W_dec.grad, + self.W_dec.data, + "d_sae d_in, d_sae d_in -> d_sae", + ) + self.W_dec.grad -= einops.einsum( + parallel_component, + self.W_dec.data, + "d_sae, d_sae d_in -> d_sae d_in", + ) + + @t.no_grad() + def scale_biases(self, scale: float): + self.b_enc.data *= scale + self.b_dec.data *= scale + if self.threshold >= 0: + self.threshold *= scale + + @classmethod + def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "MatroyshkaBatchTopKSAE": + state_dict = t.load(path) + dict_size, activation_dim = state_dict["W_enc"].shape + if k is None: + k = state_dict["k"].item() + elif "k" in state_dict and k != state_dict["k"].item(): + raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']") + + autoencoder = cls(activation_dim, dict_size, k) + autoencoder.load_state_dict(state_dict) + if device is not None: + autoencoder.to(device) + return autoencoder + + +class MatroyshkaBatchTopKTrainer(SAETrainer): + def __init__( + self, + steps: int, # total number of steps to train for + activation_dim: int, + dict_size: int, + k: int, + layer: int, + lm_name: str, + group_fractions: list[float], + group_weights: Optional[list[float]] = None, + dict_class: type = MatroyshkaBatchTopKSAE, + auxk_alpha: float = 1 / 32, + warmup_steps: int = 1000, + decay_start: Optional[int] = None, # when does the lr decay start + threshold_beta: float = 0.999, + threshold_start_step: int = 1000, + top_k_aux: int = 512, + seed: Optional[int] = None, + device: Optional[str] = None, + wandb_name: str = "BatchTopKSAE", + submodule_name: Optional[str] = None, + ): + super().__init__(seed) + assert layer is not None and lm_name is not None + self.layer = layer + self.lm_name = lm_name + self.submodule_name = submodule_name + self.wandb_name = wandb_name + self.steps = steps + self.decay_start = decay_start + self.warmup_steps = warmup_steps + self.k = k + self.threshold_beta = threshold_beta + self.threshold_start_step = threshold_start_step + + if seed is not None: + t.manual_seed(seed) + t.cuda.manual_seed_all(seed) + + assert isclose(sum(group_fractions), 1.0), "group_fractions must sum to 1.0" + # Calculate all groups except the last one + group_sizes = [int(f * dict_size) for f in group_fractions[:-1]] + # Put remainder in the last group + group_sizes.append(dict_size - sum(group_sizes)) + + if group_weights is None: + group_weights = [1.0] * len(group_sizes) + + assert len(group_sizes) == len( + group_weights + ), "group_sizes and group_weights must have the same length" + + self.group_fractions = group_fractions + self.group_sizes = group_sizes + self.group_weights = group_weights + + self.ae = dict_class(activation_dim, dict_size, k, group_sizes) + + if device is None: + self.device = "cuda" if t.cuda.is_available() else "cpu" + else: + self.device = device + self.ae.to(self.device) + + scale = dict_size / (2**14) + self.lr = 2e-4 / scale**0.5 + self.auxk_alpha = auxk_alpha + self.dead_feature_threshold = 10_000_000 + self.top_k_aux = top_k_aux + + self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) + + if decay_start is not None: + assert 0 <= decay_start < steps, "decay_start must be >= 0 and < steps." + assert decay_start > warmup_steps, "decay_start must be > warmup_steps." + + assert 0 <= warmup_steps < steps, "warmup_steps must be >= 0 and < steps." + + def lr_fn(step): + if step < warmup_steps: + return step / warmup_steps + + if decay_start is not None and step >= decay_start: + return (steps - step) / (steps - decay_start) + + return 1.0 + + self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) + + self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device) + self.logging_parameters = ["effective_l0", "dead_features"] + self.effective_l0 = -1 + self.dead_features = -1 + + def get_auxiliary_loss(self, x, x_reconstruct, acts): + dead_features = self.num_tokens_since_fired >= self.dead_feature_threshold + if dead_features.sum() > 0: + residual = x.float() - x_reconstruct.float() + acts_topk_aux = t.topk( + acts[:, dead_features], + min(self.top_k_aux, dead_features.sum()), + dim=-1, + ) + acts_aux = t.zeros_like(acts[:, dead_features]).scatter( + -1, acts_topk_aux.indices, acts_topk_aux.values + ) + x_reconstruct_aux = F.linear(acts_aux, self.ae.decoder.weight[:, dead_features]) + l2_loss_aux = ( + self.auxk_alpha * (x_reconstruct_aux.float() - residual.float()).pow(2).mean() + ) + return l2_loss_aux + else: + return t.tensor(0, dtype=x.dtype, device=x.device) + + def loss(self, x, step=None, logging=False): + f, active_indices = self.ae.encode(x, return_active=True, use_threshold=False) + # l0 = (f != 0).float().sum(dim=-1).mean().item() + + if step > self.threshold_start_step: + with t.no_grad(): + active = f[f > 0] + + if active.size(0) == 0: + min_activation = 0.0 + else: + min_activation = active.min().detach() + + if self.ae.threshold < 0: + self.ae.threshold = min_activation + else: + self.ae.threshold = (self.threshold_beta * self.ae.threshold) + ( + (1 - self.threshold_beta) * min_activation + ) + + x_reconstruct = t.zeros_like(x) + self.ae.b_dec + total_l2_loss = 0.0 + l2_losses = t.tensor([]).to(self.device) + + for i in range(self.ae.active_groups): + group_start = self.ae.group_indices[i] + group_end = self.ae.group_indices[i + 1] + W_dec_slice = self.ae.W_dec[group_start:group_end, :] + acts_slice = f[:, group_start:group_end] + x_reconstruct = x_reconstruct + acts_slice @ W_dec_slice + + l2_loss = (x_reconstruct - x).pow(2).sum(dim=-1).mean() * self.group_weights[i] + total_l2_loss += l2_loss + l2_losses = t.cat([l2_losses, l2_loss.unsqueeze(0)]) + + min_l2_loss = l2_losses.min().item() + max_l2_loss = l2_losses.max().item() + mean_l2_loss = l2_losses.mean() + + self.effective_l0 = self.k + + num_tokens_in_step = x.size(0) + did_fire = t.zeros_like(self.num_tokens_since_fired, dtype=t.bool) + did_fire[active_indices] = True + self.num_tokens_since_fired += num_tokens_in_step + self.num_tokens_since_fired[did_fire] = 0 + + auxk_loss = self.get_auxiliary_loss(x, x_reconstruct, f) + + auxk_loss = auxk_loss.sum(dim=-1).mean() + loss = mean_l2_loss + self.auxk_alpha * auxk_loss + + if not logging: + return loss + else: + return namedtuple("LossLog", ["x", "x_hat", "f", "losses"])( + x, + x_reconstruct, + f, + { + "l2_loss": mean_l2_loss.item(), + "auxk_loss": auxk_loss.item(), + "loss": loss.item(), + "min_l2_loss": min_l2_loss, + "max_l2_loss": max_l2_loss, + }, + ) + + def update(self, step, x): + if step == 0: + median = self.geometric_median(x) + self.ae.b_dec.data = median + + self.ae.set_decoder_norm_to_unit_norm() + + x = x.to(self.device) + loss = self.loss(x, step=step) + loss.backward() + + t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) + self.ae.remove_gradient_parallel_to_decoder_directions() + + self.optimizer.step() + self.optimizer.zero_grad() + self.scheduler.step() + + return loss.item() + + @property + def config(self): + return { + "trainer_class": "MatroyshkaBatchTopKTrainer", + "dict_class": "MatroyshkaBatchTopKSAE", + "lr": self.lr, + "steps": self.steps, + "auxk_alpha": self.auxk_alpha, + "warmup_steps": self.warmup_steps, + "decay_start": self.decay_start, + "threshold_beta": self.threshold_beta, + "threshold_start_step": self.threshold_start_step, + "top_k_aux": self.top_k_aux, + "seed": self.seed, + "activation_dim": self.ae.activation_dim, + "dict_size": self.ae.dict_size, + "group_fractions": self.group_fractions, + "group_weights": self.group_weights, + "group_sizes": self.group_sizes, + "k": self.ae.k.item(), + "device": self.device, + "layer": self.layer, + "lm_name": self.lm_name, + "wandb_name": self.wandb_name, + "submodule_name": self.submodule_name, + } + + @staticmethod + def geometric_median(points: t.Tensor, max_iter: int = 100, tol: float = 1e-5): + guess = points.mean(dim=0) + prev = t.zeros_like(guess) + weights = t.ones(len(points), device=points.device) + + for _ in range(max_iter): + prev = guess + weights = 1 / t.norm(points - guess, dim=1) + weights /= weights.sum() + guess = (weights.unsqueeze(1) * points).sum(dim=0) + if t.norm(guess - prev) < tol: + break + + return guess From 764d4ac4450ea6b7d79de52fdec70c7c1e0dfb79 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Tue, 31 Dec 2024 18:08:27 +0000 Subject: [PATCH 30/70] Fix loading matroyshkas from_pretrained() --- trainers/matroyshka_batch_top_k.py | 10 ++++++---- training.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/trainers/matroyshka_batch_top_k.py b/trainers/matroyshka_batch_top_k.py index 23f28e7..3721be3 100644 --- a/trainers/matroyshka_batch_top_k.py +++ b/trainers/matroyshka_batch_top_k.py @@ -23,11 +23,11 @@ def __init__(self, activation_dim: int, dict_size: int, k: int, group_sizes: lis self.register_buffer("k", t.tensor(k)) self.register_buffer("threshold", t.tensor(-1.0)) - self.group_sizes = group_sizes self.active_groups = len(group_sizes) group_indices = [0] + list(t.cumsum(t.tensor(group_sizes), dim=0)) + self.group_indices = group_indices - self.register_buffer("group_indices", t.tensor(group_indices)) + self.register_buffer("group_sizes", t.tensor(group_sizes)) self.W_enc = nn.Parameter(t.empty(activation_dim, dict_size)) self.b_enc = nn.Parameter(t.zeros(dict_size)) @@ -106,13 +106,15 @@ def scale_biases(self, scale: float): @classmethod def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "MatroyshkaBatchTopKSAE": state_dict = t.load(path) - dict_size, activation_dim = state_dict["W_enc"].shape + activation_dim, dict_size = state_dict["W_enc"].shape if k is None: k = state_dict["k"].item() elif "k" in state_dict and k != state_dict["k"].item(): raise ValueError(f"k={k} != {state_dict['k'].item()}=state_dict['k']") - autoencoder = cls(activation_dim, dict_size, k) + group_sizes = state_dict["group_sizes"].tolist() + + autoencoder = cls(activation_dim, dict_size, k=k, group_sizes=group_sizes) autoencoder.load_state_dict(state_dict) if device is not None: autoencoder.to(device) diff --git a/training.py b/training.py index 34c644e..63d1cb8 100644 --- a/training.py +++ b/training.py @@ -187,6 +187,16 @@ def trainSAE( log_stats( trainers, step, act, activations_split_by_head, transcoder, log_queues=log_queues ) + if step % 100 == 0: + z = act.clone() + for i, trainer in enumerate(trainers): + act = z.clone() + act, act_hat, f, losslog = trainer.loss(act, step=step, logging=True) + + # L0 + l0 = (f != 0).float().sum(dim=-1).mean().item() + + print(f"Step {step}: L0 = {l0}") # saving if save_steps is not None and step in save_steps: From 53836033b305142fb6d076a52a7679e0642ddb7a Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Wed, 1 Jan 2025 00:06:09 +0000 Subject: [PATCH 31/70] norm the correct decoder dimension --- trainers/matroyshka_batch_top_k.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/trainers/matroyshka_batch_top_k.py b/trainers/matroyshka_batch_top_k.py index 3721be3..d96d6f7 100644 --- a/trainers/matroyshka_batch_top_k.py +++ b/trainers/matroyshka_batch_top_k.py @@ -78,7 +78,8 @@ def forward(self, x: t.Tensor, output_features: bool = False): @t.no_grad() def set_decoder_norm_to_unit_norm(self): eps = t.finfo(self.W_dec.dtype).eps - norm = t.norm(self.W_dec.data, dim=0, keepdim=True) + norm = t.norm(self.W_dec.data, dim=1, keepdim=True) + self.W_dec.data /= norm + eps @t.no_grad() From ceabbc5233dcf28f0f5afd53e0de850d19f34d78 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Wed, 1 Jan 2025 00:06:57 +0000 Subject: [PATCH 32/70] Add temperature scaling to matroyshka --- trainers/matroyshka_batch_top_k.py | 26 +++++++++++++++++++++++++- utils.py | 7 ++++++- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/trainers/matroyshka_batch_top_k.py b/trainers/matroyshka_batch_top_k.py index d96d6f7..483ea0f 100644 --- a/trainers/matroyshka_batch_top_k.py +++ b/trainers/matroyshka_batch_top_k.py @@ -10,6 +10,25 @@ from ..trainers.trainer import SAETrainer +def apply_temperature(probabilities: list[float], temperature: float) -> list[float]: + """ + Apply temperature scaling to a list of probabilities using PyTorch. + + Args: + probabilities (list[float]): Initial probability distribution + temperature (float): Temperature parameter (> 0) + + Returns: + list[float]: Scaled and normalized probabilities + """ + probs_tensor = t.tensor(probabilities, dtype=t.float32) + logits = t.log(probs_tensor) + scaled_logits = logits / temperature + scaled_probs = t.nn.functional.softmax(scaled_logits, dim=0) + + return scaled_probs.tolist() + + class MatroyshkaBatchTopKSAE(Dictionary, nn.Module): def __init__(self, activation_dim: int, dict_size: int, k: int, group_sizes: list[int]): super().__init__() @@ -133,6 +152,7 @@ def __init__( lm_name: str, group_fractions: list[float], group_weights: Optional[list[float]] = None, + weights_temperature: float = 1.0, dict_class: type = MatroyshkaBatchTopKSAE, auxk_alpha: float = 1 / 32, warmup_steps: int = 1000, @@ -169,7 +189,9 @@ def __init__( group_sizes.append(dict_size - sum(group_sizes)) if group_weights is None: - group_weights = [1.0] * len(group_sizes) + group_weights = group_fractions.copy() + + group_weights = apply_temperature(group_weights, weights_temperature) assert len(group_sizes) == len( group_weights @@ -178,6 +200,7 @@ def __init__( self.group_fractions = group_fractions self.group_sizes = group_sizes self.group_weights = group_weights + self.weights_temperature = weights_temperature self.ae = dict_class(activation_dim, dict_size, k, group_sizes) @@ -344,6 +367,7 @@ def config(self): "group_fractions": self.group_fractions, "group_weights": self.group_weights, "group_sizes": self.group_sizes, + "weights_temperature": self.weights_temperature, "k": self.ae.k.item(), "device": self.device, "layer": self.layer, diff --git a/utils.py b/utils.py index 4f34a4e..2cc12b3 100644 --- a/utils.py +++ b/utils.py @@ -7,6 +7,7 @@ from dictionary_learning.trainers.top_k import AutoEncoderTopK from dictionary_learning.trainers.batch_top_k import BatchTopKSAE +from dictionary_learning.trainers.matroyshka_batch_top_k import MatroyshkaBatchTopKSAE from dictionary_learning.dictionary import ( AutoEncoder, GatedAutoEncoder, @@ -74,7 +75,11 @@ def load_dictionary(base_path: str, device: str) -> tuple: k = config["trainer"]["k"] dictionary = AutoEncoderTopK.from_pretrained(ae_path, k=k, device=device) elif dict_class == "BatchTopKSAE": - dictionary = BatchTopKSAE.from_pretrained(ae_path, device=device) + k = config["trainer"]["k"] + dictionary = BatchTopKSAE.from_pretrained(ae_path, k=k, device=device) + elif dict_class == "MatroyshkaBatchTopKSAE": + k = config["trainer"]["k"] + dictionary = MatroyshkaBatchTopKSAE.from_pretrained(ae_path, k=k, device=device) elif dict_class == "JumpReluAutoEncoder": dictionary = JumpReluAutoEncoder.from_pretrained(ae_path, device=device) else: From 3e31571b20d3e86823540882ec03c87b155d8e3d Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Wed, 1 Jan 2025 17:10:36 +0000 Subject: [PATCH 33/70] Format with ruff --- dictionary.py | 91 +++++++++++++++++++++++++++++---------------------- 1 file changed, 52 insertions(+), 39 deletions(-) diff --git a/dictionary.py b/dictionary.py index 09e80ab..5217d74 100644 --- a/dictionary.py +++ b/dictionary.py @@ -7,12 +7,14 @@ import torch.nn as nn import torch.nn.init as init + class Dictionary(ABC, nn.Module): """ A dictionary consists of a collection of vectors, an encoder, and a decoder. """ - dict_size : int # number of features in the dictionary - activation_dim : int # dimension of the activation vectors + + dict_size: int # number of features in the dictionary + activation_dim: int # dimension of the activation vectors @abstractmethod def encode(self, x): @@ -20,7 +22,7 @@ def encode(self, x): Encode a vector x in the activation space. """ pass - + @abstractmethod def decode(self, f): """ @@ -41,6 +43,7 @@ class AutoEncoder(Dictionary, nn.Module): """ A one-layer autoencoder. """ + def __init__(self, activation_dim, dict_size): super().__init__() self.activation_dim = activation_dim @@ -56,14 +59,13 @@ def __init__(self, activation_dim, dict_size): ## set encoder and decoder weights self.encoder.weight = nn.Parameter(w.clone().T) self.decoder.weight = nn.Parameter(w.clone()) - def encode(self, x): return nn.ReLU()(self.encoder(x - self.bias)) - + def decode(self, f): return self.decoder(f) + self.bias - + def forward(self, x, output_features=False, ghost_mask=None): """ Forward pass of an autoencoder. @@ -71,20 +73,22 @@ def forward(self, x, output_features=False, ghost_mask=None): output_features : if True, return the encoded features as well as the decoded x ghost_mask : if not None, run this autoencoder in "ghost mode" where features are masked """ - if ghost_mask is None: # normal mode + if ghost_mask is None: # normal mode f = self.encode(x) x_hat = self.decode(f) if output_features: return x_hat, f else: return x_hat - - else: # ghost mode + + else: # ghost mode f_pre = self.encoder(x - self.bias) f_ghost = t.exp(f_pre) * ghost_mask.to(f_pre) f = nn.ReLU()(f_pre) - x_ghost = self.decoder(f_ghost) # note that this only applies the decoder weight matrix, no bias + x_ghost = self.decoder( + f_ghost + ) # note that this only applies the decoder weight matrix, no bias x_hat = self.decode(f) if output_features: return x_hat, x_ghost, f @@ -94,24 +98,26 @@ def forward(self, x, output_features=False, ghost_mask=None): def scale_biases(self, scale: float): self.encoder.bias.data *= scale self.bias.data *= scale - + @classmethod def from_pretrained(cls, path, dtype=t.float, device=None): """ Load a pretrained autoencoder from a file. """ state_dict = t.load(path) - dict_size, activation_dim = state_dict['encoder.weight'].shape + dict_size, activation_dim = state_dict["encoder.weight"].shape autoencoder = cls(activation_dim, dict_size) autoencoder.load_state_dict(state_dict) if device is not None: autoencoder.to(dtype=dtype, device=device) return autoencoder - + + class IdentityDict(Dictionary, nn.Module): """ An identity dictionary, i.e. the identity function. """ + def __init__(self, activation_dim=None): super().__init__() self.activation_dim = activation_dim @@ -119,28 +125,30 @@ def __init__(self, activation_dim=None): def encode(self, x): return x - + def decode(self, f): return f - + def forward(self, x, output_features=False, ghost_mask=None): if output_features: return x, x else: return x - + @classmethod def from_pretrained(cls, path, dtype=t.float, device=None): """ Load a pretrained dictionary from a file. """ return cls(None) - + + class GatedAutoEncoder(Dictionary, nn.Module): """ An autoencoder with separate gating and magnitude networks. """ - def __init__(self, activation_dim, dict_size, initialization='default', device=None): + + def __init__(self, activation_dim, dict_size, initialization="default", device=None): super().__init__() self.activation_dim = activation_dim self.dict_size = dict_size @@ -150,7 +158,7 @@ def __init__(self, activation_dim, dict_size, initialization='default', device=N self.gate_bias = nn.Parameter(t.empty(dict_size, device=device)) self.mag_bias = nn.Parameter(t.empty(dict_size, device=device)) self.decoder = nn.Linear(dict_size, activation_dim, bias=False, device=device) - if initialization == 'default': + if initialization == "default": self._reset_parameters() else: initialization(self) @@ -200,7 +208,7 @@ def decode(self, f): # Normalizing after encode, and renormalizing before decode to enable comparability f = f / self.decoder.weight.norm(dim=0, keepdim=True) return self.decoder(f) + self.decoder_bias - + def forward(self, x, output_features=False): f = self.encode(x) x_hat = self.decode(f) @@ -222,20 +230,20 @@ def from_pretrained(path, device=None): Load a pretrained autoencoder from a file. """ state_dict = t.load(path) - dict_size, activation_dim = state_dict['encoder.weight'].shape + dict_size, activation_dim = state_dict["encoder.weight"].shape autoencoder = GatedAutoEncoder(activation_dim, dict_size) autoencoder.load_state_dict(state_dict) if device is not None: autoencoder.to(device) return autoencoder - + class JumpReluAutoEncoder(Dictionary, nn.Module): """ An autoencoder with jump ReLUs. """ - def __init__(self, activation_dim, dict_size, device='cpu'): + def __init__(self, activation_dim, dict_size, device="cpu"): super().__init__() self.activation_dim = activation_dim self.dict_size = dict_size @@ -264,11 +272,11 @@ def encode(self, x, output_pre_jump=False): return f, pre_jump else: return f - + def decode(self, f): f = f / self.W_dec.norm(dim=1) return f @ self.W_dec + self.b_dec - + def forward(self, x, output_features=False): """ Forward pass of an autoencoder. @@ -286,15 +294,15 @@ def scale_biases(self, scale: float): self.b_dec.data *= scale self.b_enc.data *= scale self.threshold.data *= scale - + @classmethod def from_pretrained( - cls, - path: str | None = None, - load_from_sae_lens: bool = False, - dtype: t.dtype = t.float32, - device: t.device | None = None, - **kwargs, + cls, + path: str | None = None, + load_from_sae_lens: bool = False, + dtype: t.dtype = t.float32, + device: t.device | None = None, + **kwargs, ): """ Load a pretrained autoencoder from a file. @@ -303,14 +311,17 @@ def from_pretrained( """ if not load_from_sae_lens: state_dict = t.load(path) - activation_dim, dict_size = state_dict['W_enc'].shape + activation_dim, dict_size = state_dict["W_enc"].shape autoencoder = JumpReluAutoEncoder(activation_dim, dict_size) autoencoder.load_state_dict(state_dict) autoencoder = autoencoder.to(dtype=dtype, device=device) else: from sae_lens import SAE + sae, cfg_dict, _ = SAE.from_pretrained(**kwargs) - assert cfg_dict["finetuning_scaling_factor"] == False, "Finetuning scaling factor not supported" + assert ( + cfg_dict["finetuning_scaling_factor"] == False + ), "Finetuning scaling factor not supported" dict_size, activation_dim = cfg_dict["d_sae"], cfg_dict["d_in"] autoencoder = JumpReluAutoEncoder(activation_dim, dict_size, device=device) autoencoder.load_state_dict(sae.state_dict()) @@ -320,11 +331,13 @@ def from_pretrained( device = autoencoder.W_enc.device return autoencoder.to(dtype=dtype, device=device) + # TODO merge this with AutoEncoder class AutoEncoderNew(Dictionary, nn.Module): """ The autoencoder architecture and initialization used in https://transformer-circuits.pub/2024/april-update/index.html#training-saes """ + def __init__(self, activation_dim, dict_size): super().__init__() self.activation_dim = activation_dim @@ -346,10 +359,10 @@ def __init__(self, activation_dim, dict_size): def encode(self, x): return nn.ReLU()(self.encoder(x)) - + def decode(self, f): return self.decoder(f) - + def forward(self, x, output_features=False): """ Forward pass of an autoencoder. @@ -357,19 +370,19 @@ def forward(self, x, output_features=False): """ if not output_features: return self.decode(self.encode(x)) - else: # TODO rewrite so that x_hat depends on f + else: # TODO rewrite so that x_hat depends on f f = self.encode(x) x_hat = self.decode(f) # multiply f by decoder column norms f = f * self.decoder.weight.norm(dim=0, keepdim=True) return x_hat, f - + def from_pretrained(path, device=None): """ Load a pretrained autoencoder from a file. """ state_dict = t.load(path) - dict_size, activation_dim = state_dict['encoder.weight'].shape + dict_size, activation_dim = state_dict["encoder.weight"].shape autoencoder = AutoEncoderNew(activation_dim, dict_size) autoencoder.load_state_dict(state_dict) if device is not None: From 8eaa8b2407eabd714bbe7d55fd0c15fcb05fba24 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Wed, 1 Jan 2025 17:39:56 +0000 Subject: [PATCH 34/70] Use kaiming initialization if specified in paper, fix batch_top_k aux_k_alpha --- dictionary.py | 13 +++++---- tests/test_end_to_end.py | 44 +++++++++++++++--------------- trainers/batch_top_k.py | 18 ++++++------ trainers/matroyshka_batch_top_k.py | 13 ++++----- trainers/top_k.py | 12 ++++---- 5 files changed, 48 insertions(+), 52 deletions(-) diff --git a/dictionary.py b/dictionary.py index 5217d74..58c62d4 100644 --- a/dictionary.py +++ b/dictionary.py @@ -177,6 +177,7 @@ def _reset_parameters(self): dec_weight = t.randn_like(self.decoder.weight) dec_weight = dec_weight / dec_weight.norm(dim=0, keepdim=True) self.decoder.weight = nn.Parameter(dec_weight) + self.encoder.weight = nn.Parameter(dec_weight.clone().T) def encode(self, x, return_gate=False): """ @@ -249,16 +250,16 @@ def __init__(self, activation_dim, dict_size, device="cpu"): self.dict_size = dict_size self.W_enc = nn.Parameter(t.empty(activation_dim, dict_size, device=device)) self.b_enc = nn.Parameter(t.zeros(dict_size, device=device)) - self.W_dec = nn.Parameter(t.empty(dict_size, activation_dim, device=device)) + self.W_dec = nn.Parameter( + t.nn.init.kaiming_uniform_(t.empty(dict_size, activation_dim)), device=device + ) self.b_dec = nn.Parameter(t.zeros(activation_dim, device=device)) - self.threshold = nn.Parameter(t.zeros(dict_size, device=device)) + self.threshold = nn.Parameter(t.ones(dict_size, device=device)) * 0.001 # Appendix I self.apply_b_dec_to_input = False - # rows of decoder weight matrix are initialized to unit vectors - self.W_enc.data = t.randn_like(self.W_enc) - self.W_enc.data = self.W_enc / self.W_enc.norm(dim=0, keepdim=True) - self.W_dec.data = self.W_enc.data.clone().T + self.W_dec.data = self.W_dec / self.W_dec.norm(dim=1, keepdim=True) + self.W_enc.data = self.W_dec.data.clone().T def encode(self, x, output_pre_jump=False): if self.apply_b_dec_to_input: diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index fdbce15..b7b5cc9 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -19,31 +19,31 @@ EXPECTED_RESULTS = { "AutoEncoderTopK": { - "l2_loss": 4.325331306457519, - "l1_loss": 47.92763671875, + "l2_loss": 4.362327718734742, + "l1_loss": 50.94957427978515, "l0": 40.0, - "frac_variance_explained": 0.9584966480731965, - "cossim": 0.948570293188095, - "l2_ratio": 0.94872345328331, - "relative_reconstruction_bias": 0.9998040139675141, - "loss_original": 3.328495955467224, - "loss_reconstructed": 3.819682216644287, - "loss_zero": 13.250199031829833, - "frac_recovered": 0.9503251194953919, + "frac_variance_explained": 0.9578053653240204, + "cossim": 0.9478691875934601, + "l2_ratio": 0.9478908002376556, + "relative_reconstruction_bias": 0.999762898683548, + "loss_original": 3.3361297130584715, + "loss_reconstructed": 3.8404462814331053, + "loss_zero": 13.251659297943116, + "frac_recovered": 0.948982036113739, "frac_alive": 0.99951171875, }, "AutoEncoder": { - "l2_loss": 6.822399997711182, - "l1_loss": 19.381900978088378, - "l0": 37.4492919921875, - "frac_variance_explained": 0.8993505954742431, - "cossim": 0.8791077017784119, - "l2_ratio": 0.7455410599708557, - "relative_reconstruction_bias": 0.9595056653022767, - "loss_original": 3.3284960985183716, - "loss_reconstructed": 5.203806638717651, - "loss_zero": 13.250199031829833, - "frac_recovered": 0.8104169845581055, + "l2_loss": 6.822444677352905, + "l1_loss": 19.382131576538086, + "l0": 37.45087890625, + "frac_variance_explained": 0.8993501663208008, + "cossim": 0.8791120409965515, + "l2_ratio": 0.74552041888237, + "relative_reconstruction_bias": 0.9595054805278778, + "loss_original": 3.3361297130584715, + "loss_reconstructed": 5.208198881149292, + "loss_zero": 13.251659297943116, + "frac_recovered": 0.8106247961521149, "frac_alive": 0.99658203125, }, } @@ -62,7 +62,7 @@ def test_sae_training(): """End to end test for training an SAE. Takes ~2 minutes on an RTX 3090. This isn't a nice suite of unit tests, but it's better than nothing. I have observed that results can slightly vary with library versions. For full determinism, - use pytorch 2.2.0 and nnsight 0.3.3. + use pytorch 2.5.1 and nnsight 0.3.7. NOTE: `dictionary_learning` is meant to be used as a submodule. Thus, to run this test, you need to use `dictionary_learning` as a submodule and run the test from the root of the repository using `pytest -s`. Refer to https://github.com/adamkarvonen/dictionary_learning_demo for an example""" diff --git a/trainers/batch_top_k.py b/trainers/batch_top_k.py index dd59f1a..2d4d261 100644 --- a/trainers/batch_top_k.py +++ b/trainers/batch_top_k.py @@ -19,11 +19,12 @@ def __init__(self, activation_dim: int, dict_size: int, k: int): self.register_buffer("k", t.tensor(k)) self.register_buffer("threshold", t.tensor(-1.0)) - self.encoder = nn.Linear(activation_dim, dict_size) - self.encoder.bias.data.zero_() self.decoder = nn.Linear(dict_size, activation_dim, bias=False) - self.decoder.weight.data = self.encoder.weight.data.clone().T self.set_decoder_norm_to_unit_norm() + + self.encoder = nn.Linear(activation_dim, dict_size) + self.encoder.weight.data = self.decoder.weight.T.clone() + self.encoder.bias.data.zero_() self.b_dec = nn.Parameter(t.zeros(activation_dim)) def encode(self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True): @@ -109,7 +110,7 @@ def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "BatchTopKSAE": class BatchTopKTrainer(SAETrainer): def __init__( self, - steps: int, # total number of steps to train for + steps: int, # total number of steps to train for activation_dim: int, dict_size: int, k: int, @@ -121,7 +122,6 @@ def __init__( decay_start: Optional[int] = None, # when does the lr decay start threshold_beta: float = 0.999, threshold_start_step: int = 1000, - top_k_aux: int = 512, seed: Optional[int] = None, device: Optional[str] = None, wandb_name: str = "BatchTopKSAE", @@ -156,7 +156,7 @@ def __init__( self.lr = 2e-4 / scale**0.5 self.auxk_alpha = auxk_alpha self.dead_feature_threshold = 10_000_000 - self.top_k_aux = top_k_aux + self.top_k_aux = activation_dim // 2 # Heuristic from B.1 of the paper self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) @@ -169,7 +169,7 @@ def __init__( def lr_fn(step): if step < warmup_steps: return step / warmup_steps - + if decay_start is not None and step >= decay_start: return (steps - step) / (steps - decay_start) @@ -195,9 +195,7 @@ def get_auxiliary_loss(self, x, x_reconstruct, acts): -1, acts_topk_aux.indices, acts_topk_aux.values ) x_reconstruct_aux = F.linear(acts_aux, self.ae.decoder.weight[:, dead_features]) - l2_loss_aux = ( - self.auxk_alpha * (x_reconstruct_aux.float() - residual.float()).pow(2).mean() - ) + l2_loss_aux = (x_reconstruct_aux.float() - residual.float()).pow(2).mean() return l2_loss_aux else: return t.tensor(0, dtype=x.dtype, device=x.device) diff --git a/trainers/matroyshka_batch_top_k.py b/trainers/matroyshka_batch_top_k.py index 483ea0f..58c1524 100644 --- a/trainers/matroyshka_batch_top_k.py +++ b/trainers/matroyshka_batch_top_k.py @@ -50,10 +50,9 @@ def __init__(self, activation_dim: int, dict_size: int, k: int, group_sizes: lis self.W_enc = nn.Parameter(t.empty(activation_dim, dict_size)) self.b_enc = nn.Parameter(t.zeros(dict_size)) - self.W_dec = nn.Parameter(t.empty(dict_size, activation_dim)) + self.W_dec = nn.Parameter(t.nn.init.kaiming_uniform_(t.empty(dict_size, activation_dim))) self.b_dec = nn.Parameter(t.zeros(activation_dim)) - self.W_dec.data = t.randn_like(self.W_dec) self.set_decoder_norm_to_unit_norm() self.W_enc.data = self.W_dec.data.clone().T @@ -159,7 +158,6 @@ def __init__( decay_start: Optional[int] = None, # when does the lr decay start threshold_beta: float = 0.999, threshold_start_step: int = 1000, - top_k_aux: int = 512, seed: Optional[int] = None, device: Optional[str] = None, wandb_name: str = "BatchTopKSAE", @@ -214,7 +212,7 @@ def __init__( self.lr = 2e-4 / scale**0.5 self.auxk_alpha = auxk_alpha self.dead_feature_threshold = 10_000_000 - self.top_k_aux = top_k_aux + self.top_k_aux = activation_dim // 2 # Heuristic from B.1 of the paper self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) @@ -252,10 +250,9 @@ def get_auxiliary_loss(self, x, x_reconstruct, acts): acts_aux = t.zeros_like(acts[:, dead_features]).scatter( -1, acts_topk_aux.indices, acts_topk_aux.values ) - x_reconstruct_aux = F.linear(acts_aux, self.ae.decoder.weight[:, dead_features]) - l2_loss_aux = ( - self.auxk_alpha * (x_reconstruct_aux.float() - residual.float()).pow(2).mean() - ) + x_reconstruct_aux = F.linear(acts_aux, self.ae.W_dec[dead_features, :].T) + l2_loss_aux = (x_reconstruct_aux.float() - residual.float()).pow(2).mean() + return l2_loss_aux else: return t.tensor(0, dtype=x.dtype, device=x.device) diff --git a/trainers/top_k.py b/trainers/top_k.py index 64ca2ab..0215295 100644 --- a/trainers/top_k.py +++ b/trainers/top_k.py @@ -64,13 +64,13 @@ def __init__(self, activation_dim: int, dict_size: int, k: int): self.register_buffer("k", t.tensor(k)) self.register_buffer("threshold", t.tensor(-1.0)) - self.encoder = nn.Linear(activation_dim, dict_size) - self.encoder.bias.data.zero_() - self.decoder = nn.Linear(dict_size, activation_dim, bias=False) - self.decoder.weight.data = self.encoder.weight.data.clone().T self.set_decoder_norm_to_unit_norm() + self.encoder = nn.Linear(activation_dim, dict_size) + self.encoder.weight.data = self.decoder.weight.T.clone() + self.encoder.bias.data.zero_() + self.b_dec = nn.Parameter(t.zeros(activation_dim)) def encode(self, x: t.Tensor, return_topk: bool = False, use_threshold: bool = False): @@ -162,7 +162,7 @@ class TopKTrainer(SAETrainer): def __init__( self, - steps: int, # total number of steps to train for + steps: int, # total number of steps to train for activation_dim: int, dict_size: int, k: int, @@ -224,7 +224,7 @@ def __init__( def lr_fn(step): if step < warmup_steps: return step / warmup_steps - + if decay_start is not None and step >= decay_start: return (steps - step) / (steps - decay_start) From ec961acde2244b98b26bcf796c3ec00b721088bb Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Wed, 1 Jan 2025 23:00:17 +0000 Subject: [PATCH 35/70] Fix jumprelu training --- dictionary.py | 29 +++++++++++++++++++++++++---- trainers/jumprelu.py | 11 ++++++++--- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/dictionary.py b/dictionary.py index 58c62d4..9886282 100644 --- a/dictionary.py +++ b/dictionary.py @@ -6,6 +6,7 @@ import torch as t import torch.nn as nn import torch.nn.init as init +import einops class Dictionary(ABC, nn.Module): @@ -251,10 +252,10 @@ def __init__(self, activation_dim, dict_size, device="cpu"): self.W_enc = nn.Parameter(t.empty(activation_dim, dict_size, device=device)) self.b_enc = nn.Parameter(t.zeros(dict_size, device=device)) self.W_dec = nn.Parameter( - t.nn.init.kaiming_uniform_(t.empty(dict_size, activation_dim)), device=device + t.nn.init.kaiming_uniform_(t.empty(dict_size, activation_dim, device=device)) ) self.b_dec = nn.Parameter(t.zeros(activation_dim, device=device)) - self.threshold = nn.Parameter(t.ones(dict_size, device=device)) * 0.001 # Appendix I + self.threshold = nn.Parameter(t.ones(dict_size, device=device) * 0.001) # Appendix I self.apply_b_dec_to_input = False @@ -267,7 +268,6 @@ def encode(self, x, output_pre_jump=False): pre_jump = x @ self.W_enc + self.b_enc f = nn.ReLU()(pre_jump * (pre_jump > self.threshold)) - f = f * self.W_dec.norm(dim=1) if output_pre_jump: return f, pre_jump @@ -275,7 +275,6 @@ def encode(self, x, output_pre_jump=False): return f def decode(self, f): - f = f / self.W_dec.norm(dim=1) return f @ self.W_dec + self.b_dec def forward(self, x, output_features=False): @@ -332,6 +331,28 @@ def from_pretrained( device = autoencoder.W_enc.device return autoencoder.to(dtype=dtype, device=device) + @t.no_grad() + def set_decoder_norm_to_unit_norm(self): + eps = t.finfo(self.W_dec.dtype).eps + norm = t.norm(self.W_dec.data, dim=1, keepdim=True) + + self.W_dec.data /= norm + eps + + @t.no_grad() + def remove_gradient_parallel_to_decoder_directions(self): + assert self.W_dec.grad is not None + + parallel_component = einops.einsum( + self.W_dec.grad, + self.W_dec.data, + "d_sae d_in, d_sae d_in -> d_sae", + ) + self.W_dec.grad -= einops.einsum( + parallel_component, + self.W_dec.data, + "d_sae, d_sae d_in -> d_sae d_in", + ) + # TODO merge this with AutoEncoder class AutoEncoderNew(Dictionary, nn.Module): diff --git a/trainers/jumprelu.py b/trainers/jumprelu.py index b2f7155..42a3ea0 100644 --- a/trainers/jumprelu.py +++ b/trainers/jumprelu.py @@ -8,7 +8,6 @@ from ..dictionary import Dictionary, JumpReluAutoEncoder from .trainer import SAETrainer - class RectangleFunction(autograd.Function): @staticmethod def forward(ctx, x): @@ -32,7 +31,7 @@ def forward(ctx, x, threshold, bandwidth): @staticmethod def backward(ctx, grad_output): x, threshold, bandwidth_tensor = ctx.saved_tensors - bandwidth = bandwidth_tensor.item() + bandwidth = bandwidth_tensor.item() x_grad = (x > threshold).float() * grad_output threshold_grad = ( -(threshold / bandwidth) @@ -155,7 +154,10 @@ def loss(self, x: torch.Tensor, step: int, logging=False, **_): else: sparsity_scale = 1.0 - f = self.ae.encode(x) + + pre_jump = x @ self.ae.W_enc + self.ae.b_enc + f = JumpReLUFunction.apply(pre_jump, self.ae.threshold, self.bandwidth) + recon = self.ae.decode(f) recon_loss = (x - recon).pow(2).sum(dim=-1).mean() @@ -178,11 +180,14 @@ def loss(self, x: torch.Tensor, step: int, logging=False, **_): ) def update(self, step, x): + self.ae.set_decoder_norm_to_unit_norm() + x = x.to(self.device) loss = self.loss(x, step=step) loss.backward() torch.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) + self.ae.remove_gradient_parallel_to_decoder_directions() self.optimizer.step() self.scheduler.step() From c2fe5b89e78ae4a9d41a4809f4d00b8a3fcd0b36 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Wed, 1 Jan 2025 23:00:28 +0000 Subject: [PATCH 36/70] Add option to ignore bos tokens --- buffer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/buffer.py b/buffer.py index 6cbf8e3..7d304b7 100644 --- a/buffer.py +++ b/buffer.py @@ -26,7 +26,8 @@ def __init__(self, ctx_len=128, # length of each context refresh_batch_size=512, # size of batches in which to process the data when adding to buffer out_batch_size=8192, # size of batches in which to yield activations - device='cpu' # device on which to store the activations + device='cpu', # device on which to store the activations + remove_bos: bool = False, ): if io not in ['in', 'out']: @@ -54,6 +55,7 @@ def __init__(self, self.refresh_batch_size = refresh_batch_size self.out_batch_size = out_batch_size self.device = device + self.remove_bos = remove_bos def __iter__(self): return self @@ -131,6 +133,9 @@ def refresh(self): hidden_states = hidden_states.value if isinstance(hidden_states, tuple): hidden_states = hidden_states[0] + if self.remove_bos: + hidden_states = hidden_states[:, 1:, :] + attn_mask = attn_mask[:, 1:] hidden_states = hidden_states[attn_mask != 0] remaining_space = self.activation_buffer_size - current_idx From 810dbb8bdce4ac6f1ce371872297b4f7a104e3f6 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Thu, 2 Jan 2025 03:05:16 +0000 Subject: [PATCH 37/70] Add notes --- trainers/jumprelu.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/trainers/jumprelu.py b/trainers/jumprelu.py index 42a3ea0..6a1fc5b 100644 --- a/trainers/jumprelu.py +++ b/trainers/jumprelu.py @@ -74,7 +74,6 @@ def __init__( layer: int, lm_name: str, dict_class=JumpReluAutoEncoder, - # XXX: Training decay is not implemented seed: Optional[int] = None, # TODO: What's the default lr use in the paper? lr: float = 7e-5, @@ -148,6 +147,9 @@ def warmup_fn(step): self.logging_parameters = [] def loss(self, x: torch.Tensor, step: int, logging=False, **_): + # Note: We are using threshold, not log_threshold as in this notebook: + # https://colab.research.google.com/drive/1PlFzI_PWGTN9yCQLuBcSuPJUjgHL7GiD#scrollTo=yP828a6uIlSO + # I had poor results when using log_threshold and it would complicate the scale_biases() function if self.sparsity_warmup_steps is not None: sparsity_scale = min(step / self.sparsity_warmup_steps, 1.0) From 3b03b92b97d61a95e98b6f187dad97e939f6f977 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Thu, 2 Jan 2025 04:42:02 +0000 Subject: [PATCH 38/70] Add trainer number to wandb name --- training.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/training.py b/training.py index 63d1cb8..e8e22b2 100644 --- a/training.py +++ b/training.py @@ -126,7 +126,9 @@ def trainSAE( """ trainers = [] - for config in trainer_configs: + for i, config in enumerate(trainer_configs): + if "wandb_name" in config: + config["wandb_name"] = f"{config['wandb_name']}_trainer_{i}" trainer_class = config["trainer"] del config["trainer"] trainers.append(trainer_class(**config)) From 77da7945f520f448b0524e476f539b3a44a4ca43 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Thu, 2 Jan 2025 04:42:14 +0000 Subject: [PATCH 39/70] Log number of dead features to wandb --- trainers/jumprelu.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/trainers/jumprelu.py b/trainers/jumprelu.py index 6a1fc5b..4d72ac3 100644 --- a/trainers/jumprelu.py +++ b/trainers/jumprelu.py @@ -144,7 +144,11 @@ def warmup_fn(step): self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup_fn) - self.logging_parameters = [] + # Purely for logging purposes + self.dead_feature_threshold = 10_000_000 + self.num_tokens_since_fired = torch.zeros(dict_size, dtype=torch.long, device=device) + self.dead_features = -1 + self.logging_parameters = ["dead_features"] def loss(self, x: torch.Tensor, step: int, logging=False, **_): # Note: We are using threshold, not log_threshold as in this notebook: @@ -156,10 +160,16 @@ def loss(self, x: torch.Tensor, step: int, logging=False, **_): else: sparsity_scale = 1.0 - pre_jump = x @ self.ae.W_enc + self.ae.b_enc f = JumpReLUFunction.apply(pre_jump, self.ae.threshold, self.bandwidth) + active_indices = f.sum(0) > 0 + did_fire = torch.zeros_like(self.num_tokens_since_fired, dtype=torch.bool) + did_fire[active_indices] = True + self.num_tokens_since_fired += x.size(0) + self.num_tokens_since_fired[active_indices] = 0 + self.dead_features = (self.num_tokens_since_fired > self.dead_feature_threshold).sum().item() + recon = self.ae.decode(f) recon_loss = (x - recon).pow(2).sum(dim=-1).mean() From 936a69c38a74980830f24fc851c40fb93abe8f07 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Thu, 2 Jan 2025 04:42:48 +0000 Subject: [PATCH 40/70] Log dead features for batch top k SAEs --- trainers/batch_top_k.py | 2 ++ trainers/matroyshka_batch_top_k.py | 1 + 2 files changed, 3 insertions(+) diff --git a/trainers/batch_top_k.py b/trainers/batch_top_k.py index 2d4d261..a548149 100644 --- a/trainers/batch_top_k.py +++ b/trainers/batch_top_k.py @@ -184,6 +184,8 @@ def lr_fn(step): def get_auxiliary_loss(self, x, x_reconstruct, acts): dead_features = self.num_tokens_since_fired >= self.dead_feature_threshold + self.dead_features = int(dead_features.sum()) + if dead_features.sum() > 0: residual = x.float() - x_reconstruct.float() acts_topk_aux = t.topk( diff --git a/trainers/matroyshka_batch_top_k.py b/trainers/matroyshka_batch_top_k.py index 58c1524..11d3bc3 100644 --- a/trainers/matroyshka_batch_top_k.py +++ b/trainers/matroyshka_batch_top_k.py @@ -240,6 +240,7 @@ def lr_fn(step): def get_auxiliary_loss(self, x, x_reconstruct, acts): dead_features = self.num_tokens_since_fired >= self.dead_feature_threshold + self.dead_features = int(dead_features.sum()) if dead_features.sum() > 0: residual = x.float() - x_reconstruct.float() acts_topk_aux = t.topk( From 370272a4aac0ad0e59a2982073aa7b08970712b6 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Thu, 2 Jan 2025 20:37:00 +0000 Subject: [PATCH 41/70] Prevent wandb cuda multiprocessing errors --- training.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/training.py b/training.py index e8e22b2..6671111 100644 --- a/training.py +++ b/training.py @@ -3,7 +3,7 @@ """ import json -import multiprocessing as mp +import torch.multiprocessing as mp import os from queue import Empty from typing import Optional @@ -64,10 +64,12 @@ def log_stats( l0 = (f != 0).float().sum(dim=-1).mean().item() # log parameters from training - log.update({f"{k}": v for k, v in losslog.items()}) + log.update({f"{k}": v.cpu().item() if isinstance(v, t.Tensor) else v for k, v in losslog.items()}) log[f"l0"] = l0 trainer_log = trainer.get_logging_parameters() for name, value in trainer_log.items(): + if isinstance(value, t.Tensor): + value = value.cpu().item() log[f"{name}"] = value if log_queues: @@ -137,10 +139,16 @@ def trainSAE( log_queues = [] if use_wandb: + # Note: If encountering wandb and CUDA related errors, try setting start method to spawn in the if __name__ == "__main__" block + # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.set_start_method + # Everything should work fine with the default fork method but it may not be as robust for i, trainer in enumerate(trainers): log_queue = mp.Queue() log_queues.append(log_queue) wandb_config = trainer.config | run_cfg + # Make sure wandb config doesn't contain any CUDA tensors + wandb_config = {k: v.cpu().item() if isinstance(v, t.Tensor) else v + for k, v in wandb_config.items()} wandb_process = mp.Process( target=new_wandb_process, args=(wandb_config, log_queue, wandb_entity, wandb_project), From 0ff687bdc12cba66a0233825cb301df28da3a9db Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Fri, 3 Jan 2025 04:09:38 +0000 Subject: [PATCH 42/70] Add a verbose option during training --- training.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/training.py b/training.py index 6671111..7bdf1a7 100644 --- a/training.py +++ b/training.py @@ -38,6 +38,7 @@ def log_stats( activations_split_by_head: bool, transcoder: bool, log_queues: list=[], + verbose: bool=False, ): with t.no_grad(): # quick hack to make sure all trainers get the same x @@ -63,6 +64,9 @@ def log_stats( # L0 l0 = (f != 0).float().sum(dim=-1).mean().item() + if verbose: + print(f"Step {step}: L0 = {l0}, frac_variance_explained = {frac_variance_explained}") + # log parameters from training log.update({f"{k}": v.cpu().item() if isinstance(v, t.Tensor) else v for k, v in losslog.items()}) log[f"l0"] = l0 @@ -118,6 +122,7 @@ def trainSAE( transcoder:bool=False, run_cfg:dict={}, normalize_activations:bool=False, + verbose:bool=False, ): """ Train SAEs using the given trainers @@ -193,20 +198,10 @@ def trainSAE( break # logging - if log_steps is not None and step % log_steps == 0: + if (use_wandb or verbose) and step % log_steps == 0: log_stats( - trainers, step, act, activations_split_by_head, transcoder, log_queues=log_queues + trainers, step, act, activations_split_by_head, transcoder, log_queues=log_queues, verbose=verbose ) - if step % 100 == 0: - z = act.clone() - for i, trainer in enumerate(trainers): - act = z.clone() - act, act_hat, f, losslog = trainer.loss(act, step=step, logging=True) - - # L0 - l0 = (f != 0).float().sum(dim=-1).mean().item() - - print(f"Step {step}: L0 = {l0}") # saving if save_steps is not None and step in save_steps: From 9751c57731a25c04871e8173d16a0e4d902edc19 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Fri, 3 Jan 2025 05:37:49 +0000 Subject: [PATCH 43/70] Consolidate LR Schedulers, Sparsity Schedulers, and constrained optimizers --- dictionary.py | 24 +--- trainers/batch_top_k.py | 60 ++++------ trainers/gated_anneal.py | 57 +--------- trainers/gdm.py | 52 ++------- trainers/jumprelu.py | 79 ++++++------- trainers/matroyshka_batch_top_k.py | 60 ++++------ trainers/p_anneal.py | 54 +-------- trainers/standard.py | 54 +-------- trainers/top_k.py | 58 ++++------ trainers/trainer.py | 177 +++++++++++++++++++++++++++-- 10 files changed, 293 insertions(+), 382 deletions(-) diff --git a/dictionary.py b/dictionary.py index 9886282..8a7dad4 100644 --- a/dictionary.py +++ b/dictionary.py @@ -255,7 +255,7 @@ def __init__(self, activation_dim, dict_size, device="cpu"): t.nn.init.kaiming_uniform_(t.empty(dict_size, activation_dim, device=device)) ) self.b_dec = nn.Parameter(t.zeros(activation_dim, device=device)) - self.threshold = nn.Parameter(t.ones(dict_size, device=device) * 0.001) # Appendix I + self.threshold = nn.Parameter(t.ones(dict_size, device=device) * 0.001) # Appendix I self.apply_b_dec_to_input = False @@ -331,28 +331,6 @@ def from_pretrained( device = autoencoder.W_enc.device return autoencoder.to(dtype=dtype, device=device) - @t.no_grad() - def set_decoder_norm_to_unit_norm(self): - eps = t.finfo(self.W_dec.dtype).eps - norm = t.norm(self.W_dec.data, dim=1, keepdim=True) - - self.W_dec.data /= norm + eps - - @t.no_grad() - def remove_gradient_parallel_to_decoder_directions(self): - assert self.W_dec.grad is not None - - parallel_component = einops.einsum( - self.W_dec.grad, - self.W_dec.data, - "d_sae d_in, d_sae d_in -> d_sae", - ) - self.W_dec.grad -= einops.einsum( - parallel_component, - self.W_dec.data, - "d_sae, d_sae d_in -> d_sae d_in", - ) - # TODO merge this with AutoEncoder class AutoEncoderNew(Dictionary, nn.Module): diff --git a/trainers/batch_top_k.py b/trainers/batch_top_k.py index a548149..a3c6214 100644 --- a/trainers/batch_top_k.py +++ b/trainers/batch_top_k.py @@ -6,7 +6,12 @@ from typing import Optional from ..dictionary import Dictionary -from ..trainers.trainer import SAETrainer +from ..trainers.trainer import ( + SAETrainer, + get_lr_schedule, + set_decoder_norm_to_unit_norm, + remove_gradient_parallel_to_decoder_directions, +) class BatchTopKSAE(Dictionary, nn.Module): @@ -20,7 +25,9 @@ def __init__(self, activation_dim: int, dict_size: int, k: int): self.register_buffer("threshold", t.tensor(-1.0)) self.decoder = nn.Linear(dict_size, activation_dim, bias=False) - self.set_decoder_norm_to_unit_norm() + self.decoder.weight.data = set_decoder_norm_to_unit_norm( + self.decoder.weight, activation_dim, dict_size + ) self.encoder = nn.Linear(activation_dim, dict_size) self.encoder.weight.data = self.decoder.weight.T.clone() @@ -65,26 +72,6 @@ def forward(self, x: t.Tensor, output_features: bool = False): else: return x_hat_BD, encoded_acts_BF - @t.no_grad() - def set_decoder_norm_to_unit_norm(self): - eps = t.finfo(self.decoder.weight.dtype).eps - norm = t.norm(self.decoder.weight.data, dim=0, keepdim=True) - self.decoder.weight.data /= norm + eps - - @t.no_grad() - def remove_gradient_parallel_to_decoder_directions(self): - assert self.decoder.weight.grad is not None - parallel_component = einops.einsum( - self.decoder.weight.grad, - self.decoder.weight.data, - "d_in d_sae, d_in d_sae -> d_sae", - ) - self.decoder.weight.grad -= einops.einsum( - parallel_component, - self.decoder.weight.data, - "d_sae, d_in d_sae -> d_in d_sae", - ) - def scale_biases(self, scale: float): self.encoder.bias.data *= scale self.b_dec.data *= scale @@ -160,20 +147,7 @@ def __init__( self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) - if decay_start is not None: - assert 0 <= decay_start < steps, "decay_start must be >= 0 and < steps." - assert decay_start > warmup_steps, "decay_start must be > warmup_steps." - - assert 0 <= warmup_steps < steps, "warmup_steps must be >= 0 and < steps." - - def lr_fn(step): - if step < warmup_steps: - return step / warmup_steps - - if decay_start is not None and step >= decay_start: - return (steps - step) / (steps - decay_start) - - return 1.0 + lr_fn = get_lr_schedule(steps, warmup_steps, decay_start=decay_start) self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) @@ -185,7 +159,7 @@ def lr_fn(step): def get_auxiliary_loss(self, x, x_reconstruct, acts): dead_features = self.num_tokens_since_fired >= self.dead_feature_threshold self.dead_features = int(dead_features.sum()) - + if dead_features.sum() > 0: residual = x.float() - x_reconstruct.float() acts_topk_aux = t.topk( @@ -255,14 +229,22 @@ def update(self, step, x): median = self.geometric_median(x) self.ae.b_dec.data = median - self.ae.set_decoder_norm_to_unit_norm() + # Make sure the decoder is still unit-norm + self.ae.decoder.weight.data = set_decoder_norm_to_unit_norm( + self.ae.decoder.weight, self.ae.activation_dim, self.ae.dict_size + ) x = x.to(self.device) loss = self.loss(x, step=step) loss.backward() t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) - self.ae.remove_gradient_parallel_to_decoder_directions() + self.ae.decoder.weight.grad = remove_gradient_parallel_to_decoder_directions( + self.ae.decoder.weight, + self.ae.decoder.weight.grad, + self.ae.activation_dim, + self.ae.dict_size, + ) self.optimizer.step() self.optimizer.zero_grad() diff --git a/trainers/gated_anneal.py b/trainers/gated_anneal.py index 6b4e774..09f69e6 100644 --- a/trainers/gated_anneal.py +++ b/trainers/gated_anneal.py @@ -5,31 +5,11 @@ import torch as t from typing import Optional -from ..trainers.trainer import SAETrainer +from ..trainers.trainer import SAETrainer, get_lr_schedule, get_sparsity_warmup_fn, ConstrainedAdam from ..config import DEBUG from ..dictionary import GatedAutoEncoder from collections import namedtuple -class ConstrainedAdam(t.optim.Adam): - """ - A variant of Adam where some of the parameters are constrained to have unit norm. - """ - def __init__(self, params, constrained_params, lr): - super().__init__(params, lr=lr, betas=(0, 0.999)) - self.constrained_params = list(constrained_params) - - def step(self, closure=None): - with t.no_grad(): - for p in self.constrained_params: - normed_p = p / p.norm(dim=0, keepdim=True) - # project away the parallel component of the gradient - p.grad -= (p.grad * normed_p).sum(dim=0, keepdim=True) * normed_p - super().step(closure=closure) - with t.no_grad(): - for p in self.constrained_params: - # renormalize the constrained parameters - p /= p.norm(dim=0, keepdim=True) - class GatedAnnealTrainer(SAETrainer): """ Gated SAE training scheme with p-annealing. @@ -116,33 +96,11 @@ def __init__(self, else: self.steps_since_active = None - self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr) - - if decay_start is not None: - assert resample_steps is None, "decay_start and resample_steps are currently mutually exclusive." - assert 0 <= decay_start < steps, "decay_start must be >= 0 and < steps." - assert decay_start > warmup_steps, "decay_start must be > warmup_steps." - if sparsity_warmup_steps is not None: - assert decay_start > sparsity_warmup_steps, "decay_start must be > sparsity_warmup_steps." + self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr, betas=(0.0, 0.999)) - assert 0 <= warmup_steps < anneal_start, "warmup_steps must be >= 0 and < anneal_start." - - if sparsity_warmup_steps is not None: - assert 0 <= sparsity_warmup_steps < anneal_start, "sparsity_warmup_steps must be >= 0 and < anneal_start." - - if resample_steps is None: - def warmup_fn(step): - if step < warmup_steps: - return step / warmup_steps - - if decay_start is not None and step >= decay_start: - return (steps - step) / (steps - decay_start) - - return 1.0 - else: - def warmup_fn(step): - return min((step % resample_steps) / warmup_steps, 1.) - self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup_fn) + lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps, sparsity_warmup_steps) + self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) + self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps) def resample_neurons(self, deads, activations): with t.no_grad(): @@ -186,10 +144,7 @@ def lp_norm(self, f, p): raise ValueError("Sparsity function must be 'Lp' or 'Lp^p'") def loss(self, x:t.Tensor, step:int, logging=False, **kwargs): - if self.sparsity_warmup_steps is not None: - sparsity_scale = min(step / self.sparsity_warmup_steps, 1.0) - else: - sparsity_scale = 1.0 + sparsity_scale = self.sparsity_warmup_fn(step) f, f_gate = self.ae.encode(x, return_gate=True) x_hat = self.ae.decode(f) x_hat_gate = f_gate @ self.ae.decoder.weight.detach().T + self.ae.decoder_bias.detach() diff --git a/trainers/gdm.py b/trainers/gdm.py index bb8ff0f..ecab59c 100644 --- a/trainers/gdm.py +++ b/trainers/gdm.py @@ -5,31 +5,11 @@ import torch as t from typing import Optional -from ..trainers.trainer import SAETrainer +from ..trainers.trainer import SAETrainer, get_lr_schedule, get_sparsity_warmup_fn, ConstrainedAdam from ..config import DEBUG from ..dictionary import GatedAutoEncoder from collections import namedtuple -class ConstrainedAdam(t.optim.Adam): - """ - A variant of Adam where some of the parameters are constrained to have unit norm. - """ - def __init__(self, params, constrained_params, lr): - super().__init__(params, lr=lr, betas=(0, 0.999)) - self.constrained_params = list(constrained_params) - - def step(self, closure=None): - with t.no_grad(): - for p in self.constrained_params: - normed_p = p / p.norm(dim=0, keepdim=True) - # project away the parallel component of the gradient - p.grad -= (p.grad * normed_p).sum(dim=0, keepdim=True) * normed_p - super().step(closure=closure) - with t.no_grad(): - for p in self.constrained_params: - # renormalize the constrained parameters - p /= p.norm(dim=0, keepdim=True) - class GatedSAETrainer(SAETrainer): """ Gated SAE training scheme. @@ -81,37 +61,19 @@ def __init__(self, self.optimizer = ConstrainedAdam( self.ae.parameters(), self.ae.decoder.parameters(), - lr=lr + lr=lr, + betas=(0.0, 0.999), ) - if decay_start is not None: - assert 0 <= decay_start < steps, "decay_start must be >= 0 and < steps." - assert decay_start > warmup_steps, "decay_start must be > warmup_steps." - if sparsity_warmup_steps is not None: - assert decay_start > sparsity_warmup_steps, "decay_start must be > sparsity_warmup_steps." + lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps=None, sparsity_warmup_steps=sparsity_warmup_steps) - assert 0 <= warmup_steps < steps, "warmup_steps must be >= 0 and < steps." + self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_fn) + self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps) - if sparsity_warmup_steps is not None: - assert 0 <= sparsity_warmup_steps < steps, "sparsity_warmup_steps must be >= 0 and < steps." - - def warmup_fn(step): - if step < warmup_steps: - return step / warmup_steps - - if decay_start is not None and step >= decay_start: - return (steps - step) / (steps - decay_start) - - return 1.0 - - self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, warmup_fn) def loss(self, x:t.Tensor, step:int, logging:bool=False, **kwargs): - if self.sparsity_warmup_steps is not None: - sparsity_scale = min(step / self.sparsity_warmup_steps, 1.0) - else: - sparsity_scale = 1.0 + sparsity_scale = self.sparsity_warmup_fn(step) f, f_gate = self.ae.encode(x, return_gate=True) x_hat = self.ae.decode(f) diff --git a/trainers/jumprelu.py b/trainers/jumprelu.py index 4d72ac3..7323c27 100644 --- a/trainers/jumprelu.py +++ b/trainers/jumprelu.py @@ -6,7 +6,14 @@ from typing import Optional from ..dictionary import Dictionary, JumpReluAutoEncoder -from .trainer import SAETrainer +from ..trainers.trainer import ( + SAETrainer, + get_lr_schedule, + get_sparsity_warmup_fn, + set_decoder_norm_to_unit_norm, + remove_gradient_parallel_to_decoder_directions, +) + class RectangleFunction(autograd.Function): @staticmethod @@ -31,7 +38,7 @@ def forward(ctx, x, threshold, bandwidth): @staticmethod def backward(ctx, grad_output): x, threshold, bandwidth_tensor = ctx.saved_tensors - bandwidth = bandwidth_tensor.item() + bandwidth = bandwidth_tensor.item() x_grad = (x > threshold).float() * grad_output threshold_grad = ( -(threshold / bandwidth) @@ -53,9 +60,7 @@ def backward(ctx, grad_output): bandwidth = bandwidth_tensor.item() x_grad = torch.zeros_like(x) threshold_grad = ( - -(1.0 / bandwidth) - * RectangleFunction.apply((x - threshold) / bandwidth) - * grad_output + -(1.0 / bandwidth) * RectangleFunction.apply((x - threshold) / bandwidth) * grad_output ) return x_grad, threshold_grad, None # None for bandwidth @@ -66,9 +71,10 @@ class JumpReluTrainer(nn.Module, SAETrainer): Note does not use learning rate or sparsity scheduling as in the paper. """ + def __init__( self, - steps: int, # total number of steps to train for + steps: int, # total number of steps to train for activation_dim: int, dict_size: int, layer: int, @@ -79,9 +85,9 @@ def __init__( lr: float = 7e-5, bandwidth: float = 0.001, sparsity_penalty: float = 1.0, - warmup_steps:int=1000, # lr warmup period at start of training and after each resample - sparsity_warmup_steps:Optional[int]=2000, # sparsity warmup period at start of training - decay_start:Optional[int]=None, # decay learning rate after this many steps + warmup_steps: int = 1000, # lr warmup period at start of training and after each resample + sparsity_warmup_steps: Optional[int] = 2000, # sparsity warmup period at start of training + decay_start: Optional[int] = None, # decay learning rate after this many steps target_l0: float = 20.0, device: str = "cpu", wandb_name: str = "JumpRelu", @@ -118,31 +124,19 @@ def __init__( ).to(self.device) # Parameters from the paper - self.optimizer = torch.optim.Adam( - self.ae.parameters(), lr=lr, betas=(0.0, 0.999), eps=1e-8 + self.optimizer = torch.optim.Adam(self.ae.parameters(), lr=lr, betas=(0.0, 0.999), eps=1e-8) + + lr_fn = get_lr_schedule( + steps, + warmup_steps, + decay_start, + resample_steps=None, + sparsity_warmup_steps=sparsity_warmup_steps, ) - if decay_start is not None: - assert 0 <= decay_start < steps, "decay_start must be >= 0 and < steps." - assert decay_start > warmup_steps, "decay_start must be > warmup_steps." - if sparsity_warmup_steps is not None: - assert decay_start > sparsity_warmup_steps, "decay_start must be > sparsity_warmup_steps." - - assert 0 <= warmup_steps < steps, "warmup_steps must be >= 0 and < steps." + self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) - if sparsity_warmup_steps is not None: - assert 0 <= sparsity_warmup_steps < steps, "sparsity_warmup_steps must be >= 0 and < steps." - - def warmup_fn(step): - if step < warmup_steps: - return step / warmup_steps - - if decay_start is not None and step >= decay_start: - return (steps - step) / (steps - decay_start) - - return 1.0 - - self.scheduler = torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup_fn) + self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps) # Purely for logging purposes self.dead_feature_threshold = 10_000_000 @@ -155,10 +149,7 @@ def loss(self, x: torch.Tensor, step: int, logging=False, **_): # https://colab.research.google.com/drive/1PlFzI_PWGTN9yCQLuBcSuPJUjgHL7GiD#scrollTo=yP828a6uIlSO # I had poor results when using log_threshold and it would complicate the scale_biases() function - if self.sparsity_warmup_steps is not None: - sparsity_scale = min(step / self.sparsity_warmup_steps, 1.0) - else: - sparsity_scale = 1.0 + sparsity_scale = self.sparsity_warmup_fn(step) pre_jump = x @ self.ae.W_enc + self.ae.b_enc f = JumpReLUFunction.apply(pre_jump, self.ae.threshold, self.bandwidth) @@ -168,14 +159,18 @@ def loss(self, x: torch.Tensor, step: int, logging=False, **_): did_fire[active_indices] = True self.num_tokens_since_fired += x.size(0) self.num_tokens_since_fired[active_indices] = 0 - self.dead_features = (self.num_tokens_since_fired > self.dead_feature_threshold).sum().item() + self.dead_features = ( + (self.num_tokens_since_fired > self.dead_feature_threshold).sum().item() + ) recon = self.ae.decode(f) recon_loss = (x - recon).pow(2).sum(dim=-1).mean() l0 = StepFunction.apply(f, self.ae.threshold, self.bandwidth).sum(dim=-1).mean() - sparsity_loss = self.sparsity_coefficient * ((l0 / self.target_l0) - 1).pow(2) * sparsity_scale + sparsity_loss = ( + self.sparsity_coefficient * ((l0 / self.target_l0) - 1).pow(2) * sparsity_scale + ) loss = recon_loss + sparsity_loss if not logging: @@ -192,14 +187,20 @@ def loss(self, x: torch.Tensor, step: int, logging=False, **_): ) def update(self, step, x): - self.ae.set_decoder_norm_to_unit_norm() + # We must transpose because we are using nn.Parameter, not nn.Linear + self.ae.W_dec.data = set_decoder_norm_to_unit_norm( + self.ae.W_dec.T, self.ae.activation_dim, self.ae.dict_size + ).T x = x.to(self.device) loss = self.loss(x, step=step) loss.backward() torch.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) - self.ae.remove_gradient_parallel_to_decoder_directions() + # We must transpose because we are using nn.Parameter, not nn.Linear + self.ae.W_dec.grad = remove_gradient_parallel_to_decoder_directions( + self.ae.W_dec.T, self.ae.W_dec.grad.T, self.ae.activation_dim, self.ae.dict_size + ).T self.optimizer.step() self.scheduler.step() diff --git a/trainers/matroyshka_batch_top_k.py b/trainers/matroyshka_batch_top_k.py index 11d3bc3..69425fd 100644 --- a/trainers/matroyshka_batch_top_k.py +++ b/trainers/matroyshka_batch_top_k.py @@ -7,7 +7,12 @@ from math import isclose from ..dictionary import Dictionary -from ..trainers.trainer import SAETrainer +from ..trainers.trainer import ( + SAETrainer, + get_lr_schedule, + set_decoder_norm_to_unit_norm, + remove_gradient_parallel_to_decoder_directions, +) def apply_temperature(probabilities: list[float], temperature: float) -> list[float]: @@ -53,7 +58,10 @@ def __init__(self, activation_dim: int, dict_size: int, k: int, group_sizes: lis self.W_dec = nn.Parameter(t.nn.init.kaiming_uniform_(t.empty(dict_size, activation_dim))) self.b_dec = nn.Parameter(t.zeros(activation_dim)) - self.set_decoder_norm_to_unit_norm() + # We must transpose because we are using nn.Parameter, not nn.Linear + self.W_dec.data = set_decoder_norm_to_unit_norm( + self.W_dec.data.T, activation_dim, dict_size + ).T self.W_enc.data = self.W_dec.data.clone().T def encode(self, x: t.Tensor, return_active: bool = False, use_threshold: bool = True): @@ -93,28 +101,6 @@ def forward(self, x: t.Tensor, output_features: bool = False): else: return x_hat_BD, encoded_acts_BF - @t.no_grad() - def set_decoder_norm_to_unit_norm(self): - eps = t.finfo(self.W_dec.dtype).eps - norm = t.norm(self.W_dec.data, dim=1, keepdim=True) - - self.W_dec.data /= norm + eps - - @t.no_grad() - def remove_gradient_parallel_to_decoder_directions(self): - assert self.W_dec.grad is not None - - parallel_component = einops.einsum( - self.W_dec.grad, - self.W_dec.data, - "d_sae d_in, d_sae d_in -> d_sae", - ) - self.W_dec.grad -= einops.einsum( - parallel_component, - self.W_dec.data, - "d_sae, d_sae d_in -> d_sae d_in", - ) - @t.no_grad() def scale_biases(self, scale: float): self.b_enc.data *= scale @@ -216,21 +202,7 @@ def __init__( self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) - if decay_start is not None: - assert 0 <= decay_start < steps, "decay_start must be >= 0 and < steps." - assert decay_start > warmup_steps, "decay_start must be > warmup_steps." - - assert 0 <= warmup_steps < steps, "warmup_steps must be >= 0 and < steps." - - def lr_fn(step): - if step < warmup_steps: - return step / warmup_steps - - if decay_start is not None and step >= decay_start: - return (steps - step) / (steps - decay_start) - - return 1.0 - + lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps=None) self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device) @@ -331,14 +303,20 @@ def update(self, step, x): median = self.geometric_median(x) self.ae.b_dec.data = median - self.ae.set_decoder_norm_to_unit_norm() + # We must transpose because we are using nn.Parameter, not nn.Linear + self.ae.W_dec.data = set_decoder_norm_to_unit_norm( + self.ae.W_dec.T, self.ae.activation_dim, self.ae.dict_size + ).T x = x.to(self.device) loss = self.loss(x, step=step) loss.backward() t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) - self.ae.remove_gradient_parallel_to_decoder_directions() + # We must transpose because we are using nn.Parameter, not nn.Linear + self.ae.W_dec.grad = remove_gradient_parallel_to_decoder_directions( + self.ae.W_dec.T, self.ae.W_dec.grad.T, self.ae.activation_dim, self.ae.dict_size + ).T self.optimizer.step() self.optimizer.zero_grad() diff --git a/trainers/p_anneal.py b/trainers/p_anneal.py index cf886ef..de9bdff 100644 --- a/trainers/p_anneal.py +++ b/trainers/p_anneal.py @@ -5,29 +5,9 @@ """ from ..dictionary import AutoEncoder -from ..trainers.trainer import SAETrainer +from ..trainers.trainer import SAETrainer, get_lr_schedule, get_sparsity_warmup_fn, ConstrainedAdam from ..config import DEBUG -class ConstrainedAdam(t.optim.Adam): - """ - A variant of Adam where some of the parameters are constrained to have unit norm. - """ - def __init__(self, params, constrained_params, lr): - super().__init__(params, lr=lr) - self.constrained_params = list(constrained_params) - - def step(self, closure=None): - with t.no_grad(): - for p in self.constrained_params: - normed_p = p / p.norm(dim=0, keepdim=True) - # project away the parallel component of the gradient - p.grad -= (p.grad * normed_p).sum(dim=0, keepdim=True) * normed_p - super().step(closure=closure) - with t.no_grad(): - for p in self.constrained_params: - # renormalize the constrained parameters - p /= p.norm(dim=0, keepdim=True) - class PAnnealTrainer(SAETrainer): """ SAE training scheme with the option to anneal the sparsity parameter p. @@ -116,31 +96,10 @@ def __init__(self, self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr) - if decay_start is not None: - assert resample_steps is None, "decay_start and resample_steps are currently mutually exclusive." - assert 0 <= decay_start < steps, "decay_start must be >= 0 and < steps." - assert decay_start > warmup_steps, "decay_start must be > warmup_steps." - if sparsity_warmup_steps is not None: - assert decay_start > sparsity_warmup_steps, "decay_start must be > sparsity_warmup_steps." + lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps, sparsity_warmup_steps) + self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) - assert 0 <= warmup_steps < anneal_start, "warmup_steps must be >= 0 and < anneal_start." - - if sparsity_warmup_steps is not None: - assert 0 <= sparsity_warmup_steps < anneal_start, "sparsity_warmup_steps must be >= 0 and < anneal_start." - - if resample_steps is None: - def warmup_fn(step): - if step < warmup_steps: - return step / warmup_steps - - if decay_start is not None and step >= decay_start: - return (steps - step) / (steps - decay_start) - - return 1.0 - else: - def warmup_fn(step): - return min((step % resample_steps) / warmup_steps, 1.) - self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup_fn) + self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps) if (self.sparsity_update_steps.unique(return_counts=True)[1] >1).any(): print("Warning! Duplicates om self.sparsity_update_steps detected!") @@ -187,10 +146,7 @@ def lp_norm(self, f, p): raise ValueError("Sparsity function must be 'Lp' or 'Lp^p'") def loss(self, x: t.Tensor, step:int, logging=False): - if self.sparsity_warmup_steps is not None: - sparsity_scale = min(step / self.sparsity_warmup_steps, 1.0) - else: - sparsity_scale = 1.0 + sparsity_scale = self.sparsity_warmup_fn(step) # Compute loss terms x_hat, f = self.ae(x, output_features=True) diff --git a/trainers/standard.py b/trainers/standard.py index a466cc7..ccc44bc 100644 --- a/trainers/standard.py +++ b/trainers/standard.py @@ -4,31 +4,11 @@ import torch as t from typing import Optional -from ..trainers.trainer import SAETrainer +from ..trainers.trainer import SAETrainer, get_lr_schedule, get_sparsity_warmup_fn, ConstrainedAdam from ..config import DEBUG from ..dictionary import AutoEncoder from collections import namedtuple -class ConstrainedAdam(t.optim.Adam): - """ - A variant of Adam where some of the parameters are constrained to have unit norm. - """ - def __init__(self, params, constrained_params, lr): - super().__init__(params, lr=lr) - self.constrained_params = list(constrained_params) - - def step(self, closure=None): - with t.no_grad(): - for p in self.constrained_params: - normed_p = p / p.norm(dim=0, keepdim=True) - # project away the parallel component of the gradient - p.grad -= (p.grad * normed_p).sum(dim=0, keepdim=True) * normed_p - super().step(closure=closure) - with t.no_grad(): - for p in self.constrained_params: - # renormalize the constrained parameters - p /= p.norm(dim=0, keepdim=True) - class StandardTrainer(SAETrainer): """ Standard SAE training scheme. @@ -88,31 +68,10 @@ def __init__(self, self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr) - if decay_start is not None: - assert resample_steps is None, "decay_start and resample_steps are currently mutually exclusive." - assert 0 <= decay_start < steps, "decay_start must be >= 0 and < steps." - assert decay_start > warmup_steps, "decay_start must be > warmup_steps." - if sparsity_warmup_steps is not None: - assert decay_start > sparsity_warmup_steps, "decay_start must be > sparsity_warmup_steps." + lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps, sparsity_warmup_steps) + self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) - assert 0 <= warmup_steps < steps, "warmup_steps must be >= 0 and < steps." - - if sparsity_warmup_steps is not None: - assert 0 <= sparsity_warmup_steps < steps, "sparsity_warmup_steps must be >= 0 and < steps." - - if resample_steps is None: - def warmup_fn(step): - if step < warmup_steps: - return step / warmup_steps - - if decay_start is not None and step >= decay_start: - return (steps - step) / (steps - decay_start) - - return 1.0 - else: - def warmup_fn(step): - return min((step % resample_steps) / warmup_steps, 1.) - self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup_fn) + self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps) def resample_neurons(self, deads, activations): with t.no_grad(): @@ -151,10 +110,7 @@ def resample_neurons(self, deads, activations): def loss(self, x, step: int, logging=False, **kwargs): - if self.sparsity_warmup_steps is not None: - sparsity_scale = min(step / self.sparsity_warmup_steps, 1.0) - else: - sparsity_scale = 1.0 + sparsity_scale = self.sparsity_warmup_fn(step) x_hat, f = self.ae(x, output_features=True) l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() diff --git a/trainers/top_k.py b/trainers/top_k.py index 0215295..a155a9a 100644 --- a/trainers/top_k.py +++ b/trainers/top_k.py @@ -11,7 +11,12 @@ from ..config import DEBUG from ..dictionary import Dictionary -from ..trainers.trainer import SAETrainer +from ..trainers.trainer import ( + SAETrainer, + get_lr_schedule, + set_decoder_norm_to_unit_norm, + remove_gradient_parallel_to_decoder_directions, +) @t.no_grad() @@ -65,7 +70,9 @@ def __init__(self, activation_dim: int, dict_size: int, k: int): self.register_buffer("threshold", t.tensor(-1.0)) self.decoder = nn.Linear(dict_size, activation_dim, bias=False) - self.set_decoder_norm_to_unit_norm() + self.decoder.weight.data = set_decoder_norm_to_unit_norm( + self.decoder.weight, activation_dim, dict_size + ) self.encoder = nn.Linear(activation_dim, dict_size) self.encoder.weight.data = self.decoder.weight.T.clone() @@ -109,27 +116,6 @@ def forward(self, x: t.Tensor, output_features: bool = False): else: return x_hat_BD, encoded_acts_BF - @t.no_grad() - def set_decoder_norm_to_unit_norm(self): - eps = t.finfo(self.decoder.weight.dtype).eps - norm = t.norm(self.decoder.weight.data, dim=0, keepdim=True) - self.decoder.weight.data /= norm + eps - - @t.no_grad() - def remove_gradient_parallel_to_decoder_directions(self): - assert self.decoder.weight.grad is not None # keep pyright happy - - parallel_component = einops.einsum( - self.decoder.weight.grad, - self.decoder.weight.data, - "d_in d_sae, d_in d_sae -> d_sae", - ) - self.decoder.weight.grad -= einops.einsum( - parallel_component, - self.decoder.weight.data, - "d_sae, d_in d_sae -> d_in d_sae", - ) - def scale_biases(self, scale: float): self.encoder.bias.data *= scale self.b_dec.data *= scale @@ -215,20 +201,7 @@ def __init__( # Optimizer and scheduler self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) - if decay_start is not None: - assert 0 <= decay_start < steps, "decay_start must be >= 0 and < steps." - assert decay_start > warmup_steps, "decay_start must be > warmup_steps." - - assert 0 <= warmup_steps < steps, "warmup_steps must be >= 0 and < steps." - - def lr_fn(step): - if step < warmup_steps: - return step / warmup_steps - - if decay_start is not None and step >= decay_start: - return (steps - step) / (steps - decay_start) - - return 1.0 + lr_fn = get_lr_schedule(steps, warmup_steps, decay_start=decay_start) self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) @@ -335,7 +308,9 @@ def update(self, step, x): self.ae.b_dec.data = median # Make sure the decoder is still unit-norm - self.ae.set_decoder_norm_to_unit_norm() + self.ae.decoder.weight.data = set_decoder_norm_to_unit_norm( + self.ae.decoder.weight, self.ae.activation_dim, self.ae.dict_size + ) # compute the loss x = x.to(self.device) @@ -344,7 +319,12 @@ def update(self, step, x): # clip grad norm and remove grads parallel to decoder directions t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) - self.ae.remove_gradient_parallel_to_decoder_directions() + self.ae.decoder.weight.grad = remove_gradient_parallel_to_decoder_directions( + self.ae.decoder.weight, + self.ae.decoder.weight.grad, + self.ae.activation_dim, + self.ae.dict_size, + ) # do a training step self.optimizer.step() diff --git a/trainers/trainer.py b/trainers/trainer.py index 04170b9..03ffb54 100644 --- a/trainers/trainer.py +++ b/trainers/trainer.py @@ -1,16 +1,23 @@ +from typing import Optional, Callable +import torch +import einops + + class SAETrainer: """ Generic class for implementing SAE training algorithms """ + def __init__(self, seed=None): self.seed = seed self.logging_parameters = [] - def update(self, - step, # index of step in training - activations, # of shape [batch_size, d_submodule] - ): - pass # implemented by subclasses + def update( + self, + step, # index of step in training + activations, # of shape [batch_size, d_submodule] + ): + pass # implemented by subclasses def get_logging_parameters(self): stats = {} @@ -20,9 +27,165 @@ def get_logging_parameters(self): else: print(f"Warning: {param} not found in {self}") return stats - + @property def config(self): return { - 'wandb_name': 'trainer', + "wandb_name": "trainer", } + + +class ConstrainedAdam(torch.optim.Adam): + """ + A variant of Adam where some of the parameters are constrained to have unit norm. + Note: This should be used with a decoder that is nn.Linear, not nn.Parameter. + If nn.Parameter, the dim argument to norm should be 1. + """ + + def __init__(self, params, constrained_params, lr: float, betas: tuple[float, float] = (0.9, 0.999)): + super().__init__(params, lr=lr, betas=betas) + self.constrained_params = list(constrained_params) + + def step(self, closure=None): + with torch.no_grad(): + for p in self.constrained_params: + normed_p = p / p.norm(dim=0, keepdim=True) + # project away the parallel component of the gradient + p.grad -= (p.grad * normed_p).sum(dim=0, keepdim=True) * normed_p + super().step(closure=closure) + with torch.no_grad(): + for p in self.constrained_params: + # renormalize the constrained parameters + p /= p.norm(dim=0, keepdim=True) + + +# The next two functions could be replaced with the ConstrainedAdam Optimizer +@torch.no_grad() +def set_decoder_norm_to_unit_norm( + W_dec_DF: torch.nn.Parameter, activation_dim: int, d_sae: int +) -> torch.Tensor: + """There's a major footgun here: we use this with both nn.Linear and nn.Parameter decoders. + nn.Linear stores the decoder weights in a transposed format (d_model, d_sae). So, we pass the dimensions in + to catch this error.""" + + D, F = W_dec_DF.shape + + assert D == activation_dim + assert F == d_sae + + eps = torch.finfo(W_dec_DF.dtype).eps + norm = torch.norm(W_dec_DF.data, dim=0, keepdim=True) + W_dec_DF.data /= norm + eps + return W_dec_DF.data + + +@torch.no_grad() +def remove_gradient_parallel_to_decoder_directions( + W_dec_DF: torch.Tensor, + W_dec_DF_grad: torch.Tensor, + activation_dim: int, + d_sae: int, +) -> torch.Tensor: + """There's a major footgun here: we use this with both nn.Linear and nn.Parameter decoders. + nn.Linear stores the decoder weights in a transposed format (d_model, d_sae). So, we pass the dimensions in + to catch this error.""" + + D, F = W_dec_DF.shape + assert D == activation_dim + assert F == d_sae + + parallel_component = einops.einsum( + W_dec_DF_grad, + W_dec_DF, + "d_in d_sae, d_in d_sae -> d_sae", + ) + W_dec_DF_grad -= einops.einsum( + parallel_component, + W_dec_DF, + "d_sae, d_in d_sae -> d_in d_sae", + ) + return W_dec_DF_grad + + +def get_lr_schedule( + total_steps: int, + warmup_steps: int, + decay_start: Optional[int] = None, + resample_steps: Optional[int] = None, + sparsity_warmup_steps: Optional[int] = None, +) -> Callable[[int], float]: + """ + Creates a learning rate schedule function with linear warmup followed by an optional decay phase. + + Note: resample_steps creates a repeating warmup pattern instead of the standard phases, but + is rarely used in practice. + + Args: + total_steps: Total number of training steps + warmup_steps: Steps for linear warmup from 0 to 1 + decay_start: Optional step to begin linear decay to 0 + resample_steps: Optional period for repeating warmup pattern + sparsity_warmup_steps: Used for validation with decay_start + + Returns: + Function that computes LR scale factor for a given step + """ + if decay_start is not None: + assert ( + resample_steps is None + ), "decay_start and resample_steps are currently mutually exclusive." + assert 0 <= decay_start < total_steps, "decay_start must be >= 0 and < steps." + assert decay_start > warmup_steps, "decay_start must be > warmup_steps." + if sparsity_warmup_steps is not None: + assert ( + decay_start > sparsity_warmup_steps + ), "decay_start must be > sparsity_warmup_steps." + + assert 0 <= warmup_steps < total_steps, "warmup_steps must be >= 0 and < steps." + + if resample_steps is None: + + def lr_schedule(step: int) -> float: + if step < warmup_steps: + # Warm-up phase + return step / warmup_steps + + if decay_start is not None and step >= decay_start: + # Decay phase + return (total_steps - step) / (total_steps - decay_start) + + # Constant phase + return 1.0 + else: + assert 0 < resample_steps < total_steps, "resample_steps must be > 0 and < steps." + + def lr_schedule(step: int) -> float: + return min((step % resample_steps) / warmup_steps, 1.0) + + return lr_schedule + + +def get_sparsity_warmup_fn( + total_steps: int, sparsity_warmup_steps: Optional[int] = None +) -> Callable[[int], float]: + """ + Return a function that computes a scale factor for sparsity penalty at a given step. + + If `sparsity_warmup_steps` is None or 0, returns 1.0 for all steps. + Otherwise, scales from 0.0 up to 1.0 across `sparsity_warmup_steps`. + """ + + if sparsity_warmup_steps is not None: + assert ( + 0 <= sparsity_warmup_steps < total_steps + ), "sparsity_warmup_steps must be >= 0 and < steps." + + def scale_fn(step: int) -> float: + if not sparsity_warmup_steps: + # If it's None or zero, we just return 1.0 + return 1.0 + else: + # Gradually increase from 0.0 -> 1.0 as step goes from 0 -> sparsity_warmup_steps + return min(step / sparsity_warmup_steps, 1.0) + + return scale_fn From cfb36fff661fa60f38a2d1b372b6802517c08257 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Fri, 3 Jan 2025 19:53:13 +0000 Subject: [PATCH 44/70] Add April Update Standard Trainer --- dictionary.py | 33 +++++++++++- trainers/standard.py | 116 ++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 147 insertions(+), 2 deletions(-) diff --git a/dictionary.py b/dictionary.py index 8a7dad4..3b7e976 100644 --- a/dictionary.py +++ b/dictionary.py @@ -100,8 +100,32 @@ def scale_biases(self, scale: float): self.encoder.bias.data *= scale self.bias.data *= scale + def normalize_decoder(self): + norms = t.norm(self.decoder.weight, dim=0).to(dtype=self.decoder.weight.dtype, device=self.decoder.weight.device) + + if t.allclose(norms, t.ones_like(norms)): + return + print("Normalizing decoder weights") + + test_input = t.randn(10, self.activation_dim) + initial_output = self(test_input) + + self.decoder.weight.data /= norms + + new_norms = t.norm(self.decoder.weight, dim=0) + assert t.allclose(new_norms, t.ones_like(new_norms)) + + self.encoder.weight.data *= norms[:, None] + self.encoder.bias.data *= norms + + new_output = self(test_input) + + # Errors can be relatively large in larger SAEs due to floating point precision + assert t.allclose(initial_output, new_output, atol=1e-4) + + @classmethod - def from_pretrained(cls, path, dtype=t.float, device=None): + def from_pretrained(cls, path, dtype=t.float, device=None, normalize_decoder=True): """ Load a pretrained autoencoder from a file. """ @@ -109,8 +133,15 @@ def from_pretrained(cls, path, dtype=t.float, device=None): dict_size, activation_dim = state_dict["encoder.weight"].shape autoencoder = cls(activation_dim, dict_size) autoencoder.load_state_dict(state_dict) + + # This is useful for doing analysis where e.g. feature activation magnitudes are important + # If training the SAE using the April update, the decoder weights are not normalized + if normalize_decoder: + autoencoder.normalize_decoder() + if device is not None: autoencoder.to(dtype=dtype, device=device) + return autoencoder diff --git a/trainers/standard.py b/trainers/standard.py index ccc44bc..a839737 100644 --- a/trainers/standard.py +++ b/trainers/standard.py @@ -11,7 +11,7 @@ class StandardTrainer(SAETrainer): """ - Standard SAE training scheme. + Standard SAE training scheme following Towards Monosemanticity. Decoder column norms are constrained to 1. """ def __init__(self, steps: int, # total number of steps to train for @@ -173,3 +173,117 @@ def config(self): 'submodule_name': self.submodule_name, } + +class StandardTrainerAprilUpdate(SAETrainer): + """ + Standard SAE training scheme following the Anthropic April update. Decoder column norms are NOT constrained to 1. + This trainer does not support resampling or ghost gradients. This trainer will have fewer dead neurons than the standard trainer. + """ + def __init__(self, + steps: int, # total number of steps to train for + activation_dim: int, + dict_size: int, + layer: int, + lm_name: str, + dict_class=AutoEncoder, + lr:float=1e-3, + l1_penalty:float=1e-1, + warmup_steps:int=1000, # lr warmup period at start of training + sparsity_warmup_steps:Optional[int]=2000, # sparsity warmup period at start of training + decay_start:Optional[int]=None, # decay learning rate after this many steps + seed:Optional[int]=None, + device=None, + wandb_name:Optional[str]='StandardTrainerAprilUpdate', + submodule_name:Optional[str]=None, + ): + super().__init__(seed) + + assert layer is not None and lm_name is not None + self.layer = layer + self.lm_name = lm_name + self.submodule_name = submodule_name + + if seed is not None: + t.manual_seed(seed) + t.cuda.manual_seed_all(seed) + + # initialize dictionary + self.ae = dict_class(activation_dim, dict_size) + + self.lr = lr + self.l1_penalty=l1_penalty + self.warmup_steps = warmup_steps + self.sparsity_warmup_steps = sparsity_warmup_steps + self.steps = steps + self.decay_start = decay_start + self.wandb_name = wandb_name + + if device is None: + self.device = 'cuda' if t.cuda.is_available() else 'cpu' + else: + self.device = device + self.ae.to(self.device) + + self.optimizer = t.optim.Adam(self.ae.parameters(), lr=lr) + + lr_fn = get_lr_schedule(steps, warmup_steps, decay_start, resample_steps=None, sparsity_warmup_steps=sparsity_warmup_steps) + self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) + + self.sparsity_warmup_fn = get_sparsity_warmup_fn(steps, sparsity_warmup_steps) + + def loss(self, x, step: int, logging=False, **kwargs): + + sparsity_scale = self.sparsity_warmup_fn(step) + + x_hat, f = self.ae(x, output_features=True) + l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() + recon_loss = (x - x_hat).pow(2).sum(dim=-1).mean() + l1_loss = (f * self.ae.decoder.weight.norm(p=2, dim=0)).sum(dim=-1).mean() + + loss = recon_loss + self.l1_penalty * sparsity_scale * l1_loss + + if not logging: + return loss + else: + return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])( + x, x_hat, f, + { + 'l2_loss' : l2_loss.item(), + 'mse_loss' : recon_loss.item(), + 'sparsity_loss' : l1_loss.item(), + 'loss' : loss.item() + } + ) + + + def update(self, step, activations): + activations = activations.to(self.device) + + self.optimizer.zero_grad() + loss = self.loss(activations, step=step) + loss.backward() + t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) + self.optimizer.step() + self.scheduler.step() + + @property + def config(self): + return { + 'dict_class': 'AutoEncoder', + 'trainer_class' : 'StandardTrainerAprilUpdate', + 'activation_dim': self.ae.activation_dim, + 'dict_size': self.ae.dict_size, + 'lr' : self.lr, + 'l1_penalty' : self.l1_penalty, + 'warmup_steps' : self.warmup_steps, + 'sparsity_warmup_steps' : self.sparsity_warmup_steps, + 'steps' : self.steps, + 'decay_start' : self.decay_start, + 'seed' : self.seed, + 'device' : self.device, + 'layer' : self.layer, + 'lm_name' : self.lm_name, + 'wandb_name': self.wandb_name, + 'submodule_name': self.submodule_name, + } + From 8316a4418dc4acb70ccad9854d3b05df1b817b9d Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Fri, 3 Jan 2025 20:21:00 +0000 Subject: [PATCH 45/70] Add an option to pass LR to TopK trainers --- tests/test_end_to_end.py | 1 + trainers/batch_top_k.py | 9 +++++++-- trainers/matroyshka_batch_top_k.py | 9 +++++++-- trainers/top_k.py | 10 +++++++--- 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index b7b5cc9..4172129 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -124,6 +124,7 @@ def test_sae_training(): { "trainer": TopKTrainer, "dict_class": AutoEncoderTopK, + "lr": None, "activation_dim": activation_dim, "dict_size": expansion_factor * activation_dim, "k": k, diff --git a/trainers/batch_top_k.py b/trainers/batch_top_k.py index a3c6214..d2062ea 100644 --- a/trainers/batch_top_k.py +++ b/trainers/batch_top_k.py @@ -104,6 +104,7 @@ def __init__( layer: int, lm_name: str, dict_class: type = BatchTopKSAE, + lr: Optional[float] = None, auxk_alpha: float = 1 / 32, warmup_steps: int = 1000, decay_start: Optional[int] = None, # when does the lr decay start @@ -139,8 +140,12 @@ def __init__( self.device = device self.ae.to(self.device) - scale = dict_size / (2**14) - self.lr = 2e-4 / scale**0.5 + if lr is not None: + self.lr = lr + else: + # Auto-select LR using 1 / sqrt(d) scaling law from Figure 3 of the paper + scale = dict_size / (2**14) + self.lr = 2e-4 / scale**0.5 self.auxk_alpha = auxk_alpha self.dead_feature_threshold = 10_000_000 self.top_k_aux = activation_dim // 2 # Heuristic from B.1 of the paper diff --git a/trainers/matroyshka_batch_top_k.py b/trainers/matroyshka_batch_top_k.py index 69425fd..a71838d 100644 --- a/trainers/matroyshka_batch_top_k.py +++ b/trainers/matroyshka_batch_top_k.py @@ -139,6 +139,7 @@ def __init__( group_weights: Optional[list[float]] = None, weights_temperature: float = 1.0, dict_class: type = MatroyshkaBatchTopKSAE, + lr: Optional[float] = None, auxk_alpha: float = 1 / 32, warmup_steps: int = 1000, decay_start: Optional[int] = None, # when does the lr decay start @@ -194,8 +195,12 @@ def __init__( self.device = device self.ae.to(self.device) - scale = dict_size / (2**14) - self.lr = 2e-4 / scale**0.5 + if lr is not None: + self.lr = lr + else: + # Auto-select LR using 1 / sqrt(d) scaling law from Figure 3 of the paper + scale = dict_size / (2**14) + self.lr = 2e-4 / scale**0.5 self.auxk_alpha = auxk_alpha self.dead_feature_threshold = 10_000_000 self.top_k_aux = activation_dim // 2 # Heuristic from B.1 of the paper diff --git a/trainers/top_k.py b/trainers/top_k.py index a155a9a..6889f5c 100644 --- a/trainers/top_k.py +++ b/trainers/top_k.py @@ -155,6 +155,7 @@ def __init__( layer: int, lm_name: str, dict_class: type = AutoEncoderTopK, + lr: Optional[float] = None, auxk_alpha: float = 1 / 32, # see Appendix A.2 warmup_steps: int = 1000, decay_start: Optional[int] = None, # when does the lr decay start @@ -192,9 +193,12 @@ def __init__( self.device = device self.ae.to(self.device) - # Auto-select LR using 1 / sqrt(d) scaling law from Figure 3 of the paper - scale = dict_size / (2**14) - self.lr = 2e-4 / scale**0.5 + if lr is not None: + self.lr = lr + else: + # Auto-select LR using 1 / sqrt(d) scaling law from Figure 3 of the paper + scale = dict_size / (2**14) + self.lr = 2e-4 / scale**0.5 self.auxk_alpha = auxk_alpha self.dead_feature_threshold = 10_000_000 From 3c5a5cdef682cbeb12e23b825f39709f518e2c0a Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Fri, 3 Jan 2025 21:32:52 +0000 Subject: [PATCH 46/70] Save state dicts to cpu --- training.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/training.py b/training.py index 7bdf1a7..39932eb 100644 --- a/training.py +++ b/training.py @@ -214,8 +214,10 @@ def trainSAE( if not os.path.exists(os.path.join(dir, "checkpoints")): os.mkdir(os.path.join(dir, "checkpoints")) + + checkpoint = {k: v.cpu() for k, v in trainer.ae.state_dict().items()} t.save( - trainer.ae.state_dict(), + checkpoint, os.path.join(dir, "checkpoints", f"ae_{step}.pt"), ) @@ -231,7 +233,8 @@ def trainSAE( if normalize_activations: trainer.ae.scale_biases(norm_factor) if save_dir is not None: - t.save(trainer.ae.state_dict(), os.path.join(save_dir, "ae.pt")) + final = {k: v.cpu() for k, v in trainer.ae.state_dict().items()} + t.save(final, os.path.join(save_dir, "ae.pt")) # Signal wandb processes to finish if use_wandb: From 832f4a32428cda68ec418aff9abe7dca66a9f66e Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Tue, 7 Jan 2025 03:04:52 +0000 Subject: [PATCH 47/70] Add torch autocast to training loop --- training.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/training.py b/training.py index 39932eb..50f20ee 100644 --- a/training.py +++ b/training.py @@ -7,6 +7,7 @@ import os from queue import Empty from typing import Optional +from contextlib import nullcontext import torch as t from tqdm import tqdm @@ -123,6 +124,8 @@ def trainSAE( run_cfg:dict={}, normalize_activations:bool=False, verbose:bool=False, + device:str="cuda", + autocast_dtype: t.dtype = t.float32, ): """ Train SAEs using the given trainers @@ -130,8 +133,13 @@ def trainSAE( If normalize_activations is True, the activations will be normalized to have unit mean squared norm. The autoencoders weights will be scaled before saving, so the activations don't need to be scaled during inference. This is very helpful for hyperparameter transfer between different layers and models. + + Setting autocast_dtype to t.bfloat16 provides a significant speedup with minimal change in performance. """ + device_type = "cuda" if "cuda" in device else "cpu" + autocast_context = nullcontext() if device_type == "cpu" else t.autocast(device_type=device_type, dtype=autocast_dtype) + trainers = [] for i, config in enumerate(trainer_configs): if "wandb_name" in config: @@ -189,7 +197,7 @@ def trainSAE( for step, act in enumerate(tqdm(data, total=steps)): - act = act.to(dtype=t.float32) + act = act.to(dtype=autocast_dtype) if normalize_activations: act /= norm_factor @@ -226,7 +234,8 @@ def trainSAE( # training for trainer in trainers: - trainer.update(step, act) + with autocast_context: + trainer.update(step, act) # save final SAEs for save_dir, trainer in zip(save_dirs, trainers): From 17aa5d52f818545afe5fbbe3edf1f774cde92f44 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Tue, 7 Jan 2025 06:17:40 +0000 Subject: [PATCH 48/70] Disable autocast for threshold tracking --- trainers/batch_top_k.py | 10 ++++++---- trainers/top_k.py | 9 +++++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/trainers/batch_top_k.py b/trainers/batch_top_k.py index d2062ea..21a21ff 100644 --- a/trainers/batch_top_k.py +++ b/trainers/batch_top_k.py @@ -21,8 +21,8 @@ def __init__(self, activation_dim: int, dict_size: int, k: int): self.dict_size = dict_size assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer" - self.register_buffer("k", t.tensor(k)) - self.register_buffer("threshold", t.tensor(-1.0)) + self.register_buffer("k", t.tensor(k, dtype=t.int)) + self.register_buffer("threshold", t.tensor(-1.0, dtype=t.float32)) self.decoder = nn.Linear(dict_size, activation_dim, bias=False) self.decoder.weight.data = set_decoder_norm_to_unit_norm( @@ -186,13 +186,14 @@ def loss(self, x, step=None, logging=False): # l0 = (f != 0).float().sum(dim=-1).mean().item() if step > self.threshold_start_step: - with t.no_grad(): + device_type = 'cuda' if x.is_cuda else 'cpu' + with t.autocast(device_type=device_type, enabled=False), t.no_grad(): active = f[f > 0] if active.size(0) == 0: min_activation = 0.0 else: - min_activation = active.min().detach() + min_activation = active.min().detach().to(dtype=t.float32) if self.ae.threshold < 0: self.ae.threshold = min_activation @@ -232,6 +233,7 @@ def loss(self, x, step=None, logging=False): def update(self, step, x): if step == 0: median = self.geometric_median(x) + median = median.to(self.ae.b_dec.dtype) self.ae.b_dec.data = median # Make sure the decoder is still unit-norm diff --git a/trainers/top_k.py b/trainers/top_k.py index 6889f5c..0f79fd1 100644 --- a/trainers/top_k.py +++ b/trainers/top_k.py @@ -66,8 +66,8 @@ def __init__(self, activation_dim: int, dict_size: int, k: int): self.dict_size = dict_size assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer" - self.register_buffer("k", t.tensor(k)) - self.register_buffer("threshold", t.tensor(-1.0)) + self.register_buffer("k", t.tensor(k, dtype=t.int)) + self.register_buffer("threshold", t.tensor(-1.0, dtype=t.float32)) self.decoder = nn.Linear(dict_size, activation_dim, bias=False) self.decoder.weight.data = set_decoder_norm_to_unit_norm( @@ -223,10 +223,11 @@ def loss(self, x, step=None, logging=False): f, top_acts, top_indices = self.ae.encode(x, return_topk=True, use_threshold=False) if step > self.threshold_start_step: - with t.no_grad(): + device_type = 'cuda' if x.is_cuda else 'cpu' + with t.autocast(device_type=device_type, enabled=False), t.no_grad(): active = top_acts.clone().detach() active[active <= 0] = float("inf") - min_activations = active.min(dim=1).values + min_activations = active.min(dim=1).values.to(dtype=t.float32) min_activation = min_activations.mean() B, K = active.shape From 65e7af80441e5b601114756afc36a4041cec152f Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Tue, 7 Jan 2025 06:38:13 +0000 Subject: [PATCH 49/70] Also update context manager for matroyshka threshold --- trainers/matroyshka_batch_top_k.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/trainers/matroyshka_batch_top_k.py b/trainers/matroyshka_batch_top_k.py index a71838d..fc76533 100644 --- a/trainers/matroyshka_batch_top_k.py +++ b/trainers/matroyshka_batch_top_k.py @@ -44,8 +44,8 @@ def __init__(self, activation_dim: int, dict_size: int, k: int, group_sizes: lis assert all(s > 0 for s in group_sizes), "all group sizes must be positive" assert isinstance(k, int) and k > 0, f"k={k} must be a positive integer" - self.register_buffer("k", t.tensor(k)) - self.register_buffer("threshold", t.tensor(-1.0)) + self.register_buffer("k", t.tensor(k, dtype=t.int)) + self.register_buffer("threshold", t.tensor(-1.0, dtype=t.float32)) self.active_groups = len(group_sizes) group_indices = [0] + list(t.cumsum(t.tensor(group_sizes), dim=0)) @@ -240,13 +240,14 @@ def loss(self, x, step=None, logging=False): # l0 = (f != 0).float().sum(dim=-1).mean().item() if step > self.threshold_start_step: - with t.no_grad(): + device_type = 'cuda' if x.is_cuda else 'cpu' + with t.autocast(device_type=device_type, enabled=False), t.no_grad(): active = f[f > 0] if active.size(0) == 0: min_activation = 0.0 else: - min_activation = active.min().detach() + min_activation = active.min().detach().to(dtype=t.float32) if self.ae.threshold < 0: self.ae.threshold = min_activation From 52b0c54ba92630cfb2ae007f020ed447d4a5ba9f Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Tue, 7 Jan 2025 18:07:51 +0000 Subject: [PATCH 50/70] By default, don't normalize Gated activations during inference --- dictionary.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/dictionary.py b/dictionary.py index 3b7e976..bceed6a 100644 --- a/dictionary.py +++ b/dictionary.py @@ -211,7 +211,7 @@ def _reset_parameters(self): self.decoder.weight = nn.Parameter(dec_weight) self.encoder.weight = nn.Parameter(dec_weight.clone().T) - def encode(self, x, return_gate=False): + def encode(self, x: t.Tensor, return_gate:bool=False, normalize_decoder:bool=False): """ Returns features, gate value (pre-Heavyside) """ @@ -227,26 +227,29 @@ def encode(self, x, return_gate=False): f = f_gate * f_mag - # W_dec norm is not kept constant, as per Anthropic's April 2024 Update - # Normalizing after encode, and renormalizing before decode to enable comparability - f = f * self.decoder.weight.norm(dim=0, keepdim=True) + if normalize_decoder: + # If the SAE is trained without ConstrainedAdam, the decoder vectors are not normalized + # Normalizing after encode, and renormalizing before decode to enable comparability + f = f * self.decoder.weight.norm(dim=0, keepdim=True) if return_gate: return f, nn.ReLU()(pi_gate) return f - def decode(self, f): - # W_dec norm is not kept constant, as per Anthropic's April 2024 Update - # Normalizing after encode, and renormalizing before decode to enable comparability - f = f / self.decoder.weight.norm(dim=0, keepdim=True) + def decode(self, f: t.Tensor, normalize_decoder:bool=False): + if normalize_decoder: + # If the SAE is trained without ConstrainedAdam, the decoder vectors are not normalized + # Normalizing after encode, and renormalizing before decode to enable comparability + f = f / self.decoder.weight.norm(dim=0, keepdim=True) return self.decoder(f) + self.decoder_bias - def forward(self, x, output_features=False): + def forward(self, x:t.Tensor, output_features:bool=False, normalize_decoder:bool=False): f = self.encode(x) x_hat = self.decode(f) - f = f * self.decoder.weight.norm(dim=0, keepdim=True) + if normalize_decoder: + f = f * self.decoder.weight.norm(dim=0, keepdim=True) if output_features: return x_hat, f From 8363ff779eee04518edaac9d10d97e459f708b66 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Tue, 7 Jan 2025 19:56:29 +0000 Subject: [PATCH 51/70] Import trainers from correct relative location for submodule use --- utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/utils.py b/utils.py index 2cc12b3..c5b4dbc 100644 --- a/utils.py +++ b/utils.py @@ -5,10 +5,10 @@ import os from nnsight import LanguageModel -from dictionary_learning.trainers.top_k import AutoEncoderTopK -from dictionary_learning.trainers.batch_top_k import BatchTopKSAE -from dictionary_learning.trainers.matroyshka_batch_top_k import MatroyshkaBatchTopKSAE -from dictionary_learning.dictionary import ( +from .trainers.top_k import AutoEncoderTopK +from .trainers.batch_top_k import BatchTopKSAE +from .trainers.matroyshka_batch_top_k import MatroyshkaBatchTopKSAE +from .dictionary import ( AutoEncoder, GatedAutoEncoder, AutoEncoderNew, From c697d0f83984f0f257be2044231c30f2abb15aa1 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Tue, 7 Jan 2025 21:02:37 +0000 Subject: [PATCH 52/70] Make sure x is on the correct dtype for jumprelu when logging --- trainers/jumprelu.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trainers/jumprelu.py b/trainers/jumprelu.py index 7323c27..5e8d27d 100644 --- a/trainers/jumprelu.py +++ b/trainers/jumprelu.py @@ -150,6 +150,7 @@ def loss(self, x: torch.Tensor, step: int, logging=False, **_): # I had poor results when using log_threshold and it would complicate the scale_biases() function sparsity_scale = self.sparsity_warmup_fn(step) + x = x.to(self.ae.W_enc.dtype) pre_jump = x @ self.ae.W_enc + self.ae.b_enc f = JumpReLUFunction.apply(pre_jump, self.ae.threshold, self.bandwidth) From 6c2fcfc2a8108ba720591eb414be6ab16157dc36 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Wed, 8 Jan 2025 22:22:08 +0000 Subject: [PATCH 53/70] Remove experimental matroyshka temperature --- trainers/matroyshka_batch_top_k.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/trainers/matroyshka_batch_top_k.py b/trainers/matroyshka_batch_top_k.py index fc76533..e4ec401 100644 --- a/trainers/matroyshka_batch_top_k.py +++ b/trainers/matroyshka_batch_top_k.py @@ -137,7 +137,6 @@ def __init__( lm_name: str, group_fractions: list[float], group_weights: Optional[list[float]] = None, - weights_temperature: float = 1.0, dict_class: type = MatroyshkaBatchTopKSAE, lr: Optional[float] = None, auxk_alpha: float = 1 / 32, @@ -174,9 +173,7 @@ def __init__( group_sizes.append(dict_size - sum(group_sizes)) if group_weights is None: - group_weights = group_fractions.copy() - - group_weights = apply_temperature(group_weights, weights_temperature) + group_weights = [(1.0 / len(group_sizes))] * len(group_sizes) assert len(group_sizes) == len( group_weights @@ -185,7 +182,6 @@ def __init__( self.group_fractions = group_fractions self.group_sizes = group_sizes self.group_weights = group_weights - self.weights_temperature = weights_temperature self.ae = dict_class(activation_dim, dict_size, k, group_sizes) @@ -349,7 +345,6 @@ def config(self): "group_fractions": self.group_fractions, "group_weights": self.group_weights, "group_sizes": self.group_sizes, - "weights_temperature": self.weights_temperature, "k": self.ae.k.item(), "device": self.device, "layer": self.layer, From 200ed3bed09c88d336c25a886eee4cb98c1e616e Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Mon, 13 Jan 2025 04:05:03 +0000 Subject: [PATCH 54/70] Normalize decoder after optimzer step --- trainers/batch_top_k.py | 12 ++++++------ trainers/jumprelu.py | 13 +++++++------ trainers/matroyshka_batch_top_k.py | 20 ++++++++++---------- trainers/top_k.py | 13 +++++++------ trainers/trainer.py | 28 ++++++++++++++++------------ 5 files changed, 46 insertions(+), 40 deletions(-) diff --git a/trainers/batch_top_k.py b/trainers/batch_top_k.py index 21a21ff..d4af156 100644 --- a/trainers/batch_top_k.py +++ b/trainers/batch_top_k.py @@ -236,27 +236,27 @@ def update(self, step, x): median = median.to(self.ae.b_dec.dtype) self.ae.b_dec.data = median - # Make sure the decoder is still unit-norm - self.ae.decoder.weight.data = set_decoder_norm_to_unit_norm( - self.ae.decoder.weight, self.ae.activation_dim, self.ae.dict_size - ) - x = x.to(self.device) loss = self.loss(x, step=step) loss.backward() - t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) self.ae.decoder.weight.grad = remove_gradient_parallel_to_decoder_directions( self.ae.decoder.weight, self.ae.decoder.weight.grad, self.ae.activation_dim, self.ae.dict_size, ) + t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) self.optimizer.step() self.optimizer.zero_grad() self.scheduler.step() + # Make sure the decoder is still unit-norm + self.ae.decoder.weight.data = set_decoder_norm_to_unit_norm( + self.ae.decoder.weight, self.ae.activation_dim, self.ae.dict_size + ) + return loss.item() @property diff --git a/trainers/jumprelu.py b/trainers/jumprelu.py index 5e8d27d..2ac717a 100644 --- a/trainers/jumprelu.py +++ b/trainers/jumprelu.py @@ -188,24 +188,25 @@ def loss(self, x: torch.Tensor, step: int, logging=False, **_): ) def update(self, step, x): - # We must transpose because we are using nn.Parameter, not nn.Linear - self.ae.W_dec.data = set_decoder_norm_to_unit_norm( - self.ae.W_dec.T, self.ae.activation_dim, self.ae.dict_size - ).T - x = x.to(self.device) loss = self.loss(x, step=step) loss.backward() - torch.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) # We must transpose because we are using nn.Parameter, not nn.Linear self.ae.W_dec.grad = remove_gradient_parallel_to_decoder_directions( self.ae.W_dec.T, self.ae.W_dec.grad.T, self.ae.activation_dim, self.ae.dict_size ).T + torch.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) self.optimizer.step() self.scheduler.step() self.optimizer.zero_grad() + + # We must transpose because we are using nn.Parameter, not nn.Linear + self.ae.W_dec.data = set_decoder_norm_to_unit_norm( + self.ae.W_dec.T, self.ae.activation_dim, self.ae.dict_size + ).T + return loss.item() @property diff --git a/trainers/matroyshka_batch_top_k.py b/trainers/matroyshka_batch_top_k.py index e4ec401..fea0803 100644 --- a/trainers/matroyshka_batch_top_k.py +++ b/trainers/matroyshka_batch_top_k.py @@ -175,9 +175,9 @@ def __init__( if group_weights is None: group_weights = [(1.0 / len(group_sizes))] * len(group_sizes) - assert len(group_sizes) == len( - group_weights - ), "group_sizes and group_weights must have the same length" + assert len(group_sizes) == len(group_weights), ( + "group_sizes and group_weights must have the same length" + ) self.group_fractions = group_fractions self.group_sizes = group_sizes @@ -236,7 +236,7 @@ def loss(self, x, step=None, logging=False): # l0 = (f != 0).float().sum(dim=-1).mean().item() if step > self.threshold_start_step: - device_type = 'cuda' if x.is_cuda else 'cpu' + device_type = "cuda" if x.is_cuda else "cpu" with t.autocast(device_type=device_type, enabled=False), t.no_grad(): active = f[f > 0] @@ -305,25 +305,25 @@ def update(self, step, x): median = self.geometric_median(x) self.ae.b_dec.data = median - # We must transpose because we are using nn.Parameter, not nn.Linear - self.ae.W_dec.data = set_decoder_norm_to_unit_norm( - self.ae.W_dec.T, self.ae.activation_dim, self.ae.dict_size - ).T - x = x.to(self.device) loss = self.loss(x, step=step) loss.backward() - t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) # We must transpose because we are using nn.Parameter, not nn.Linear self.ae.W_dec.grad = remove_gradient_parallel_to_decoder_directions( self.ae.W_dec.T, self.ae.W_dec.grad.T, self.ae.activation_dim, self.ae.dict_size ).T + t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) self.optimizer.step() self.optimizer.zero_grad() self.scheduler.step() + # We must transpose because we are using nn.Parameter, not nn.Linear + self.ae.W_dec.data = set_decoder_norm_to_unit_norm( + self.ae.W_dec.T, self.ae.activation_dim, self.ae.dict_size + ).T + return loss.item() @property diff --git a/trainers/top_k.py b/trainers/top_k.py index 0f79fd1..2141afc 100644 --- a/trainers/top_k.py +++ b/trainers/top_k.py @@ -312,29 +312,30 @@ def update(self, step, x): median = median.to(self.ae.b_dec.dtype) self.ae.b_dec.data = median - # Make sure the decoder is still unit-norm - self.ae.decoder.weight.data = set_decoder_norm_to_unit_norm( - self.ae.decoder.weight, self.ae.activation_dim, self.ae.dict_size - ) - # compute the loss x = x.to(self.device) loss = self.loss(x, step=step) loss.backward() # clip grad norm and remove grads parallel to decoder directions - t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) self.ae.decoder.weight.grad = remove_gradient_parallel_to_decoder_directions( self.ae.decoder.weight, self.ae.decoder.weight.grad, self.ae.activation_dim, self.ae.dict_size, ) + t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) # do a training step self.optimizer.step() self.optimizer.zero_grad() self.scheduler.step() + + # Make sure the decoder is still unit-norm + self.ae.decoder.weight.data = set_decoder_norm_to_unit_norm( + self.ae.decoder.weight, self.ae.activation_dim, self.ae.dict_size + ) + return loss.item() @property diff --git a/trainers/trainer.py b/trainers/trainer.py index 03ffb54..15eb4ed 100644 --- a/trainers/trainer.py +++ b/trainers/trainer.py @@ -42,7 +42,9 @@ class ConstrainedAdam(torch.optim.Adam): If nn.Parameter, the dim argument to norm should be 1. """ - def __init__(self, params, constrained_params, lr: float, betas: tuple[float, float] = (0.9, 0.999)): + def __init__( + self, params, constrained_params, lr: float, betas: tuple[float, float] = (0.9, 0.999) + ): super().__init__(params, lr=lr, betas=betas) self.constrained_params = list(constrained_params) @@ -94,14 +96,16 @@ def remove_gradient_parallel_to_decoder_directions( assert D == activation_dim assert F == d_sae + normed_W_dec_DF = W_dec_DF / (torch.norm(W_dec_DF, dim=0, keepdim=True) + 1e-6) + parallel_component = einops.einsum( W_dec_DF_grad, - W_dec_DF, + normed_W_dec_DF, "d_in d_sae, d_in d_sae -> d_sae", ) W_dec_DF_grad -= einops.einsum( parallel_component, - W_dec_DF, + normed_W_dec_DF, "d_sae, d_in d_sae -> d_in d_sae", ) return W_dec_DF_grad @@ -131,15 +135,15 @@ def get_lr_schedule( Function that computes LR scale factor for a given step """ if decay_start is not None: - assert ( - resample_steps is None - ), "decay_start and resample_steps are currently mutually exclusive." + assert resample_steps is None, ( + "decay_start and resample_steps are currently mutually exclusive." + ) assert 0 <= decay_start < total_steps, "decay_start must be >= 0 and < steps." assert decay_start > warmup_steps, "decay_start must be > warmup_steps." if sparsity_warmup_steps is not None: - assert ( - decay_start > sparsity_warmup_steps - ), "decay_start must be > sparsity_warmup_steps." + assert decay_start > sparsity_warmup_steps, ( + "decay_start must be > sparsity_warmup_steps." + ) assert 0 <= warmup_steps < total_steps, "warmup_steps must be >= 0 and < steps." @@ -176,9 +180,9 @@ def get_sparsity_warmup_fn( """ if sparsity_warmup_steps is not None: - assert ( - 0 <= sparsity_warmup_steps < total_steps - ), "sparsity_warmup_steps must be >= 0 and < steps." + assert 0 <= sparsity_warmup_steps < total_steps, ( + "sparsity_warmup_steps must be >= 0 and < steps." + ) def scale_fn(step: int) -> float: if not sparsity_warmup_steps: From 0af19713feb5b4c35788039245013736bf974383 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Mon, 13 Jan 2025 05:15:40 +0000 Subject: [PATCH 55/70] Standardize and fix topk auxk loss implementation --- trainers/batch_top_k.py | 122 ++++++++++++++------------ trainers/matroyshka_batch_top_k.py | 98 ++++++++++++--------- trainers/top_k.py | 134 +++++++++++++++-------------- 3 files changed, 192 insertions(+), 162 deletions(-) diff --git a/trainers/batch_top_k.py b/trainers/batch_top_k.py index d4af156..c933267 100644 --- a/trainers/batch_top_k.py +++ b/trainers/batch_top_k.py @@ -39,24 +39,19 @@ def encode(self, x: t.Tensor, return_active: bool = False, use_threshold: bool = if use_threshold: encoded_acts_BF = post_relu_feat_acts_BF * (post_relu_feat_acts_BF > self.threshold) - if return_active: - return encoded_acts_BF, encoded_acts_BF.sum(0) > 0 - else: - return encoded_acts_BF - - # Flatten and perform batch top-k - flattened_acts = post_relu_feat_acts_BF.flatten() - post_topk = flattened_acts.topk(self.k * x.size(0), sorted=False, dim=-1) - - buffer_BF = t.zeros_like(post_relu_feat_acts_BF) - encoded_acts_BF = ( - buffer_BF.flatten() - .scatter(-1, post_topk.indices, post_topk.values) - .reshape(buffer_BF.shape) - ) + else: + # Flatten and perform batch top-k + flattened_acts = post_relu_feat_acts_BF.flatten() + post_topk = flattened_acts.topk(self.k * x.size(0), sorted=False, dim=-1) + + encoded_acts_BF = ( + t.zeros_like(post_relu_feat_acts_BF.flatten()) + .scatter_(-1, post_topk.indices, post_topk.values) + .reshape(post_relu_feat_acts_BF.shape) + ) if return_active: - return encoded_acts_BF, encoded_acts_BF.sum(0) > 0 + return encoded_acts_BF, encoded_acts_BF.sum(0) > 0, post_relu_feat_acts_BF else: return encoded_acts_BF @@ -146,9 +141,15 @@ def __init__( # Auto-select LR using 1 / sqrt(d) scaling law from Figure 3 of the paper scale = dict_size / (2**14) self.lr = 2e-4 / scale**0.5 + self.auxk_alpha = auxk_alpha self.dead_feature_threshold = 10_000_000 self.top_k_aux = activation_dim // 2 # Heuristic from B.1 of the paper + self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device) + self.logging_parameters = ["effective_l0", "dead_features", "pre_norm_aux_loss"] + self.effective_l0 = -1 + self.dead_features = -1 + self.pre_norm_aux_loss = -1 self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) @@ -156,68 +157,79 @@ def __init__( self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) - self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device) - self.logging_parameters = ["effective_l0", "dead_features"] - self.effective_l0 = -1 - self.dead_features = -1 - - def get_auxiliary_loss(self, x, x_reconstruct, acts): + def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor): dead_features = self.num_tokens_since_fired >= self.dead_feature_threshold self.dead_features = int(dead_features.sum()) if dead_features.sum() > 0: - residual = x.float() - x_reconstruct.float() - acts_topk_aux = t.topk( - acts[:, dead_features], - min(self.top_k_aux, dead_features.sum()), - dim=-1, - ) - acts_aux = t.zeros_like(acts[:, dead_features]).scatter( - -1, acts_topk_aux.indices, acts_topk_aux.values + k_aux = min(self.top_k_aux, dead_features.sum()) + + auxk_latents = t.where(dead_features[None], post_relu_acts_BF, -t.inf) + + # Top-k dead latents + auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False) + + auxk_buffer_BF = t.zeros_like(post_relu_acts_BF) + auxk_acts_BF = auxk_buffer_BF.scatter_(dim=-1, index=auxk_indices, src=auxk_acts) + + # Note: decoder(), not decode(), as we don't want to apply the bias + x_reconstruct_aux = self.ae.decoder(auxk_acts_BF) + l2_loss_aux = ( + (residual_BD.float() - x_reconstruct_aux.float()).pow(2).sum(dim=-1).mean() ) - x_reconstruct_aux = F.linear(acts_aux, self.ae.decoder.weight[:, dead_features]) - l2_loss_aux = (x_reconstruct_aux.float() - residual.float()).pow(2).mean() - return l2_loss_aux + + self.pre_norm_auxk_loss = l2_loss_aux + + # normalization from OpenAI implementation: https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/kernels.py#L614 + residual_mu = residual_BD.mean(dim=0)[None, :].broadcast_to(residual_BD.shape) + loss_denom = (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean() + normalized_auxk_loss = l2_loss_aux / loss_denom + + return normalized_auxk_loss else: - return t.tensor(0, dtype=x.dtype, device=x.device) + self.pre_norm_auxk_loss = -1 + return t.tensor(0, dtype=residual_BD.dtype, device=residual_BD.device) + + def update_threshold(self, f: t.Tensor): + device_type = "cuda" if f.is_cuda else "cpu" + with t.autocast(device_type=device_type, enabled=False), t.no_grad(): + active = f[f > 0] + + if active.size(0) == 0: + min_activation = 0.0 + else: + min_activation = active.min().detach().to(dtype=t.float32) + + if self.ae.threshold < 0: + self.ae.threshold = min_activation + else: + self.ae.threshold = (self.threshold_beta * self.ae.threshold) + ( + (1 - self.threshold_beta) * min_activation + ) def loss(self, x, step=None, logging=False): - f, active_indices = self.ae.encode(x, return_active=True, use_threshold=False) + f, active_indices_F, post_relu_acts_BF = self.ae.encode( + x, return_active=True, use_threshold=False + ) # l0 = (f != 0).float().sum(dim=-1).mean().item() if step > self.threshold_start_step: - device_type = 'cuda' if x.is_cuda else 'cpu' - with t.autocast(device_type=device_type, enabled=False), t.no_grad(): - active = f[f > 0] - - if active.size(0) == 0: - min_activation = 0.0 - else: - min_activation = active.min().detach().to(dtype=t.float32) - - if self.ae.threshold < 0: - self.ae.threshold = min_activation - else: - self.ae.threshold = (self.threshold_beta * self.ae.threshold) + ( - (1 - self.threshold_beta) * min_activation - ) + self.update_threshold(f) x_hat = self.ae.decode(f) - e = x_hat - x + e = x - x_hat self.effective_l0 = self.k num_tokens_in_step = x.size(0) did_fire = t.zeros_like(self.num_tokens_since_fired, dtype=t.bool) - did_fire[active_indices] = True + did_fire[active_indices_F] = True self.num_tokens_since_fired += num_tokens_in_step self.num_tokens_since_fired[did_fire] = 0 - auxk_loss = self.get_auxiliary_loss(x, x_hat, f) - l2_loss = e.pow(2).sum(dim=-1).mean() - auxk_loss = auxk_loss.sum(dim=-1).mean() + auxk_loss = self.get_auxiliary_loss(e, post_relu_acts_BF) loss = l2_loss + self.auxk_alpha * auxk_loss if not logging: diff --git a/trainers/matroyshka_batch_top_k.py b/trainers/matroyshka_batch_top_k.py index fea0803..e191377 100644 --- a/trainers/matroyshka_batch_top_k.py +++ b/trainers/matroyshka_batch_top_k.py @@ -74,18 +74,17 @@ def encode(self, x: t.Tensor, return_active: bool = False, use_threshold: bool = flattened_acts = post_relu_feat_acts_BF.flatten() post_topk = flattened_acts.topk(self.k * x.size(0), sorted=False, dim=-1) - buffer_BF = t.zeros_like(post_relu_feat_acts_BF) encoded_acts_BF = ( - buffer_BF.flatten() - .scatter(-1, post_topk.indices, post_topk.values) - .reshape(buffer_BF.shape) + t.zeros_like(post_relu_feat_acts_BF.flatten()) + .scatter_(-1, post_topk.indices, post_topk.values) + .reshape(post_relu_feat_acts_BF.shape) ) max_act_index = self.group_indices[self.active_groups] encoded_acts_BF[:, max_act_index:] = 0 if return_active: - return encoded_acts_BF, encoded_acts_BF.sum(0) > 0 + return encoded_acts_BF, encoded_acts_BF.sum(0) > 0, post_relu_feat_acts_BF else: return encoded_acts_BF @@ -207,50 +206,69 @@ def __init__( self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device) - self.logging_parameters = ["effective_l0", "dead_features"] + self.logging_parameters = ["effective_l0", "dead_features", "pre_norm_auxk_loss"] self.effective_l0 = -1 self.dead_features = -1 + self.pre_norm_auxk_loss = -1 - def get_auxiliary_loss(self, x, x_reconstruct, acts): + def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor): dead_features = self.num_tokens_since_fired >= self.dead_feature_threshold self.dead_features = int(dead_features.sum()) - if dead_features.sum() > 0: - residual = x.float() - x_reconstruct.float() - acts_topk_aux = t.topk( - acts[:, dead_features], - min(self.top_k_aux, dead_features.sum()), - dim=-1, - ) - acts_aux = t.zeros_like(acts[:, dead_features]).scatter( - -1, acts_topk_aux.indices, acts_topk_aux.values + + if self.dead_features > 0: + k_aux = min(self.top_k_aux, self.dead_features) + + auxk_latents = t.where(dead_features[None], post_relu_acts_BF, -t.inf) + + # Top-k dead latents + auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False) + + auxk_buffer_BF = t.zeros_like(post_relu_acts_BF) + auxk_acts_BF = auxk_buffer_BF.scatter_(dim=-1, index=auxk_indices, src=auxk_acts) + + # We don't want to apply the bias + x_reconstruct_aux = auxk_acts_BF @ self.ae.W_dec + l2_loss_aux = ( + (residual_BD.float() - x_reconstruct_aux.float()).pow(2).sum(dim=-1).mean() ) - x_reconstruct_aux = F.linear(acts_aux, self.ae.W_dec[dead_features, :].T) - l2_loss_aux = (x_reconstruct_aux.float() - residual.float()).pow(2).mean() - return l2_loss_aux + self.pre_norm_auxk_loss = l2_loss_aux + + # normalization from OpenAI implementation: https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/kernels.py#L614 + residual_mu = residual_BD.mean(dim=0)[None, :].broadcast_to(residual_BD.shape) + loss_denom = (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean() + normalized_auxk_loss = l2_loss_aux / loss_denom + + return normalized_auxk_loss else: - return t.tensor(0, dtype=x.dtype, device=x.device) + self.pre_norm_auxk_loss = -1 + return t.tensor(0, dtype=residual_BD.dtype, device=residual_BD.device) + + def update_threshold(self, f: t.Tensor): + device_type = "cuda" if f.is_cuda else "cpu" + with t.autocast(device_type=device_type, enabled=False), t.no_grad(): + active = f[f > 0] + + if active.size(0) == 0: + min_activation = 0.0 + else: + min_activation = active.min().detach().to(dtype=t.float32) + + if self.ae.threshold < 0: + self.ae.threshold = min_activation + else: + self.ae.threshold = (self.threshold_beta * self.ae.threshold) + ( + (1 - self.threshold_beta) * min_activation + ) def loss(self, x, step=None, logging=False): - f, active_indices = self.ae.encode(x, return_active=True, use_threshold=False) + f, active_indices_F, post_relu_acts_BF = self.ae.encode( + x, return_active=True, use_threshold=False + ) # l0 = (f != 0).float().sum(dim=-1).mean().item() if step > self.threshold_start_step: - device_type = "cuda" if x.is_cuda else "cpu" - with t.autocast(device_type=device_type, enabled=False), t.no_grad(): - active = f[f > 0] - - if active.size(0) == 0: - min_activation = 0.0 - else: - min_activation = active.min().detach().to(dtype=t.float32) - - if self.ae.threshold < 0: - self.ae.threshold = min_activation - else: - self.ae.threshold = (self.threshold_beta * self.ae.threshold) + ( - (1 - self.threshold_beta) * min_activation - ) + self.update_threshold(f) x_reconstruct = t.zeros_like(x) + self.ae.b_dec total_l2_loss = 0.0 @@ -263,7 +281,7 @@ def loss(self, x, step=None, logging=False): acts_slice = f[:, group_start:group_end] x_reconstruct = x_reconstruct + acts_slice @ W_dec_slice - l2_loss = (x_reconstruct - x).pow(2).sum(dim=-1).mean() * self.group_weights[i] + l2_loss = (x - x_reconstruct).pow(2).sum(dim=-1).mean() * self.group_weights[i] total_l2_loss += l2_loss l2_losses = t.cat([l2_losses, l2_loss.unsqueeze(0)]) @@ -275,13 +293,11 @@ def loss(self, x, step=None, logging=False): num_tokens_in_step = x.size(0) did_fire = t.zeros_like(self.num_tokens_since_fired, dtype=t.bool) - did_fire[active_indices] = True + did_fire[active_indices_F] = True self.num_tokens_since_fired += num_tokens_in_step self.num_tokens_since_fired[did_fire] = 0 - auxk_loss = self.get_auxiliary_loss(x, x_reconstruct, f) - - auxk_loss = auxk_loss.sum(dim=-1).mean() + auxk_loss = self.get_auxiliary_loss((x - x_reconstruct), post_relu_acts_BF) loss = mean_l2_loss + self.auxk_alpha * auxk_loss if not logging: diff --git a/trainers/top_k.py b/trainers/top_k.py index 2141afc..f60b711 100644 --- a/trainers/top_k.py +++ b/trainers/top_k.py @@ -87,7 +87,7 @@ def encode(self, x: t.Tensor, return_topk: bool = False, use_threshold: bool = F encoded_acts_BF = post_relu_feat_acts_BF * (post_relu_feat_acts_BF > self.threshold) if return_topk: post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1) - return encoded_acts_BF, post_topk.values, post_topk.indices + return encoded_acts_BF, post_topk.values, post_topk.indices, post_relu_feat_acts_BF else: return encoded_acts_BF @@ -101,7 +101,7 @@ def encode(self, x: t.Tensor, return_topk: bool = False, use_threshold: bool = F encoded_acts_BF = buffer_BF.scatter_(dim=-1, index=top_indices_BK, src=tops_acts_BK) if return_topk: - return encoded_acts_BF, tops_acts_BK, top_indices_BK + return encoded_acts_BF, tops_acts_BK, top_indices_BK, post_relu_feat_acts_BF else: return encoded_acts_BF @@ -199,8 +199,15 @@ def __init__( # Auto-select LR using 1 / sqrt(d) scaling law from Figure 3 of the paper scale = dict_size / (2**14) self.lr = 2e-4 / scale**0.5 + self.auxk_alpha = auxk_alpha self.dead_feature_threshold = 10_000_000 + self.top_k_aux = activation_dim // 2 # Heuristic from B.1 of the paper + self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device) + self.logging_parameters = ["effective_l0", "dead_features", "pre_norm_auxk_loss"] + self.effective_l0 = -1 + self.dead_features = -1 + self.pre_norm_auxk_loss = -1 # Optimizer and scheduler self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) @@ -209,90 +216,85 @@ def __init__( self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) - # Training parameters - self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device) + def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor): + dead_features = self.num_tokens_since_fired >= self.dead_feature_threshold + self.dead_features = int(dead_features.sum()) - # Log the effective L0, i.e. number of features actually used, which should a constant value (K) - # Note: The standard L0 is essentially a measure of dead features for Top-K SAEs) - self.logging_parameters = ["effective_l0", "dead_features"] - self.effective_l0 = -1 - self.dead_features = -1 + if self.dead_features > 0: + k_aux = min(self.top_k_aux, self.dead_features) + + auxk_latents = t.where(dead_features[None], post_relu_acts_BF, -t.inf) + + # Top-k dead latents + auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False) + + auxk_buffer_BF = t.zeros_like(post_relu_acts_BF) + auxk_acts_BF = auxk_buffer_BF.scatter_(dim=-1, index=auxk_indices, src=auxk_acts) + + # Note: decoder(), not decode(), as we don't want to apply the bias + x_reconstruct_aux = self.ae.decoder(auxk_acts_BF) + l2_loss_aux = ( + (residual_BD.float() - x_reconstruct_aux.float()).pow(2).sum(dim=-1).mean() + ) + + self.pre_norm_auxk_loss = l2_loss_aux + + # normalization from OpenAI implementation: https://github.com/openai/sparse_autoencoder/blob/main/sparse_autoencoder/kernels.py#L614 + residual_mu = residual_BD.mean(dim=0)[None, :].broadcast_to(residual_BD.shape) + loss_denom = (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean() + normalized_auxk_loss = l2_loss_aux / loss_denom + + return normalized_auxk_loss + else: + self.pre_norm_auxk_loss = -1 + return t.tensor(0, dtype=residual_BD.dtype, device=residual_BD.device) + + def update_threshold(self, top_acts_BK: t.Tensor): + device_type = "cuda" if top_acts_BK.is_cuda else "cpu" + with t.autocast(device_type=device_type, enabled=False), t.no_grad(): + active = top_acts_BK.clone().detach() + active[active <= 0] = float("inf") + min_activations = active.min(dim=1).values.to(dtype=t.float32) + min_activation = min_activations.mean() + + B, K = active.shape + assert len(active.shape) == 2 + assert min_activations.shape == (B,) + + if self.ae.threshold < 0: + self.ae.threshold = min_activation + else: + self.ae.threshold = (self.threshold_beta * self.ae.threshold) + ( + (1 - self.threshold_beta) * min_activation + ) def loss(self, x, step=None, logging=False): # Run the SAE - f, top_acts, top_indices = self.ae.encode(x, return_topk=True, use_threshold=False) + f, top_acts_BK, top_indices_BK, post_relu_acts_BF = self.ae.encode( + x, return_topk=True, use_threshold=False + ) if step > self.threshold_start_step: - device_type = 'cuda' if x.is_cuda else 'cpu' - with t.autocast(device_type=device_type, enabled=False), t.no_grad(): - active = top_acts.clone().detach() - active[active <= 0] = float("inf") - min_activations = active.min(dim=1).values.to(dtype=t.float32) - min_activation = min_activations.mean() - - B, K = active.shape - assert len(active.shape) == 2 - assert min_activations.shape == (B,) - - if self.ae.threshold < 0: - self.ae.threshold = min_activation - else: - self.ae.threshold = (self.threshold_beta * self.ae.threshold) + ( - (1 - self.threshold_beta) * min_activation - ) + self.update_threshold(top_acts_BK) x_hat = self.ae.decode(f) # Measure goodness of reconstruction - e = x_hat - x - total_variance = (x - x.mean(0)).pow(2).sum(0) + e = x - x_hat # Update the effective L0 (again, should just be K) - self.effective_l0 = top_acts.size(1) + self.effective_l0 = top_acts_BK.size(1) # Update "number of tokens since fired" for each features num_tokens_in_step = x.size(0) did_fire = t.zeros_like(self.num_tokens_since_fired, dtype=t.bool) - did_fire[top_indices.flatten()] = True + did_fire[top_indices_BK.flatten()] = True self.num_tokens_since_fired += num_tokens_in_step self.num_tokens_since_fired[did_fire] = 0 - # Compute dead feature mask based on "number of tokens since fired" - dead_mask = ( - self.num_tokens_since_fired > self.dead_feature_threshold - if self.auxk_alpha > 0 - else None - ).to(f.device) - self.dead_features = int(dead_mask.sum()) - - # If dead features: Second decoder pass for AuxK loss - if dead_mask is not None and (num_dead := int(dead_mask.sum())) > 0: - # Heuristic from Appendix B.1 in the paper - k_aux = x.shape[-1] // 2 - - # Reduce the scale of the loss if there are a small number of dead latents - scale = min(num_dead / k_aux, 1.0) - k_aux = min(k_aux, num_dead) - - # Don't include living latents in this loss - auxk_latents = t.where(dead_mask[None], f, -t.inf) - - # Top-k dead latents - auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False) - - auxk_buffer_BF = t.zeros_like(f) - auxk_acts_BF = auxk_buffer_BF.scatter_(dim=-1, index=auxk_indices, src=auxk_acts) - - # Encourage the top ~50% of dead latents to predict the residual of the - # top k living latents - e_hat = self.ae.decode(auxk_acts_BF) - auxk_loss = (e_hat - e).pow(2) # .sum(0) - auxk_loss = scale * t.mean(auxk_loss / total_variance) - else: - auxk_loss = x_hat.new_tensor(0.0) - l2_loss = e.pow(2).sum(dim=-1).mean() - auxk_loss = auxk_loss.sum(dim=-1).mean() + auxk_loss = self.get_auxiliary_loss(e, post_relu_acts_BF) if self.auxk_alpha > 0 else 0 + loss = l2_loss + self.auxk_alpha * auxk_loss if not logging: From db2b5642e2966559a907e4885bf3317ea997a494 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Mon, 13 Jan 2025 21:00:42 +0000 Subject: [PATCH 56/70] Make sure to detach reconstruction before calculating aux loss --- trainers/batch_top_k.py | 4 ++-- trainers/matroyshka_batch_top_k.py | 4 ++-- trainers/top_k.py | 6 ++++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/trainers/batch_top_k.py b/trainers/batch_top_k.py index c933267..4abd8e6 100644 --- a/trainers/batch_top_k.py +++ b/trainers/batch_top_k.py @@ -185,7 +185,7 @@ def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor) loss_denom = (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean() normalized_auxk_loss = l2_loss_aux / loss_denom - return normalized_auxk_loss + return normalized_auxk_loss.nan_to_num(0.0) else: self.pre_norm_auxk_loss = -1 return t.tensor(0, dtype=residual_BD.dtype, device=residual_BD.device) @@ -229,7 +229,7 @@ def loss(self, x, step=None, logging=False): self.num_tokens_since_fired[did_fire] = 0 l2_loss = e.pow(2).sum(dim=-1).mean() - auxk_loss = self.get_auxiliary_loss(e, post_relu_acts_BF) + auxk_loss = self.get_auxiliary_loss(e.detach(), post_relu_acts_BF) loss = l2_loss + self.auxk_alpha * auxk_loss if not logging: diff --git a/trainers/matroyshka_batch_top_k.py b/trainers/matroyshka_batch_top_k.py index e191377..8b32eed 100644 --- a/trainers/matroyshka_batch_top_k.py +++ b/trainers/matroyshka_batch_top_k.py @@ -239,7 +239,7 @@ def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor) loss_denom = (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean() normalized_auxk_loss = l2_loss_aux / loss_denom - return normalized_auxk_loss + return normalized_auxk_loss.nan_to_num(0.0) else: self.pre_norm_auxk_loss = -1 return t.tensor(0, dtype=residual_BD.dtype, device=residual_BD.device) @@ -297,7 +297,7 @@ def loss(self, x, step=None, logging=False): self.num_tokens_since_fired += num_tokens_in_step self.num_tokens_since_fired[did_fire] = 0 - auxk_loss = self.get_auxiliary_loss((x - x_reconstruct), post_relu_acts_BF) + auxk_loss = self.get_auxiliary_loss((x - x_reconstruct).detach(), post_relu_acts_BF) loss = mean_l2_loss + self.auxk_alpha * auxk_loss if not logging: diff --git a/trainers/top_k.py b/trainers/top_k.py index f60b711..f6f5692 100644 --- a/trainers/top_k.py +++ b/trainers/top_k.py @@ -244,7 +244,7 @@ def get_auxiliary_loss(self, residual_BD: t.Tensor, post_relu_acts_BF: t.Tensor) loss_denom = (residual_BD.float() - residual_mu.float()).pow(2).sum(dim=-1).mean() normalized_auxk_loss = l2_loss_aux / loss_denom - return normalized_auxk_loss + return normalized_auxk_loss.nan_to_num(0.0) else: self.pre_norm_auxk_loss = -1 return t.tensor(0, dtype=residual_BD.dtype, device=residual_BD.device) @@ -293,7 +293,9 @@ def loss(self, x, step=None, logging=False): self.num_tokens_since_fired[did_fire] = 0 l2_loss = e.pow(2).sum(dim=-1).mean() - auxk_loss = self.get_auxiliary_loss(e, post_relu_acts_BF) if self.auxk_alpha > 0 else 0 + auxk_loss = ( + self.get_auxiliary_loss(e.detach(), post_relu_acts_BF) if self.auxk_alpha > 0 else 0 + ) loss = l2_loss + self.auxk_alpha * auxk_loss From 77f2690abcd56ce19aaf3c1404dcfcfc6cf9381b Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Mon, 13 Jan 2025 19:09:54 -0600 Subject: [PATCH 57/70] Add citation --- README.md | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 1766b69..d01811b 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ -This is a repository for doing dictionary learning via sparse autoencoders on neural network activations. It was developed by Samuel Marks and Aaron Mueller. +This is a repository for doing dictionary learning via sparse autoencoders on neural network activations. It was developed by Samuel Marks, Adam Karvonen, and Aaron Mueller. For accessing, saving, and intervening on NN activations, we use the [`nnsight`](http://nnsight.net/) package; as of March 2024, `nnsight` is under active development and may undergo breaking changes. That said, `nnsight` is easy to use and quick to learn; if you plan to modify this repo, then we recommend going through the main `nnsight` demo [here](https://nnsight.net/notebooks/tutorials/walkthrough/). -Some dictionaries trained using this repository (and associated training checkpoints) can be accessed at [https://baulab.us/u/smarks/autoencoders/](https://baulab.us/u/smarks/autoencoders/). See below for more information about these dictionaries. +Some dictionaries trained using this repository (and associated training checkpoints) can be accessed at [https://baulab.us/u/smarks/autoencoders/](https://baulab.us/u/smarks/autoencoders/). See below for more information about these dictionaries. SAEs trained with `dictionary_learning` can be evaluated with [SAE Bench](https://www.neuronpedia.org/sae-bench/info) using a convenient [evaluation script](https://github.com/adamkarvonen/SAEBench/tree/main/sae_bench/custom_saes). # Set-up @@ -211,3 +211,16 @@ We've included support for some experimental features. We briefly investigated t * h/t to Max Li for this suggestion. * **Replacing L1 loss with entropy**. Based on the ideas in this [post](https://transformer-circuits.pub/2023/may-update/index.html#simple-factorization), we experimented with using entropy to regularize a dictionary's hidden state instead of L1 loss. This seemed to cause the features to split into dead features (which never fired) and very high-frequency features which fired on nearly every input, which was not the desired behavior. But plausibly there is a way to make this work better. * **Ghost grads**, as described [here](https://transformer-circuits.pub/2024/jan-update/index.html). + +# Citation + +Please cite the package as follows: + +``` +@misc{marks2024dictionary_learning, + title = {dictionary_learning}, + author = {Samuel Marks, Adam Karvonen, and Aaron Mueller}, + year = {2024}, + howpublished = {\url{https://github.com/saprmarks/dictionary_learning}}, +} +``` From 784a62a405be4ee8754a76ad4d3e61fd7de06348 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Mon, 13 Jan 2025 20:22:53 -0600 Subject: [PATCH 58/70] Fix incorrect auxk logging name --- trainers/batch_top_k.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trainers/batch_top_k.py b/trainers/batch_top_k.py index 4abd8e6..686dc0a 100644 --- a/trainers/batch_top_k.py +++ b/trainers/batch_top_k.py @@ -146,10 +146,10 @@ def __init__( self.dead_feature_threshold = 10_000_000 self.top_k_aux = activation_dim // 2 # Heuristic from B.1 of the paper self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device) - self.logging_parameters = ["effective_l0", "dead_features", "pre_norm_aux_loss"] + self.logging_parameters = ["effective_l0", "dead_features", "pre_norm_auxk_loss"] self.effective_l0 = -1 self.dead_features = -1 - self.pre_norm_aux_loss = -1 + self.pre_norm_auxk_loss = -1 self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) From aa45bf6ed9aa981f6a266f333e6d4a8b9d459909 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Thu, 16 Jan 2025 21:36:52 +0000 Subject: [PATCH 59/70] Fix matryoshka spelling --- ...shka_batch_top_k.py => matryoshka_batch_top_k.py} | 12 ++++++------ utils.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) rename trainers/{matroyshka_batch_top_k.py => matryoshka_batch_top_k.py} (98%) diff --git a/trainers/matroyshka_batch_top_k.py b/trainers/matryoshka_batch_top_k.py similarity index 98% rename from trainers/matroyshka_batch_top_k.py rename to trainers/matryoshka_batch_top_k.py index 8b32eed..fc0c805 100644 --- a/trainers/matroyshka_batch_top_k.py +++ b/trainers/matryoshka_batch_top_k.py @@ -34,7 +34,7 @@ def apply_temperature(probabilities: list[float], temperature: float) -> list[fl return scaled_probs.tolist() -class MatroyshkaBatchTopKSAE(Dictionary, nn.Module): +class MatryoshkaBatchTopKSAE(Dictionary, nn.Module): def __init__(self, activation_dim: int, dict_size: int, k: int, group_sizes: list[int]): super().__init__() self.activation_dim = activation_dim @@ -108,7 +108,7 @@ def scale_biases(self, scale: float): self.threshold *= scale @classmethod - def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "MatroyshkaBatchTopKSAE": + def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "MatryoshkaBatchTopKSAE": state_dict = t.load(path) activation_dim, dict_size = state_dict["W_enc"].shape if k is None: @@ -125,7 +125,7 @@ def from_pretrained(cls, path, k=None, device=None, **kwargs) -> "MatroyshkaBatc return autoencoder -class MatroyshkaBatchTopKTrainer(SAETrainer): +class MatryoshkaBatchTopKTrainer(SAETrainer): def __init__( self, steps: int, # total number of steps to train for @@ -136,7 +136,7 @@ def __init__( lm_name: str, group_fractions: list[float], group_weights: Optional[list[float]] = None, - dict_class: type = MatroyshkaBatchTopKSAE, + dict_class: type = MatryoshkaBatchTopKSAE, lr: Optional[float] = None, auxk_alpha: float = 1 / 32, warmup_steps: int = 1000, @@ -345,8 +345,8 @@ def update(self, step, x): @property def config(self): return { - "trainer_class": "MatroyshkaBatchTopKTrainer", - "dict_class": "MatroyshkaBatchTopKSAE", + "trainer_class": "MatryoshkaBatchTopKTrainer", + "dict_class": "MatryoshkaBatchTopKSAE", "lr": self.lr, "steps": self.steps, "auxk_alpha": self.auxk_alpha, diff --git a/utils.py b/utils.py index c5b4dbc..3b1077e 100644 --- a/utils.py +++ b/utils.py @@ -7,7 +7,7 @@ from .trainers.top_k import AutoEncoderTopK from .trainers.batch_top_k import BatchTopKSAE -from .trainers.matroyshka_batch_top_k import MatroyshkaBatchTopKSAE +from .trainers.matryoshka_batch_top_k import MatryoshkaBatchTopKSAE from .dictionary import ( AutoEncoder, GatedAutoEncoder, @@ -77,9 +77,9 @@ def load_dictionary(base_path: str, device: str) -> tuple: elif dict_class == "BatchTopKSAE": k = config["trainer"]["k"] dictionary = BatchTopKSAE.from_pretrained(ae_path, k=k, device=device) - elif dict_class == "MatroyshkaBatchTopKSAE": + elif dict_class == "MatryoshkaBatchTopKSAE": k = config["trainer"]["k"] - dictionary = MatroyshkaBatchTopKSAE.from_pretrained(ae_path, k=k, device=device) + dictionary = MatryoshkaBatchTopKSAE.from_pretrained(ae_path, k=k, device=device) elif dict_class == "JumpReluAutoEncoder": dictionary = JumpReluAutoEncoder.from_pretrained(ae_path, device=device) else: From 505a4455358f079db9f2b0309cc0922169869965 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Thu, 16 Jan 2025 21:39:01 +0000 Subject: [PATCH 60/70] Use torch.split() instead of direct indexing for 25% speedup --- trainers/matryoshka_batch_top_k.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/trainers/matryoshka_batch_top_k.py b/trainers/matryoshka_batch_top_k.py index fc0c805..3d4f85e 100644 --- a/trainers/matryoshka_batch_top_k.py +++ b/trainers/matryoshka_batch_top_k.py @@ -274,14 +274,22 @@ def loss(self, x, step=None, logging=False): total_l2_loss = 0.0 l2_losses = t.tensor([]).to(self.device) + intermediates = [] + # We could potentially refactor the ae class to use W_dec_chunks instead of W_dec, may be more efficient + W_dec_chunks = t.split(self.ae.W_dec, self.ae.group_sizes.tolist(), dim=0) + f_chunks = t.split(f, self.ae.group_sizes.tolist(), dim=1) + for i in range(self.ae.active_groups): - group_start = self.ae.group_indices[i] - group_end = self.ae.group_indices[i + 1] - W_dec_slice = self.ae.W_dec[group_start:group_end, :] - acts_slice = f[:, group_start:group_end] + W_dec_slice = W_dec_chunks[i] + acts_slice = f_chunks[i] + x_reconstruct = x_reconstruct + acts_slice @ W_dec_slice + intermediates.append(x_reconstruct) - l2_loss = (x - x_reconstruct).pow(2).sum(dim=-1).mean() * self.group_weights[i] + for intermediate_reconstruct in intermediates: + l2_loss = (x - intermediate_reconstruct).pow(2).sum(dim=-1).mean() * self.group_weights[ + i + ] total_l2_loss += l2_loss l2_losses = t.cat([l2_losses, l2_loss.unsqueeze(0)]) From 43421f5934a1476cb3f32f0b9e1b5d14b84540a1 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Thu, 16 Jan 2025 23:07:36 +0000 Subject: [PATCH 61/70] simplify matryoshka loss --- trainers/matryoshka_batch_top_k.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/trainers/matryoshka_batch_top_k.py b/trainers/matryoshka_batch_top_k.py index 3d4f85e..67647fb 100644 --- a/trainers/matryoshka_batch_top_k.py +++ b/trainers/matryoshka_batch_top_k.py @@ -274,7 +274,6 @@ def loss(self, x, step=None, logging=False): total_l2_loss = 0.0 l2_losses = t.tensor([]).to(self.device) - intermediates = [] # We could potentially refactor the ae class to use W_dec_chunks instead of W_dec, may be more efficient W_dec_chunks = t.split(self.ae.W_dec, self.ae.group_sizes.tolist(), dim=0) f_chunks = t.split(f, self.ae.group_sizes.tolist(), dim=1) @@ -282,14 +281,9 @@ def loss(self, x, step=None, logging=False): for i in range(self.ae.active_groups): W_dec_slice = W_dec_chunks[i] acts_slice = f_chunks[i] - x_reconstruct = x_reconstruct + acts_slice @ W_dec_slice - intermediates.append(x_reconstruct) - for intermediate_reconstruct in intermediates: - l2_loss = (x - intermediate_reconstruct).pow(2).sum(dim=-1).mean() * self.group_weights[ - i - ] + l2_loss = (x - x_reconstruct).pow(2).sum(dim=-1).mean() * self.group_weights[i] total_l2_loss += l2_loss l2_losses = t.cat([l2_losses, l2_loss.unsqueeze(0)]) From 0ff88883e7caac8ebd7ea0d8e07585451d8b7f9f Mon Sep 17 00:00:00 2001 From: David Chanin Date: Sun, 9 Feb 2025 22:07:01 -0800 Subject: [PATCH 62/70] feat: pypi packaging and auto-release with semantic release --- .github/workflows/build.yml | 90 ++++++++++++ .gitignore | 2 +- README.md | 6 +- __init__.py | 2 - dictionary_learning/__init__.py | 6 + buffer.py => dictionary_learning/buffer.py | 0 config.py => dictionary_learning/config.py | 0 .../dictionary.py | 0 .../evaluation.py | 0 .../grad_pursuit.py | 0 interp.py => dictionary_learning/interp.py | 2 +- .../trainers}/__init__.py | 12 ++ .../trainers}/batch_top_k.py | 0 .../trainers}/gated_anneal.py | 0 .../trainers}/gdm.py | 0 .../trainers}/jumprelu.py | 0 .../trainers}/matryoshka_batch_top_k.py | 0 .../trainers}/p_anneal.py | 0 .../trainers}/standard.py | 0 .../trainers}/top_k.py | 0 .../trainers}/trainer.py | 0 .../training.py | 0 utils.py => dictionary_learning/utils.py | 0 pyproject.toml | 45 ++++++ requirements.txt | 13 -- tests/test_end_to_end.py | 10 +- tests/unit/test_dictionary.py | 136 ++++++++++++++++++ 27 files changed, 298 insertions(+), 26 deletions(-) create mode 100644 .github/workflows/build.yml delete mode 100644 __init__.py create mode 100644 dictionary_learning/__init__.py rename buffer.py => dictionary_learning/buffer.py (100%) rename config.py => dictionary_learning/config.py (100%) rename dictionary.py => dictionary_learning/dictionary.py (100%) rename evaluation.py => dictionary_learning/evaluation.py (100%) rename grad_pursuit.py => dictionary_learning/grad_pursuit.py (100%) rename interp.py => dictionary_learning/interp.py (99%) rename {trainers => dictionary_learning/trainers}/__init__.py (58%) rename {trainers => dictionary_learning/trainers}/batch_top_k.py (100%) rename {trainers => dictionary_learning/trainers}/gated_anneal.py (100%) rename {trainers => dictionary_learning/trainers}/gdm.py (100%) rename {trainers => dictionary_learning/trainers}/jumprelu.py (100%) rename {trainers => dictionary_learning/trainers}/matryoshka_batch_top_k.py (100%) rename {trainers => dictionary_learning/trainers}/p_anneal.py (100%) rename {trainers => dictionary_learning/trainers}/standard.py (100%) rename {trainers => dictionary_learning/trainers}/top_k.py (100%) rename {trainers => dictionary_learning/trainers}/trainer.py (100%) rename training.py => dictionary_learning/training.py (100%) rename utils.py => dictionary_learning/utils.py (100%) create mode 100644 pyproject.toml delete mode 100644 requirements.txt create mode 100644 tests/unit/test_dictionary.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..c99b233 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,90 @@ +name: build + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Cache Huggingface assets + uses: actions/cache@v4 + with: + key: huggingface-0-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }} + path: ~/.cache/huggingface + restore-keys: | + huggingface-0-${{ runner.os }}-${{ matrix.python-version }}- + - name: Load cached Poetry installation + id: cached-poetry + uses: actions/cache@v4 + with: + path: ~/.local # the path depends on the OS + key: poetry-${{ runner.os }}-${{ matrix.python-version }}-1 # increment to reset cache + - name: Install Poetry + if: steps.cached-poetry.outputs.cache-hit != 'true' + uses: snok/install-poetry@v1 + with: + virtualenvs-create: true + virtualenvs-in-project: true + installer-parallel: true + - name: Load cached venv + id: cached-poetry-dependencies + uses: actions/cache@v4 + with: + path: .venv + key: venv-0-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }} + restore-keys: | + venv-0-${{ runner.os }}-${{ matrix.python-version }}- + - name: Install dependencies + if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' + run: poetry install --no-interaction + - name: Run Unit Tests + run: poetry run pytest tests/unit + - name: Build package + run: poetry build + + release: + needs: build + permissions: + contents: write + id-token: write + # https://github.community/t/how-do-i-specify-job-dependency-running-in-another-workflow/16482 + if: github.event_name == 'push' && github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, 'chore(release):') + runs-on: ubuntu-latest + concurrency: release + environment: + name: pypi + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Semantic Release + id: release + uses: python-semantic-release/python-semantic-release@v8.0.7 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + - name: Publish package distributions to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + if: steps.release.outputs.released == 'true' + - name: Publish package distributions to GitHub Releases + uses: python-semantic-release/upload-to-gh-release@main + if: steps.release.outputs.released == 'true' + with: + github_token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore index 98cc45c..71326c1 100644 --- a/.gitignore +++ b/.gitignore @@ -99,7 +99,7 @@ ipython_config.py # This is especially recommended for binary packages to ensure reproducibility, and is more # commonly ignored for libraries. # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock +poetry.lock # pdm # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. diff --git a/README.md b/README.md index d01811b..febe79f 100644 --- a/README.md +++ b/README.md @@ -8,13 +8,9 @@ Some dictionaries trained using this repository (and associated training checkpo Navigate to the to the location where you would like to clone this repo, clone and enter the repo, and install the requirements. ```bash -git clone https://github.com/saprmarks/dictionary_learning -cd dictionary_learning -pip install -r requirements.txt +pip install dictionary-learning ``` -To use `dictionary_learning`, include it as a subdirectory in some project's directory and import it; see the examples below. - We also provide a [demonstration](https://github.com/adamkarvonen/dictionary_learning_demo), which trains and evaluates 2 SAEs in ~30 minutes before plotting the results. # Using trained dictionaries diff --git a/__init__.py b/__init__.py deleted file mode 100644 index d4f5e83..0000000 --- a/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .dictionary import AutoEncoder, GatedAutoEncoder, JumpReluAutoEncoder -from .buffer import ActivationBuffer \ No newline at end of file diff --git a/dictionary_learning/__init__.py b/dictionary_learning/__init__.py new file mode 100644 index 0000000..2067aaa --- /dev/null +++ b/dictionary_learning/__init__.py @@ -0,0 +1,6 @@ +__version__ = "0.1.0" + +from .dictionary import AutoEncoder, GatedAutoEncoder, JumpReluAutoEncoder +from .buffer import ActivationBuffer + +__all__ = ["AutoEncoder", "GatedAutoEncoder", "JumpReluAutoEncoder", "ActivationBuffer"] diff --git a/buffer.py b/dictionary_learning/buffer.py similarity index 100% rename from buffer.py rename to dictionary_learning/buffer.py diff --git a/config.py b/dictionary_learning/config.py similarity index 100% rename from config.py rename to dictionary_learning/config.py diff --git a/dictionary.py b/dictionary_learning/dictionary.py similarity index 100% rename from dictionary.py rename to dictionary_learning/dictionary.py diff --git a/evaluation.py b/dictionary_learning/evaluation.py similarity index 100% rename from evaluation.py rename to dictionary_learning/evaluation.py diff --git a/grad_pursuit.py b/dictionary_learning/grad_pursuit.py similarity index 100% rename from grad_pursuit.py rename to dictionary_learning/grad_pursuit.py diff --git a/interp.py b/dictionary_learning/interp.py similarity index 99% rename from interp.py rename to dictionary_learning/interp.py index e721eb9..18ac308 100644 --- a/interp.py +++ b/dictionary_learning/interp.py @@ -188,4 +188,4 @@ def feature_umap( hover_name=df.index, color=colors, ) - raise ValueError("n_components must be 2 or 3") + raise ValueError("n_components must be 2 or 3") \ No newline at end of file diff --git a/trainers/__init__.py b/dictionary_learning/trainers/__init__.py similarity index 58% rename from trainers/__init__.py rename to dictionary_learning/trainers/__init__.py index 81998af..4135a82 100644 --- a/trainers/__init__.py +++ b/dictionary_learning/trainers/__init__.py @@ -5,3 +5,15 @@ from .top_k import TopKTrainer from .jumprelu import JumpReluTrainer from .batch_top_k import BatchTopKTrainer, BatchTopKSAE + + +__all__ = [ + "StandardTrainer", + "GatedSAETrainer", + "PAnnealTrainer", + "GatedAnnealTrainer", + "TopKTrainer", + "JumpReluTrainer", + "BatchTopKTrainer", + "BatchTopKSAE", +] diff --git a/trainers/batch_top_k.py b/dictionary_learning/trainers/batch_top_k.py similarity index 100% rename from trainers/batch_top_k.py rename to dictionary_learning/trainers/batch_top_k.py diff --git a/trainers/gated_anneal.py b/dictionary_learning/trainers/gated_anneal.py similarity index 100% rename from trainers/gated_anneal.py rename to dictionary_learning/trainers/gated_anneal.py diff --git a/trainers/gdm.py b/dictionary_learning/trainers/gdm.py similarity index 100% rename from trainers/gdm.py rename to dictionary_learning/trainers/gdm.py diff --git a/trainers/jumprelu.py b/dictionary_learning/trainers/jumprelu.py similarity index 100% rename from trainers/jumprelu.py rename to dictionary_learning/trainers/jumprelu.py diff --git a/trainers/matryoshka_batch_top_k.py b/dictionary_learning/trainers/matryoshka_batch_top_k.py similarity index 100% rename from trainers/matryoshka_batch_top_k.py rename to dictionary_learning/trainers/matryoshka_batch_top_k.py diff --git a/trainers/p_anneal.py b/dictionary_learning/trainers/p_anneal.py similarity index 100% rename from trainers/p_anneal.py rename to dictionary_learning/trainers/p_anneal.py diff --git a/trainers/standard.py b/dictionary_learning/trainers/standard.py similarity index 100% rename from trainers/standard.py rename to dictionary_learning/trainers/standard.py diff --git a/trainers/top_k.py b/dictionary_learning/trainers/top_k.py similarity index 100% rename from trainers/top_k.py rename to dictionary_learning/trainers/top_k.py diff --git a/trainers/trainer.py b/dictionary_learning/trainers/trainer.py similarity index 100% rename from trainers/trainer.py rename to dictionary_learning/trainers/trainer.py diff --git a/training.py b/dictionary_learning/training.py similarity index 100% rename from training.py rename to dictionary_learning/training.py diff --git a/utils.py b/dictionary_learning/utils.py similarity index 100% rename from utils.py rename to dictionary_learning/utils.py diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..28b2b9e --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,45 @@ +[tool.poetry] +name = "dictionary-learning" +version = "0.1.0" +description = "Dictionary learning via sparse autoencoders on neural network activations" +authors = ["Samuel Marks", "Adam Karvonen", "Aaron Mueller"] +packages = [{ include = "dictionary_learning" }] +license = "MIT" +readme = "README.md" +keywords = [ + "deep-learning", + "sparse-autoencoders", + "mechanistic-interpretability", + "PyTorch", +] +classifiers = ["Topic :: Scientific/Engineering :: Artificial Intelligence"] +repository = "https://github.com/saprmarks/dictionary_learning" +homepage = "https://github.com/saprmarks/dictionary_learning" + + +[tool.poetry.dependencies] +python = "^3.10" +circuitsvis = ">=1.43.2" +datasets = ">=2.18.0" +einops = ">=0.7.0" +nnsight = ">=0.3.0,<0.4.0" +pandas = ">=2.2.1" +plotly = ">=5.18.0" +tqdm = ">=4.66.1" +zstandard = ">=0.22.0" +wandb = ">=0.12.0" +umap-learn = ">=0.5.6" +llvmlite = ">=0.40.0" + +[tool.poetry.group.dev.dependencies] +pytest = "^8.3.4" + +[build-system] +requires = ["poetry-core>=2.0.0,<3.0.0"] +build-backend = "poetry.core.masonry.api" + +[tool.semantic_release] +version_variables = ["dictionary_learning/__init__.py:__version__"] +version_toml = ["pyproject.toml:tool.poetry.version"] +branch = "main" +build_command = "pip install poetry && poetry build" diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index bda16d1..0000000 --- a/requirements.txt +++ /dev/null @@ -1,13 +0,0 @@ -circuitsvis>=1.43.2 -datasets>=2.18.0 -einops>=0.7.0 -matplotlib>=3.8.3 -nnsight>=0.3.0 -pandas>=2.2.1 -plotly>=5.18.0 -torch>=2.1.2 -tqdm>=4.66.1 -umap-learn>=0.5.6 -zstandard>=0.22.0 -wandb>=0.12.0 -pytest>=6.2.4 \ No newline at end of file diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 4172129..797cbab 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -7,7 +7,11 @@ from dictionary_learning.training import trainSAE from dictionary_learning.trainers.standard import StandardTrainer from dictionary_learning.trainers.top_k import TopKTrainer, AutoEncoderTopK -from dictionary_learning.utils import hf_dataset_to_generator, get_nested_folders, load_dictionary +from dictionary_learning.utils import ( + hf_dataset_to_generator, + get_nested_folders, + load_dictionary, +) from dictionary_learning.buffer import ActivationBuffer from dictionary_learning.dictionary import ( AutoEncoder, @@ -62,10 +66,8 @@ def test_sae_training(): """End to end test for training an SAE. Takes ~2 minutes on an RTX 3090. This isn't a nice suite of unit tests, but it's better than nothing. I have observed that results can slightly vary with library versions. For full determinism, - use pytorch 2.5.1 and nnsight 0.3.7. + use pytorch 2.5.1 and nnsight 0.3.7.""" - NOTE: `dictionary_learning` is meant to be used as a submodule. Thus, to run this test, you need to use `dictionary_learning` as a submodule - and run the test from the root of the repository using `pytest -s`. Refer to https://github.com/adamkarvonen/dictionary_learning_demo for an example""" random.seed(RANDOM_SEED) t.manual_seed(RANDOM_SEED) diff --git a/tests/unit/test_dictionary.py b/tests/unit/test_dictionary.py new file mode 100644 index 0000000..232eb0e --- /dev/null +++ b/tests/unit/test_dictionary.py @@ -0,0 +1,136 @@ +import torch as t +import pytest +from dictionary_learning.dictionary import ( + AutoEncoder, + GatedAutoEncoder, + AutoEncoderNew, + JumpReluAutoEncoder, +) + + +@pytest.mark.parametrize( + "sae_cls", [AutoEncoder, GatedAutoEncoder, JumpReluAutoEncoder] +) +def test_forward_equals_decode_encode(sae_cls: type) -> None: + """Test that forward pass equals decode(encode(x)) for all SAE types""" + batch_size = 4 + act_dim = 8 + dict_size = 6 + x = t.randn(batch_size, act_dim) + + sae = sae_cls(activation_dim=act_dim, dict_size=dict_size) + + # Test without output_features + forward_out = sae(x) + encode_decode = sae.decode(sae.encode(x)) + assert t.allclose(forward_out, encode_decode) + + # Test with output_features + forward_out, features = sae(x, output_features=True) + encode_features = sae.encode(x) + assert t.allclose(features, encode_features) + + +def test_simple_autoencoder() -> None: + """Test AutoEncoder with simple weight matrices""" + sae = AutoEncoder(activation_dim=2, dict_size=2) + + # Set simple weights + with t.no_grad(): + sae.encoder.weight.data = t.tensor([[1.0, 0.0], [0.0, 1.0]]) + sae.decoder.weight.data = t.tensor([[1.0, 0.0], [0.0, 1.0]]) + sae.encoder.bias.data = t.zeros(2) + sae.bias.data = t.zeros(2) + + # Test encoding + x = t.tensor([[2.0, -1.0]]) + encoded = sae.encode(x) + assert t.allclose(encoded, t.tensor([[2.0, 0.0]])) # ReLU clips negative value + + # Test decoding + decoded = sae.decode(encoded) + assert t.allclose(decoded, t.tensor([[2.0, 0.0]])) + + +def test_simple_gated_autoencoder() -> None: + """Test GatedAutoEncoder with simple weight matrices""" + sae = GatedAutoEncoder(activation_dim=2, dict_size=2) + + # Set simple weights and biases + with t.no_grad(): + sae.encoder.weight.data = t.tensor([[1.0, 0.0], [0.0, 1.0]]) + sae.decoder.weight.data = t.tensor([[1.0, 0.0], [0.0, 1.0]]) + sae.gate_bias.data = t.zeros(2) + sae.mag_bias.data = t.zeros(2) + sae.r_mag.data = t.zeros(2) + sae.decoder_bias.data = t.zeros(2) + + x = t.tensor([[2.0, -1.0]]) + encoded = sae.encode(x) + assert t.allclose( + encoded, t.tensor([[2.0, 0.0]]) + ) # Only positive values pass through + + +def test_normalize_decoder() -> None: + """Test that normalize_decoder maintains output while normalizing weights""" + sae = AutoEncoder(activation_dim=4, dict_size=3) + x = t.randn(2, 4) + + # Get initial output + initial_output = sae(x) + + # Normalize decoder + sae.normalize_decoder() + + # Check decoder weights are normalized + norms = t.norm(sae.decoder.weight, dim=0) + assert t.allclose(norms, t.ones_like(norms)) + + # Check output is maintained + new_output = sae(x) + assert t.allclose(initial_output, new_output, atol=1e-4) + + +def test_scale_biases() -> None: + """Test that scale_biases correctly scales all bias terms""" + sae = AutoEncoder(activation_dim=4, dict_size=3) + + # Record initial biases + initial_encoder_bias = sae.encoder.bias.data.clone() + initial_bias = sae.bias.data.clone() + + scale = 2.0 + sae.scale_biases(scale) + + assert t.allclose(sae.encoder.bias.data, initial_encoder_bias * scale) + assert t.allclose(sae.bias.data, initial_bias * scale) + + +@pytest.mark.parametrize( + "sae_cls", [AutoEncoder, GatedAutoEncoder, AutoEncoderNew, JumpReluAutoEncoder] +) +def test_output_shapes(sae_cls: type) -> None: + """Test that output shapes are correct for all operations""" + batch_size = 3 + act_dim = 4 + dict_size = 5 + x = t.randn(batch_size, act_dim) + + sae = sae_cls(activation_dim=act_dim, dict_size=dict_size) + + # Test encode shape + encoded = sae.encode(x) + assert encoded.shape == (batch_size, dict_size) + + # Test decode shape + decoded = sae.decode(encoded) + assert decoded.shape == (batch_size, act_dim) + + # Test forward shape with and without features + output = sae(x) + assert output.shape == (batch_size, act_dim) + + output, features = sae(x, output_features=True) + assert output.shape == (batch_size, act_dim) + assert features.shape == (batch_size, dict_size) From 07975f7a7c505042b6619846db18dbd122c4f4e6 Mon Sep 17 00:00:00 2001 From: github-actions Date: Wed, 12 Feb 2025 06:55:15 +0000 Subject: [PATCH 63/70] 0.1.0 Automatically generated by python-semantic-release --- CHANGELOG.md | 669 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 669 insertions(+) create mode 100644 CHANGELOG.md diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..207ed78 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,669 @@ +# CHANGELOG + + + +## v0.1.0 (2025-02-12) + +### Feature + +* feat: pypi packaging and auto-release with semantic release ([`0ff8888`](https://github.com/saprmarks/dictionary_learning/commit/0ff88883e7caac8ebd7ea0d8e07585451d8b7f9f)) + +### Unknown + +* Merge pull request #37 from chanind/pypi-package + +feat: pypi packaging and auto-release with semantic release ([`a711efe`](https://github.com/saprmarks/dictionary_learning/commit/a711efe3b60aabc99a35e7279cd35fa8bf4c930a)) + +* simplify matryoshka loss ([`43421f5`](https://github.com/saprmarks/dictionary_learning/commit/43421f5934a1476cb3f32f0b9e1b5d14b84540a1)) + +* Use torch.split() instead of direct indexing for 25% speedup ([`505a445`](https://github.com/saprmarks/dictionary_learning/commit/505a4455358f079db9f2b0309cc0922169869965)) + +* Fix matryoshka spelling ([`aa45bf6`](https://github.com/saprmarks/dictionary_learning/commit/aa45bf6ed9aa981f6a266f333e6d4a8b9d459909)) + +* Fix incorrect auxk logging name ([`784a62a`](https://github.com/saprmarks/dictionary_learning/commit/784a62a405be4ee8754a76ad4d3e61fd7de06348)) + +* Add citation ([`77f2690`](https://github.com/saprmarks/dictionary_learning/commit/77f2690abcd56ce19aaf3c1404dcfcfc6cf9381b)) + +* Make sure to detach reconstruction before calculating aux loss ([`db2b564`](https://github.com/saprmarks/dictionary_learning/commit/db2b5642e2966559a907e4885bf3317ea997a494)) + +* Merge pull request #36 from saprmarks/aux_loss_fixes + +Aux loss fixes, standardize decoder normalization ([`34eefda`](https://github.com/saprmarks/dictionary_learning/commit/34eefdafcbcac784f3761abf5037c5178cbfd866)) + +* Standardize and fix topk auxk loss implementation ([`0af1971`](https://github.com/saprmarks/dictionary_learning/commit/0af19713feb5b4c35788039245013736bf974383)) + +* Normalize decoder after optimzer step ([`200ed3b`](https://github.com/saprmarks/dictionary_learning/commit/200ed3bed09c88d336c25a886eee4cb98c1e616e)) + +* Remove experimental matroyshka temperature ([`6c2fcfc`](https://github.com/saprmarks/dictionary_learning/commit/6c2fcfc2a8108ba720591eb414be6ab16157dc36)) + +* Make sure x is on the correct dtype for jumprelu when logging ([`c697d0f`](https://github.com/saprmarks/dictionary_learning/commit/c697d0f83984f0f257be2044231c30f2abb15aa1)) + +* Import trainers from correct relative location for submodule use ([`8363ff7`](https://github.com/saprmarks/dictionary_learning/commit/8363ff779eee04518edaac9d10d97e459f708b66)) + +* By default, don't normalize Gated activations during inference ([`52b0c54`](https://github.com/saprmarks/dictionary_learning/commit/52b0c54ba92630cfb2ae007f020ed447d4a5ba9f)) + +* Also update context manager for matroyshka threshold ([`65e7af8`](https://github.com/saprmarks/dictionary_learning/commit/65e7af80441e5b601114756afc36a4041cec152f)) + +* Disable autocast for threshold tracking ([`17aa5d5`](https://github.com/saprmarks/dictionary_learning/commit/17aa5d52f818545afe5fbbe3edf1f774cde92f44)) + +* Add torch autocast to training loop ([`832f4a3`](https://github.com/saprmarks/dictionary_learning/commit/832f4a32428cda68ec418aff9abe7dca66a9f66e)) + +* Save state dicts to cpu ([`3c5a5cd`](https://github.com/saprmarks/dictionary_learning/commit/3c5a5cdef682cbeb12e23b825f39709f518e2c0a)) + +* Add an option to pass LR to TopK trainers ([`8316a44`](https://github.com/saprmarks/dictionary_learning/commit/8316a4418dc4acb70ccad9854d3b05df1b817b9d)) + +* Add April Update Standard Trainer ([`cfb36ff`](https://github.com/saprmarks/dictionary_learning/commit/cfb36fff661fa60f38a2d1b372b6802517c08257)) + +* Merge pull request #35 from saprmarks/code_cleanup + +Consolidate LR Schedulers, Sparsity Schedulers, and constrained optimizers ([`f19db98`](https://github.com/saprmarks/dictionary_learning/commit/f19db98106302ed1d75dc8380160463ff812b1ad)) + +* Consolidate LR Schedulers, Sparsity Schedulers, and constrained optimizers ([`9751c57`](https://github.com/saprmarks/dictionary_learning/commit/9751c57731a25c04871e8173d16a0e4d902edc19)) + +* Merge pull request #34 from adamkarvonen/matroyshka + +Add Matroyshka, Fix Jump ReLU training, modify initialization ([`92648d4`](https://github.com/saprmarks/dictionary_learning/commit/92648d4e3d28aa397dbc89c43147aa6faf8874b7)) + +* Add a verbose option during training ([`0ff687b`](https://github.com/saprmarks/dictionary_learning/commit/0ff687bdc12cba66a0233825cb301df28da3a9db)) + +* Prevent wandb cuda multiprocessing errors ([`370272a`](https://github.com/saprmarks/dictionary_learning/commit/370272a4aac0ad0e59a2982073aa7b08970712b6)) + +* Log dead features for batch top k SAEs ([`936a69c`](https://github.com/saprmarks/dictionary_learning/commit/936a69c38a74980830f24fc851c40fb93abe8f07)) + +* Log number of dead features to wandb ([`77da794`](https://github.com/saprmarks/dictionary_learning/commit/77da7945f520f448b0524e476f539b3a44a4ca43)) + +* Add trainer number to wandb name ([`3b03b92`](https://github.com/saprmarks/dictionary_learning/commit/3b03b92b97d61a95e98b6f187dad97e939f6f977)) + +* Add notes ([`810dbb8`](https://github.com/saprmarks/dictionary_learning/commit/810dbb8bdce4ac6f1ce371872297b4f7a104e3f6)) + +* Add option to ignore bos tokens ([`c2fe5b8`](https://github.com/saprmarks/dictionary_learning/commit/c2fe5b89e78ae4a9d41a4809f4d00b8a3fcd0b36)) + +* Fix jumprelu training ([`ec961ac`](https://github.com/saprmarks/dictionary_learning/commit/ec961acde2244b98b26bcf796c3ec00b721088bb)) + +* Use kaiming initialization if specified in paper, fix batch_top_k aux_k_alpha ([`8eaa8b2`](https://github.com/saprmarks/dictionary_learning/commit/8eaa8b2407eabd714bbe7d55fd0c15fcb05fba24)) + +* Format with ruff ([`3e31571`](https://github.com/saprmarks/dictionary_learning/commit/3e31571b20d3e86823540882ec03c87b155d8e3d)) + +* Add temperature scaling to matroyshka ([`ceabbc5`](https://github.com/saprmarks/dictionary_learning/commit/ceabbc5233dcf28f0f5afd53e0de850d19f34d78)) + +* norm the correct decoder dimension ([`5383603`](https://github.com/saprmarks/dictionary_learning/commit/53836033b305142fb6d076a52a7679e0642ddb7a)) + +* Fix loading matroyshkas from_pretrained() ([`764d4ac`](https://github.com/saprmarks/dictionary_learning/commit/764d4ac4450ea6b7d79de52fdec70c7c1e0dfb79)) + +* Initial matroyshka implementation ([`8ade55b`](https://github.com/saprmarks/dictionary_learning/commit/8ade55b6eb57ed7c7b06a70187ee68e1056bb95b)) + +* Make sure we step the learning rate scheduler ([`1df47d8`](https://github.com/saprmarks/dictionary_learning/commit/1df47d83d9dea07d2fb905509b635ac6139bcd48)) + +* Merge pull request #33 from saprmarks/lr_scheduling + +Lr scheduling ([`316dbbe`](https://github.com/saprmarks/dictionary_learning/commit/316dbbe9a905bdab91fb2db63bbc61646e7039a6)) + +* Properly set new parameters in end to end test ([`e00fd64`](https://github.com/saprmarks/dictionary_learning/commit/e00fd643050584f4cfe15ad41e6a01e29e3c0766)) + +* Standardize learning rate and sparsity schedules ([`a2d6c43`](https://github.com/saprmarks/dictionary_learning/commit/a2d6c43e94ef068821441d47fef8ae7b3215d09e)) + +* Merge pull request #32 from saprmarks/add_sparsity_warmup + +Add sparsity warmup ([`a11670f`](https://github.com/saprmarks/dictionary_learning/commit/a11670fc6b96b1af3fe8a97175218041f2a9791f)) + +* Add sparsity warmup for trainers with a sparsity penalty ([`911b958`](https://github.com/saprmarks/dictionary_learning/commit/911b95890e20998df92710a01d158f4663d6834b)) + +* Clean up lr decay ([`e0db40b`](https://github.com/saprmarks/dictionary_learning/commit/e0db40b8fadcdd1e24c1945829ecd4eb57451fa8)) + +* Track lr decay implementation ([`f0bb66d`](https://github.com/saprmarks/dictionary_learning/commit/f0bb66d1c25bcb7dc8df62d8dbc3bfd47d26b14c)) + +* Remove leftover variable, update expected results with standard SAE improvements ([`9687bb9`](https://github.com/saprmarks/dictionary_learning/commit/9687bb9858ef05306227309af99cd5c09d91642a)) + +* Merge pull request #31 from saprmarks/add_demo + +Add option to normalize dataset, track thresholds for TopK SAEs, Fix Standard SAE ([`67a7857`](https://github.com/saprmarks/dictionary_learning/commit/67a7857ca63eb9299c340bc8f9804cdd569df1a9)) + +* Also scale topk thresholds when scaling biases ([`efd76b1`](https://github.com/saprmarks/dictionary_learning/commit/efd76b138f429bb8e5e969e2e45926e886fdd71b)) + +* Use the correct standard SAE reconstruction loss, initialize W_dec to W_enc.T ([`8b95ec9`](https://github.com/saprmarks/dictionary_learning/commit/8b95ec9b6e9a6d8d6255092e51b7580dccac70d6)) + +* Add bias scaling to topk saes ([`484ca01`](https://github.com/saprmarks/dictionary_learning/commit/484ca01f405e5791968883123718fd67ee35f299)) + +* Fix topk bfloat16 dtype error ([`488a154`](https://github.com/saprmarks/dictionary_learning/commit/488a1545922249cdb9ce5a5885c1931a5c21a37f)) + +* Add option to normalize dataset activations ([`81968f2`](https://github.com/saprmarks/dictionary_learning/commit/81968f2659082996539f08ea3188a5d2ed327696)) + +* Remove demo script and graphing notebook ([`57f451b`](https://github.com/saprmarks/dictionary_learning/commit/57f451b5635c4677ab47a4172aa588a5bdffdb4e)) + +* Track thresholds for topk and batchtopk during training ([`b5821fd`](https://github.com/saprmarks/dictionary_learning/commit/b5821fd87e3676e7a9ab6b87d423c03c57a344dd)) + +* Track threshold for batchtopk, rename for consistency ([`32d198f`](https://github.com/saprmarks/dictionary_learning/commit/32d198f738c61b0c1109f1803c43e01afb977d3e)) + +* Modularize demo script ([`dcc02f0`](https://github.com/saprmarks/dictionary_learning/commit/dcc02f04e504331011a54ce851a91976daf15170)) + +* Begin creation of demo script ([`712eb98`](https://github.com/saprmarks/dictionary_learning/commit/712eb98f78d9537aa3ff01a1d9e007361e67c267)) + +* Fix JumpReLU training and loading ([`552a8c2`](https://github.com/saprmarks/dictionary_learning/commit/552a8c2c12d41b5d520c99bf3534dff5329f0fde)) + +* Ensure activation buffer has the correct dtype ([`d416eab`](https://github.com/saprmarks/dictionary_learning/commit/d416eab5de1edfe8ea75c972cdf78d9de68642c2)) + +* Merge pull request #30 from adamkarvonen/add_tests + +Add end to end test, upgrade nnsight to support 0.3.0, fix bugs ([`c4eed3c`](https://github.com/saprmarks/dictionary_learning/commit/c4eed3cca27e93f0ad80cd49057cb862d03c86d7)) + +* Merge pull request #26 from mntss/batchtokp_aux_fix + +Fix BatchTopKSAE training ([`2ec1890`](https://github.com/saprmarks/dictionary_learning/commit/2ec18905045109ec0647bc127bacb794312fc2f6)) + +* Check for is_tuple to support mlp / attn submodules ([`d350415`](https://github.com/saprmarks/dictionary_learning/commit/d350415e119cacb6547703eb9733daf8ef57075b)) + +* Change save_steps to a list of ints ([`f1b9b80`](https://github.com/saprmarks/dictionary_learning/commit/f1b9b800bc8e2cc308d4d14690df71f854b30fce)) + +* Add early stopping in forward pass ([`05fe179`](https://github.com/saprmarks/dictionary_learning/commit/05fe179f5b0616310253deaf758c370071f534fa)) + +* Obtain better test results using multiple batches ([`067bf7b`](https://github.com/saprmarks/dictionary_learning/commit/067bf7b05470f61b9ed4f38b95be55c5ac45fb8f)) + +* Fix frac_alive calculation, perform evaluation over multiple batches ([`dc30720`](https://github.com/saprmarks/dictionary_learning/commit/dc3072089c24ce1eb8bc40e9f5248c69a92f5174)) + +* Complete nnsight 0.2 to 0.3 changes ([`807f6ef`](https://github.com/saprmarks/dictionary_learning/commit/807f6ef735872a5cab68773a315f15bc920c3d72)) + +* Rename input to inputs per nnsight 0.3.0 ([`9ed4af2`](https://github.com/saprmarks/dictionary_learning/commit/9ed4af245a22e095e932d6065d368c58947d9a3d)) + +* Add a simple end to end test ([`fe54b00`](https://github.com/saprmarks/dictionary_learning/commit/fe54b001cba976ca96d46add8539580268dc5cb6)) + +* Create LICENSE ([`32fec9c`](https://github.com/saprmarks/dictionary_learning/commit/32fec9c4556b3acaa709d756e8693edde1e74644)) + +* Fix BatchTopKSAE training ([`4aea538`](https://github.com/saprmarks/dictionary_learning/commit/4aea5388811284f4fd3daa8fb97916073bfe8841)) + +* dtype for loading SAEs ([`932e10a`](https://github.com/saprmarks/dictionary_learning/commit/932e10a46523347e8c2da70a10bb8e6dd42d17c6)) + +* Merge pull request #22 from pleask/jumprelu + +Implement jumprelu training ([`713f638`](https://github.com/saprmarks/dictionary_learning/commit/713f6389dde35177c83361f90daaba99b5ac3d08)) + +* Merge branch 'main' into jumprelu ([`099dbbf`](https://github.com/saprmarks/dictionary_learning/commit/099dbbfcdcad07dfc85dd65bfbd15ca9eece70a5)) + +* Merge pull request #21 from pleask/separate-wandb-runs + +Use separate wandb runs for each SAE being trained ([`df60f52`](https://github.com/saprmarks/dictionary_learning/commit/df60f52737f18ce0b1ecd2eb9e08d0706871442d)) + +* Merge branch 'main' into jumprelu ([`3dfc069`](https://github.com/saprmarks/dictionary_learning/commit/3dfc069d39ceeb33ce60581fc7cb17f08ec0e428)) + +* implement jumprelu training ([`16bdfd9`](https://github.com/saprmarks/dictionary_learning/commit/16bdfd95bc04000b89f81b0496df59f17653a2f8)) + +* handle no wandb ([`8164d32`](https://github.com/saprmarks/dictionary_learning/commit/8164d32ec79325d3cc31063098b9108386eb15cf)) + +* Merge pull request #20 from pleask/batchtopk + +Implement BatchTopK ([`b001fb0`](https://github.com/saprmarks/dictionary_learning/commit/b001fb0fd358efc7647acf835123a5e874a9a822)) + +* separate runs for each sae being trained ([`7d3b127`](https://github.com/saprmarks/dictionary_learning/commit/7d3b12778070b88fd39c439751973ac83afbe7a0)) + +* add batchtopk ([`f08e00b`](https://github.com/saprmarks/dictionary_learning/commit/f08e00b2585ab9a965984af4932614a2e408b6e3)) + +* Move f_gate to encoder's dtype ([`43bdb3b`](https://github.com/saprmarks/dictionary_learning/commit/43bdb3b903f7a45ee52b4d865f6d6b7bd60647a3)) + +* Ensure that x_hat is in correct dtype ([`3376f1b`](https://github.com/saprmarks/dictionary_learning/commit/3376f1bd9d05bedd03179475052d3a26a61fad7a)) + +* Preallocate buffer memory to lower peak VRAM usage when replenishing buffer ([`90aff63`](https://github.com/saprmarks/dictionary_learning/commit/90aff63b042c50c3c81a3977b62248254115e907)) + +* Perform logging outside of training loop to lower peak memory usage ([`57f8812`](https://github.com/saprmarks/dictionary_learning/commit/57f8812ff93d4d9ac437d29a74f1d920daa45515)) + +* Remove triton usage ([`475fece`](https://github.com/saprmarks/dictionary_learning/commit/475feceba9e47d6e74b17c87844253f0a209d75d)) + +* Revert to triton TopK implementation ([`d94697d`](https://github.com/saprmarks/dictionary_learning/commit/d94697df1783da8b6739e565c3a1bd297b8b1e98)) + +* Add relative reconstruction bias from GDM Gated SAE paper to evaluate() ([`8984b01`](https://github.com/saprmarks/dictionary_learning/commit/8984b0112e6f9eebcf869aba78ad713b2016d6a6)) + +* git push origin main:Merge branch 'ElanaPearl-small_bug_fixes' into main ([`2d586e4`](https://github.com/saprmarks/dictionary_learning/commit/2d586e417cd30473e1c608146df47eb5767e2527)) + +* simplifying readme ([`9c46e06`](https://github.com/saprmarks/dictionary_learning/commit/9c46e061eb3b29d055e7221ce92524c6546d2a59)) + +* simplify readme ([`5c96003`](https://github.com/saprmarks/dictionary_learning/commit/5c9600344e033b5a7834a48914e958b257bcb720)) + +* add missing imports ([`7f689d9`](https://github.com/saprmarks/dictionary_learning/commit/7f689d9a3a60d577a0d860ac306ae7ba0c71240a)) + +* fix arg name in trainer_config ([`9577d26`](https://github.com/saprmarks/dictionary_learning/commit/9577d26c92affa71a9dcc3a3b3f6cb905f230388)) + +* update sae training example code ([`9374546`](https://github.com/saprmarks/dictionary_learning/commit/937454616f087a6e30afa2ae5f6d52ea685ebfee)) + +* Merge branch 'main' of https://github.com/saprmarks/dictionary_learning into main ([`7d405f7`](https://github.com/saprmarks/dictionary_learning/commit/7d405f7d7555444c66121bc853ab027f49c408b0)) + +* GatedSAE: moved feature re-normalization into encode ([`f628c0e`](https://github.com/saprmarks/dictionary_learning/commit/f628c0ef2ec53d20ffd4d3d06f84100054c358e1)) + +* documenting JumpReLU SAE support ([`322b6c0`](https://github.com/saprmarks/dictionary_learning/commit/322b6c0c75767b7fe110d1454b9dcd4106bb942b)) + +* support for JumpReluAutoEncoders ([`57df4e7`](https://github.com/saprmarks/dictionary_learning/commit/57df4e75cbf181e3662058a6609ab2bb5921c9c4)) + +* Add submodule_name to PAnnealTrainer ([`ecdac03`](https://github.com/saprmarks/dictionary_learning/commit/ecdac0376285912d9468695c024b39100c663b07)) + +* host SAEs on huggingface ([`0ae37fe`](https://github.com/saprmarks/dictionary_learning/commit/0ae37feeb8beac0fce5036c6ff4188c86627775e)) + +* fixed batch loading in examine_dimension ([`82485d7`](https://github.com/saprmarks/dictionary_learning/commit/82485d78bcb6d3bcec67965743fac32e6d29ff37)) + +* Merge pull request #17 from saprmarks/collab + +Merge Collab Branch ([`cdf8222`](https://github.com/saprmarks/dictionary_learning/commit/cdf82227d24295fe8a83fbcfe785e6d6d4f2b997)) + +* moved experimental trainers to collab-dev ([`8d1d581`](https://github.com/saprmarks/dictionary_learning/commit/8d1d581f3df482c77ca99d0839f1677b19ca1ae7)) + +* Merge branch 'main' into collab ([`dda38b9`](https://github.com/saprmarks/dictionary_learning/commit/dda38b94a491261fd92bf9754f1c673221d7f270)) + +* Update README.md ([`4d6c6a6`](https://github.com/saprmarks/dictionary_learning/commit/4d6c6a6cb5816571e045f3c42c9f5b508d395d83)) + +* remove a sentence ([`2d40ed5`](https://github.com/saprmarks/dictionary_learning/commit/2d40ed598074c57904e9566d82bbd8ce27b661b5)) + +* add a list of trainers to the README ([`746927a`](https://github.com/saprmarks/dictionary_learning/commit/746927ae0b597e1fcb69aed58a5e9d4b6103732c)) + +* add architecture details to README ([`60422a8`](https://github.com/saprmarks/dictionary_learning/commit/60422a87231439425b9e27384352b03bc245365a)) + +* make wandb integration optional ([`a26c4e5`](https://github.com/saprmarks/dictionary_learning/commit/a26c4e57985458735bbf887685b679d16008de98)) + +* make wandb integration optional ([`0bdc871`](https://github.com/saprmarks/dictionary_learning/commit/0bdc871a95dae4de17b5116eda38f20d2375ebd1)) + +* Fix tutorial 404 ([`deb3df7`](https://github.com/saprmarks/dictionary_learning/commit/deb3df7906c8a0d00a4286f42cb65ae27667b2a7)) + +* Add missing values to config ([`9e44ea9`](https://github.com/saprmarks/dictionary_learning/commit/9e44ea9dc015c6bf919bd61aa40892be1da66dc3)) + +* changed TrainerTopK class name ([`c52ff00`](https://github.com/saprmarks/dictionary_learning/commit/c52ff008869a021b5f58d1beb80f8afe014757c5)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`c04ee3b`](https://github.com/saprmarks/dictionary_learning/commit/c04ee3b006ae72e69266d0ac2163035aee326b6a)) + +* fixed loss_recovered to incorporate top_k ([`6be5635`](https://github.com/saprmarks/dictionary_learning/commit/6be563540801caf185069051985b453dacc421d8)) + +* fixed TopK loss (spotted by Anish) ([`a3b71f7`](https://github.com/saprmarks/dictionary_learning/commit/a3b71f71212b839c8814ffa4223a5026837738c3)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`40bcdf6`](https://github.com/saprmarks/dictionary_learning/commit/40bcdf65b646f0e387d030b0c2211eaf07636b4c)) + +* naming conventions ([`5ff7fa1`](https://github.com/saprmarks/dictionary_learning/commit/5ff7fa101da07dfdb0663a484214b75c79e02fe0)) + +* small fix to triton kernel ([`5d21265`](https://github.com/saprmarks/dictionary_learning/commit/5d21265bd390d35b937d10d83cdf617151212cb3)) + +* small updates for eval ([`585e820`](https://github.com/saprmarks/dictionary_learning/commit/585e82070620771ee5bef4278d4d500b02983e0c)) + +* added some housekeeping stuff to top_k ([`5559c2c`](https://github.com/saprmarks/dictionary_learning/commit/5559c2c02d84df49531a631d3f4b29ef8acf94c4)) + +* add support for Top-k SAEs ([`2d549d0`](https://github.com/saprmarks/dictionary_learning/commit/2d549d0d98e400fedf4d7c4127d540f97240b89e)) + +* add transcoder eval ([`8446f4f`](https://github.com/saprmarks/dictionary_learning/commit/8446f4fc1aa9e7a08ece6e2fd59e6fa9583a7501)) + +* add transcoder support ([`c590a25`](https://github.com/saprmarks/dictionary_learning/commit/c590a254990691947b244e09849db7b288ed6bee)) + +* added wandb finish to trainer ([`113c042`](https://github.com/saprmarks/dictionary_learning/commit/113c042101b6df6de60b04c7e65116c3a9460904)) + +* fixed anneal end bug ([`fbd9ee4`](https://github.com/saprmarks/dictionary_learning/commit/fbd9ee41ed23d65cbaedb43447b64ae4117dab9a)) + +* added layer and lm_name ([`d173235`](https://github.com/saprmarks/dictionary_learning/commit/d17323572d23067bbba949732a842b1c2c149188)) + +* adding layer and lm_name to trainer config ([`6168ee0`](https://github.com/saprmarks/dictionary_learning/commit/6168ee0210308a42f3536f5bff19db70e91311ae)) + +* make tracer_args optional ([`31b2828`](https://github.com/saprmarks/dictionary_learning/commit/31b2828869bd560ac29eafbd3abf06f752063047)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`87d2b58`](https://github.com/saprmarks/dictionary_learning/commit/87d2b58da4b5714a44e6d301f2b5595e6bdd4296)) + +* bug fix evaluating CE loss with NNsight models ([`f8d81a1`](https://github.com/saprmarks/dictionary_learning/commit/f8d81a1d56b96f34c26fcc9f3feac0cb11ab3065)) + +* Combining P Annealing and Anthropic Update ([`44318e9`](https://github.com/saprmarks/dictionary_learning/commit/44318e999d6d123daad63fa399935ba339421070)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`43e9ca6`](https://github.com/saprmarks/dictionary_learning/commit/43e9ca63664dafd9a9f23f81b0bf57917a9f36ba)) + +* removing normalization ([`7a98d77`](https://github.com/saprmarks/dictionary_learning/commit/7a98d77318b3abcc0aab237de455eb33e20f691e)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`5f2b598`](https://github.com/saprmarks/dictionary_learning/commit/5f2b598cdbeb311f32c4ce1e2e816769240bb75e)) + +* added buffer for NNsight models (not LanguageModel classes) as an extra class. We'll want to combine the three buffers wo currently have at some point ([`f19d284`](https://github.com/saprmarks/dictionary_learning/commit/f19d2843f9fc64192ddac12f345a4ad910b96310)) + +* fixed nnsight issues model tracing for chess-gpt ([`7e8c9f9`](https://github.com/saprmarks/dictionary_learning/commit/7e8c9f95cd25bb6bc56def8210852841a30f22fd)) + +* added W_O projection to HeadBuffer ([`47bd4cd`](https://github.com/saprmarks/dictionary_learning/commit/47bd4cdea4a64563d2f8ba9ab39b246caf9f3c8c)) + +* added support for training SAEs on individual heads ([`a0e3119`](https://github.com/saprmarks/dictionary_learning/commit/a0e31199f2a02c86328bc4551f1d9a0b89d0d87b)) + +* added support for training SAEs on individual heads ([`47351b4`](https://github.com/saprmarks/dictionary_learning/commit/47351b4f6ca0bbd42981f73a784e1a395a941025)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`7de0bd3`](https://github.com/saprmarks/dictionary_learning/commit/7de0bd3d062693d8a35f309a6bc8b494c98408a3)) + +* default hyperparameter adjustments ([`a09346b`](https://github.com/saprmarks/dictionary_learning/commit/a09346b928a9782f57ec137b95d9e7636eda2abf)) + +* normalization in gated_new ([`104aba2`](https://github.com/saprmarks/dictionary_learning/commit/104aba291b0c17a5ec9e86655a281f457ce14cbc)) + +* fixing bug where inputs can get overwritten ([`93fd46e`](https://github.com/saprmarks/dictionary_learning/commit/93fd46e3884daf2fb2e17d952b7a4030b0129957)) + +* fixing tuple bug ([`b05dcaf`](https://github.com/saprmarks/dictionary_learning/commit/b05dcafc816370be8f0584700fd5a882be4a2e8f)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`73b5663`](https://github.com/saprmarks/dictionary_learning/commit/73b5663ed47c91c1c0d2fa8d47c029686fcf8a48)) + +* multiple steps debugging ([`de3eef1`](https://github.com/saprmarks/dictionary_learning/commit/de3eef10d322502f150dc63e9f71d84c9b777b71)) + +* adding gradient pursuit function ([`72941f1`](https://github.com/saprmarks/dictionary_learning/commit/72941f10e794401b2a6b682aa097f4db3f7aa1fe)) + +* bugfix ([`53aabc0`](https://github.com/saprmarks/dictionary_learning/commit/53aabc0ae45464fd3d1d1d384969fe7066d94a7a)) + +* bugfix ([`91691b5`](https://github.com/saprmarks/dictionary_learning/commit/91691b5b8da50f6b1d44eae501529d72b935752e)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`9ce7d80`](https://github.com/saprmarks/dictionary_learning/commit/9ce7d80ec96e7324c095ffef81039d0e6a896feb)) + +* logging more things ([`8498a75`](https://github.com/saprmarks/dictionary_learning/commit/8498a754acacca467494182dbf7444b34e1184c3)) + +* changing initialization for AutoEncoderNew ([`c7ee7ec`](https://github.com/saprmarks/dictionary_learning/commit/c7ee7ec8e7c4bc235cf969f7653a2d99f9bd5723)) + +* fixing gated SAE encoder scheme ([`4084bc3`](https://github.com/saprmarks/dictionary_learning/commit/4084bc3fa50f0764864630b2fe476722a9303b47)) + +* changes to gatedSAE API ([`9e001d1`](https://github.com/saprmarks/dictionary_learning/commit/9e001d170c752c1887c27942bf3d6336322a0ff0)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`05b397b`](https://github.com/saprmarks/dictionary_learning/commit/05b397bcc60f3026da7d55aefebfd3b2223273a6)) + +* changing initialization ([`ebe0d57`](https://github.com/saprmarks/dictionary_learning/commit/ebe0d57c62ebde85386fd7ec59157758e85d3ce3)) + +* finished combining gated and p-annealing ([`4c08614`](https://github.com/saprmarks/dictionary_learning/commit/4c08614403d51328c672983cb28aaaee846092bc)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`8e0a6f9`](https://github.com/saprmarks/dictionary_learning/commit/8e0a6f998ded264270c019ee9b14ffb9c31d650a)) + +* gated_anneal first steps ([`ba8b8fa`](https://github.com/saprmarks/dictionary_learning/commit/ba8b8fa1efda86ea843b0a837f98f106ab089448)) + +* jump SAE ([`873b764`](https://github.com/saprmarks/dictionary_learning/commit/873b764b5a17bdc8704a4a871362e1b03de3ef5f)) + +* adapted loss logging in p_anneal ([`33997c0`](https://github.com/saprmarks/dictionary_learning/commit/33997c05699862896191dc683c9622efc3e97f95)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`1eecbda`](https://github.com/saprmarks/dictionary_learning/commit/1eecbdaf651dffa1ed4962d79d3f2577d1979e91)) + +* merging gated and Anthropic SAEs ([`b6a24d0`](https://github.com/saprmarks/dictionary_learning/commit/b6a24d001234e38c2f6b4c52215d65fdcb50a09e)) + +* revert trainer naming ([`c0af6d9`](https://github.com/saprmarks/dictionary_learning/commit/c0af6d9c20fda36ee700e2884611fd12edc3fb59)) + +* restored trainer naming ([`2ec3c67`](https://github.com/saprmarks/dictionary_learning/commit/2ec3c6768d21b019ba12d0065876011e85bc2aae)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`fe7e93b`](https://github.com/saprmarks/dictionary_learning/commit/fe7e93bf606a6c8e2e2d13335565455359905345)) + +* various changes ([`32027ae`](https://github.com/saprmarks/dictionary_learning/commit/32027ae3781e367affc23c3fde5fc504ef49ebc4)) + +* debug panneal ([`463907d`](https://github.com/saprmarks/dictionary_learning/commit/463907dab4ee91254d0ae674752d3c3803a8044d)) + +* debug panneal ([`8c00100`](https://github.com/saprmarks/dictionary_learning/commit/8c00100423223dc21d57ee4114f9ef6b38ee209e)) + +* debug panneal ([`dc632cd`](https://github.com/saprmarks/dictionary_learning/commit/dc632cd69df0c1719ebed7bdd677d7373f37dc74)) + +* debug panneal ([`166f6a9`](https://github.com/saprmarks/dictionary_learning/commit/166f6a9e582d45728d1e8291c6ab451dbb7a35fd)) + +* debug panneal ([`bcebaa6`](https://github.com/saprmarks/dictionary_learning/commit/bcebaa6b2adedaecd0779d6551929b3a213aef1e)) + +* debug pannealing ([`446c568`](https://github.com/saprmarks/dictionary_learning/commit/446c568d32ff7c93c9688c193dd459abe9086ed5)) + +* p_annealing loss buffer ([`e4d4a35`](https://github.com/saprmarks/dictionary_learning/commit/e4d4a3532536d9b450f39c034c0aabd8e95560fa)) + +* implement Ben's p-annealing strategy ([`06a27f0`](https://github.com/saprmarks/dictionary_learning/commit/06a27f096c0e62df695d60d9e1ec7df77c305498)) + +* panneal changes ([`fe4ff6f`](https://github.com/saprmarks/dictionary_learning/commit/fe4ff6fa5d0c85942b45fffa0bb2908f4d13a2aa)) + +* logging trainer names to wandb ([`f9c5e45`](https://github.com/saprmarks/dictionary_learning/commit/f9c5e45a85ed345fdd95502a4de7a873c25f8456)) + +* bugfixes for StandardTrainerNew ([`70acd85`](https://github.com/saprmarks/dictionary_learning/commit/70acd8572b1b5250ac38cb9be04069cb1a6f981e)) + +* trainer for new anthropic infrastructure ([`531c285`](https://github.com/saprmarks/dictionary_learning/commit/531c28596cbbe296c45a2a6e22ea175e8633f2a1)) + +* adding r_mag parameter to GSAE ([`198ddf4`](https://github.com/saprmarks/dictionary_learning/commit/198ddf4bd4210b95a11b0a29862fe615d1774fe0)) + +* gatedSAE trainer ([`3567d6d`](https://github.com/saprmarks/dictionary_learning/commit/3567d6d2a2cb6d810df32b838029daacc354aaaa)) + +* cosmetic change ([`0200976`](https://github.com/saprmarks/dictionary_learning/commit/0200976ba04409d477e5321b586b844dd545b976)) + +* GatedAutoEncoder class ([`2cfc47b`](https://github.com/saprmarks/dictionary_learning/commit/2cfc47b42e89c294e14c98969a993f5910604211)) + +* p annealing not affected by resampling ([`ad8d837`](https://github.com/saprmarks/dictionary_learning/commit/ad8d8371411067c6a031d87faf08f4ec2fe96032)) + +* integrated trainer update ([`c7613d3`](https://github.com/saprmarks/dictionary_learning/commit/c7613d386a5677451f6da6f9260ceb9d28a3a4d4)) + +* Merge branch 'collab' into p_annealing ([`933b80c`](https://github.com/saprmarks/dictionary_learning/commit/933b80c91a3e49e2e7a761422c629588774370eb)) + +* fixed p calculation ([`9837a6f`](https://github.com/saprmarks/dictionary_learning/commit/9837a6fa4e88303b7694aa3556485661fa512f1c)) + +* getting rid of useless seed arguement ([`377c762`](https://github.com/saprmarks/dictionary_learning/commit/377c762d9a9333aed42ad097d393796fcf8a7e57)) + +* trainer initializes SAE ([`7dffb66`](https://github.com/saprmarks/dictionary_learning/commit/7dffb663a0dcc5f5e3c2855e24e9f8b322704bcc)) + +* trainer initialized SAE ([`6e80590`](https://github.com/saprmarks/dictionary_learning/commit/6e80590fb441c53df70345bfd20da4fbad7c9cf9)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`c58d23d`](https://github.com/saprmarks/dictionary_learning/commit/c58d23d5a6e2d38c0ff47e42b157f1686f7e98a6)) + +* changes to lista p_anneal trainers ([`3cc6642`](https://github.com/saprmarks/dictionary_learning/commit/3cc6642b414608e5d0e86c733b0855f927afa52c)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`9dfd3db`](https://github.com/saprmarks/dictionary_learning/commit/9dfd3dbf42d3ad35b0bb32f9d8374ac00201edda)) + +* decoupled lr warmup and p warmup in p_anneal trainer ([`c3c1645`](https://github.com/saprmarks/dictionary_learning/commit/c3c164540476d69ff4c3bfa7f9a1a4532c4603c0)) + +* Merge pull request #14 from saprmarks/p_annealing + +added annealing and trainer_param_callback ([`61927bc`](https://github.com/saprmarks/dictionary_learning/commit/61927bcf99537a15651a9829a6a261cffad9e65f)) + +* cosmetic changes to interp ([`4a7966f`](https://github.com/saprmarks/dictionary_learning/commit/4a7966f979ea4b660613c980cdefd48494511955)) + +* Merge branch 'collab' of https://github.com/saprmarks/dictionary_learning into collab ([`c76818e`](https://github.com/saprmarks/dictionary_learning/commit/c76818e4dbf7e980251a6f652529e50cd1b1b7de)) + +* Merge pull request #13 from jannik-brinkmann/collab + +add ListaTrainer ([`d4d2fd9`](https://github.com/saprmarks/dictionary_learning/commit/d4d2fd9b57a4ab380a56b1b5fa0faf1d91a29989)) + +* additional evluation metrics ([`fa2ec08`](https://github.com/saprmarks/dictionary_learning/commit/fa2ec081e2ff42377eb98b031320933806b2faf7)) + +* add GroupSAETrainer ([`60e6068`](https://github.com/saprmarks/dictionary_learning/commit/60e6068924a42b8252d11b398b9972205b46ece4)) + +* added annealing and trainer_param_callback ([`18e3fca`](https://github.com/saprmarks/dictionary_learning/commit/18e3fcaaf5428e998d26a0be80f1be56ffea7981)) + +* Merge remote-tracking branch 'upstream/collab' into collab ([`4650c2a`](https://github.com/saprmarks/dictionary_learning/commit/4650c2a7db87c7ca32db043cb15db8a28450a013)) + +* fixing neuron resampling ([`a346be9`](https://github.com/saprmarks/dictionary_learning/commit/a346be9abc6644fd59ae493e44ef8fdbd1e339e2)) + +* improvements to saving and logging ([`4a1d7ae`](https://github.com/saprmarks/dictionary_learning/commit/4a1d7ae76d59713fe0c4722e821ad3882c0aa757)) + +* can export buffer config ([`d19d8d9`](https://github.com/saprmarks/dictionary_learning/commit/d19d8d956da3e04ab899b93fc67c63b0a7bd5020)) + +* fixing evaluation.py ([`c91a581`](https://github.com/saprmarks/dictionary_learning/commit/c91a5815e4e11197a8031d21193381f9b596b95c)) + +* fixing bug in neuron resampling ([`67a03c7`](https://github.com/saprmarks/dictionary_learning/commit/67a03c763feec3bcebd9070389b8481257bdf10b)) + +* add ListaTrainer ([`880f570`](https://github.com/saprmarks/dictionary_learning/commit/880f5706a42c337e021530855166089b6722e1df)) + +* fixing neuron resampling in standard trainer ([`3406262`](https://github.com/saprmarks/dictionary_learning/commit/3406262b31dd97f29130532d694aecd62f092f80)) + +* improvements to training and evaluating ([`b111d40`](https://github.com/saprmarks/dictionary_learning/commit/b111d40898d97123722cda60084f46d0766cd3e2)) + +* Factoring out SAETrainer class ([`fabd001`](https://github.com/saprmarks/dictionary_learning/commit/fabd001d97f869c01e67ea26f2e02822eba9ab82)) + +* updating syntax for buffer ([`035a0f9`](https://github.com/saprmarks/dictionary_learning/commit/035a0f9d4ffa8e7307ae637fb801a78c0ea9eb95)) + +* updating readme for from_pretrained ([`70e8c2a`](https://github.com/saprmarks/dictionary_learning/commit/70e8c2a13682ef12658f92b459c1bf552cb78180)) + +* from_pretrained ([`db96abc`](https://github.com/saprmarks/dictionary_learning/commit/db96abc96be7ba975bb09a41c7a81b13c2ea5f3e)) + +* Change syntax for specifying activation dimensions and batch sizes ([`bdf1f19`](https://github.com/saprmarks/dictionary_learning/commit/bdf1f19b292b152b3c4601fc7a77fc6d66cd04c0)) + +* Merge branch 'main' of https://github.com/saprmarks/dictionary_learning into main ([`86c7475`](https://github.com/saprmarks/dictionary_learning/commit/86c7475a945c0a70c0a82c914d9733c8d2bcc651)) + +* activation_dim for IdentityDict is optional ([`be1b68c`](https://github.com/saprmarks/dictionary_learning/commit/be1b68c0df0de955d722f1739f5c115dfbfbf702)) + +* update umap requirement ([`776b53e`](https://github.com/saprmarks/dictionary_learning/commit/776b53e506a2c720139d056542a3397d883e2c79)) + +* Merge pull request #10 from adamkarvonen/shell_script_change + +Add sae_set_name to local_path for dictionary downloader ([`33b5a6b`](https://github.com/saprmarks/dictionary_learning/commit/33b5a6be4ea3c76aa918178f2dfcd3f7c81e2b97)) + +* Add sae_set_name to local_path for dictionary downloader ([`d6163be`](https://github.com/saprmarks/dictionary_learning/commit/d6163be200d28653394c2b9adac540c7a27e2659)) + +* dispatch no longer needed when loading models ([`69c32ca`](https://github.com/saprmarks/dictionary_learning/commit/69c32ca6fcf1c94c4b7fb7ac8b82fe7257123400)) + +* removed in_and_out option for activation buffer ([`cf6ad1d`](https://github.com/saprmarks/dictionary_learning/commit/cf6ad1d72de9fc11acba34e73a03799e2b893692)) + +* updating readme with 10_32768 dictionaries ([`614883f`](https://github.com/saprmarks/dictionary_learning/commit/614883f9476613e7c1c48b951cd3947451e1f534)) + +* upgrade to nnsight 0.2 ([`cbc5f79`](https://github.com/saprmarks/dictionary_learning/commit/cbc5f7991c9233579c36b4972c6273f3f250f0ef)) + +* downloader script ([`7a305c5`](https://github.com/saprmarks/dictionary_learning/commit/7a305c583dbbf06f3dbb223387dc3536a489b0de)) + +* fixing device issue in buffer ([`b1b44f1`](https://github.com/saprmarks/dictionary_learning/commit/b1b44f12e176e73544d863d1d41009a284bc1db5)) + +* added pretrained_dictionary_downloader.sh ([`0028ebe`](https://github.com/saprmarks/dictionary_learning/commit/0028ebe739ac90e2587a86b92b0aa4b2c0b8497e)) + +* added pretrained_dictionary_downloader.sh ([`8b63d8d`](https://github.com/saprmarks/dictionary_learning/commit/8b63d8d6d74f51c00b191519d383de7f6052df0b)) + +* added pretrained_dictionary_downloader.sh ([`6771aff`](https://github.com/saprmarks/dictionary_learning/commit/6771aff6543b320e14fb3db99e0c6fd2613cc905)) + +* efficiency improvements ([`94844d4`](https://github.com/saprmarks/dictionary_learning/commit/94844d4fa9ce4a593faf9b709cf61a45447f84f3)) + +* adding identity dict ([`76bd32f`](https://github.com/saprmarks/dictionary_learning/commit/76bd32fe87bf3c7f3ce45d13d6fe6a69c81e05b4)) + +* debugging interp ([`2f75db3`](https://github.com/saprmarks/dictionary_learning/commit/2f75db31233b1296af97c2002194888715355759)) + +* Merge branch 'main' of https://github.com/saprmarks/dictionary_learning into main ([`86812f5`](https://github.com/saprmarks/dictionary_learning/commit/86812f5dae6a4ebc1605f3b067c27d7b8b96001e)) + +* warns user when evaluating without enough data ([`246c472`](https://github.com/saprmarks/dictionary_learning/commit/246c472d7efb845875c4aa67a8e0dfd417c28f6d)) + +* cleaning up interp ([`95d7310`](https://github.com/saprmarks/dictionary_learning/commit/95d7310ef39ed2fe7a496d0a63a142fe569bdcf5)) + +* examine_dimension returns mbottom_tokens and logit stats ([`40137ff`](https://github.com/saprmarks/dictionary_learning/commit/40137ffe47d9c3ee03e9b46f994c5bd98f5b953e)) + +* continuing merge ([`db693a6`](https://github.com/saprmarks/dictionary_learning/commit/db693a6c4c290bb670f37c0a7e222e25b6b916c6)) + +* progress on merge ([`949b3a7`](https://github.com/saprmarks/dictionary_learning/commit/949b3a755c1458e7d216cc02dc5bf7d8e8f62a1a)) + +* changes to buffer.py ([`792546b`](https://github.com/saprmarks/dictionary_learning/commit/792546b35c45fda3e93abcb0f8cc28f70d0e439c)) + +* fixing some things in buffer.py ([`f58688e`](https://github.com/saprmarks/dictionary_learning/commit/f58688e574f5353f906a470abfbcc386730fdda6)) + +* updating requirements ([`a54b496`](https://github.com/saprmarks/dictionary_learning/commit/a54b4961a7ac9996566a3c32f4d216968afac7b1)) + +* updating requirements ([`a1db591`](https://github.com/saprmarks/dictionary_learning/commit/a1db5917be710c046736574a48bc7f0c2ea98506)) + +* identity dictionary ([`5e1f35e`](https://github.com/saprmarks/dictionary_learning/commit/5e1f35e09abc20c6ee7bc43cfba6231d97121403)) + +* bug fix for neuron resampling ([`b281b53`](https://github.com/saprmarks/dictionary_learning/commit/b281b538c1de2b5ce220b429dd3ea4be44c5b72f)) + +* UMAP visualizations ([`81f8e1f`](https://github.com/saprmarks/dictionary_learning/commit/81f8e1f164def236423e53b89da37d50c115fc62)) + +* better normalization for ghost_loss ([`fc74af7`](https://github.com/saprmarks/dictionary_learning/commit/fc74af75ca2d9d4fdbca6fefb3feb583ef11583d)) + +* neuron resampling without replacement ([`4565e9a`](https://github.com/saprmarks/dictionary_learning/commit/4565e9a14975a4a2d9c736ba7c5551b6c9685ae2)) + +* simplifications to interp functions ([`2318666`](https://github.com/saprmarks/dictionary_learning/commit/231866665154d80e933b5d9ab5be5de5a522c398)) + +* Second nnsight 0.2 pass through ([`3bcebed`](https://github.com/saprmarks/dictionary_learning/commit/3bcebedb801d5654edb3fc7118144953af2366da)) + +* Conversion to nnsight 0.2 first pass ([`cac410a`](https://github.com/saprmarks/dictionary_learning/commit/cac410a72e52cbd6f359fd69bd6fdb346923a9e1)) + +* detaching another thing in ghost grads ([`2f212d6`](https://github.com/saprmarks/dictionary_learning/commit/2f212d6cab348d565633f1bdc0d3e305a6e98d42)) + +* Neuron resampling no longer errors when resampling zero neurons ([`376dd3b`](https://github.com/saprmarks/dictionary_learning/commit/376dd3b51b1433625386ca357c61497a13b6bf6d)) + +* NNsight v0.2 Updates ([`90bbc76`](https://github.com/saprmarks/dictionary_learning/commit/90bbc762aaf369a138f544f2e1f3a4e7a6b5fc4a)) + +* cosmetic improvements to buffer.py ([`b2bd5f0`](https://github.com/saprmarks/dictionary_learning/commit/b2bd5f09cc7f657b7121f0659514d81336903bba)) + +* fix to ghost grads ([`9531fe5`](https://github.com/saprmarks/dictionary_learning/commit/9531fe5f65a23acb32e0f1c96920d67bb1bed15b)) + +* fixing table formatting ([`0e69c8c`](https://github.com/saprmarks/dictionary_learning/commit/0e69c8cc7c446db0ddf86da984417965714ec7ec)) + +* Fixing some table formatting ([`75f927f`](https://github.com/saprmarks/dictionary_learning/commit/75f927f4c722db4c05d64d732b1d025ecdc186aa)) + +* gpt2-small support ([`f82146c`](https://github.com/saprmarks/dictionary_learning/commit/f82146cf586e53407d639ef81f64e1be481a666b)) + +* fixing bug relevant to UnifiedTransformer support ([`9ec9ce4`](https://github.com/saprmarks/dictionary_learning/commit/9ec9ce494384ab303db26be066b3a8004230a16a)) + +* Getting rid of histograms ([`31d09d7`](https://github.com/saprmarks/dictionary_learning/commit/31d09d7136d97c553b8f06c1074ef08ea65be879)) + +* Fixing tables in readme ([`5934011`](https://github.com/saprmarks/dictionary_learning/commit/59340116bb24cbc01cefa76c52641b5b5b46a340)) + +* Updates to the readme ([`a5ca51e`](https://github.com/saprmarks/dictionary_learning/commit/a5ca51ea13cfcd4bb286d644e3416a9af3b5fc53)) + +* Fixing ghost grad bugs ([`633d583`](https://github.com/saprmarks/dictionary_learning/commit/633d583ddaa3090039fca3f1f3e8820ded942e76)) + +* Handling ghost grad case with no dead neurons ([`4f19425`](https://github.com/saprmarks/dictionary_learning/commit/4f19425a4e09ea93bb7ebaad436c2ef227cb420e)) + +* adding support for buffer on other devices ([`f3cf296`](https://github.com/saprmarks/dictionary_learning/commit/f3cf296fe00bf547412f7d500b8993796e30a8b9)) + +* support for ghost grads ([`25d2a62`](https://github.com/saprmarks/dictionary_learning/commit/25d2a62fcaa8bc9be048b5e37aa57441e78262b5)) + +* add an implementation of ghost gradients ([`2e09210`](https://github.com/saprmarks/dictionary_learning/commit/2e09210099d991d45500488dac9654d141815530)) + +* fixing a bug with warmup, adding utils ([`47bbde1`](https://github.com/saprmarks/dictionary_learning/commit/47bbde13f47010bbebf6ac393ae3cdc59b804e9d)) + +* remove HF arg from buffer. rename search_utils to interp ([`7276f17`](https://github.com/saprmarks/dictionary_learning/commit/7276f17288286429162432af6a30763fa80f8117)) + +* typo fix ([`3f6b922`](https://github.com/saprmarks/dictionary_learning/commit/3f6b922c031f9b31652c3998f0ce1e985629c62a)) + +* Merge branch 'main' of https://github.com/saprmarks/dictionary_learning into main ([`278084b`](https://github.com/saprmarks/dictionary_learning/commit/278084b0a54e5804a064358a1fb28bc007e4fae4)) + +* added utils for converting hf dataset to generator ([`82fff19`](https://github.com/saprmarks/dictionary_learning/commit/82fff1968ae883afec82d14246041df793ffd170)) + +* add ablated token effects to ; restore support for HF datasets ([`799e2ca`](https://github.com/saprmarks/dictionary_learning/commit/799e2caeb3f4f4f922531cfb3b14dd34d999ae9d)) + +* merge in function for examining features ([`986bf96`](https://github.com/saprmarks/dictionary_learning/commit/986bf9646e82f35186c74ce88e6c6e4dc1c8470f)) + +* easier submodule/dictionary feature examination ([`2c8b985`](https://github.com/saprmarks/dictionary_learning/commit/2c8b98567e1908a4279efc342f46bd4bd72ab618)) + +* Adding lr warmup after every time neurons are resampled ([`429c582`](https://github.com/saprmarks/dictionary_learning/commit/429c582f84be12d6c326b131f926b33d48698c7b)) + +* fixing issues with EmptyStream exception ([`39ff6e1`](https://github.com/saprmarks/dictionary_learning/commit/39ff6e1cdccb438d335c39c36656657f974f585f)) + +* Minor changes due to updates in nnsight ([`49bbbac`](https://github.com/saprmarks/dictionary_learning/commit/49bbbac6a653398be8726587c2c634e0fd831f02)) + +* Revert "restore support for streaming HF datasets" + +This reverts commit b43527b9b6b24521f6eba68242dc22c3c68173d8. ([`23ada98`](https://github.com/saprmarks/dictionary_learning/commit/23ada983527a748887b7481e255b8dfdb310a23d)) + +* restore support for streaming HF datasets ([`b43527b`](https://github.com/saprmarks/dictionary_learning/commit/b43527b9b6b24521f6eba68242dc22c3c68173d8)) + +* first version of automatic feature labeling ([`c6753f6`](https://github.com/saprmarks/dictionary_learning/commit/c6753f62967503583aae33978b0684d5af0947e5)) + +* Add feature_effect function to search_utils.py ([`0ada2c6`](https://github.com/saprmarks/dictionary_learning/commit/0ada2c654b2dcc71e14869afc813b3adce445472)) + +* Merge branch 'main' of https://github.com/saprmarks/dictionary_learning into main ([`fab70b1`](https://github.com/saprmarks/dictionary_learning/commit/fab70b1b1a17fbe46fbdc54ea34095457c8cbe64)) + +* adding sqrt to MSE ([`63b2174`](https://github.com/saprmarks/dictionary_learning/commit/63b217449c651c78da68571bb032563ac73ebd71)) + +* Merge pull request #1 from cadentj/main + +Update README.md ([`fd79bb3`](https://github.com/saprmarks/dictionary_learning/commit/fd79bb34a7cb56bd987ce8a24764a72586999431)) + +* Update README.md ([`cf5ec24`](https://github.com/saprmarks/dictionary_learning/commit/cf5ec240bcb31db7007dceb7b4362967b044fd01)) + +* Update README.md ([`55f33f2`](https://github.com/saprmarks/dictionary_learning/commit/55f33f226d94baace938501d741ccfb5e9816a56)) + +* evaluation.py ([`2edf59e`](https://github.com/saprmarks/dictionary_learning/commit/2edf59ebb2a625e0862cecd5e4d84249589d95b9)) + +* evaluating dictionaries ([`71e28fb`](https://github.com/saprmarks/dictionary_learning/commit/71e28fbfa2976b099e849c766176252fa8d9fbc2)) + +* Removing experimental use of sqrt on MSELoss ([`865bbb5`](https://github.com/saprmarks/dictionary_learning/commit/865bbb58fdd1af681a2a435f546f4f6dceaaf930)) + +* Adding readme, evaluation, cleaning up ([`ddac948`](https://github.com/saprmarks/dictionary_learning/commit/ddac948a7971e526a47a9dae7311a25c0c56a81c)) + +* some stuff for saving dicts ([`d1f0e21`](https://github.com/saprmarks/dictionary_learning/commit/d1f0e21afc6395ddec71e274bbd3075750f4a76f)) + +* removing device from buffer ([`398f15c`](https://github.com/saprmarks/dictionary_learning/commit/398f15cb5d44ba81e12dee5299841a983e9f54df)) + +* Merge branch 'main' of https://github.com/saprmarks/dictionary_learning into main ([`7f013c2`](https://github.com/saprmarks/dictionary_learning/commit/7f013c2441620391eabba4f408deaa14140a5239)) + +* lr schedule + enabling stretched mlp ([`4eaf7e3`](https://github.com/saprmarks/dictionary_learning/commit/4eaf7e35e8c1c461da761a71968d8e9d1ef0c6b3)) + +* add random feature search ([`e58cc67`](https://github.com/saprmarks/dictionary_learning/commit/e58cc67cb8303b48cf40cb52e586d464f8cb6b48)) + +* restore HF support and progress bar ([`7e2b6c6`](https://github.com/saprmarks/dictionary_learning/commit/7e2b6c69aa7095680affe58c4251577f96505915)) + +* Merge branch 'main' of https://github.com/saprmarks/dictionary_learning into main ([`d33ef05`](https://github.com/saprmarks/dictionary_learning/commit/d33ef052e7d4175a5042855c04a6f3b60acb07ff)) + +* more support for saving checkpints ([`0ca258a`](https://github.com/saprmarks/dictionary_learning/commit/0ca258af3775910ce20d4cce541ff4de962bef3d)) + +* fix unit column bug + add scheduler ([`5a05c8c`](https://github.com/saprmarks/dictionary_learning/commit/5a05c8cd1b29894e8ba77b115727f1511c3334bd)) + +* fix merge bugs: checkpointing support ([`9c5bbd8`](https://github.com/saprmarks/dictionary_learning/commit/9c5bbd8a3ac82e8611434d7ba95da172a80a44a0)) + +* Merge: add HF datasets and checkpointing ([`ccf6ed1`](https://github.com/saprmarks/dictionary_learning/commit/ccf6ed1d9fdc7c0df68c879c893f919d8c192b83)) + +* checkpointing, progress bar, HF dataset support ([`fd8a3ee`](https://github.com/saprmarks/dictionary_learning/commit/fd8a3ee3ee70354191c4d8ecce9d4f8b878d40c6)) + +* progress bar for training autoencoders ([`0a8064d`](https://github.com/saprmarks/dictionary_learning/commit/0a8064dd7ef93904c4b5b4edb9fc7ddbc1e42af1)) + +* implementing neuron resampling ([`f9b9d02`](https://github.com/saprmarks/dictionary_learning/commit/f9b9d020cd5c2daf857d44de2c956a6df2cf7cc3)) + +* lotsa stuff ([`bc09ba4`](https://github.com/saprmarks/dictionary_learning/commit/bc09ba48a701900311d7049dab52549b8239cb15)) + +* adding __init__.py file for imports ([`3d9fd43`](https://github.com/saprmarks/dictionary_learning/commit/3d9fd43957b8c35e1d6377aa33341f663ae5d289)) + +* modifying buffer ([`ba9441b`](https://github.com/saprmarks/dictionary_learning/commit/ba9441b444cd56b2a01c341357d3ede11b06e2b6)) + +* first commit ([`ea89e90`](https://github.com/saprmarks/dictionary_learning/commit/ea89e90e3f737ec8e2a339cfd0b2f1a1082ef850)) + +* Initial commit ([`741f4d6`](https://github.com/saprmarks/dictionary_learning/commit/741f4d6e1d07e55f6c6df5340cc22b9c7f8d49b7)) From 944edd1b0c588ba288df556a4cfbc90523b78fda Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Tue, 13 May 2025 02:02:02 +0000 Subject: [PATCH 64/70] Add a pytorch activation buffer, enable model truncation --- dictionary_learning/pytorch_buffer.py | 225 ++++++++++++++++++++++ dictionary_learning/utils.py | 62 +++++- tests/test_end_to_end.py | 8 +- tests/test_pytorch_end_to_end.py | 261 ++++++++++++++++++++++++++ 4 files changed, 548 insertions(+), 8 deletions(-) create mode 100644 dictionary_learning/pytorch_buffer.py create mode 100644 tests/test_pytorch_end_to_end.py diff --git a/dictionary_learning/pytorch_buffer.py b/dictionary_learning/pytorch_buffer.py new file mode 100644 index 0000000..720101f --- /dev/null +++ b/dictionary_learning/pytorch_buffer.py @@ -0,0 +1,225 @@ +import torch as t +from transformers import AutoModelForCausalLM, AutoTokenizer +import gc +from tqdm import tqdm +import contextlib + + +class EarlyStopException(Exception): + """Custom exception for stopping model forward pass early.""" + + pass + + +def collect_activations( + model: AutoModelForCausalLM, + submodule: t.nn.Module, + inputs_BL: dict[str, t.Tensor], + use_no_grad: bool = True, +) -> t.Tensor: + """ + Registers a forward hook on the submodule to capture the residual (or hidden) + activations. We then raise an EarlyStopException to skip unneeded computations. + + Args: + model: The model to run. + submodule: The submodule to hook into. + inputs_BL: The inputs to the model. + use_no_grad: Whether to run the forward pass within a `t.no_grad()` context. Defaults to True. + """ + activations_BLD = None + + def gather_target_act_hook(module, inputs, outputs): + nonlocal activations_BLD + # For many models, the submodule outputs are a tuple or a single tensor: + # If "outputs" is a tuple, pick the relevant item: + # e.g. if your layer returns (hidden, something_else), you'd do outputs[0] + # Otherwise just do outputs + if isinstance(outputs, tuple): + activations_BLD = outputs[0] + else: + activations_BLD = outputs + + raise EarlyStopException("Early stopping after capturing activations") + + handle = submodule.register_forward_hook(gather_target_act_hook) + + # Determine the context manager based on the flag + context_manager = t.no_grad() if use_no_grad else contextlib.nullcontext() + + try: + # Use the selected context manager + with context_manager: + _ = model(**inputs_BL) + except EarlyStopException: + pass + except Exception as e: + print(f"Unexpected error during forward pass: {str(e)}") + raise + finally: + handle.remove() + + if activations_BLD is None: + # This should ideally not happen if the hook worked and EarlyStopException was raised, + # but handle it just in case. + raise RuntimeError( + "Failed to collect activations. The hook might not have run correctly." + ) + + return activations_BLD + + +class ActivationBuffer: + """ + Implements a buffer of activations. The buffer stores activations from a model, + yields them in batches, and refreshes them when the buffer is less than half full. + """ + + def __init__( + self, + data, # generator which yields text data + model: AutoModelForCausalLM, # Language Model from which to extract activations + submodule, # submodule of the model from which to extract activations + d_submodule=None, # submodule dimension; if None, try to detect automatically + io="out", # can be 'in' or 'out'; whether to extract input or output activations + n_ctxs=3e4, # approximate number of contexts to store in the buffer + ctx_len=128, # length of each context + refresh_batch_size=512, # size of batches in which to process the data when adding to buffer + out_batch_size=8192, # size of batches in which to yield activations + device="cpu", # device on which to store the activations + remove_bos: bool = False, + add_special_tokens: bool = True, + ): + if io not in ["in", "out"]: + raise ValueError("io must be either 'in' or 'out'") + + if d_submodule is None: + try: + if io == "in": + d_submodule = submodule.in_features + else: + d_submodule = submodule.out_features + except: + raise ValueError( + "d_submodule cannot be inferred and must be specified directly" + ) + self.activations = t.empty(0, d_submodule, device=device, dtype=model.dtype) + self.read = t.zeros(0).bool() + + self.data = data + self.model = model + self.submodule = submodule + self.d_submodule = d_submodule + self.io = io + self.n_ctxs = n_ctxs + self.ctx_len = ctx_len + self.activation_buffer_size = n_ctxs * ctx_len + self.refresh_batch_size = refresh_batch_size + self.out_batch_size = out_batch_size + self.device = device + self.remove_bos = remove_bos + self.add_special_tokens = add_special_tokens + self.tokenizer = AutoTokenizer.from_pretrained(model.name_or_path) + + if not self.tokenizer.pad_token: + self.tokenizer.pad_token = self.tokenizer.eos_token + + def __iter__(self): + return self + + def __next__(self): + """ + Return a batch of activations + """ + with t.no_grad(): + # if buffer is less than half full, refresh + if (~self.read).sum() < self.activation_buffer_size // 2: + self.refresh() + + # return a batch + unreads = (~self.read).nonzero().squeeze() + idxs = unreads[ + t.randperm(len(unreads), device=unreads.device)[: self.out_batch_size] + ] + self.read[idxs] = True + return self.activations[idxs] + + def text_batch(self, batch_size=None): + """ + Return a list of text + """ + if batch_size is None: + batch_size = self.refresh_batch_size + try: + return [next(self.data) for _ in range(batch_size)] + except StopIteration: + raise StopIteration("End of data stream reached") + + def tokenized_batch(self, batch_size=None): + """ + Return a batch of tokenized inputs. + """ + texts = self.text_batch(batch_size=batch_size) + return self.tokenizer( + texts, + return_tensors="pt", + max_length=self.ctx_len, + padding=True, + truncation=True, + add_special_tokens=self.add_special_tokens, + ).to(self.device) + + def refresh(self): + gc.collect() + t.cuda.empty_cache() + self.activations = self.activations[~self.read] + + current_idx = len(self.activations) + new_activations = t.empty( + self.activation_buffer_size, + self.d_submodule, + device=self.device, + dtype=self.model.dtype, + ) + + new_activations[: len(self.activations)] = self.activations + self.activations = new_activations + + # Optional progress bar when filling buffer. At larger models / buffer sizes (e.g. gemma-2-2b, 1M tokens on a 4090) this can take a couple minutes. + # pbar = tqdm(total=self.activation_buffer_size, initial=current_idx, desc="Refreshing activations") + + while current_idx < self.activation_buffer_size: + with t.no_grad(): + input = self.tokenized_batch() + hidden_states = collect_activations(self.model, self.submodule, input) + attn_mask = input["attention_mask"] + if self.remove_bos: + hidden_states = hidden_states[:, 1:, :] + attn_mask = attn_mask[:, 1:] + hidden_states = hidden_states[attn_mask != 0] + + remaining_space = self.activation_buffer_size - current_idx + assert remaining_space > 0 + hidden_states = hidden_states[:remaining_space] + + self.activations[current_idx : current_idx + len(hidden_states)] = ( + hidden_states.to(self.device) + ) + current_idx += len(hidden_states) + + # pbar.update(len(hidden_states)) + + # pbar.close() + self.read = t.zeros(len(self.activations), dtype=t.bool, device=self.device) + + @property + def config(self): + return { + "d_submodule": self.d_submodule, + "io": self.io, + "n_ctxs": self.n_ctxs, + "ctx_len": self.ctx_len, + "refresh_batch_size": self.refresh_batch_size, + "out_batch_size": self.out_batch_size, + "device": self.device, + } diff --git a/dictionary_learning/utils.py b/dictionary_learning/utils.py index 3b1077e..537754c 100644 --- a/dictionary_learning/utils.py +++ b/dictionary_learning/utils.py @@ -3,7 +3,11 @@ import io import json import os -from nnsight import LanguageModel +from transformers import AutoModelForCausalLM +from fractions import Fraction +import random +from transformers import AutoTokenizer +import torch as t from .trainers.top_k import AutoEncoderTopK from .trainers.batch_top_k import BatchTopKSAE @@ -88,13 +92,61 @@ def load_dictionary(base_path: str, device: str) -> tuple: return dictionary, config -def get_submodule(model: LanguageModel, layer: int): +def get_submodule(model: AutoModelForCausalLM, layer: int): """Gets the residual stream submodule""" - model_name = model._model_key + model_name = model.name_or_path - if "pythia" in model_name: + if model.config.architectures[0] == "GPTNeoXForCausalLM": return model.gpt_neox.layers[layer] - elif "gemma" in model_name: + elif ( + model.config.architectures[0] == "Qwen2ForCausalLM" + or model.config.architectures[0] == "Gemma2ForCausalLM" + ): return model.model.layers[layer] else: raise ValueError(f"Please add submodule for model {model_name}") + + +def truncate_model(model: AutoModelForCausalLM, layer: int): + """From tilde-research/activault + https://github.com/tilde-research/activault/blob/db6d1e4e36c2d3eb4fdce79e72be94f387eccee1/pipeline/setup.py#L74 + This provides significant memory savings by deleting all layers that aren't needed for the given layer. + You should probably test this before using it""" + import gc + + total_params_before = sum(p.numel() for p in model.parameters()) + print(f"Model parameters before truncation: {total_params_before:,}") + + if ( + model.config.architectures[0] == "Qwen2ForCausalLM" + or model.config.architectures[0] == "Gemma2ForCausalLM" + ): + removed_layers = model.model.layers[layer + 1 :] + + model.model.layers = model.model.layers[: layer + 1] + + del removed_layers + del model.lm_head + + model.lm_head = t.nn.Identity() + + elif model.config.architectures[0] == "GPTNeoXForCausalLM": + removed_layers = model.gpt_neox.layers[layer + 1 :] + + model.gpt_neox.layers = model.gpt_neox.layers[: layer + 1] + + del removed_layers + del model.embed_out + + model.embed_out = t.nn.Identity() + + else: + raise ValueError(f"Please add truncation for model {model.name_or_path}") + + total_params_after = sum(p.numel() for p in model.parameters()) + print(f"Model parameters after truncation: {total_params_after:,}") + + gc.collect() + t.cuda.empty_cache() + + return model diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index 797cbab..055db18 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -59,14 +59,16 @@ LAYER = 3 DATASET_NAME = "monology/pile-uncopyrighted" -EVAL_TOLERANCE = 0.01 +EVAL_TOLERANCE_PERCENT = 0.005 def test_sae_training(): """End to end test for training an SAE. Takes ~2 minutes on an RTX 3090. This isn't a nice suite of unit tests, but it's better than nothing. I have observed that results can slightly vary with library versions. For full determinism, - use pytorch 2.5.1 and nnsight 0.3.7.""" + use pytorch 2.5.1 and nnsight 0.3.7. + Unfortunately an RTX 3090 is also required for full determinism. On an H100 the results are off by ~0.3%, meaning this test will + not be within the EVAL_TOLERANCE.""" random.seed(RANDOM_SEED) t.manual_seed(RANDOM_SEED) @@ -257,4 +259,4 @@ def test_evaluation(): max_diff_percent = max(max_diff_percent, diff / value) print(f"Max diff: {max_diff}, max diff %: {max_diff_percent}") - assert max_diff < EVAL_TOLERANCE + assert max_diff_percent < EVAL_TOLERANCE_PERCENT diff --git a/tests/test_pytorch_end_to_end.py b/tests/test_pytorch_end_to_end.py new file mode 100644 index 0000000..79ef5b3 --- /dev/null +++ b/tests/test_pytorch_end_to_end.py @@ -0,0 +1,261 @@ +import torch as t +from transformers import AutoModelForCausalLM, AutoTokenizer +import os +import json +import random + +from dictionary_learning.training import trainSAE +from dictionary_learning.trainers.standard import StandardTrainer +from dictionary_learning.trainers.top_k import TopKTrainer, AutoEncoderTopK +from dictionary_learning.utils import ( + hf_dataset_to_generator, + get_nested_folders, + load_dictionary, +) + +# from dictionary_learning.buffer import ActivationBuffer +from dictionary_learning.pytorch_buffer import ActivationBuffer +from dictionary_learning.dictionary import ( + AutoEncoder, + GatedAutoEncoder, + AutoEncoderNew, + JumpReluAutoEncoder, +) +from dictionary_learning.evaluation import evaluate + +EXPECTED_RESULTS = { + "AutoEncoderTopK": { + "l2_loss": 4.358876752853393, + "l1_loss": 50.90618553161621, + "l0": 40.0, + "frac_variance_explained": 0.9577824175357819, + "cossim": 0.9476200461387634, + "l2_ratio": 0.9476299166679383, + "relative_reconstruction_bias": 0.9996505916118622, + "frac_alive": 1.0, + }, + "AutoEncoder": { + "l2_loss": 6.8308186531066895, + "l1_loss": 19.398421669006346, + "l0": 37.4469970703125, + "frac_variance_explained": 0.9003101229667664, + "cossim": 0.8782103300094605, + "l2_ratio": 0.7444103538990021, + "relative_reconstruction_bias": 0.960041344165802, + "frac_alive": 0.9970703125, + }, +} + +DEVICE = "cuda:0" +SAVE_DIR = "./test_data" +MODEL_NAME = "EleutherAI/pythia-70m-deduped" +RANDOM_SEED = 42 +LAYER = 3 +DATASET_NAME = "monology/pile-uncopyrighted" + +EVAL_TOLERANCE_PERCENT = 0.005 + + +def test_sae_training(): + """End to end test for training an SAE. Takes ~2 minutes on an RTX 3090. + This isn't a nice suite of unit tests, but it's better than nothing. + I have observed that results can slightly vary with library versions. For full determinism, + use pytorch 2.5.1 and nnsight 0.3.7. + Unfortunately an RTX 3090 is also required for full determinism. On an H100 the results are off by ~0.3%, meaning this test will + not be within the EVAL_TOLERANCE.""" + + random.seed(RANDOM_SEED) + t.manual_seed(RANDOM_SEED) + + # model = LanguageModel(MODEL_NAME, dispatch=True, device_map=DEVICE) + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, device_map="auto", torch_dtype=t.float32 + ).to(DEVICE) + + context_length = 128 + llm_batch_size = 512 # Fits on a 24GB GPU + sae_batch_size = 8192 + num_contexts_per_sae_batch = sae_batch_size // context_length + + num_inputs_in_buffer = num_contexts_per_sae_batch * 20 + + num_tokens = 10_000_000 + + # sae training parameters + k = 40 + sparsity_penalty = 2.0 + expansion_factor = 8 + + steps = int(num_tokens / sae_batch_size) # Total number of batches to train + save_steps = None + warmup_steps = 1000 # Warmup period at start of training and after each resample + resample_steps = None + + # standard sae training parameters + learning_rate = 3e-4 + + # topk sae training parameters + decay_start = None + auxk_alpha = 1 / 32 + + submodule = model.gpt_neox.layers[LAYER] + submodule_name = f"resid_post_layer_{LAYER}" + io = "out" + activation_dim = model.config.hidden_size + + generator = hf_dataset_to_generator(DATASET_NAME) + + activation_buffer = ActivationBuffer( + generator, + model, + submodule, + n_ctxs=num_inputs_in_buffer, + ctx_len=context_length, + refresh_batch_size=llm_batch_size, + out_batch_size=sae_batch_size, + io=io, + d_submodule=activation_dim, + device=DEVICE, + ) + + # create the list of configs + trainer_configs = [] + trainer_configs.extend( + [ + { + "trainer": TopKTrainer, + "dict_class": AutoEncoderTopK, + "lr": None, + "activation_dim": activation_dim, + "dict_size": expansion_factor * activation_dim, + "k": k, + "auxk_alpha": auxk_alpha, # see Appendix A.2 + "warmup_steps": 0, + "decay_start": decay_start, # when does the lr decay start + "steps": steps, # when when does training end + "seed": RANDOM_SEED, + "wandb_name": f"TopKTrainer-{MODEL_NAME}-{submodule_name}", + "device": DEVICE, + "layer": LAYER, + "lm_name": MODEL_NAME, + "submodule_name": submodule_name, + }, + ] + ) + trainer_configs.extend( + [ + { + "trainer": StandardTrainer, + "dict_class": AutoEncoder, + "activation_dim": activation_dim, + "dict_size": expansion_factor * activation_dim, + "lr": learning_rate, + "l1_penalty": sparsity_penalty, + "warmup_steps": warmup_steps, + "sparsity_warmup_steps": None, + "decay_start": decay_start, + "steps": steps, + "resample_steps": resample_steps, + "seed": RANDOM_SEED, + "wandb_name": f"StandardTrainer-{MODEL_NAME}-{submodule_name}", + "layer": LAYER, + "lm_name": MODEL_NAME, + "device": DEVICE, + "submodule_name": submodule_name, + }, + ] + ) + + print(f"len trainer configs: {len(trainer_configs)}") + output_dir = f"{SAVE_DIR}/{submodule_name}" + + trainSAE( + data=activation_buffer, + trainer_configs=trainer_configs, + steps=steps, + save_steps=save_steps, + save_dir=output_dir, + ) + + folders = get_nested_folders(output_dir) + + assert len(folders) == 2 + + for folder in folders: + dictionary, config = load_dictionary(folder, DEVICE) + + assert dictionary is not None + assert config is not None + + +def test_evaluation(): + random.seed(RANDOM_SEED) + t.manual_seed(RANDOM_SEED) + + model = AutoModelForCausalLM.from_pretrained( + MODEL_NAME, device_map="auto", torch_dtype=t.float32 + ).to(DEVICE) + ae_paths = get_nested_folders(SAVE_DIR) + + context_length = 128 + llm_batch_size = 100 + sae_batch_size = 4096 + n_batches = 10 + buffer_size = 256 + io = "out" + + generator = hf_dataset_to_generator(DATASET_NAME) + submodule = model.gpt_neox.layers[LAYER] + + input_strings = [] + for i, example in enumerate(generator): + input_strings.append(example) + if i > buffer_size * n_batches: + break + + for ae_path in ae_paths: + dictionary, config = load_dictionary(ae_path, DEVICE) + dictionary = dictionary.to(dtype=model.dtype) + + activation_dim = config["trainer"]["activation_dim"] + context_length = config["buffer"]["ctx_len"] + + activation_buffer_data = iter(input_strings) + + activation_buffer = ActivationBuffer( + activation_buffer_data, + model, + submodule, + n_ctxs=buffer_size, + ctx_len=context_length, + refresh_batch_size=llm_batch_size, + out_batch_size=sae_batch_size, + io=io, + d_submodule=activation_dim, + device=DEVICE, + ) + + eval_results = evaluate( + dictionary, + activation_buffer, + context_length, + llm_batch_size, + io=io, + device=DEVICE, + n_batches=n_batches, + ) + + print(eval_results) + + dict_class = config["trainer"]["dict_class"] + expected_results = EXPECTED_RESULTS[dict_class] + + max_diff = 0 + max_diff_percent = 0 + for key, value in expected_results.items(): + diff = abs(eval_results[key] - value) + max_diff = max(max_diff, diff) + max_diff_percent = max(max_diff_percent, diff / value) + + print(f"Max diff: {max_diff}, max diff %: {max_diff_percent}") + assert max_diff_percent < EVAL_TOLERANCE_PERCENT From c644ccd20d6b17e1bc1199e981cf8539c5c06f0a Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Tue, 13 May 2025 02:02:20 +0000 Subject: [PATCH 65/70] Add better dataset generators --- dictionary_learning/buffer.py | 10 +- dictionary_learning/utils.py | 189 ++++++++++++++++++++++++++++++++++ 2 files changed, 196 insertions(+), 3 deletions(-) diff --git a/dictionary_learning/buffer.py b/dictionary_learning/buffer.py index 7d304b7..67e1c0f 100644 --- a/dictionary_learning/buffer.py +++ b/dictionary_learning/buffer.py @@ -28,6 +28,7 @@ def __init__(self, out_batch_size=8192, # size of batches in which to yield activations device='cpu', # device on which to store the activations remove_bos: bool = False, + add_special_tokens: bool = True, ): if io not in ['in', 'out']: @@ -56,7 +57,8 @@ def __init__(self, self.out_batch_size = out_batch_size self.device = device self.remove_bos = remove_bos - + self.add_special_tokens = add_special_tokens + def __iter__(self): return self @@ -98,7 +100,8 @@ def tokenized_batch(self, batch_size=None): return_tensors='pt', max_length=self.ctx_len, padding=True, - truncation=True + truncation=True, + add_special_tokens=self.add_special_tokens ) def refresh(self): @@ -117,8 +120,9 @@ def refresh(self): while current_idx < self.activation_buffer_size: with t.no_grad(): + tokens = self.tokenized_batch() with self.model.trace( - self.text_batch(), + tokens, **tracer_kwargs, invoker_args={"truncation": True, "max_length": self.ctx_len}, ): diff --git a/dictionary_learning/utils.py b/dictionary_learning/utils.py index 537754c..6f2d2c0 100644 --- a/dictionary_learning/utils.py +++ b/dictionary_learning/utils.py @@ -47,6 +47,195 @@ def generator(): return generator() +def randomly_remove_system_prompt( + text: str, freq: float, system_prompt: str | None = None +) -> str: + if system_prompt and random.random() < freq: + assert system_prompt in text + text = text.replace(system_prompt, "") + return text + + +def hf_mixed_dataset_to_generator( + tokenizer: AutoTokenizer, + pretrain_dataset: str = "HuggingFaceFW/fineweb", + chat_dataset: str = "lmsys/lmsys-chat-1m", + min_chars: int = 1, + pretrain_frac: float = 0.9, # 0.9 → 90 % pretrain, 10 % chat + split: str = "train", + streaming: bool = True, + pretrain_key: str = "text", + chat_key: str = "conversation", + sequence_pack_pretrain: bool = True, + sequence_pack_chat: bool = False, + system_prompt_to_remove: str | None = None, + system_prompt_removal_freq: float = 0.9, +): + """Get a mix of pretrain and chat data at a specified ratio. By default, 90% of the data will be pretrain and 10% will be chat. + + Default datasets: + pretrain_dataset: "HuggingFaceFW/fineweb" + chat_dataset: "lmsys/lmsys-chat-1m" + + Note that you will have to request permission for lmsys (instant approval on HuggingFace). + + min_chars: minimum number of characters per sample. To perform sequence packing, set it to ~4x sequence length in tokens. + Samples will be joined with the eos token. + If it's low (like 1), each sample will just be a single row from the dataset, padded to the max length. Sometimes this will fill the context, sometimes it won't. + + Why use strings instead of tokens? Because dictionary learning expects an iterator of strings, and this is simple and good enough. + + Implicit assumption: each sample will be truncated to sequence length when tokenized. + + By default, we sequence pack the pretrain data and DO NOT sequence pack the chat data, as it would look kind of weird. The EOS token is used to separate + user / assistant messages, not to separate conversations from different users. + If you want to sequence pack the chat data, set sequence_pack_chat to True. + + Pretrain format will be: texttexttext... + Chat format will be Optionally: ... + + Other parameters: + - system_prompt_to_remove: an optional string that will be removed from the chat data with a given frequency. + You probably want to verify that the system prompt you pass in is correct. + - system_prompt_removal_freq: the frequency with which the system prompt will be removed + + Why? Well, we probably don't want to have 1000's of copies of the system prompt in the training dataset. But we also may not want to remove it entirely. + And we may want to use the LLM with no system prompt when comparing between models. + IDK, this is a complicated and annoying detail. At least this constrains the complexity to the dataset generator. + """ + if not 0 < pretrain_frac < 1: + raise ValueError("main_frac must be between 0 and 1 (exclusive)") + + assert min_chars > 0 + + # Load both datasets as iterable streams + pretrain_ds = iter(load_dataset(pretrain_dataset, split=split, streaming=streaming)) + chat_ds = iter(load_dataset(chat_dataset, split=split, streaming=streaming)) + + # Convert the fraction to two small integers (e.g. 0.9 → 9 / 10) + frac = Fraction(pretrain_frac).limit_denominator() + n_pretrain = frac.numerator + n_chat = frac.denominator - n_pretrain + eos_token = tokenizer.eos_token + + bos_token = tokenizer.bos_token if tokenizer.bos_token else eos_token + + def gen(): + while True: + for _ in range(n_pretrain): + if sequence_pack_pretrain: + length = 0 + samples = [] + while length < min_chars: + # Add bos token to the beginning of the sample + sample = next(pretrain_ds)[pretrain_key] + samples.append(sample) + length += len(sample) + samples = bos_token + eos_token.join(samples) + yield samples + else: + sample = bos_token + next(pretrain_ds)[pretrain_key] + yield sample + for _ in range(n_chat): + if sequence_pack_chat: + length = 0 + samples = [] + while length < min_chars: + sample = next(chat_ds)[chat_key] + # Apply chat template also includes bos token + sample = tokenizer.apply_chat_template(sample, tokenize=False) + sample = randomly_remove_system_prompt( + sample, system_prompt_removal_freq, system_prompt_to_remove + ) + samples.append(sample) + length += len(sample) + samples = "".join(samples) + yield samples + else: + sample = tokenizer.apply_chat_template( + next(chat_ds)[chat_key], tokenize=False + ) + sample = randomly_remove_system_prompt( + sample, system_prompt_removal_freq, system_prompt_to_remove + ) + yield sample + + return gen() + + +def hf_sequence_packing_dataset_to_generator( + tokenizer: AutoTokenizer, + pretrain_dataset: str = "HuggingFaceFW/fineweb", + min_chars: int = 1, + split: str = "train", + streaming: bool = True, + pretrain_key: str = "text", + sequence_pack_pretrain: bool = True, +): + """min_chars: minimum number of characters per sample. To perform sequence packing, set it to ~4x sequence length in tokens. + Samples will be joined with the eos token. + If it's low (like 1), each sample will just be a single row from the dataset, padded to the max length. Sometimes this will fill the context, sometimes it won't.""" + assert min_chars > 0 + + # Load both datasets as iterable streams + pretrain_ds = iter(load_dataset(pretrain_dataset, split=split, streaming=streaming)) + + eos_token = tokenizer.eos_token + + bos_token = tokenizer.bos_token if tokenizer.bos_token else eos_token + + def gen(): + while True: + if sequence_pack_pretrain: + length = 0 + samples = [] + while length < min_chars: + # Add bos token to the beginning of the sample + sample = next(pretrain_ds)[pretrain_key] + samples.append(sample) + length += len(sample) + samples = bos_token + eos_token.join(samples) + yield samples + else: + sample = bos_token + next(pretrain_ds)[pretrain_key] + yield sample + + return gen() + + +def simple_hf_mixed_dataset_to_generator( + main_name: str, + aux_name: str, + main_frac: float = 0.9, # 0.9 → 90 % main, 10 % aux + split: str = "train", + streaming: bool = True, + main_key: str = "text", + aux_key: str = "text", +): + if not 0 < main_frac < 1: + raise ValueError("main_frac must be between 0 and 1 (exclusive)") + + # Load both datasets as iterable streams + main_ds = iter(load_dataset(main_name, split=split, streaming=streaming)) + aux_ds = iter(load_dataset(aux_name, split=split, streaming=streaming)) + + # Convert the fraction to two small integers (e.g. 0.9 → 9 / 10) + frac = Fraction(main_frac).limit_denominator() + n_main = frac.numerator + n_aux = frac.denominator - n_main + + def gen(): + while True: + # Yield `n_main` items from the main dataset + for _ in range(n_main): + yield next(main_ds)[main_key] + # Yield `n_aux` items from the auxiliary dataset + for _ in range(n_aux): + yield next(aux_ds)[aux_key] + + return gen() + + def get_nested_folders(path: str) -> list[str]: """ Recursively get a list of folders that contain an ae.pt file, starting the search from the given path From 17a41c76335c038a0da3ae6f6589ba2afc6a19b3 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Tue, 13 May 2025 02:35:03 +0000 Subject: [PATCH 66/70] Add optional backup step --- dictionary_learning/training.py | 54 ++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/dictionary_learning/training.py b/dictionary_learning/training.py index 50f20ee..0671f31 100644 --- a/dictionary_learning/training.py +++ b/dictionary_learning/training.py @@ -126,6 +126,7 @@ def trainSAE( verbose:bool=False, device:str="cuda", autocast_dtype: t.dtype = t.float32, + backup_steps:Optional[int]=None, ): """ Train SAEs using the given trainers @@ -214,23 +215,42 @@ def trainSAE( # saving if save_steps is not None and step in save_steps: for dir, trainer in zip(save_dirs, trainers): - if dir is not None: - - if normalize_activations: - # Temporarily scale up biases for checkpoint saving - trainer.ae.scale_biases(norm_factor) - - if not os.path.exists(os.path.join(dir, "checkpoints")): - os.mkdir(os.path.join(dir, "checkpoints")) - - checkpoint = {k: v.cpu() for k, v in trainer.ae.state_dict().items()} - t.save( - checkpoint, - os.path.join(dir, "checkpoints", f"ae_{step}.pt"), - ) - - if normalize_activations: - trainer.ae.scale_biases(1 / norm_factor) + if dir is None: + continue + + if normalize_activations: + # Temporarily scale up biases for checkpoint saving + trainer.ae.scale_biases(norm_factor) + + if not os.path.exists(os.path.join(dir, "checkpoints")): + os.mkdir(os.path.join(dir, "checkpoints")) + + checkpoint = {k: v.cpu() for k, v in trainer.ae.state_dict().items()} + t.save( + checkpoint, + os.path.join(dir, "checkpoints", f"ae_{step}.pt"), + ) + + if normalize_activations: + trainer.ae.scale_biases(1 / norm_factor) + + # backup + if backup_steps is not None and step % backup_steps == 0: + for save_dir, trainer in zip(save_dirs, trainers): + if save_dir is None: + continue + # save the current state of the trainer for resume if training is interrupted + # this will be overwritten by the next checkpoint and at the end of training + t.save( + { + "step": step, + "ae": trainer.ae.state_dict(), + "optimizer": trainer.optimizer.state_dict(), + "config": trainer.config, + "norm_factor": norm_factor, + }, + os.path.join(save_dir, "ae.pt"), + ) # training for trainer in trainers: From fe9d8c7c144ec85eff87154f3704afaa6a24e671 Mon Sep 17 00:00:00 2001 From: Adam Karvonen Date: Tue, 13 May 2025 03:29:26 +0000 Subject: [PATCH 67/70] add activault buffer implementation --- dictionary_learning/activault_s3_buffer.py | 744 +++++++++++++++++++++ 1 file changed, 744 insertions(+) create mode 100644 dictionary_learning/activault_s3_buffer.py diff --git a/dictionary_learning/activault_s3_buffer.py b/dictionary_learning/activault_s3_buffer.py new file mode 100644 index 0000000..1b94a7e --- /dev/null +++ b/dictionary_learning/activault_s3_buffer.py @@ -0,0 +1,744 @@ +"""Copyright (2025) Tilde Research Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import asyncio +import io +import json +import os +import random +import signal +import sys +import time +import warnings +from multiprocessing import Process, Queue, Value +from typing import Optional + +import einops +import aiohttp +import boto3 +import torch +import torch.nn as nn +import multiprocessing as mp +import warnings +import logging + +logger = logging.getLogger(__name__) + +# Constants for file sizes +KB = 1024 +MB = KB * KB + +# Cache directory constants +OUTER_CACHE_DIR = "cache" +INNER_CACHE_DIR = "cache" +BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", "main") + + +def _metadata_path(run_name): + """Generate the metadata file path for a given run name.""" + return f"{run_name}/metadata.json" + + +def _statistics_path(run_name): + """Generate the statistics file path for a given run name.""" + return f"{run_name}/statistics.json" + + +async def download_chunks(session, url, total_size, chunk_size): + """Download file chunks asynchronously with retries.""" + tries_left = 5 + while tries_left > 0: + chunks = [ + (i, min(i + chunk_size - 1, total_size - 1)) + for i in range(0, total_size, chunk_size) + ] + tasks = [ + asyncio.create_task(request_chunk(session, url, start, end)) + for start, end in chunks + ] + responses = await asyncio.gather(*tasks, return_exceptions=True) + + results = [] + retry = False + for response in responses: + if isinstance(response, Exception): + logger.error(f"Error occurred: {response}") + logger.error( + f"Session: {session}, URL: {url}, Tries left: {tries_left}" + ) + tries_left -= 1 + retry = True + break + else: + results.append(response) + + if not retry: + return results + + return None + + +async def request_chunk(session, url, start, end): + """Request a specific chunk of a file.""" + headers = {"Range": f"bytes={start}-{end}"} + try: + async with session.get(url, headers=headers) as response: + response.raise_for_status() + return start, await response.read() + except Exception as e: + return e + + +def download_loop(*args): + """Run the asynchronous download loop.""" + asyncio.run(_async_download(*args)) + + +def compile(byte_buffers, shuffle=True, seed=None, return_ids=False): + """Compile downloaded chunks into a tensor.""" + combined_bytes = b"".join( + chunk for _, chunk in sorted(byte_buffers, key=lambda x: x[0]) + ) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # n = np.frombuffer(combined_bytes, dtype=np.float16) + # t = torch.from_numpy(n) + # t = torch.frombuffer(combined_bytes, dtype=dtype) # torch.float32 + buffer = io.BytesIO(combined_bytes) + t = torch.load(buffer) + if ( + isinstance(t, dict) and "states" in t and not return_ids + ): # backward compatibility + t = t["states"] # ignore input_ids + buffer.close() + + if shuffle and not return_ids: + t = shuffle_megabatch_tokens(t, seed) + + return t + + +def shuffle_megabatch_tokens(t, seed=None): + """ + Shuffle within a megabatch (across batches and sequences), using each token as the unit of shuffling. + + Args: + t (torch.Tensor): Input tensor of shape (batch_size * batches_per_file, sequence_length, d_in + 1) + seed (int): Seed for the random number generator + + Returns: + torch.Tensor: Shuffled tensor of the same shape as input + """ + original_shape = ( + t.shape + ) # (batch_size * batches_per_file, sequence_length, d_in + 1) + + total_tokens = ( + original_shape[0] * original_shape[1] + ) # reshape to (total_tokens, d_in + 1) + t_reshaped = t.reshape(total_tokens, -1) + + rng = torch.Generator() + if seed is not None: + rng.manual_seed(seed) + + shuffled_indices = torch.randperm(total_tokens, generator=rng) + t_shuffled = t_reshaped[shuffled_indices] + + t = t_shuffled.reshape(original_shape) # revert + + return t + + +def write_tensor(t, buffer, writeable_tensors, readable_tensors, ongoing_downloads): + """Write a tensor to the shared buffer.""" + idx = writeable_tensors.get(block=True) + if isinstance(buffer[0], SharedBuffer): + buffer[idx].states.copy_(t["states"]) + buffer[idx].input_ids.copy_(t["input_ids"]) + else: + buffer[idx] = t + + readable_tensors.put(idx, block=True) + with ongoing_downloads.get_lock(): + ongoing_downloads.value -= 1 + + +async def _async_download( + buffer, + file_index, + s3_paths, + stop, + readable_tensors, + writeable_tensors, + ongoing_downloads, + concurrency, + bytes_per_file, + chunk_size, + shuffle, + seed, + return_ids, +): + """Asynchronously download and process files from S3.""" + connector = aiohttp.TCPConnector(limit=concurrency) + async with aiohttp.ClientSession(connector=connector) as session: + while file_index.value < len(s3_paths) and not stop.value: + with ongoing_downloads.get_lock(): + ongoing_downloads.value += 1 + with file_index.get_lock(): + url = s3_paths[file_index.value] + file_index.value += 1 + bytes_results = await download_chunks( + session, url, bytes_per_file, chunk_size + ) + if bytes_results is not None: + try: + t = compile(bytes_results, shuffle, seed, return_ids) + write_tensor( + t, + buffer, + writeable_tensors, + readable_tensors, + ongoing_downloads, + ) + except Exception as e: + logger.error(f"Exception while downloading: {e}") + logger.error(f"Failed URL: {url}") + stop.value = True # Set stop flag + break # Exit the loop + else: + logger.error(f"Failed to download URL: {url}") + with ongoing_downloads.get_lock(): + ongoing_downloads.value -= 1 + + +class S3RCache: + """A cache that reads data from Amazon S3.""" + + @classmethod + def from_credentials( + self, aws_access_key_id, aws_secret_access_key, *args, **kwargs + ): + s3_client = boto3.client( + "s3", + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + endpoint_url=os.environ.get("S3_ENDPOINT_URL"), + ) + return S3RCache(s3_client, *args, **kwargs) + + def __init__( + self, + s3_client, + s3_prefix, + bucket_name=BUCKET_NAME, + device="cpu", + concurrency=100, + chunk_size=MB * 16, + buffer_size=2, + shuffle=True, + preserve_file_order=False, + seed=42, + paths=None, + n_workers=1, + return_ids=False, + ) -> None: + """Initialize S3 cache.""" + ensure_spawn_context() + + # Configure S3 client with correct signature version + self.s3_client = ( + boto3.client( + "s3", + region_name="eu-north1", # Make sure this matches your bucket region + config=boto3.session.Config(signature_version="s3v4"), + ) + if s3_client is None + else s3_client + ) + + self.s3_prefix = s3_prefix + self.bucket_name = bucket_name + self.device = device + self.concurrency = concurrency + self.chunk_size = chunk_size + self.buffer_size = buffer_size + self.shuffle = shuffle + self.preserve_file_order = preserve_file_order + self.seed = seed + self.return_ids = return_ids + + random.seed(self.seed) + torch.manual_seed(self.seed) # unclear if this has effect + # but we drill down the seed to download loop anyway + + self.paths = paths + self._s3_paths = self._list_s3_files() + if isinstance(self.s3_prefix, list): + target_prefix = self.s3_prefix[0] + else: + target_prefix = self.s3_prefix + response = self.s3_client.get_object( + Bucket=bucket_name, Key=_metadata_path(target_prefix) + ) + content = response["Body"].read() + self.metadata = json.loads(content) + # self.metadata["bytes_per_file"] = 1612711320 + self._activation_dtype = eval(self.metadata["dtype"]) + + self._running_processes = [] + self.n_workers = n_workers + + self.readable_tensors = Queue(maxsize=self.buffer_size) + self.writeable_tensors = Queue(maxsize=self.buffer_size) + + for i in range(self.buffer_size): + self.writeable_tensors.put(i) + + if self.return_ids: + self.buffer = [ + SharedBuffer( + self.metadata["shape"], + self.metadata["input_ids_shape"], + self._activation_dtype, + ) + for _ in range(self.buffer_size) + ] + for shared_buffer in self.buffer: + shared_buffer.share_memory() + else: + self.buffer = torch.empty( + (self.buffer_size, *self.metadata["shape"]), + dtype=self._activation_dtype, + ).share_memory_() + + self._stop = Value("b", False) + self._file_index = Value("i", 0) + self._ongoing_downloads = Value("i", 0) + + signal.signal(signal.SIGTERM, self._catch_stop) + signal.signal(signal.SIGINT, self._catch_stop) + + self._initial_file_index = 0 + + @property + def current_file_index(self): + return self._file_index.value + + def set_file_index(self, index): + self._initial_file_index = index + + def _catch_stop(self, *args, **kwargs): + logger.info("cleaning up before process is killed") + self._stop_downloading() + sys.exit(0) + + def sync(self): + self._s3_paths = self._list_s3_files() + + def _reset(self): + self._file_index.value = self._initial_file_index + self._ongoing_downloads.value = 0 + self._stop.value = False + + while not self.readable_tensors.empty(): + self.readable_tensors.get() + + while not self.writeable_tensors.empty(): + self.writeable_tensors.get() + for i in range(self.buffer_size): + self.writeable_tensors.put(i) + + def _list_s3_files(self): + """List and prepare all data files from one or more S3 prefixes.""" + paths = [] + combined_metadata = None + combined_config = None + + # Handle single prefix case for backward compatibility + prefixes = ( + [self.s3_prefix] if isinstance(self.s3_prefix, str) else self.s3_prefix + ) + + # Process each prefix + for prefix in prefixes: + # Get metadata for this prefix + response = self.s3_client.get_object( + Bucket=self.bucket_name, Key=_metadata_path(prefix) + ) + metadata = json.loads(response["Body"].read()) + + # Get config for this prefix + try: + config_response = self.s3_client.get_object( + Bucket=self.bucket_name, + Key=f"{'/'.join(prefix.split('/')[:-1])}/cfg.json", + ) + config = json.loads(config_response["Body"].read()) + except Exception as e: + logger.warning( + f"Warning: Could not load config for prefix {prefix}: {e}" + ) + config = {} + + # Initialize combined metadata and config from first prefix + if combined_metadata is None: + combined_metadata = metadata.copy() + combined_config = config.copy() + # Initialize accumulation fields + combined_config["total_tokens"] = 0 + combined_config["n_total_files"] = 0 + combined_config["batches_processed"] = 0 + else: + # Verify metadata compatibility + if metadata["shape"][1:] != combined_metadata["shape"][1:]: + raise ValueError( + f"Incompatible shapes between datasets: {metadata['shape']} vs {combined_metadata['shape']}" + ) + if metadata["dtype"] != combined_metadata["dtype"]: + raise ValueError(f"Incompatible dtypes between datasets") + + # Accumulate config fields + combined_config["total_tokens"] += config.get("total_tokens", 0) + combined_config["n_total_files"] += config.get("n_total_files", 0) + combined_config["batches_processed"] += config.get("batches_processed", 0) + + # List files for this prefix + paginator = self.s3_client.get_paginator("list_objects_v2") + page_iterator = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix) + + prefix_paths = [] + for page in page_iterator: + if "Contents" not in page: + continue + + for obj in page["Contents"]: + if ( + obj["Key"] != _metadata_path(prefix) + and obj["Key"] != _statistics_path(prefix) + and not obj["Key"].endswith("cfg.json") + ): + url = self.s3_client.generate_presigned_url( + "get_object", + Params={"Bucket": self.bucket_name, "Key": obj["Key"]}, + ExpiresIn=604700, + ) + prefix_paths.append(url) + + paths.extend(prefix_paths) + + # Store the combined metadata and config + self.metadata = combined_metadata + self.config = combined_config # Store combined config for potential later use + + if self.preserve_file_order: + # chronological upload order + return sorted(paths) + else: + # shuffle the file order + random.shuffle(paths) + return paths + + def __iter__(self): + self._reset() + + if self._running_processes: + raise ValueError( + "Cannot iterate over cache a second time while it is downloading" + ) + + if len(self._s3_paths) > self._initial_file_index: + while len(self._running_processes) < self.n_workers: + p = Process( + target=download_loop, + args=( + self.buffer, + self._file_index, + self._s3_paths[ + self._initial_file_index : + ], # Start from the initial index + self._stop, + self.readable_tensors, + self.writeable_tensors, + self._ongoing_downloads, + self.concurrency, + self.metadata["bytes_per_file"], + self.chunk_size, + self.shuffle, + self.seed, + self.return_ids, + ), + ) + p.start() + self._running_processes.append(p) + time.sleep(0.75) + + return self + + def _next_tensor(self): + try: + idx = self.readable_tensors.get(block=True) + if self.return_ids: + t = { + "states": self.buffer[idx].states.clone().detach(), + "input_ids": self.buffer[idx].input_ids.clone().detach(), + } + else: + t = self.buffer[idx].clone().detach() + + self.writeable_tensors.put(idx, block=True) + return t + except Exception as e: + logger.error(f"exception while iterating: {e}") + self._stop_downloading() + raise StopIteration + + def __next__(self): + while ( + self._file_index.value < len(self._s3_paths) + or not self.readable_tensors.empty() + or self._ongoing_downloads.value > 0 + ): + return self._next_tensor() + + if self._running_processes: + self._stop_downloading() + raise StopIteration + + def finalize(self): + self._stop_downloading() + + def _stop_downloading(self): + logger.info("stopping workers...") + self._file_index.value = len(self._s3_paths) + self._stop.value = True + + while not all([not p.is_alive() for p in self._running_processes]): + if not self.readable_tensors.empty(): + self.readable_tensors.get() + + if not self.writeable_tensors.full(): + self.writeable_tensors.put(0) + + time.sleep(0.25) + + for p in self._running_processes: + p.join() # still join to make sure all resources are cleaned up + + self._ongoing_downloads.value = 0 + self._running_processes = [] + + +""" +tl;dr of why we need this: +shared memory is handled differently for nested structures -- see buffer intiialization +we can initialize a dict with two tensors with shared memory, and these tensors themselves are shared but NOT the dict +hence writing to buffer[idx] in write_tensor will not actually write to self.buffer[idx], which _next_tensor uses +(possibly a better fix, but for now this works) +""" + + +class SharedBuffer(nn.Module): + def __init__(self, shape, input_ids_shape, dtype): + super().__init__() + self.states = nn.Parameter(torch.ones(shape, dtype=dtype), requires_grad=False) + self.input_ids = nn.Parameter( + torch.ones(input_ids_shape, dtype=torch.int64), requires_grad=False + ) + + def forward(self): + return {"states": self.states, "input_ids": self.input_ids} + + +### mini-helper for multiprocessing +def ensure_spawn_context(): + """ + Ensures multiprocessing uses 'spawn' context if not already set. + Returns silently if already set to 'spawn'. + Issues warning if unable to set to 'spawn'. + """ + if mp.get_start_method(allow_none=True) != "spawn": + try: + mp.set_start_method("spawn", force=True) + except RuntimeError: + warnings.warn( + "Multiprocessing start method is not 'spawn'. This may cause issues." + ) + + +def create_s3_client( + access_key_id: Optional[str] = None, + secret_access_key: Optional[str] = None, + endpoint_url: Optional[str] = None, +) -> boto3.client: + """Create an S3 client configured for S3-compatible storage services. + + This function creates a boto3 S3 client with optimized settings for reliable + data transfer. It supports both direct credential passing and environment + variable configuration. + + Args: + access_key_id: S3 access key ID. If None, reads from AWS_ACCESS_KEY_ID env var + secret_access_key: S3 secret key. If None, reads from AWS_SECRET_ACCESS_KEY env var + endpoint_url: S3-compatible storage service endpoint URL + + Returns: + boto3.client: Configured S3 client with optimized settings + + Environment Variables: + - AWS_ACCESS_KEY_ID: S3 access key ID (if not provided as argument) + - AWS_SECRET_ACCESS_KEY: S3 secret key (if not provided as argument) + + Example: + ```python + # Using environment variables + s3_client = create_s3_client() + + # Using explicit credentials + s3_client = create_s3_client( + access_key_id="your_key", + secret_access_key="your_secret", + endpoint_url="your_endpoint_url" + ) + ``` + + Note: + The client is configured with path-style addressing and S3v4 signatures + for maximum compatibility with S3-compatible storage services. + """ + access_key_id = access_key_id or os.environ.get("AWS_ACCESS_KEY_ID") + secret_access_key = secret_access_key or os.environ.get("AWS_SECRET_ACCESS_KEY") + endpoint_url = endpoint_url or os.environ.get("S3_ENDPOINT_URL") + + if not access_key_id or not secret_access_key: + raise ValueError( + "S3 credentials must be provided either through arguments or " + "AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY environment variables" + ) + + if not endpoint_url: + raise ValueError( + "S3 endpoint URL must be provided either through arguments or " + "S3_ENDPOINT_URL environment variable" + ) + + session = boto3.session.Session() + return session.client( + service_name="s3", + aws_access_key_id=access_key_id, + aws_secret_access_key=secret_access_key, + endpoint_url=endpoint_url, + use_ssl=True, + verify=True, + config=boto3.session.Config( + s3={"addressing_style": "path"}, + signature_version="s3v4", + # Advanced configuration options (currently commented out): + # retries=dict( + # max_attempts=3, # Number of retry attempts + # mode='adaptive' # Adds exponential backoff + # ), + # max_pool_connections=20, # Limits concurrent connections + # connect_timeout=60, # Connection timeout in seconds + # read_timeout=300, # Read timeout in seconds + # tcp_keepalive=True, # Enable TCP keepalive + ), + ) + + +class ActivaultS3ActivationBuffer: + def __init__( + self, + cache: S3RCache, + batch_size: int = 8192, + device: str = "cpu", + io: str = "out", + ): + self.cache = iter(cache) # Make sure it's an iterator + self.batch_size = batch_size + self.device = device + self.io = io + + self.states = None # Shape: [N, D] + self.read_mask = None # Shape: [N] + self.refresh() # Load the first batch + + def __iter__(self): + return self + + def __next__(self): + with torch.no_grad(): + if (~self.read_mask).sum() < self.batch_size: + self.refresh() + + if self.states is None or self.states.shape[0] == 0: + raise StopIteration + + unreads = (~self.read_mask).nonzero().squeeze() + if unreads.ndim == 0: + unreads = unreads.unsqueeze(0) + selected = unreads[ + torch.randperm(len(unreads), device=self.device)[: self.batch_size] + ] + self.read_mask[selected] = True + return self.states[selected] + + def refresh(self): + try: + next_batch = next(self.cache) # dict with "states" key + except StopIteration: + self.states = None + self.read_mask = None + return + + states = next_batch["states"].to(self.device) # [B, L, D] + flat_states = einops.rearrange(states, "b l d -> (b l) d").contiguous() + self.states = flat_states + self.read_mask = torch.zeros( + flat_states.shape[0], dtype=torch.bool, device=self.device + ) + + def close(self): + if hasattr(self.cache, "finalize"): + self.cache.finalize() + elif hasattr(self.cache, "close"): + self.cache.close() + + +if __name__ == "__main__": + device = "cuda" + sae_batch_size = 2048 + io = "out" + + # example activault usage + + BUCKET_NAME = os.environ.get("S3_BUCKET_NAME", "main") + s3_prefix = ["mistral.8b.fineweb/blocks.9.hook_resid_post"] + cache = S3RCache.from_credentials( + aws_access_key_id=os.environ.get("AWS_ACCESS_KEY_ID"), + aws_secret_access_key=os.environ.get("AWS_SECRET_ACCESS_KEY"), + s3_prefix=s3_prefix, + bucket_name=BUCKET_NAME, + device=device, + buffer_size=2, + return_ids=True, + shuffle=True, + n_workers=2, + ) + + s3_buffer = ActivaultS3ActivationBuffer( + cache, batch_size=sae_batch_size, device=device, io=io + ) From c7b2527ef8f8fc7e1b46ee7d570f9f3723f1fd1d Mon Sep 17 00:00:00 2001 From: Andy Arditi Date: Sun, 18 May 2025 02:56:45 +0000 Subject: [PATCH 68/70] assert right padding for remove_bos logic --- dictionary_learning/buffer.py | 6 ++++++ dictionary_learning/pytorch_buffer.py | 4 ++++ 2 files changed, 10 insertions(+) diff --git a/dictionary_learning/buffer.py b/dictionary_learning/buffer.py index 67e1c0f..ad3727e 100644 --- a/dictionary_learning/buffer.py +++ b/dictionary_learning/buffer.py @@ -59,6 +59,11 @@ def __init__(self, self.remove_bos = remove_bos self.add_special_tokens = add_special_tokens + print(self.model.tokenizer.padding_side) + + if self.remove_bos: + assert self.model.tokenizer.padding_side == "right", "Padding side must be right (bos-trimming logic assumes right padding)" + def __iter__(self): return self @@ -138,6 +143,7 @@ def refresh(self): if isinstance(hidden_states, tuple): hidden_states = hidden_states[0] if self.remove_bos: + assert self.model.tokenizer.padding_side == "right", "Padding side must be right (bos-trimming logic assumes right padding)" hidden_states = hidden_states[:, 1:, :] attn_mask = attn_mask[:, 1:] hidden_states = hidden_states[attn_mask != 0] diff --git a/dictionary_learning/pytorch_buffer.py b/dictionary_learning/pytorch_buffer.py index 720101f..c4cd6ec 100644 --- a/dictionary_learning/pytorch_buffer.py +++ b/dictionary_learning/pytorch_buffer.py @@ -123,6 +123,9 @@ def __init__( if not self.tokenizer.pad_token: self.tokenizer.pad_token = self.tokenizer.eos_token + + if self.remove_bos: + assert self.tokenizer.padding_side == "right", "Padding side must be right (bos-trimming logic assumes right padding)" def __iter__(self): return self @@ -194,6 +197,7 @@ def refresh(self): hidden_states = collect_activations(self.model, self.submodule, input) attn_mask = input["attention_mask"] if self.remove_bos: + assert self.tokenizer.padding_side == "right", "Padding side must be right (bos-trimming logic assumes right padding)" hidden_states = hidden_states[:, 1:, :] attn_mask = attn_mask[:, 1:] hidden_states = hidden_states[attn_mask != 0] From 59abc88dc10e48c5f893587877027092a9d3b75f Mon Sep 17 00:00:00 2001 From: Andy Arditi Date: Sun, 18 May 2025 04:10:11 +0000 Subject: [PATCH 69/70] mask out bos activations --- dictionary_learning/buffer.py | 17 ++++++----------- dictionary_learning/pytorch_buffer.py | 14 +++++--------- 2 files changed, 11 insertions(+), 20 deletions(-) diff --git a/dictionary_learning/buffer.py b/dictionary_learning/buffer.py index ad3727e..f0e02e1 100644 --- a/dictionary_learning/buffer.py +++ b/dictionary_learning/buffer.py @@ -56,14 +56,9 @@ def __init__(self, self.refresh_batch_size = refresh_batch_size self.out_batch_size = out_batch_size self.device = device - self.remove_bos = remove_bos + self.remove_bos = remove_bos and (self.model.tokenizer.bos_token_id is not None) self.add_special_tokens = add_special_tokens - print(self.model.tokenizer.padding_side) - - if self.remove_bos: - assert self.model.tokenizer.padding_side == "right", "Padding side must be right (bos-trimming logic assumes right padding)" - def __iter__(self): return self @@ -138,15 +133,15 @@ def refresh(self): input = self.model.inputs.save() self.submodule.output.stop() - attn_mask = input.value[1]["attention_mask"] + + mask = (input.value[1]["attention_mask"] != 0) hidden_states = hidden_states.value if isinstance(hidden_states, tuple): hidden_states = hidden_states[0] if self.remove_bos: - assert self.model.tokenizer.padding_side == "right", "Padding side must be right (bos-trimming logic assumes right padding)" - hidden_states = hidden_states[:, 1:, :] - attn_mask = attn_mask[:, 1:] - hidden_states = hidden_states[attn_mask != 0] + bos_mask = (input.value[1]["input_ids"] == self.model.tokenizer.bos_token_id) + mask = mask & ~bos_mask + hidden_states = hidden_states[mask] remaining_space = self.activation_buffer_size - current_idx assert remaining_space > 0 diff --git a/dictionary_learning/pytorch_buffer.py b/dictionary_learning/pytorch_buffer.py index c4cd6ec..9d943e4 100644 --- a/dictionary_learning/pytorch_buffer.py +++ b/dictionary_learning/pytorch_buffer.py @@ -117,15 +117,12 @@ def __init__( self.refresh_batch_size = refresh_batch_size self.out_batch_size = out_batch_size self.device = device - self.remove_bos = remove_bos self.add_special_tokens = add_special_tokens self.tokenizer = AutoTokenizer.from_pretrained(model.name_or_path) + self.remove_bos = remove_bos and (self.tokenizer.bos_token_id is not None) if not self.tokenizer.pad_token: self.tokenizer.pad_token = self.tokenizer.eos_token - - if self.remove_bos: - assert self.tokenizer.padding_side == "right", "Padding side must be right (bos-trimming logic assumes right padding)" def __iter__(self): return self @@ -195,12 +192,11 @@ def refresh(self): with t.no_grad(): input = self.tokenized_batch() hidden_states = collect_activations(self.model, self.submodule, input) - attn_mask = input["attention_mask"] + mask = (input["attention_mask"] != 0) if self.remove_bos: - assert self.tokenizer.padding_side == "right", "Padding side must be right (bos-trimming logic assumes right padding)" - hidden_states = hidden_states[:, 1:, :] - attn_mask = attn_mask[:, 1:] - hidden_states = hidden_states[attn_mask != 0] + bos_mask = (input["input_ids"] == self.tokenizer.bos_token_id) + mask = mask & ~bos_mask + hidden_states = hidden_states[mask] remaining_space = self.activation_buffer_size - current_idx assert remaining_space > 0 From 61ac634845bd76c839482f3b725ab3d898c8b277 Mon Sep 17 00:00:00 2001 From: canrager <61095597+canrager@users.noreply.github.com> Date: Mon, 14 Jul 2025 20:45:10 +0200 Subject: [PATCH 70/70] add handling of device and dtype to IdentityDict --- dictionary_learning/dictionary.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/dictionary_learning/dictionary.py b/dictionary_learning/dictionary.py index bceed6a..238a866 100644 --- a/dictionary_learning/dictionary.py +++ b/dictionary_learning/dictionary.py @@ -150,29 +150,43 @@ class IdentityDict(Dictionary, nn.Module): An identity dictionary, i.e. the identity function. """ - def __init__(self, activation_dim=None): + def __init__(self, activation_dim=None, dtype=None, device=None): super().__init__() self.activation_dim = activation_dim self.dict_size = activation_dim + self.device = device + self.dtype = dtype def encode(self, x): + if self.device is not None: + x = x.to(self.device) + if self.dtype is not None: + x = x.to(self.dtype) return x def decode(self, f): + if self.device is not None: + f = f.to(self.device) + if self.dtype is not None: + f = f.to(self.dtype) return f def forward(self, x, output_features=False, ghost_mask=None): + if self.device is not None: + x = x.to(self.device) + if self.dtype is not None: + x = x.to(self.dtype) if output_features: return x, x else: return x @classmethod - def from_pretrained(cls, path, dtype=t.float, device=None): + def from_pretrained(cls, activation_dim, path, dtype=None, device=None): """ Load a pretrained dictionary from a file. """ - return cls(None) + return cls(activation_dim, device=device, dtype=dtype) class GatedAutoEncoder(Dictionary, nn.Module):