11
11
import time
12
12
import json
13
13
from .config import DEBUG
14
- from .utils import dtype_to_str , str_to_dtype , torch_to_numpy_dtype
15
-
14
+ from .utils import (
15
+ dtype_to_str ,
16
+ str_to_dtype ,
17
+ torch_to_numpy_dtype ,
18
+ ActivationNormalizer ,
19
+ )
16
20
17
21
# Tracer options: both "scan" and "validate" are switched on only when DEBUG
# is truthy, and both are switched off otherwise.
tracer_kwargs = (
    {"scan": True, "validate": True}
    if DEBUG
    else {"scan": False, "validate": False}
)
21
25
26
+ import torch
27
+ from typing import Tuple
28
+
29
+
30
+ class RunningStatWelford :
31
+ """
32
+ Streaming (online) mean / variance with Welford's algorithm.
33
+
34
+ Works for arbitrary feature shapes – e.g. a vector of size D, a 2-D image
35
+ channel grid, … anything except that the first axis of the update batch
36
+ is interpreted as the *sample* axis.
37
+
38
+ Args:
39
+ shape: Feature shape tuple (e.g., (512,) for vector features)
40
+ dtype: Data type for internal computations
41
+ device: Device to store tensors on
42
+
43
+ Example:
44
+ stats = RunningStatWelford(shape=(512,), device="cuda")
45
+ for batch in dataloader:
46
+ stats.update(batch) # batch.shape == (B, 512)
47
+
48
+ print(stats.mean) # current running mean
49
+ print(stats.std(unbiased=True)) # sample std-dev (Bessel-corrected)
50
+ """
51
+
52
+ def __init__ (
53
+ self ,
54
+ shape : Tuple [int , ...],
55
+ dtype = torch .float64 ,
56
+ device : torch .device | str = "cpu" ,
57
+ ):
58
+ self .device = torch .device (device )
59
+ self .dtype = dtype
60
+
61
+ self .count = torch .tensor (0 , dtype = torch .long , device = self .device )
62
+ self .mean = torch .zeros (shape , dtype = dtype , device = self .device )
63
+ self .M2 = torch .zeros (shape , dtype = dtype , device = self .device )
64
+
65
+ def save_state (self , store_dir : str ):
66
+ """
67
+ Save the current state of the running statistics to a file.
68
+ """
69
+ torch .save (self .count .cpu (), os .path .join (store_dir , "count.pt" ))
70
+ torch .save (self .mean .cpu (), os .path .join (store_dir , "mean.pt" ))
71
+ torch .save (self .M2 .cpu (), os .path .join (store_dir , "M2.pt" ))
72
+
73
+ @staticmethod
74
+ def load_or_create_state (
75
+ store_dir : str ,
76
+ dtype : torch .dtype = torch .float64 ,
77
+ device : torch .device | str = "cpu" ,
78
+ shape : Tuple [int , ...] = None ,
79
+ ):
80
+ """
81
+ Load the current state of the running statistics from a file.
82
+ """
83
+ if os .path .exists (os .path .join (store_dir , "count.pt" )):
84
+ count = torch .load (
85
+ os .path .join (store_dir , "count.pt" ),
86
+ weights_only = True ,
87
+ map_location = device ,
88
+ )
89
+ mean = torch .load (
90
+ os .path .join (store_dir , "mean.pt" ),
91
+ weights_only = True ,
92
+ map_location = device ,
93
+ )
94
+ M2 = torch .load (
95
+ os .path .join (store_dir , "M2.pt" ), weights_only = True , map_location = device
96
+ )
97
+ return RunningStatWelford (
98
+ shape = mean .shape ,
99
+ dtype = dtype ,
100
+ device = device ,
101
+ count = count ,
102
+ mean = mean ,
103
+ M2 = M2 ,
104
+ )
105
+ else :
106
+ return RunningStatWelford (shape = shape , dtype = dtype , device = device )
107
+
108
+ def update (self , x : torch .Tensor ) -> None :
109
+ """
110
+ Incorporate a new mini-batch `x` whose *first* dimension is batch-size.
111
+
112
+ Args:
113
+ x: Input tensor with batch dimension first
114
+ """
115
+ if x .numel () == 0 :
116
+ return # nothing to do
117
+
118
+ # ensure dtype/device match internal buffers
119
+ x = x .clone ().to (device = self .device , dtype = self .dtype )
120
+
121
+ batch_n = x .shape [0 ]
122
+ batch_mean = x .mean (dim = 0 )
123
+ batch_M2 = ((x - batch_mean ) ** 2 ).sum (dim = 0 )
124
+
125
+ delta = batch_mean - self .mean
126
+ total_n = self .count + batch_n
127
+
128
+ # merge step (Chan-Golub-LeVeque)
129
+ self .mean += delta * batch_n / total_n
130
+ self .M2 += batch_M2 + (delta ** 2 ) * self .count * batch_n / total_n
131
+ self .count = total_n
132
+
133
+ def merge (self , other : "RunningStatWelford" ) -> None :
134
+ """
135
+ Merge another (independent) accumulator into this one in O(1).
136
+ Useful for distributed training / multi-loader aggregation.
137
+
138
+ Args:
139
+ other: Another RunningStatWelford instance to merge
140
+ """
141
+ if other .count == 0 :
142
+ return
143
+ if self .count == 0 :
144
+ # shallow copy of buffers
145
+ self .count = other .count .clone ()
146
+ self .mean = other .mean .clone ()
147
+ self .M2 = other .M2 .clone ()
148
+ return
149
+
150
+ delta = other .mean - self .mean
151
+ total_n = self .count + other .count
152
+
153
+ self .mean += delta * other .count / total_n
154
+ self .M2 += other .M2 + (delta ** 2 ) * self .count * other .count / total_n
155
+ self .count = total_n
156
+
157
+ def var (self , unbiased : bool = True ) -> torch .Tensor :
158
+ """
159
+ Return per-feature variance.
160
+
161
+ Args:
162
+ unbiased: If True, divide by (n-1) for sample variance (Bessel-corrected).
163
+ If False, divide by n for population variance.
164
+
165
+ Returns:
166
+ Per-feature variance tensor
167
+ """
168
+ if self .count < (2 if unbiased else 1 ):
169
+ return torch .full_like (self .mean , float ("nan" ))
170
+ denom = self .count - 1 if unbiased else self .count
171
+ return self .M2 / denom
172
+
173
+ def std (self , unbiased : bool = True ) -> torch .Tensor :
174
+ """
175
+ Standard deviation (sqrt of `var`).
176
+
177
+ Args:
178
+ unbiased: If True, use sample std-dev. If False, use population std-dev.
179
+
180
+ Returns:
181
+ Per-feature standard deviation tensor
182
+ """
183
+ return torch .sqrt (self .var (unbiased = unbiased ))
184
+
185
+ @property
186
+ def n (self ) -> int :
187
+ """Number of samples processed."""
188
+ return int (self .count .item ())
189
+
22
190
23
191
class ActivationShard :
24
192
def __init__ (
@@ -101,6 +269,43 @@ def __init__(self, store_dir: str, submodule_name: str = None):
101
269
os .path .join (store_dir , "tokens.pt" ), weights_only = True
102
270
).cpu ()
103
271
272
+ self ._mean = None
273
+ self ._std = None
274
+
275
@property
def mean(self):
    """
    Lazily load and memoise the tensor stored in ``mean.pt``.

    The file is looked up inside ``self._cache_store_dir`` and loaded onto the
    CPU on first access; later accesses return the cached tensor.

    Raises:
        ValueError: If ``mean.pt`` does not exist in ``self._cache_store_dir``.
    """
    if self._mean is not None:
        return self._mean

    mean_path = os.path.join(self._cache_store_dir, "mean.pt")
    if not os.path.exists(mean_path):
        raise ValueError(
            f"Mean not found for {self._cache_store_dir}. Re-run the collection script."
        )

    self._mean = th.load(
        mean_path,
        weights_only=True,
        map_location=th.device("cpu"),
    )
    return self._mean
