@@ -412,3 +412,152 @@ def test_sequence_ranges_with_bos_token(temp_dir):
412
412
# Verify sequence ranges were NOT stored
413
413
sequence_ranges = cache .sequence_ranges
414
414
assert sequence_ranges is None , "sequence ranges should not be stored for model with BOS token"
415
+
416
+
417
def _assert_slice_matches_individual(cache, index, description):
    """Assert that ``cache[index]`` equals the stack of per-index lookups.

    Args:
        cache: Object supporting ``len``, integer indexing and slice indexing,
            where both forms of indexing return tensors.
        index: The ``slice`` to evaluate against ``cache``.
        description: Short label prepended to assertion messages.

    Returns:
        The tensor produced by ``cache[index]``.
    """
    start, stop, step = index.indices(len(cache))
    positions = range(start, stop, step)

    sliced = cache[index]

    # Check the length first: a wrong length would otherwise surface as an
    # opaque shape error from ``allclose`` instead of a readable message.
    assert sliced.shape[0] == len(positions), (
        f"{description}: expected slice length {len(positions)}, "
        f"got {sliced.shape[0]} for indices {start}:{stop}:{step}"
    )

    # ``th.stack`` rejects an empty list, so there is nothing more to compare.
    if len(positions) == 0:
        return sliced

    expected = th.stack([cache[i] for i in positions], dim=0)
    assert th.allclose(sliced, expected, atol=1e-5, rtol=1e-5), (
        f"{description}: slice result doesn't match individual indexing "
        f"for indices {start}:{stop}:{step}"
    )
    return sliced


def test_activation_cache_slice_indexing_cross_shard(temp_dir):
    """Test ActivationCache slice indexing that crosses shard boundaries.

    Collects GPT-2 layer-6 output activations into a multi-shard cache and
    checks that slicing the cache gives the same values as indexing every
    position individually. Covered cases: a slice over one shard boundary, a
    slice spanning several shards, stepped slices (a fixed range and one
    centred on the first shard boundary), a slice starting exactly on a
    boundary, and an empty slice.

    Args:
        temp_dir: Directory (pytest fixture) the activation shards are
            written to and loaded from.
    """
    # Set flag to handle meta tensors properly
    th.fx.experimental._config.meta_nonzero_assume_all_nonzero = True

    # Skip test if CUDA not available to avoid device mapping issues
    if not th.cuda.is_available():
        pytest.skip("CUDA not available, skipping test to avoid device mapping issues")

    # Create enough samples to ensure the cache spans multiple shards
    test_strings = [
        f"This is test sentence number {i} with some content to fill up the cache."
        for i in range(20)
    ]

    # Use the list directly
    dataset = test_strings

    # Load GPT-2 model
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2", device_map="auto", torch_dtype=th.float32
    )
    model = LanguageModel(model, torch_dtype=th.float32, tokenizer=tokenizer)
    model.tokenizer.pad_token = model.tokenizer.eos_token

    # Get a transformer block to extract activations from
    target_layer = model.transformer.h[6]  # Middle layer of GPT-2
    submodule_name = "transformer_h_6"

    # Parameters for activation collection - use small shard size to ensure multiple shards
    batch_size = 3
    context_len = 32
    d_model = 768  # GPT-2 hidden size
    shard_size = 50  # Small shard size to force multiple shards

    # Collect activations using ActivationCache
    ActivationCache.collect(
        data=dataset,
        submodules=(target_layer,),
        submodule_names=(submodule_name,),
        model=model,
        store_dir=temp_dir,
        batch_size=batch_size,
        context_len=context_len,
        shard_size=shard_size,  # Small shard size for testing cross-shard slicing
        d_model=d_model,
        io="out",
        max_total_tokens=5000,
        store_tokens=True,
        shuffle_shards=False,  # Important: don't shuffle so we can predict shard boundaries
    )

    # Load the cached activations
    cache = ActivationCache(temp_dir, submodule_name + "_out")

    # Verify we have multiple shards; every check below relies on this.
    num_shards = len(cache.shards)
    assert num_shards >= 2, f"Expected at least 2 shards, got {num_shards}"

    total_size = len(cache)
    print(f"Cache has {num_shards} shards with total size {total_size}")

    # Print shard boundaries for debugging
    shard_boundaries = cache._range_to_shard_idx
    print(f"Shard boundaries: {shard_boundaries}")

    # The code below treats entry 1 as the index where the first shard ends
    # and the second begins.
    first_boundary = int(shard_boundaries[1])

    # Test 1: Slice that crosses exactly one shard boundary
    start_idx = max(0, first_boundary - 10)
    end_idx = min(total_size, first_boundary + 10)
    _assert_slice_matches_individual(
        cache, slice(start_idx, end_idx), "Cross-shard slice"
    )
    print(f"✓ Cross-shard slice test 1 passed: indices {start_idx}:{end_idx}")

    # Test 2: Slice that spans multiple shards (needs a third shard)
    if num_shards >= 3:
        second_boundary = int(shard_boundaries[2])
        start_idx = max(0, first_boundary - 5)  # Start near end of first shard
        end_idx = min(total_size, second_boundary + 5)  # End in third shard
        _assert_slice_matches_individual(
            cache, slice(start_idx, end_idx), "Multi-shard slice"
        )
        print(f"✓ Multi-shard slice test passed: indices {start_idx}:{end_idx}")
    else:
        print("Multi-shard slice test not run: fewer than 3 shards")

    # Test 3a: Stepped slice over a fixed range (original coverage)
    step = 3
    if total_size >= 50:
        start_idx = 5
        end_idx = min(total_size, 45)
        _assert_slice_matches_individual(
            cache, slice(start_idx, end_idx, step), "Stepped slice"
        )
        print(f"✓ Stepped slice test passed: indices {start_idx}:{end_idx}:{step}")

    # Test 3b: Stepped slice centred on the first shard boundary. The fixed
    # 5:45 range above is not tied to the actual shard layout, so it may lie
    # entirely inside one shard; this range straddles the boundary by design.
    start_idx = max(0, first_boundary - 10)
    end_idx = min(total_size, first_boundary + 10)
    _assert_slice_matches_individual(
        cache, slice(start_idx, end_idx, step), "Stepped cross-shard slice"
    )
    print(
        f"✓ Stepped cross-shard slice test passed: indices {start_idx}:{end_idx}:{step}"
    )

    # Test 4: Edge case - slice starting exactly at a shard boundary
    if first_boundary < total_size - 5:
        _assert_slice_matches_individual(
            cache, slice(first_boundary, first_boundary + 5), "Boundary slice"
        )
        print(f"✓ Boundary slice test passed: starting at shard boundary {first_boundary}")
    else:
        print("Boundary slice test not run: too few entries after the first boundary")

    # Test 5: Empty slice
    empty_slice = cache[10:10]
    assert empty_slice.shape[0] == 0, f"Expected empty slice, got shape {empty_slice.shape}"
    print("✓ Empty slice test passed")

    print(f"✓ All slice indexing tests passed for cache with {num_shards} shards")
0 commit comments