File tree Expand file tree Collapse file tree 3 files changed +5
-6
lines changed Expand file tree Collapse file tree 3 files changed +5
-6
lines changed Original file line number Diff line number Diff line change @@ -503,15 +503,13 @@ def collect(
503
503
), "Shuffling shards and storing tokens is not supported yet"
504
504
505
505
# Check if we need to store sequence ranges
506
- has_bos_token = model .tokenizer .bos_token_id is not None
506
+ has_bos_token = model .tokenizer .bos_token is not None
507
507
store_sequence_ranges = (
508
508
store_tokens and
509
509
not shuffle_shards and
510
510
not has_bos_token
511
511
)
512
- if store_sequence_ranges :
513
- print ("No BOS token found. Will store sequence ranges." )
514
-
512
+
515
513
dataloader = DataLoader (data , batch_size = batch_size , num_workers = num_workers )
516
514
517
515
activation_cache = [[] for _ in submodules ]
Original file line number Diff line number Diff line change @@ -173,7 +173,7 @@ def loss(
173
173
if step > self .threshold_start_step :
174
174
self .update_threshold (f )
175
175
176
- x_hat = self .ae .decode (f , denormalize_activations = normalize_activations )
176
+ x_hat = self .ae .decode (f , denormalize_activations = False )
177
177
178
178
e = x - x_hat
179
179
Original file line number Diff line number Diff line change 11
11
import wandb
12
12
from typing import List , Optional
13
13
14
+ from .trainers .batch_top_k import BatchTopKTrainer
14
15
from .trainers .crosscoder import CrossCoderTrainer , BatchTopKCrossCoderTrainer
15
16
16
17
@@ -300,7 +301,7 @@ def trainSAE(
300
301
use_threshold = False ,
301
302
epoch_idx_per_step = epoch_idx_per_step ,
302
303
)
303
- if isinstance (trainer , BatchTopKCrossCoderTrainer ):
304
+ if isinstance (trainer , BatchTopKCrossCoderTrainer ) or isinstance ( trainer , BatchTopKTrainer ) :
304
305
log_stats (
305
306
trainer ,
306
307
step ,
You can’t perform that action at this time.
0 commit comments