Commit 9ec8cfa

Added support for slice indices in cache. Added return_last_eval_logs for trainSAE.
Parent: 47ba383

3 files changed (+196, −12 lines)

dictionary_learning/cache.py

Lines changed: 42 additions & 11 deletions
```diff
@@ -210,8 +210,14 @@ def __init__(
     def __len__(self):
         return self.activations.shape[0]
 
-    def __getitem__(self, *indices):
-        return th.tensor(self.activations[(*indices,)]).view(self.dtype)
+    def __getitem__(self, index):
+        if isinstance(index, slice):
+            data = self.activations[index]
+        elif isinstance(index, tuple):
+            data = self.activations[index]
+        else:
+            data = self.activations[index]
+        return th.tensor(data).view(self.dtype)
 
 
 def save_shard(activations, store_dir, shard_count, name, io):
```
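
At the shard level the change normalizes `__getitem__` to the standard single-argument form; note that all three `isinstance` branches currently perform the same `self.activations[index]` lookup, so int, slice, and tuple indices all pass straight through to the backing array. A small sketch of the resulting behavior, assuming a NumPy-backed store (the array below is illustrative, not taken from the repo):

```python
import numpy as np
import torch as th

# Illustrative stand-in for a shard's backing store.
activations = np.random.rand(100, 768).astype(np.float32)

# int, slice, and tuple indices all pass through to the array unchanged:
single = th.tensor(activations[3])          # int index: shape (768,)
window = th.tensor(activations[3:7])        # slice: shape (4, 768)
block = th.tensor(activations[3:7, :16])    # tuple index: shape (4, 16)
```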
```diff
@@ -313,11 +319,29 @@ def running_stats(self):
     def __len__(self):
         return self.config["total_size"]
 
-    def __getitem__(self, index: int):
-        shard_idx = np.searchsorted(self._range_to_shard_idx, index, side="right") - 1
-        offset = index - self._range_to_shard_idx[shard_idx]
-        shard = self.shards[shard_idx]
-        return shard[offset]
+    def __getitem__(self, index):
+        if isinstance(index, slice):
+            # Handle slice objects
+            start, stop, step = index.indices(len(self))
+            start_shard_idx = np.searchsorted(self._range_to_shard_idx, start, side="right") - 1
+            stop_shard_idx = np.searchsorted(self._range_to_shard_idx, stop, side="right") - 1
+            if start_shard_idx == stop_shard_idx:
+                offset = start - self._range_to_shard_idx[start_shard_idx]
+                end_offset = stop - self._range_to_shard_idx[stop_shard_idx]
+                shard = self.shards[start_shard_idx]
+                return shard[offset:end_offset:step]
+            else:
+                # Fall back to per-index loading when the slice spans multiple shards
+                # TODO: make this more efficient
+                return th.stack([self[i] for i in range(start, stop, step)], dim=0)
+        elif isinstance(index, int):
+            # Handle single integer index
+            shard_idx = np.searchsorted(self._range_to_shard_idx, index, side="right") - 1
+            offset = index - self._range_to_shard_idx[shard_idx]
+            shard = self.shards[shard_idx]
+            return shard[offset]
+        else:
+            raise TypeError(f"Index must be int or slice, got {type(index)}")
 
     @property
     def tokens(self):
```
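
The slice path reuses the boundary lookup from the integer path: `np.searchsorted` on `_range_to_shard_idx` (the cumulative shard-start offsets) maps a flat index to a shard and an offset within that shard. If both slice endpoints resolve to the same shard, the slice is served directly from that shard's storage; otherwise it falls back to stacking per-index lookups, as the TODO notes. A toy sketch of the lookup, with made-up boundary values:

```python
import numpy as np

# Hypothetical shard layout: shard k starts at flat index boundaries[k];
# the final entry is the total size. (The real cache builds this array
# as `_range_to_shard_idx`.)
boundaries = np.array([0, 50, 100, 150])  # three shards of 50 rows each

def locate(index):
    """Map a flat index to (shard_idx, offset), mirroring the diff's lookup."""
    shard_idx = np.searchsorted(boundaries, index, side="right") - 1
    return shard_idx, index - boundaries[shard_idx]

assert locate(0) == (0, 0)
assert locate(49) == (0, 49)
assert locate(50) == (1, 0)    # first row of the second shard
assert locate(120) == (2, 20)
```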
```diff
@@ -732,10 +756,17 @@ def __init__(self, store_dir_1: str, store_dir_2: str, submodule_name: str = Non
     def __len__(self):
         return len(self.activation_cache_1)
 
-    def __getitem__(self, index: int):
-        return th.stack(
-            (self.activation_cache_1[index], self.activation_cache_2[index]), dim=0
-        )
+    def __getitem__(self, index):
+        if isinstance(index, slice):
+            return th.stack(
+                (self.activation_cache_1[index], self.activation_cache_2[index]), dim=1
+            )
+        elif isinstance(index, int):
+            return th.stack(
+                (self.activation_cache_1[index], self.activation_cache_2[index]), dim=0
+            )
+        else:
+            raise TypeError(f"Index must be int or slice, got {type(index)}")
 
     @property
     def tokens(self):
```
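
Note that the stacking dimension differs by index type: an integer index stacks the two caches along `dim=0` (a `(2, d)` pair), while a slice stacks along `dim=1` (an `(n, 2, d)` batch of pairs). Choosing `dim=1` for slices keeps the two access paths consistent: the k-th element of a sliced batch equals the pair you would get from integer indexing. A minimal check with plain tensors (shapes are illustrative):

```python
import torch as th

cache_1, cache_2 = th.randn(10, 4), th.randn(10, 4)  # stand-ins for the two caches

single = th.stack((cache_1[3], cache_2[3]), dim=0)      # int index: shape (2, 4)
window = th.stack((cache_1[2:7], cache_2[2:7]), dim=1)  # slice: shape (5, 2, 4)

# dim=1 preserves per-element layout: window[k] is the pair at flat index 2 + k
assert th.equal(window[1], single)
```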

dictionary_learning/training.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -218,6 +218,7 @@ def trainSAE(
     dtype=th.float32,
     run_wandb_finish=True,
     epoch_idx_per_step: Optional[List[int]] = None,
+    return_last_eval_logs=False,
 ):
     """
     Train SAE using the given trainer
@@ -364,4 +365,7 @@ def trainSAE(
     if use_wandb and run_wandb_finish:
         wandb.finish()
 
-    return get_model(trainer)
+    if return_last_eval_logs:
+        return get_model(trainer), last_eval_logs
+    else:
+        return get_model(trainer)
```
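
The new flag defaults to `False`, so existing call sites keep receiving only the trained model; setting it returns the final evaluation logs alongside. A hedged usage sketch (the positional arguments to `trainSAE` are not shown in this diff, so they are elided here):

```python
from dictionary_learning.training import trainSAE

result = trainSAE(
    # ... trainer configs, data, and other arguments as before (elided) ...
    return_last_eval_logs=True,
)
model, last_eval_logs = result  # a (model, logs) tuple instead of a bare model
```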

tests/test_cache.py

Lines changed: 149 additions & 0 deletions
```diff
@@ -412,3 +412,152 @@ def test_sequence_ranges_with_bos_token(temp_dir):
     # Verify sequence ranges were NOT stored
     sequence_ranges = cache.sequence_ranges
     assert sequence_ranges is None, "sequence ranges should not be stored for model with BOS token"
+
+
+def test_activation_cache_slice_indexing_cross_shard(temp_dir):
+    """Test ActivationCache slice indexing that crosses shard boundaries."""
+    # Set flag to handle meta tensors properly
+    th.fx.experimental._config.meta_nonzero_assume_all_nonzero = True
+
+    # Skip test if CUDA not available to avoid device mapping issues
+    if not th.cuda.is_available():
+        pytest.skip("CUDA not available, skipping test to avoid device mapping issues")
+
+    # Create test strings with sufficient data to span multiple shards
+    test_strings = [
+        f"This is test sentence number {i} with some content to fill up the cache."
+        for i in range(20)  # Create more samples to ensure multiple shards
+    ]
+
+    # Use the list directly
+    dataset = test_strings
+
+    # Load GPT-2 model
+    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    model = AutoModelForCausalLM.from_pretrained(
+        "gpt2", device_map="auto", torch_dtype=th.float32
+    )
+    model = LanguageModel(model, torch_dtype=th.float32, tokenizer=tokenizer)
+    model.tokenizer.pad_token = model.tokenizer.eos_token
+
+    # Get a transformer block to extract activations from
+    target_layer = model.transformer.h[6]  # Middle layer of GPT-2
+    submodule_name = "transformer_h_6"
+
+    # Parameters for activation collection - use small shard size to ensure multiple shards
+    batch_size = 3
+    context_len = 32
+    d_model = 768  # GPT-2 hidden size
+    shard_size = 50  # Small shard size to force multiple shards
+
+    # Collect activations using ActivationCache
+    ActivationCache.collect(
+        data=dataset,
+        submodules=(target_layer,),
+        submodule_names=(submodule_name,),
+        model=model,
+        store_dir=temp_dir,
+        batch_size=batch_size,
+        context_len=context_len,
+        shard_size=shard_size,  # Small shard size for testing cross-shard slicing
+        d_model=d_model,
+        io="out",
+        max_total_tokens=5000,
+        store_tokens=True,
+        shuffle_shards=False,  # Important: don't shuffle so we can predict shard boundaries
+    )
+
+    # Load the cached activations
+    cache = ActivationCache(temp_dir, submodule_name + "_out")
+
+    # Verify we have multiple shards
+    assert len(cache.shards) >= 2, f"Expected at least 2 shards, got {len(cache.shards)}"
+
+    total_size = len(cache)
+    print(f"Cache has {len(cache.shards)} shards with total size {total_size}")
+
+    # Print shard boundaries for debugging
+    shard_boundaries = cache._range_to_shard_idx
+    print(f"Shard boundaries: {shard_boundaries}")
+
+    # Test 1: Slice that crosses exactly one shard boundary
+    if len(cache.shards) >= 2:
+        # Find a slice that starts in first shard and ends in second shard
+        first_shard_end = shard_boundaries[1]
+        start_idx = max(0, first_shard_end - 10)
+        end_idx = min(total_size, first_shard_end + 10)
+
+        # Get slice result
+        slice_result = cache[start_idx:end_idx]
+
+        # Get individual results for comparison
+        individual_results = th.stack([cache[i] for i in range(start_idx, end_idx)], dim=0)
+
+        # Verify they match
+        assert th.allclose(slice_result, individual_results, atol=1e-5, rtol=1e-5), \
+            f"Slice result doesn't match individual indexing for indices {start_idx}:{end_idx}"
+
+        # Verify correct shape
+        expected_length = end_idx - start_idx
+        assert slice_result.shape[0] == expected_length, \
+            f"Expected slice length {expected_length}, got {slice_result.shape[0]}"
+
+        print(f"✓ Cross-shard slice test 1 passed: indices {start_idx}:{end_idx}")
+
+    # Test 2: Slice that spans multiple shards
+    if len(cache.shards) >= 3:
+        # Find a slice that starts in first shard and ends in third shard
+        second_shard_end = shard_boundaries[2]
+        start_idx = max(0, shard_boundaries[1] - 5)  # Start near end of first shard
+        end_idx = min(total_size, second_shard_end + 5)  # End in third shard
+
+        slice_result = cache[start_idx:end_idx]
+        individual_results = th.stack([cache[i] for i in range(start_idx, end_idx)], dim=0)
+
+        assert th.allclose(slice_result, individual_results, atol=1e-5, rtol=1e-5), \
+            f"Multi-shard slice result doesn't match individual indexing for indices {start_idx}:{end_idx}"
+
+        expected_length = end_idx - start_idx
+        assert slice_result.shape[0] == expected_length, \
+            f"Expected multi-shard slice length {expected_length}, got {slice_result.shape[0]}"
+
+        print(f"✓ Multi-shard slice test passed: indices {start_idx}:{end_idx}")
+
+    # Test 3: Slice with step parameter across shards
+    if total_size >= 50:
+        start_idx = 5
+        end_idx = min(total_size, 45)
+        step = 3
+
+        slice_result = cache[start_idx:end_idx:step]
+        individual_results = th.stack([cache[i] for i in range(start_idx, end_idx, step)], dim=0)
+
+        assert th.allclose(slice_result, individual_results, atol=1e-5, rtol=1e-5), \
+            f"Stepped slice result doesn't match individual indexing for indices {start_idx}:{end_idx}:{step}"
+
+        expected_length = len(range(start_idx, end_idx, step))
+        assert slice_result.shape[0] == expected_length, \
+            f"Expected stepped slice length {expected_length}, got {slice_result.shape[0]}"
+
+        print(f"✓ Stepped slice test passed: indices {start_idx}:{end_idx}:{step}")
+
+    # Test 4: Edge cases - slice at boundaries
+    if len(cache.shards) >= 2:
+        # Test slice starting exactly at shard boundary
+        boundary_idx = shard_boundaries[1]
+        if boundary_idx < total_size - 5:
+            slice_result = cache[boundary_idx:boundary_idx + 5]
+            individual_results = th.stack([cache[i] for i in range(boundary_idx, boundary_idx + 5)], dim=0)
+
+            assert th.allclose(slice_result, individual_results, atol=1e-5, rtol=1e-5), \
+                "Boundary slice result doesn't match individual indexing"
+
+            print(f"✓ Boundary slice test passed: starting at shard boundary {boundary_idx}")
+
+    # Test 5: Empty slice
+    empty_slice = cache[10:10]
+    assert empty_slice.shape[0] == 0, f"Expected empty slice, got shape {empty_slice.shape}"
+    print("✓ Empty slice test passed")
+
+    print(f"✓ All slice indexing tests passed for cache with {len(cache.shards)} shards")
```