289
+
290
@property
def std(self):
    """
    Lazily load and memoise the tensor stored in ``std.pt``.

    The file is looked up inside ``self._cache_store_dir`` and loaded onto the
    CPU on first access; later accesses return the cached tensor.

    Raises:
        ValueError: If ``std.pt`` does not exist in ``self._cache_store_dir``.
    """
    if self._std is not None:
        return self._std

    std_path = os.path.join(self._cache_store_dir, "std.pt")
    if not os.path.exists(std_path):
        raise ValueError(
            f"Std not found for {self._cache_store_dir}. Re-run the collection script."
        )

    self._std = th.load(
        std_path,
        weights_only=True,
        map_location=th.device("cpu"),
    )
    return self._std
304
+
305
@property
def normalizer(self):
    """Build an ``ActivationNormalizer`` from this object's ``mean`` and ``std``."""
    mean_stat = self.mean
    std_stat = self.std
    return ActivationNormalizer(mean_stat, std_stat)
308
+
104
309
def __len__(self):
    """Return the ``"total_size"`` entry of ``self.config``."""
    total_size = self.config["total_size"]
    return total_size
106
311
@@ -277,6 +482,15 @@ def collect(
277
482
]
278
483
for store_sub_dir in store_sub_dirs :
279
484
os .makedirs (store_sub_dir , exist_ok = True )
485
+
486
+ # load running stats
487
+ running_stats = [
488
+ RunningStatWelford .load_or_create_state (
489
+ store_sub_dir , dtype , shape = (d_model ,)
490
+ )
491
+ for store_sub_dir in store_sub_dirs
492
+ ]
493
+
280
494
total_size = 0
281
495
current_size = 0
282
496
shard_count = 0
@@ -351,6 +565,7 @@ def collect(
351
565
.value [store_mask .reshape (- 1 ).bool ()]
352
566
.cpu ()
353
567
) # remove padding tokens
568
+ running_stats [i ].update (activation_cache [i ][- 1 ].view (- 1 , d_model ))
354
569
if dtype is not None :
355
570
activation_cache [i ][- 1 ] = activation_cache [i ][- 1 ].to (dtype )
356
571
@@ -375,6 +590,8 @@ def collect(
375
590
io ,
376
591
multiprocessing = multiprocessing ,
377
592
)
593
+ for i in range (len (submodules )):
594
+ running_stats [i ].save_state (store_sub_dirs [i ])
378
595
shard_count += 1
379
596
380
597
total_size += current_size
@@ -400,6 +617,8 @@ def collect(
400
617
io ,
401
618
multiprocessing = multiprocessing ,
402
619
)
620
+ for i in range (len (submodules )):
621
+ running_stats [i ].save_state (store_sub_dirs [i ])
403
622
shard_count += 1
404
623
total_size += current_size
405
624
@@ -430,6 +649,15 @@ def collect(
430
649
), f"{ tokens_cache .shape [0 ]} != { total_size } "
431
650
th .save (tokens_cache , os .path .join (store_dir , "tokens.pt" ))
432
651
652
+ # store running stats
653
+ for i in range (len (submodules )):
654
+ th .save (
655
+ running_stats [i ].mean .cpu (), os .path .join (store_sub_dirs [i ], "mean.pt" )
656
+ )
657
+ th .save (
658
+ running_stats [i ].std ().cpu (), os .path .join (store_sub_dirs [i ], "std.pt" )
659
+ )
660
+
433
661
ActivationCache .cleanup_multiprocessing ()
434
662
print (f"Finished collecting activations. Total size: { total_size } " )
435
663
@@ -454,11 +682,27 @@ def tokens(self):
454
682
(self .activation_cache_1 .tokens , self .activation_cache_2 .tokens ), dim = 0
455
683
)
456
684
685
@property
def mean(self):
    """
    Stack the ``mean`` of ``activation_cache_1`` and ``activation_cache_2``
    along a new leading dimension (first cache at index 0).
    """
    first_mean = self.activation_cache_1.mean
    second_mean = self.activation_cache_2.mean
    return th.stack((first_mean, second_mean), dim=0)
