Skip to content

Commit 31e5008

Browse files
committed
add scale_by_lr method to AdafactorNormalizer
1 parent e5f6869 commit 31e5008

File tree

2 files changed

+22
-14
lines changed

2 files changed

+22
-14
lines changed

bergson/gradients.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,21 @@ def to_adam(self) -> "AdamNormalizer":
140140
avg_sq = torch.outer(self.row, self.col) / self.row.mean()
141141
return AdamNormalizer(avg_sq=avg_sq, bias_avg_sq=self.bias_avg_sq)
142142

143+
def scale_by_lr(self, lr: float | Tensor) -> "AdafactorNormalizer":
    """Return a new normalizer with the learning rate folded in.

    The factorized second moments each absorb ``sqrt(lr)``, so their
    outer product (the reconstructed full second moment) scales by
    ``lr``; the unfactorized bias moment absorbs ``lr`` directly.
    Non-in-place ops are used throughout so any shared optimizer
    state is left untouched.
    """
    root = lr**0.5
    # Bias moment is optional; scale it only when present.  Shape [O].
    scaled_bias = None if self.bias_avg_sq is None else self.bias_avg_sq * lr
    return AdafactorNormalizer(
        row=self.row * root,  # shape [O]
        col=self.col * root,  # shape [I]
        bias_avg_sq=scaled_bias,
    )
157+
143158

144159
@dataclass
145160
class AdamNormalizer(Normalizer):

bergson/huggingface.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def on_step_end(
324324

325325
# Build normalizers from collected second moments
326326
for layer_name, moments in layer_second_moments.items():
327-
lr_sqrt = moments["lr"] ** 0.5
327+
lr = moments["lr"]
328328

329329
# Adam-like: has weight exp_avg_sq
330330
if "weight" in moments:
@@ -333,23 +333,16 @@ def on_step_end(
333333

334334
# Create Adam normalizer with optional bias, then convert to Adafactor
335335
# TODO: always convert to adafactor?
336-
norm = AdamNormalizer(weight_eas, bias_eas).to_adafactor()
337-
338-
# Scale by LR (factorized) - use non-in-place ops to avoid modifying optimizer state
339-
norm.row = norm.row * lr_sqrt
340-
norm.col = norm.col * lr_sqrt
341-
if norm.bias_avg_sq is not None:
342-
norm.bias_avg_sq = norm.bias_avg_sq * (lr_sqrt**2)
336+
norm = (
337+
AdamNormalizer(weight_eas, bias_eas).to_adafactor().scale_by_lr(lr)
338+
)
343339

344340
# Adafactor-like: has row/col
345341
elif "row" in moments and "col" in moments:
346342
bias_eas = moments.get("bias") # May be present
347-
norm = AdafactorNormalizer(moments["row"], moments["col"], bias_eas)
348-
# Scale by LR (factorized) - use non-in-place ops to avoid modifying optimizer state
349-
norm.row = norm.row * lr_sqrt
350-
norm.col = norm.col * lr_sqrt
351-
if norm.bias_avg_sq is not None:
352-
norm.bias_avg_sq = norm.bias_avg_sq * (lr_sqrt**2)
343+
norm = AdafactorNormalizer(
344+
moments["row"], moments["col"], bias_eas
345+
).scale_by_lr(lr)
353346
else:
354347
continue
355348

0 commit comments

Comments
 (0)