@@ -324,7 +324,7 @@ def on_step_end(
324324
325325 # Build normalizers from collected second moments
326326 for layer_name , moments in layer_second_moments .items ():
327- lr_sqrt = moments ["lr" ] ** 0.5
327+ lr = moments ["lr" ]
328328
329329 # Adam-like: has weight exp_avg_sq
330330 if "weight" in moments :
@@ -333,23 +333,16 @@ def on_step_end(
333333
334334 # Create Adam normalizer with optional bias, then convert to Adafactor
335335 # TODO: always convert to adafactor?
336- norm = AdamNormalizer (weight_eas , bias_eas ).to_adafactor ()
337-
338- # Scale by LR (factorized) - use non-in-place ops to avoid modifying optimizer state
339- norm .row = norm .row * lr_sqrt
340- norm .col = norm .col * lr_sqrt
341- if norm .bias_avg_sq is not None :
342- norm .bias_avg_sq = norm .bias_avg_sq * (lr_sqrt ** 2 )
336+ norm = (
337+ AdamNormalizer (weight_eas , bias_eas ).to_adafactor ().scale_by_lr (lr )
338+ )
343339
344340 # Adafactor-like: has row/col
345341 elif "row" in moments and "col" in moments :
346342 bias_eas = moments .get ("bias" ) # May be present
347- norm = AdafactorNormalizer (moments ["row" ], moments ["col" ], bias_eas )
348- # Scale by LR (factorized) - use non-in-place ops to avoid modifying optimizer state
349- norm .row = norm .row * lr_sqrt
350- norm .col = norm .col * lr_sqrt
351- if norm .bias_avg_sq is not None :
352- norm .bias_avg_sq = norm .bias_avg_sq * (lr_sqrt ** 2 )
343+ norm = AdafactorNormalizer (
344+ moments ["row" ], moments ["col" ], bias_eas
345+ ).scale_by_lr (lr )
353346 else :
354347 continue
355348
0 commit comments