Commit f8343df

Enhance ActivationCache to support sequence range tracking based on presence of BOS token
This commit introduces the following changes:

- Added a `_sequence_ranges` attribute to the `ActivationCache` class to store sequence start indices.
- Implemented a `sequence_ranges` property that loads sequence ranges from disk when they exist and are configured to be stored.
- Updated the `collect` method to track and store sequence ranges when the model's tokenizer has no beginning-of-sequence (BOS) token.
- Added assertions to ensure the integrity of sequence ranges during activation collection.
- Introduced tests verifying sequence-range storage both when the model has a BOS token and when it does not.

These changes let `ActivationCache` recover sequence boundaries from its flattened token stream.
1 parent 4541d10 commit f8343df
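
For orientation: the new `sequence_ranges.pt` tensor stores the start offset of every sequence in the flattened token/activation stream, plus a final entry equal to the total token count. A minimal consumer-side sketch of how those boundaries could be used (the helper `split_by_sequence` below is illustrative, not part of this commit):

import torch as th

def split_by_sequence(flat: th.Tensor, sequence_ranges: th.Tensor) -> list:
    """Recover per-sequence slices from a flattened (num_tokens, ...) tensor."""
    starts = sequence_ranges[:-1].tolist()
    ends = sequence_ranges[1:].tolist()
    return [flat[s:e] for s, e in zip(starts, ends)]

When the tokenizer has a BOS token, sequence boundaries can be recovered from the stored tokens themselves, which is why the cache only tracks ranges when no BOS token is present (and only with store_tokens=True and shuffle_shards=False).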

File tree

3 files changed: +227, -44 lines changed

dictionary_learning/cache.py

Lines changed: 83 additions & 8 deletions
@@ -269,6 +269,7 @@ def __init__(self, store_dir: str, submodule_name: str = None):
             os.path.join(store_dir, "tokens.pt"), weights_only=True
         ).cpu()
 
+        self._sequence_ranges = None
         self._mean = None
         self._std = None
 
@@ -322,6 +323,23 @@ def __getitem__(self, index: int):
     def tokens(self):
         return self._tokens
 
+    @property
+    def sequence_ranges(self):
+        if hasattr(self, '_sequence_ranges') and self._sequence_ranges is not None:
+            return self._sequence_ranges
+
+        if ("store_sequence_ranges" in self.config and
+                self.config["store_sequence_ranges"] and
+                os.path.exists(os.path.join(self._cache_store_dir, "..", "sequence_ranges.pt"))):
+            self._sequence_ranges = th.load(
+                os.path.join(self._cache_store_dir, "..", "sequence_ranges.pt"),
+                weights_only=True
+            ).cpu()
+            return self._sequence_ranges
+        else:
+            # Return None if sequence ranges not available
+            return None
+
     @staticmethod
     def get_activations(submodule: nn.Module, io: str):
         if io == "in":
@@ -434,17 +452,25 @@ def exists(
         cached data is present and num_tokens is the total number of tokens in the cache
         """
         num_tokens = 0
+        config = None
         for submodule_name in submodule_names:
-            if not os.path.exists(
-                os.path.join(store_dir, f"{submodule_name}_{io}", "config.json")
-            ):
+            config_path = os.path.join(store_dir, f"{submodule_name}_{io}", "config.json")
+            if not os.path.exists(config_path):
                 return False, 0
-            with open(
-                os.path.join(store_dir, f"{submodule_name}_{io}", "config.json"), "r"
-            ) as f:
-                num_tokens = json.load(f)["total_size"]
+            with open(config_path, "r") as f:
+                config = json.load(f)
+                num_tokens = config["total_size"]
+
         if store_tokens and not os.path.exists(os.path.join(store_dir, "tokens.pt")):
             return False, 0
+
+        # Check for sequence ranges if they should exist
+        if (config and
+                "store_sequence_ranges" in config and
+                config["store_sequence_ranges"] and
+                not os.path.exists(os.path.join(store_dir, "sequence_ranges.pt"))):
+            return False, 0
+
         return True, num_tokens
 
     @th.no_grad()
@@ -475,10 +501,24 @@ def collect(
         assert (
             not shuffle_shards or not store_tokens
         ), "Shuffling shards and storing tokens is not supported yet"
+
+        # Check if we need to store sequence ranges
+        has_bos_token = model.tokenizer.bos_token_id is not None
+        store_sequence_ranges = (
+            store_tokens and
+            not shuffle_shards and
+            not has_bos_token
+        )
+        if store_sequence_ranges:
+            print("No BOS token found. Will store sequence ranges.")
+
         dataloader = DataLoader(data, batch_size=batch_size, num_workers=num_workers)
 
         activation_cache = [[] for _ in submodules]
         tokens_cache = []
+        sequence_ranges_cache = []
+        current_token_position = 0  # Track position in flattened token stream
+
         store_sub_dirs = [
             os.path.join(store_dir, f"{submodule_names[i]}_{io}")
             for i in range(len(submodules))
@@ -530,6 +570,14 @@ def collect(
             store_mask = attention_mask.clone()
             if ignore_first_n_tokens_per_sample > 0:
                 store_mask[:, :ignore_first_n_tokens_per_sample] = 0
+
+            # Track sequence ranges if needed
+            if store_sequence_ranges:
+                batch_lengths = store_mask.sum(dim=1).tolist()
+                batch_sequence_ranges = np.cumsum([0] + batch_lengths[:-1]) + current_token_position
+                sequence_ranges_cache.extend(batch_sequence_ranges.tolist())
+                current_token_position += sum(batch_lengths)
+
             if store_tokens:
                 tokens_cache.append(
                     tokens["input_ids"].reshape(-1)[store_mask.reshape(-1).bool()]
@@ -572,7 +620,8 @@ def collect(
                 if dtype is not None:
                     activation_cache[i][-1] = activation_cache[i][-1].to(dtype)
 
-            assert len(tokens_cache[-1]) == activation_cache[0][-1].shape[0]
+            if store_tokens:
+                assert len(tokens_cache[-1]) == activation_cache[0][-1].shape[0]
             assert activation_cache[0][-1].shape[0] == store_mask.sum().item()
             current_size += activation_cache[0][-1].shape[0]
         else:
@@ -639,6 +688,7 @@ def collect(
                         "total_size": total_size,
                         "shard_count": shard_count,
                         "store_tokens": store_tokens,
+                        "store_sequence_ranges": store_sequence_ranges,
                     },
                     f,
                 )
@@ -652,6 +702,16 @@ def collect(
             ), f"{tokens_cache.shape[0]} != {total_size}"
             th.save(tokens_cache, os.path.join(store_dir, "tokens.pt"))
 
+        # store sequence ranges
+        if store_sequence_ranges:
+            print("Storing sequence ranges...")
+            # add the last sequence range to the end of the cache
+            sequence_ranges_cache.append(current_token_position)
+            assert sequence_ranges_cache[-1] == total_size
+            sequence_ranges_tensor = th.tensor(sequence_ranges_cache, dtype=th.long)
+            th.save(sequence_ranges_tensor, os.path.join(store_dir, "sequence_ranges.pt"))
+            print(f"Stored {len(sequence_ranges_cache)} sequence ranges")
+
         # store running stats
         for i in range(len(submodules)):
             th.save(
@@ -685,6 +745,14 @@ def tokens(self):
             (self.activation_cache_1.tokens, self.activation_cache_2.tokens), dim=0
         )
 
+    @property
+    def sequence_ranges(self):
+        seq_starts_1 = self.activation_cache_1.sequence_ranges
+        seq_starts_2 = self.activation_cache_2.sequence_ranges
+        if seq_starts_1 is not None and seq_starts_2 is not None:
+            return th.stack((seq_starts_1, seq_starts_2), dim=0)
+        return None
+
     @property
     def mean(self):
         return th.stack(
@@ -718,6 +786,13 @@ def __getitem__(self, index: int):
     def tokens(self):
         return th.stack([cache.tokens for cache in self.activation_caches], dim=0)
 
+    @property
+    def sequence_ranges(self):
+        seq_starts_list = [cache.sequence_ranges for cache in self.activation_caches]
+        if all(seq_starts is not None for seq_starts in seq_starts_list):
+            return th.stack(seq_starts_list, dim=0)
+        return None
+
     @property
     def mean(self):
         return th.stack([cache.mean for cache in self.activation_caches], dim=0)
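
The bookkeeping added to `collect` above is a running cumulative sum over per-sequence token counts. A small worked example with made-up batch lengths (not from the diff) shows how per-batch starts become global offsets and how the final sentinel equals the total size:

import numpy as np

current_token_position = 0
sequence_ranges_cache = []
for batch_lengths in ([3, 5], [2, 4]):  # tokens kept per sequence, per batch
    starts = np.cumsum([0] + batch_lengths[:-1]) + current_token_position
    sequence_ranges_cache.extend(starts.tolist())
    current_token_position += sum(batch_lengths)
sequence_ranges_cache.append(current_token_position)  # sentinel == total_size
print(sequence_ranges_cache)  # [0, 3, 8, 10, 14]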

dictionary_learning/dictionary.py

Lines changed: 4 additions & 8 deletions
@@ -465,11 +465,9 @@ def __init__(
             activation_mean: Optional mean tensor for input activation normalization. If None, no normalization is applied.
             activation_std: Optional std tensor for input activation normalization. If None, no normalization is applied.
         """
-        # First initialize the base classes that don't take normalization parameters
-        super().__init__()
+
+        super().__init__(activation_mean=activation_mean, activation_std=activation_std, activation_shape=(activation_dim,))
 
-        # Then explicitly initialize the NormalizableMixin
-        NormalizableMixin.__init__(self, activation_mean=activation_mean, activation_std=activation_std, activation_shape=(activation_dim,))
 
         self.activation_dim = activation_dim
         self.dict_size = dict_size
@@ -1036,9 +1034,7 @@ def __init__(
         """
         # First initialize the base classes that don't take normalization parameters
         super().__init__(activation_mean=activation_mean, activation_std=activation_std, activation_shape=(num_layers, activation_dim))
-
-        # Then explicitly initialize the NormalizableMixin
-        # NormalizableMixin.__init__(
+
 
         if num_decoder_layers is None:
             num_decoder_layers = num_layers
@@ -1266,7 +1262,7 @@ def from_pretrained(
         """
         if isinstance(code_normalization, str):
            code_normalization = CodeNormalization.from_string(code_normalization)
-        if from_hub:
+        if from_hub or path.endswith(".safetensors"):
            return super().from_pretrained(path, device=device, dtype=dtype, **kwargs)
 
        state_dict = th.load(path, map_location="cpu", weights_only=True)
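
One practical effect of the `from_pretrained` change above: a local `.safetensors` checkpoint now takes the same loading path as `from_hub=True`. A hedged usage sketch; the `CrossCoder` class name and the checkpoint path are assumptions for illustration, not taken from this diff:

# Assumed usage; both calls below would dispatch to the hub-style loader after this change.
from dictionary_learning import CrossCoder  # assumed import path

sae = CrossCoder.from_pretrained("checkpoints/ae.safetensors", device="cuda")
sae = CrossCoder.from_pretrained("org/repo-id", from_hub=True, device="cuda")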

tests/test_cache.py

Lines changed: 140 additions & 28 deletions
@@ -7,7 +7,7 @@
 from nnsight import LanguageModel
 from dictionary_learning.cache import ActivationCache
 from transformers import AutoModelForCausalLM, AutoTokenizer
-
+import numpy as np
 
 @pytest.fixture
 def temp_dir():
@@ -267,36 +267,148 @@ def test_activation_cache_with_normalizer(temp_dir):
         cache.std, computed_std, atol=1e-5, rtol=1e-5
     ), "Cached std doesn't match computed std"
 
-    # Test normalizer functionality
-    normalizer = cache.normalizer
+    print(f"✓ Successfully tested ActivationCache with {len(cache)} activations")
+    print(f"✓ Mean shape: {cache.mean.shape}, Std shape: {cache.std.shape}")
+
 
-    # Test normalization of a sample activation
-    sample_activation = cached_activations[0]
-    normalized = normalizer(sample_activation)
+def test_sequence_ranges_no_bos_token(temp_dir):
+    """Test that sequence ranges are stored when model has no BOS token."""
+    # Set flag to handle meta tensors properly
+    if hasattr(th.fx, 'experimental'):
+        th.fx.experimental._config.meta_nonzero_assume_all_nonzero = True
 
-    # Verify normalization: (x - mean) / std (with small epsilon for numerical stability)
-    expected_normalized = (sample_activation - cache.mean) / (cache.std + 1e-8)
-    assert th.allclose(
-        normalized, expected_normalized, atol=1e-6
-    ), "Normalizer doesn't work correctly"
+    # Skip test if CUDA not available
+    if not th.cuda.is_available():
+        pytest.skip("CUDA not available, skipping test")
+
+    # Test strings of different lengths
+    test_strings = [
+        "Hello world",
+        "This is a longer sentence with more tokens",
+        "Short",
+        "Medium length text here",
+    ]
 
-    # Test batch normalization
-    batch_normalized = normalizer(cached_activations[:5])
-    expected_batch_normalized = (cached_activations[:5] - cache.mean) / (
-        cache.std + 1e-8
+    # Load GPT-2 model and modify tokenizer to simulate no BOS token
+    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    model = AutoModelForCausalLM.from_pretrained(
+        "gpt2", device_map="auto", torch_dtype=th.float32
     )
-    assert th.allclose(
-        batch_normalized, expected_batch_normalized, atol=1e-6
-    ), "Batch normalization doesn't work correctly"
+    model = LanguageModel(model, torch_dtype=th.float32, tokenizer=tokenizer)
+    model.tokenizer.pad_token = model.tokenizer.eos_token
+
+    # Simulate model without BOS token
+    original_bos_token_id = model.tokenizer.bos_token_id
+    model.tokenizer.bos_token_id = None
+
+    tokens = model.tokenizer(test_strings, add_special_tokens=True, return_tensors="pt", padding=True, truncation=True)
+    lengths = tokens["attention_mask"].sum(dim=1).tolist()
+    ranges = np.cumsum([0] + lengths)
+    try:
+        # Get a transformer block
+        target_layer = model.transformer.h[6]
+        submodule_name = "transformer_h_6"
+
+        # Parameters for activation collection
+        batch_size = 2
+        context_len = 32
+        d_model = 768
+
+        # Collect activations with sequence start tracking
+        ActivationCache.collect(
+            data=test_strings,
+            submodules=(target_layer,),
+            submodule_names=(submodule_name,),
+            model=model,
+            store_dir=temp_dir,
+            batch_size=batch_size,
+            context_len=context_len,
+            shard_size=1000,
+            d_model=d_model,
+            io="out",
+            store_tokens=True,
+            shuffle_shards=False,  # Required for sequence ranges
+        )
 
-    # Test that normalization preserves shape
-    assert (
-        normalized.shape == sample_activation.shape
-    ), "Normalization changed tensor shape"
-    assert (
-        batch_normalized.shape == cached_activations[:5].shape
-    ), "Batch normalization changed tensor shape"
+        # Load the cached activations
+        cache = ActivationCache(temp_dir, submodule_name + "_out")
+
+        # Verify sequence ranges were stored
+        sequence_ranges = cache.sequence_ranges
+        assert sequence_ranges is not None, "sequence ranges should be stored for model without BOS token"
+
+        # Should have one sequence start per input string plus one for the last sequence
+        assert len(sequence_ranges) == len(test_strings) + 1, f"Expected {len(test_strings)} sequence ranges, got {len(sequence_ranges)}"
+
+        # First sequence should start at position 0
+        assert sequence_ranges[0].item() == 0, "First sequence should start at position 0"
+
+        # sequence ranges should be the same as the ranges computed from the tokens
+        assert np.allclose(sequence_ranges, ranges), "sequence ranges should be the same as the ranges computed from the tokens"
+
+        # sequence ranges should be in ascending order
+        for i in range(1, len(sequence_ranges)):
+            assert sequence_ranges[i] > sequence_ranges[i-1], f"sequence ranges should be ascending: {sequence_ranges}"
+
+        # Verify sequence ranges align with token boundaries
+        tokens = cache.tokens
+        total_tokens = len(tokens)
+
+        # All sequence ranges should be valid indices
+        for start_idx in sequence_ranges:
+            assert 0 <= start_idx <= total_tokens, f"Invalid sequence start index: {start_idx}"
+
+    finally:
+        # Restore original BOS token
+        model.tokenizer.bos_token_id = original_bos_token_id
+
+
+def test_sequence_ranges_with_bos_token(temp_dir):
+    """Test that sequence ranges are NOT stored when model has BOS token."""
+    # Set flag to handle meta tensors properly
+    if hasattr(th.fx, 'experimental'):
+        th.fx.experimental._config.meta_nonzero_assume_all_nonzero = True
 
-    print(f"✓ Successfully tested ActivationCache with {len(cache)} activations")
-    print(f"✓ Mean shape: {cache.mean.shape}, Std shape: {cache.std.shape}")
-    print(f"✓ Normalizer tests passed")
+    # Skip test if CUDA not available
+    if not th.cuda.is_available():
+        pytest.skip("CUDA not available, skipping test")
+
+    test_strings = ["Hello world", "Another test sentence"]
+
+    # Load GPT-2 model with BOS token
+    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    model = AutoModelForCausalLM.from_pretrained(
+        "gpt2", device_map="auto", torch_dtype=th.float32
+    )
+    model = LanguageModel(model, torch_dtype=th.float32, tokenizer=tokenizer)
+    model.tokenizer.pad_token = model.tokenizer.eos_token
+
+    # Ensure model has BOS token (set it explicitly)
+    model.tokenizer.bos_token_id = model.tokenizer.eos_token_id
+
+    # Get a transformer block
+    target_layer = model.transformer.h[6]
+    submodule_name = "transformer_h_6"
+
+    # Collect activations
+    ActivationCache.collect(
+        data=test_strings,
+        submodules=(target_layer,),
+        submodule_names=(submodule_name,),
+        model=model,
+        store_dir=temp_dir,
+        batch_size=2,
+        context_len=32,
+        shard_size=1000,
+        d_model=768,
+        io="out",
+        store_tokens=True,
+        shuffle_shards=False,
+    )
+
+    # Load the cached activations
+    cache = ActivationCache(temp_dir, submodule_name + "_out")
+
+    # Verify sequence ranges were NOT stored
+    sequence_ranges = cache.sequence_ranges
+    assert sequence_ranges is None, "sequence ranges should not be stored for model with BOS token"
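
The no-BOS test derives its expected boundaries directly from the tokenizer's attention mask; the same recipe can be used outside the test suite to sanity-check a cache. A sketch assuming the cache was built with store_tokens=True and shuffle_shards=False, as in the test:

import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

texts = ["Hello world", "This is a longer sentence with more tokens"]
enc = tokenizer(texts, return_tensors="pt", padding=True, truncation=True)
lengths = enc["attention_mask"].sum(dim=1).tolist()

# Start offset of each sequence in the flattened stream, plus a final sentinel.
expected_ranges = np.cumsum([0] + lengths)
print(expected_ranges)  # exact values depend on the tokenizer

These values should match `cache.sequence_ranges` element-for-element, which is what the `np.allclose` assertion in the first test checks.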
