Commit 1719007 (1 parent: f8343df)

Fixed bug with BatchTopK SAEs and normalization in loss computation.

3 files changed: +5 -6 lines


dictionary_learning/cache.py (2 additions & 4 deletions)

@@ -503,15 +503,13 @@ def collect(
         ), "Shuffling shards and storing tokens is not supported yet"

         # Check if we need to store sequence ranges
-        has_bos_token = model.tokenizer.bos_token_id is not None
+        has_bos_token = model.tokenizer.bos_token is not None
         store_sequence_ranges = (
             store_tokens and
             not shuffle_shards and
             not has_bos_token
         )
-        if store_sequence_ranges:
-            print("No BOS token found. Will store sequence ranges.")
-
+
         dataloader = DataLoader(data, batch_size=batch_size, num_workers=num_workers)

         activation_cache = [[] for _ in submodules]
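
For context on this change: Hugging Face tokenizers expose both a string-valued bos_token and an integer bos_token_id, and the guard now tests the former; presumably the two can fall out of sync on some tokenizers. A minimal sketch of the check, assuming a transformers tokenizer; GPT-2 is used purely as an illustration:

# Minimal sketch of the BOS check, assuming a Hugging Face tokenizer.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")

# bos_token (a string) and bos_token_id (an int) are separate attributes;
# the diff above switches the guard to the string-valued one.
print(tok.bos_token)     # '<|endoftext|>' for GPT-2
print(tok.bos_token_id)  # 50256 for GPT-2

# Sequence ranges are only needed when sequences are not delimited by a
# BOS token (simplified: the real guard also checks store_tokens and
# shuffle_shards).
has_bos_token = tok.bos_token is not None
store_sequence_ranges = not has_bos_token
print(store_sequence_ranges)  # False for GPT-2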

dictionary_learning/trainers/batch_top_k.py (1 addition & 1 deletion)

@@ -173,7 +173,7 @@ def loss(
         if step > self.threshold_start_step:
             self.update_threshold(f)

-        x_hat = self.ae.decode(f, denormalize_activations=normalize_activations)
+        x_hat = self.ae.decode(f, denormalize_activations=False)

         e = x - x_hat

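This one-line change is the normalization bug named in the commit message: the loss is computed on normalized activations, so decoding with denormalization would return x_hat in the raw activation scale while x stays normalized, and e = x - x_hat would compare mismatched scales. A minimal sketch of the failure mode, with a made-up norm_factor and made-up shapes:

import torch

# Hypothetical illustration only: norm_factor and the shapes are invented.
norm_factor = 10.0      # scale that was divided out of the raw activations
x = torch.randn(4, 16)  # batch of activations, already normalized
x_hat = x.clone()       # a pretend-perfect reconstruction

# Correct (matches the fix): compare in the same, normalized space.
print((x - x_hat).pow(2).sum().item())  # 0.0

# Buggy (before the fix): the decode denormalizes but x does not, so the
# reconstruction error is inflated even for a perfect SAE.
x_hat_denorm = x_hat * norm_factor
print((x - x_hat_denorm).pow(2).sum().item())  # large
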
dictionary_learning/training.py (2 additions & 1 deletion)

@@ -11,6 +11,7 @@
 import wandb
 from typing import List, Optional

+from .trainers.batch_top_k import BatchTopKTrainer
 from .trainers.crosscoder import CrossCoderTrainer, BatchTopKCrossCoderTrainer


@@ -300,7 +301,7 @@ def trainSAE(
             use_threshold=False,
             epoch_idx_per_step=epoch_idx_per_step,
         )
-        if isinstance(trainer, BatchTopKCrossCoderTrainer):
+        if isinstance(trainer, BatchTopKCrossCoderTrainer) or isinstance(trainer, BatchTopKTrainer):
             log_stats(
                 trainer,
                 step,

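The widened check routes plain BatchTopKTrainer runs through the same log_stats call as the crosscoder variant. As a style note, isinstance also accepts a tuple of types; a sketch of the equivalent spelling, with stand-in classes and a hypothetical helper name:

# Stand-in classes for illustration; the real trainers live in
# dictionary_learning.trainers.
class BatchTopKTrainer: ...
class BatchTopKCrossCoderTrainer: ...

def is_batch_top_k(trainer: object) -> bool:
    # isinstance(x, (A, B)) is equivalent to isinstance(x, A) or isinstance(x, B)
    return isinstance(trainer, (BatchTopKCrossCoderTrainer, BatchTopKTrainer))

print(is_batch_top_k(BatchTopKTrainer()))  # True
print(is_batch_top_k(object()))            # False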