Skip to content

Commit 8d9beb9

Browse files
committed
Implement RunningStatWelford for streaming statistics and enhance ActivationCache with normalization features
This commit introduces the `RunningStatWelford` class, which provides a streaming mean and variance calculation using Welford's algorithm. This class supports arbitrary feature shapes and includes methods for updating statistics, merging accumulators, and saving/loading state. Additionally, the `ActivationCache` class is enhanced to include mean and standard deviation properties, along with an `ActivationNormalizer` for normalizing activations based on these statistics. The changes improve the flexibility and robustness of activation management in the caching mechanism, facilitating better handling of model activations during training and evaluation. Furthermore, the `BatchTopKSAE`, `CrossCoder`, and their respective trainers are updated to accept an optional `activation_normalizer`, allowing for normalization of activations during encoding and decoding processes. Tests for the new functionality are also added to ensure correctness and reliability.
1 parent 317b74e commit 8d9beb9

File tree

7 files changed

+934
-45
lines changed

7 files changed

+934
-45
lines changed

dictionary_learning/cache.py

Lines changed: 260 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,182 @@
1111
import time
1212
import json
1313
from .config import DEBUG
14-
from .utils import dtype_to_str, str_to_dtype, torch_to_numpy_dtype
15-
14+
from .utils import (
15+
dtype_to_str,
16+
str_to_dtype,
17+
torch_to_numpy_dtype,
18+
ActivationNormalizer,
19+
)
1620

1721
if DEBUG:
1822
tracer_kwargs = {"scan": True, "validate": True}
1923
else:
2024
tracer_kwargs = {"scan": False, "validate": False}
2125