690
+
691
@property
def std(self):
    """
    Stack the ``std`` of ``activation_cache_1`` and ``activation_cache_2``
    along a new leading dimension (first cache at index 0).
    """
    first_std = self.activation_cache_1.std
    second_std = self.activation_cache_2.std
    return th.stack((first_std, second_std), dim=0)
696
+
697
@property
def normalizer(self):
    """Build an ``ActivationNormalizer`` from the stacked ``mean`` and ``std``."""
    stacked_mean = self.mean
    stacked_std = self.std
    return ActivationNormalizer(stacked_mean, stacked_std)
700
+
457
701
458
702
class ActivationCacheTuple :
459
- def __init__ (self , * store_dirs : str ):
703
+ def __init__ (self , * store_dirs : str , submodule_name : str = None ):
460
704
self .activation_caches = [
461
- ActivationCache (store_dir ) for store_dir in store_dirs
705
+ ActivationCache (store_dir , submodule_name ) for store_dir in store_dirs
462
706
]
463
707
assert len (self .activation_caches ) > 0
464
708
for i in range (1 , len (self .activation_caches )):
@@ -473,3 +717,15 @@ def __getitem__(self, index: int):
473
717
@property
def tokens(self):
    """
    Stack the ``tokens`` of every cache in ``self.activation_caches`` along a
    new leading dimension, in list order.
    """
    per_cache_tokens = []
    for cache in self.activation_caches:
        per_cache_tokens.append(cache.tokens)
    return th.stack(per_cache_tokens, dim=0)
720
+
721
@property
def mean(self):
    """
    Stack the ``mean`` of every cache in ``self.activation_caches`` along a
    new leading dimension, in list order.
    """
    per_cache_means = []
    for cache in self.activation_caches:
        per_cache_means.append(cache.mean)
    return th.stack(per_cache_means, dim=0)
724
+
725
@property
def std(self):
    """
    Stack the ``std`` of every cache in ``self.activation_caches`` along a
    new leading dimension, in list order.
    """
    per_cache_stds = []
    for cache in self.activation_caches:
        per_cache_stds.append(cache.std)
    return th.stack(per_cache_stds, dim=0)
728
+
729
@property
def normalizer(self):
    """Build an ``ActivationNormalizer`` from the stacked ``mean`` and ``std``."""
    stacked_mean = self.mean
    stacked_std = self.std
    return ActivationNormalizer(stacked_mean, stacked_std)
0 commit comments