Skip to content

Commit 17c0a92

Browse files
committed
Fixed issues with RunningStats loading
1 parent 7b4b216 commit 17c0a92

File tree

1 file changed

+12
-5
lines changed

1 file changed

+12
-5
lines changed

dictionary_learning/cache.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,15 @@ def load_or_create_state(
9494
M2 = torch.load(
9595
os.path.join(store_dir, "M2.pt"), weights_only=True, map_location=device
9696
)
97-
return RunningStatWelford(
97+
stat = RunningStatWelford(
9898
shape=mean.shape,
99-
dtype=dtype,
99+
dtype=mean.dtype,
100100
device=device,
101-
count=count,
102-
mean=mean,
103-
M2=M2,
104101
)
102+
stat.count = count
103+
stat.mean = mean
104+
stat.M2 = M2
105+
return stat
105106
else:
106107
return RunningStatWelford(shape=shape, dtype=dtype, device=device)
107108

@@ -306,6 +307,12 @@ def std(self):
306307
def normalizer(self):
307308
return ActivationNormalizer(self.mean, self.std)
308309

310+
@property
311+
def running_stats(self):
312+
return RunningStatWelford.load_or_create_state(
313+
self._cache_store_dir, shape=(self.config["d_model"],)
314+
)
315+
309316
def __len__(self):
310317
return self.config["total_size"]
311318

0 commit comments

Comments
 (0)