26+
import torch
27+
from typing import Tuple
28+
29+
30+
class RunningStatWelford:
    """
    Streaming (online) mean / variance with Welford's algorithm.

    Works for arbitrary feature shapes – e.g. a vector of size D, a 2-D image
    channel grid, … anything except that the first axis of the update batch
    is interpreted as the *sample* axis.

    Args:
        shape: Feature shape tuple (e.g., (512,) for vector features)
        dtype: Data type for internal computations. ``None`` falls back to
            torch's default floating-point dtype.
        device: Device to store tensors on

    Attributes:
        count: 0-dim long tensor holding the number of samples seen so far.
        mean: Running per-feature mean.
        M2: Running per-feature sum of squared deviations from the mean.

    Example:
        stats = RunningStatWelford(shape=(512,), device="cuda")
        for batch in dataloader:
            stats.update(batch)  # batch.shape == (B, 512)

        print(stats.mean)  # current running mean
        print(stats.std(unbiased=True))  # sample std-dev (Bessel-corrected)
    """

    def __init__(
        self,
        shape: Tuple[int, ...],
        dtype=torch.float64,
        device: "torch.device | str" = "cpu",
    ):
        self.device = torch.device(device)

        self.count = torch.tensor(0, dtype=torch.long, device=self.device)
        self.mean = torch.zeros(shape, dtype=dtype, device=self.device)
        self.M2 = torch.zeros(shape, dtype=dtype, device=self.device)
        # Resolve the dtype from the buffer so that dtype=None still yields a
        # concrete dtype to cast incoming batches to.
        self.dtype = self.mean.dtype

    def save_state(self, store_dir: str) -> None:
        """
        Save the current state of the running statistics to `store_dir`.

        Writes `count.pt`, `mean.pt` and `M2.pt`; the directory is created if
        it does not exist yet.
        """
        os.makedirs(store_dir, exist_ok=True)
        for name in ("count", "mean", "M2"):
            torch.save(
                getattr(self, name).cpu(), os.path.join(store_dir, f"{name}.pt")
            )

    @staticmethod
    def load_or_create_state(
        store_dir: str,
        dtype: torch.dtype = torch.float64,
        device: "torch.device | str" = "cpu",
        shape: "Tuple[int, ...] | None" = None,
    ):
        """
        Load the running statistics saved in `store_dir`, or start fresh.

        Args:
            store_dir: Directory previously written by `save_state`.
            dtype: Data type of the returned accumulator's buffers.
            device: Device of the returned accumulator's buffers.
            shape: Feature shape. Required when no saved state exists; when a
                saved state exists it is only used as a consistency check.

        Returns:
            A `RunningStatWelford` holding the saved state if `count.pt` is
            present in `store_dir`, otherwise an empty accumulator.

        Raises:
            ValueError: If no saved state exists and `shape` is None, or if
                `shape` disagrees with the shape of the saved mean.
        """
        if not os.path.exists(os.path.join(store_dir, "count.pt")):
            if shape is None:
                raise ValueError(
                    f"No saved state in {store_dir}; `shape` is required to "
                    "create a new accumulator."
                )
            return RunningStatWelford(shape=shape, dtype=dtype, device=device)

        def _load(file_name: str) -> torch.Tensor:
            return torch.load(
                os.path.join(store_dir, file_name),
                weights_only=True,
                map_location=device,
            )

        mean = _load("mean.pt")
        if shape is not None and tuple(shape) != tuple(mean.shape):
            raise ValueError(
                f"Saved state in {store_dir} has feature shape "
                f"{tuple(mean.shape)}, expected {tuple(shape)}."
            )

        # __init__ only takes (shape, dtype, device), so build an empty
        # accumulator and then install the loaded buffers on it.
        stats = RunningStatWelford(
            shape=tuple(mean.shape), dtype=dtype, device=device
        )
        stats.count = _load("count.pt").to(device=stats.device, dtype=torch.long)
        stats.mean = mean.to(device=stats.device, dtype=stats.dtype)
        stats.M2 = _load("M2.pt").to(device=stats.device, dtype=stats.dtype)
        return stats

    def update(self, x: torch.Tensor) -> None:
        """
        Incorporate a new mini-batch `x` whose *first* dimension is batch-size.

        Args:
            x: Input tensor with batch dimension first; `x.shape[1:]` must
                equal the feature shape of this accumulator.

        Raises:
            ValueError: If `x` is non-empty and its trailing dimensions do not
                match the feature shape.
        """
        if x.numel() == 0:
            return  # nothing to do

        # Reject mismatched shapes explicitly: some of them would otherwise
        # broadcast silently and corrupt the statistics.
        if x.dim() == 0 or tuple(x.shape[1:]) != tuple(self.mean.shape):
            raise ValueError(
                f"Expected batch of shape (N, *{tuple(self.mean.shape)}), "
                f"got {tuple(x.shape)}."
            )

        # ensure dtype/device match internal buffers; `x` is never modified
        # in place, so no defensive copy is needed, and detaching keeps the
        # buffers out of any autograd graph.
        x = x.detach().to(device=self.device, dtype=self.dtype)

        batch_n = x.shape[0]
        batch_mean = x.mean(dim=0)
        batch_M2 = ((x - batch_mean) ** 2).sum(dim=0)

        delta = batch_mean - self.mean
        total_n = self.count + batch_n

        # merge step (Chan-Golub-LeVeque)
        self.mean += delta * batch_n / total_n
        self.M2 += batch_M2 + (delta**2) * self.count * batch_n / total_n
        self.count = total_n

    def merge(self, other: "RunningStatWelford") -> None:
        """
        Merge another (independent) accumulator into this one in O(1).
        Useful for distributed training / multi-loader aggregation.

        Args:
            other: Another RunningStatWelford instance to merge. Its buffers
                are converted to this accumulator's device and dtype.

        Raises:
            ValueError: If both accumulators are non-empty and their feature
                shapes differ.
        """
        if other.count == 0:
            return

        other_count = other.count.to(device=self.device)
        other_mean = other.mean.to(device=self.device, dtype=self.dtype)
        other_M2 = other.M2.to(device=self.device, dtype=self.dtype)

        if self.count == 0:
            # independent copies, so later in-place updates of `self` never
            # write into `other`'s buffers
            self.count = other_count.clone()
            self.mean = other_mean.clone()
            self.M2 = other_M2.clone()
            return

        if tuple(other_mean.shape) != tuple(self.mean.shape):
            raise ValueError(
                f"Cannot merge feature shape {tuple(other_mean.shape)} into "
                f"{tuple(self.mean.shape)}."
            )

        delta = other_mean - self.mean
        total_n = self.count + other_count

        self.mean += delta * other_count / total_n
        self.M2 += other_M2 + (delta**2) * self.count * other_count / total_n
        self.count = total_n

    def var(self, unbiased: bool = True) -> torch.Tensor:
        """
        Return per-feature variance.

        Args:
            unbiased: If True, divide by (n-1) for sample variance (Bessel-corrected).
                If False, divide by n for population variance.

        Returns:
            Per-feature variance tensor; all-NaN when fewer than 2 samples
            (unbiased) or fewer than 1 sample (biased) have been seen.
        """
        if self.count < (2 if unbiased else 1):
            return torch.full_like(self.mean, float("nan"))
        denom = self.count - 1 if unbiased else self.count
        return self.M2 / denom

    def std(self, unbiased: bool = True) -> torch.Tensor:
        """
        Standard deviation (sqrt of `var`).

        Args:
            unbiased: If True, use sample std-dev. If False, use population std-dev.

        Returns:
            Per-feature standard deviation tensor
        """
        return torch.sqrt(self.var(unbiased=unbiased))

    @property
    def n(self) -> int:
        """Number of samples processed."""
        return int(self.count.item())
