File tree Expand file tree Collapse file tree 1 file changed +12
-5
lines changed Expand file tree Collapse file tree 1 file changed +12
-5
lines changed Original file line number Diff line number Diff line change @@ -94,14 +94,15 @@ def load_or_create_state(
94
94
M2 = torch .load (
95
95
os .path .join (store_dir , "M2.pt" ), weights_only = True , map_location = device
96
96
)
97
- return RunningStatWelford (
97
+ stat = RunningStatWelford (
98
98
shape = mean .shape ,
99
- dtype = dtype ,
99
+ dtype = mean . dtype ,
100
100
device = device ,
101
- count = count ,
102
- mean = mean ,
103
- M2 = M2 ,
104
101
)
102
+ stat .count = count
103
+ stat .mean = mean
104
+ stat .M2 = M2
105
+ return stat
105
106
else :
106
107
return RunningStatWelford (shape = shape , dtype = dtype , device = device )
107
108
@@ -306,6 +307,12 @@ def std(self):
306
307
def normalizer (self ):
307
308
return ActivationNormalizer (self .mean , self .std )
308
309
310
+ @property
311
+ def running_stats (self ):
312
+ return RunningStatWelford .load_or_create_state (
313
+ self ._cache_store_dir , shape = (self .config ["d_model" ],)
314
+ )
315
+
309
316
def __len__ (self ):
310
317
return self .config ["total_size" ]
311
318
You can’t perform that action at this time.
0 commit comments