189+
22190

23191
class ActivationShard:
24192
def __init__(
@@ -101,6 +269,43 @@ def __init__(self, store_dir: str, submodule_name: str = None):
101269
os.path.join(store_dir, "tokens.pt"), weights_only=True
102270
).cpu()
103271

272+
self._mean = None
273+
self._std = None
274+
275+
@property
def mean(self):
    """Return the tensor stored in ``mean.pt``, loading it on first access.

    The loaded tensor is kept in ``self._mean`` so the file is read only once.

    Raises:
        ValueError: If ``mean.pt`` does not exist in the cache directory.
    """
    if self._mean is not None:
        return self._mean

    mean_path = os.path.join(self._cache_store_dir, "mean.pt")
    if not os.path.exists(mean_path):
        raise ValueError(
            f"Mean not found for {self._cache_store_dir}. Re-run the collection script."
        )

    self._mean = th.load(
        mean_path,
        weights_only=True,
        map_location=th.device("cpu"),
    )
    return self._mean
289+
290+
@property
def std(self):
    """Return the tensor stored in ``std.pt``, loading it on first access.

    The loaded tensor is kept in ``self._std`` so the file is read only once.

    Raises:
        ValueError: If ``std.pt`` does not exist in the cache directory.
    """
    if self._std is not None:
        return self._std

    std_path = os.path.join(self._cache_store_dir, "std.pt")
    if not os.path.exists(std_path):
        raise ValueError(
            f"Std not found for {self._cache_store_dir}. Re-run the collection script."
        )

    self._std = th.load(
        std_path,
        weights_only=True,
        map_location=th.device("cpu"),
    )
    return self._std
304+
305+
@property
def normalizer(self):
    """Build an ``ActivationNormalizer`` from this cache's ``mean`` and ``std``."""
    cache_mean = self.mean
    cache_std = self.std
    return ActivationNormalizer(cache_mean, cache_std)
308+
104309
def __len__(self):
105310
return self.config["total_size"]
106311

@@ -277,6 +482,15 @@ def collect(
277482
]
278483
for store_sub_dir in store_sub_dirs:
279484
os.makedirs(store_sub_dir, exist_ok=True)
485+
486+
# load running stats
487+
running_stats = [
488+
RunningStatWelford.load_or_create_state(
489+
store_sub_dir, dtype, shape=(d_model,)
490+
)
491+
for store_sub_dir in store_sub_dirs
492+
]
493+
280494
total_size = 0
281495
current_size = 0
282496
shard_count = 0
@@ -351,6 +565,7 @@ def collect(
351565
.value[store_mask.reshape(-1).bool()]
352566
.cpu()
353567
) # remove padding tokens
568+
running_stats[i].update(activation_cache[i][-1].view(-1, d_model))
354569
if dtype is not None:
355570
activation_cache[i][-1] = activation_cache[i][-1].to(dtype)
356571

@@ -375,6 +590,8 @@ def collect(
375590
io,
376591
multiprocessing=multiprocessing,
377592
)
593+
for i in range(len(submodules)):
594+
running_stats[i].save_state(store_sub_dirs[i])
378595
shard_count += 1
379596

380597
total_size += current_size
@@ -400,6 +617,8 @@ def collect(
400617
io,
401618
multiprocessing=multiprocessing,
402619
)
620+
for i in range(len(submodules)):
621+
running_stats[i].save_state(store_sub_dirs[i])
403622
shard_count += 1
404623
total_size += current_size
405624

@@ -430,6 +649,15 @@ def collect(
430649
), f"{tokens_cache.shape[0]} != {total_size}"
431650
th.save(tokens_cache, os.path.join(store_dir, "tokens.pt"))
432651

652+
# store running stats
653+
for i in range(len(submodules)):
654+
th.save(
655+
running_stats[i].mean.cpu(), os.path.join(store_sub_dirs[i], "mean.pt")
656+
)
657+
th.save(
658+
running_stats[i].std().cpu(), os.path.join(store_sub_dirs[i], "std.pt")
659+
)
660+
433661
ActivationCache.cleanup_multiprocessing()
434662
print(f"Finished collecting activations. Total size: {total_size}")
435663

@@ -454,11 +682,27 @@ def tokens(self):
454682
(self.activation_cache_1.tokens, self.activation_cache_2.tokens), dim=0
455683
)
456684

685+
@property
def mean(self):
    """Stack the ``mean`` tensors of both wrapped caches along a new dim 0."""
    first_mean = self.activation_cache_1.mean
    second_mean = self.activation_cache_2.mean
    return th.stack((first_mean, second_mean), dim=0)
690+
691+
@property
def std(self):
    """Stack the ``std`` tensors of both wrapped caches along a new dim 0."""
    first_std = self.activation_cache_1.std
    second_std = self.activation_cache_2.std
    return th.stack((first_std, second_std), dim=0)
696+
697+
@property
def normalizer(self):
    """Build an ``ActivationNormalizer`` from the stacked ``mean`` and ``std``."""
    stacked_mean = self.mean
    stacked_std = self.std
    return ActivationNormalizer(stacked_mean, stacked_std)
700+
457701

458702
class ActivationCacheTuple:
459-
def __init__(self, *store_dirs: str):
703+
def __init__(self, *store_dirs: str, submodule_name: str = None):
460704
self.activation_caches = [
461-
ActivationCache(store_dir) for store_dir in store_dirs
705+
ActivationCache(store_dir, submodule_name) for store_dir in store_dirs
462706
]
463707
assert len(self.activation_caches) > 0
464708
for i in range(1, len(self.activation_caches)):
@@ -473,3 +717,15 @@ def __getitem__(self, index: int):
473717
@property
def tokens(self):
    """Stack the ``tokens`` tensors of every wrapped cache along a new dim 0."""
    per_cache_tokens = []
    for cache in self.activation_caches:
        per_cache_tokens.append(cache.tokens)
    return th.stack(per_cache_tokens, dim=0)
720+
721+
@property
def mean(self):
    """Stack the ``mean`` tensors of every wrapped cache along a new dim 0."""
    per_cache_means = []
    for cache in self.activation_caches:
        per_cache_means.append(cache.mean)
    return th.stack(per_cache_means, dim=0)
724+
725+
@property
def std(self):
    """Stack the ``std`` tensors of every wrapped cache along a new dim 0."""
    per_cache_stds = []
    for cache in self.activation_caches:
        per_cache_stds.append(cache.std)
    return th.stack(per_cache_stds, dim=0)
728+
729+
@property
def normalizer(self):
    """Build an ``ActivationNormalizer`` from the stacked ``mean`` and ``std``."""
    stacked_mean = self.mean
    stacked_std = self.std
    return ActivationNormalizer(stacked_mean, stacked_std)

0 commit comments

Comments
 (0)