Commit e5f6869

fix: add normalizer bias support. fix trainer callback. add tests
1 parent 40d05a4 commit e5f6869

3 files changed: +257 additions, -22 deletions

bergson/gradients.py

Lines changed: 39 additions & 6 deletions
@@ -68,14 +68,24 @@ def state_dict(self) -> dict[str, str | Tensor]:
 class AdafactorNormalizer(Normalizer):
     """
     Row and column sums of second moments of gradients for a matrix-valued parameter.
+
+    Args:
+        row: Row statistics [O]
+        col: Column statistics [I]
+        bias_avg_sq: Optional second moments for bias [O]
     """
 
     row: Tensor  # shape [O]
     col: Tensor  # shape [I]
+    bias_avg_sq: Tensor | None = None  # shape [O]
 
     def __post_init__(self):
         assert self.row.ndim == 1, f"Expected 1D tensor for row, got {self.row.ndim}D"
         assert self.col.ndim == 1, f"Expected 1D tensor for col, got {self.col.ndim}D"
+        if self.bias_avg_sq is not None:
+            assert self.bias_avg_sq.ndim == 1, (
+                f"Expected 1D tensor for bias_avg_sq, got {self.bias_avg_sq.ndim}D"
+            )
 
     @torch.compile
     def normalize_(
@@ -120,22 +130,29 @@ def to_adam(self) -> "AdamNormalizer":
         """
         Convert this Adafactor normalizer to an Adam normalizer by materializing the
         rank-one second moment matrix.
+
+        Preserves bias_avg_sq if present.
         """
         # Compute the second moment matrix as a square matrix of shape [O, I]
         # NOTE: We don't add the epsilon here, since the AdamNormalizer is going to
         # add it outside the square root. This could cause infs though if there are
         # any exactly zero rows or columns, so we should be careful.
         avg_sq = torch.outer(self.row, self.col) / self.row.mean()
-        return AdamNormalizer(avg_sq=avg_sq)
+        return AdamNormalizer(avg_sq=avg_sq, bias_avg_sq=self.bias_avg_sq)
 
 
 @dataclass
 class AdamNormalizer(Normalizer):
     """
     Contains the second moments of the gradients.
+
+    Args:
+        avg_sq: Second moments for weights [O, I]
+        bias_avg_sq: Optional second moments for bias [O]
     """
 
     avg_sq: Tensor
+    bias_avg_sq: Tensor | None = None
 
     @torch.compile
     def normalize_(
@@ -153,6 +170,8 @@ def to_adafactor(self) -> AdafactorNormalizer:
         Convert this Adam normalizer to an Adafactor normalizer, minimizing the
         I-divergence (generalized Kullback-Leibler divergence) between the original
         and the factored second moments.
+
+        Preserves bias_avg_sq if present.
         """
         # We assume avg_sq is a square matrix of shape [O, I]
         assert self.avg_sq.ndim == 2, (
@@ -163,6 +182,7 @@ def to_adafactor(self) -> AdafactorNormalizer:
         return AdafactorNormalizer(
             row=self.avg_sq.mean(dim=1),  # shape [O]
             col=self.avg_sq.mean(dim=0),  # shape [I]
+            bias_avg_sq=self.bias_avg_sq,  # Preserve bias second moments
         )
 
 
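
As a side note on the two conversions above: to_adafactor keeps only the row and column means of the second-moment matrix, and to_adam rebuilds a rank-one approximation from them. The following standalone sketch applies the same formulas to a random matrix; the shapes and the outer(row, col) / row.mean() reconstruction come from the diff above, while the variable names and values are illustrative.

import torch

O, I = 3, 5
avg_sq = torch.rand(O, I) + 0.1        # Adam-style second moments, shape [O, I]

# to_adafactor: keep only row/col means
row = avg_sq.mean(dim=1)               # shape [O]
col = avg_sq.mean(dim=0)               # shape [I]

# to_adam: materialize the rank-one second moment matrix again
reconstructed = torch.outer(row, col) / row.mean()

# The round trip is exact only when the input is itself rank one;
# bias_avg_sq is simply carried through both conversions unchanged.
print(torch.allclose(reconstructed, avg_sq))  # usually False for a random matrix

rank_one = torch.outer(torch.rand(O) + 0.1, torch.rand(I) + 0.1)
r, c = rank_one.mean(dim=1), rank_one.mean(dim=0)
print(torch.allclose(torch.outer(r, c) / r.mean(), rank_one))  # True for rank-one input
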
@@ -551,8 +571,22 @@ def _process_grad(self, module: nn.Module, _, grad_out):
         i = getattr(module, LayerAdapter.in_attr(module))
         o = getattr(module, LayerAdapter.out_attr(module))
 
-        # Pre-scale G by the Adafactor row statistics
+        # Handle bias gradients if needed (must be computed from raw G)
         norm = self.processor.normalizers.get(name)
+        bias_grad = None
+        if include_bias:
+            # Compute bias from raw G (before any normalization)
+            bias_grad = G.sum(dim=1)  # [N, S, O] -> [N, O]
+
+            # Normalize bias with appropriate second moments
+            if (
+                isinstance(norm, (AdamNormalizer, AdafactorNormalizer))
+                and hasattr(norm, "bias_avg_sq")
+                and norm.bias_avg_sq is not None
+            ):
+                bias_grad = bias_grad / norm.bias_avg_sq.sqrt().add_(1e-8)
+
+        # Pre-scale G by the Adafactor row statistics (for weight gradients)
         if isinstance(norm, AdafactorNormalizer):
             # Compare to the normalize_ method in AdafactorNormalizer
             r = norm.row.add(1e-30)
@@ -568,11 +602,10 @@ def _process_grad(self, module: nn.Module, _, grad_out):
             # Normalize the gradients using the second moment matrix
             P /= norm.avg_sq.sqrt().add_(1e-8)
 
-        if include_bias:
-            # TODO: should we normalize the bias gradients?
-            # Append the raw bias gradient to the input
+        if include_bias and bias_grad is not None:
+            # Append pre-computed and normalized bias gradient
             P = torch.cat(
-                [P, G.sum(dim=1).unsqueeze(2)],  # [N, S, O] -> [N, O] # [N, O, 1]
+                [P, bias_grad.unsqueeze(2)],  # [N, O, 1]
                 dim=2,
             )
             i += 1
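
To make the bias path above concrete: for a linear layer, the per-example bias gradient is the output gradient summed over the sequence dimension, and when second moments are available it is divided by their square root plus a small epsilon before being appended as an extra column. A minimal sketch using the [N, S, O] shape convention from the diff (values illustrative):

import torch

N, S, O = 4, 6, 3
G = torch.randn(N, S, O)             # per-token gradients w.r.t. the layer output
bias_avg_sq = torch.rand(O) + 0.1    # optional bias second moments, shape [O]

# Per-example bias gradient: sum over the sequence dimension, [N, S, O] -> [N, O]
bias_grad = G.sum(dim=1)

# Adam-style normalization, mirroring bias_avg_sq.sqrt().add_(1e-8) in the diff
bias_grad = bias_grad / (bias_avg_sq.sqrt() + 1e-8)

# Appended to the per-example gradient features as a trailing column, [N, O, 1]
print(bias_grad.unsqueeze(2).shape)  # torch.Size([4, 3, 1])
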

bergson/huggingface.py

Lines changed: 64 additions & 15 deletions
@@ -239,7 +239,6 @@ def on_step_end(
         **kwargs,
     ):
         self.on_substep_end(args, state, control)
-        print("Step end")
 
         # Record training order if enabled
         if self.order is not None:
@@ -279,32 +278,82 @@ def on_step_end(
 
         # Read normalizers off of the optimizer state. We need to figure out
         # what type of optimizer this is first.
+        # Collect references to both weight and bias second moments per layer
+        layer_second_moments: dict[str, dict[str, Tensor]] = {}
+
         for group in optimizer.param_groups:
-            lr_sqrt = group["lr"] ** 0.5
+            group_lr = group["lr"]
 
             for param in group["params"]:
-                name = param_to_name[param].removesuffix(".weight")
-                if name not in self.collector.target_info:
+                param_name = param_to_name[param]
+
+                # Extract layer name (remove .weight or .bias suffix)
+                if param_name.endswith(".weight"):
+                    param_type = "weight"
+                    layer_name = param_name.removesuffix(".weight")
+                elif param_name.endswith(".bias"):
+                    param_type = "bias"
+                    layer_name = param_name.removesuffix(".bias")
+                else:
+                    continue
+
+                if layer_name not in self.collector.target_info:
                     continue
 
                 p_state = optimizer.state[param]
 
+                # Initialize layer dict if needed, storing this group's learning rate
+                if layer_name not in layer_second_moments:
+                    layer_second_moments[layer_name] = {"lr": group_lr}
+
                 # Adam-like optimizer
                 if (eas := p_state.get("exp_avg_sq")) is not None:
-                    norm = AdamNormalizer(eas).to_adafactor()
-
+                    layer_second_moments[layer_name][param_type] = eas
                 # Adafactor-like optimizer
                 elif (vr := p_state.get("exp_avg_sq_row")) is not None:
                     vc = p_state.get("exp_avg_sq_col")
-                    norm = AdafactorNormalizer(vr, vc)
-                else:
-                    continue
-
-                # Scale the gradient by the current learning rate. It's factorized
-                # so we multiply each factor by the square root of the LR.
-                norm.row *= lr_sqrt
-                norm.col *= lr_sqrt
-                normalizers[name] = norm
+                    if param_type == "weight":
+                        # Factorized second moments for weights
+                        layer_second_moments[layer_name]["row"] = vr
+                        layer_second_moments[layer_name]["col"] = vc
+                    elif param_type == "bias":
+                        # Adafactor stores bias as regular exp_avg_sq
+                        bias_eas = p_state.get("exp_avg_sq")
+                        if bias_eas is not None:
+                            layer_second_moments[layer_name]["bias"] = bias_eas
+
+        # Build normalizers from collected second moments
+        for layer_name, moments in layer_second_moments.items():
+            lr_sqrt = moments["lr"] ** 0.5
+
+            # Adam-like: has weight exp_avg_sq
+            if "weight" in moments:
+                weight_eas = moments["weight"]
+                bias_eas = moments.get("bias")  # May be None
+
+                # Create Adam normalizer with optional bias, then convert to Adafactor
+                # TODO: always convert to adafactor?
+                norm = AdamNormalizer(weight_eas, bias_eas).to_adafactor()
+
+                # Scale by LR (factorized) - use non-in-place ops to avoid modifying optimizer state
+                norm.row = norm.row * lr_sqrt
+                norm.col = norm.col * lr_sqrt
+                if norm.bias_avg_sq is not None:
+                    norm.bias_avg_sq = norm.bias_avg_sq * (lr_sqrt**2)
+
+            # Adafactor-like: has row/col
+            elif "row" in moments and "col" in moments:
+                bias_eas = moments.get("bias")  # May be present
+                norm = AdafactorNormalizer(moments["row"], moments["col"], bias_eas)
+                # Scale by LR (factorized) - use non-in-place ops to avoid modifying optimizer state
+                norm.row = norm.row * lr_sqrt
+                norm.col = norm.col * lr_sqrt
+                if norm.bias_avg_sq is not None:
+                    norm.bias_avg_sq = norm.bias_avg_sq * (lr_sqrt**2)
+            else:
+                continue
+
+            normalizers[layer_name] = norm
 
         proc.normalizers = normalizers
 
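
For context on the collection pass above: torch.optim.Adam keeps a full exp_avg_sq tensor per parameter, while the Hugging Face Adafactor optimizer keeps factored exp_avg_sq_row / exp_avg_sq_col for matrix parameters and a plain exp_avg_sq for 1-D parameters such as biases, which is why weight and bias entries are collected differently. A rough standalone sketch of that state layout (the layer and hyperparameters here are illustrative):

import torch
from torch import nn
from transformers import Adafactor

layer = nn.Linear(5, 3)
opt = Adafactor(
    layer.parameters(), scale_parameter=False, relative_step=False, lr=1e-3
)

loss = layer(torch.randn(4, 5)).pow(2).sum()
loss.backward()
opt.step()

# Matrix parameter: factored second moments
print(sorted(opt.state[layer.weight].keys()))  # includes 'exp_avg_sq_row' and 'exp_avg_sq_col'
# Vector parameter (bias): unfactored second moments
print(sorted(opt.state[layer.bias].keys()))    # includes 'exp_avg_sq'
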
tests/test_trainer_callback.py

Lines changed: 154 additions & 1 deletion
@@ -1,12 +1,24 @@
 import os
+from pathlib import Path
+
+from torch import nn
+
+from bergson import GradientProcessor
+from bergson.gradients import AdafactorNormalizer
 
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 os.environ["WANDB_MODE"] = "disabled"
 
 import pytest
 import torch
 from datasets import Dataset
-from transformers import AutoConfig, AutoModelForCausalLM, Trainer, TrainingArguments
+from transformers import (
+    Adafactor,
+    AutoConfig,
+    AutoModelForCausalLM,
+    Trainer,
+    TrainingArguments,
+)
 from trl import SFTConfig, SFTTrainer
 
 from bergson.data import load_gradients
@@ -245,3 +257,144 @@ def test_sft_trainer(self, tmp_path, model, dataset):
         saved_order = Dataset.load_from_disk(str(order_file))
         assert len(saved_order) > 0
         assert all(key in saved_order[0] for key in ["_idx", "global_step", "epoch"])
+
+    @pytest.mark.parametrize("optimizer_name", ["adam", "adafactor"])
+    @pytest.mark.parametrize("include_bias", [True, False])
+    def test_optimizer_state_extraction(self, optimizer_name: str, include_bias: bool):
+        """Test that normalizers are correctly extracted from optimizer state.
+
+        This tests the huggingface.py callback by:
+        1. Training a model with an optimizer
+        2. Calling the callback's on_step_end method
+        3. Verifying against raw optimizer state
+        """
+        torch.manual_seed(42)
+        N = 4
+        S = 6
+        I = 5
+        O = 3
+
+        class SimpleModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = nn.Linear(I, O * 2, bias=include_bias)
+                self.relu = nn.ReLU()
+                self.fc2 = nn.Linear(O * 2, O, bias=include_bias)
+
+            def forward(self, x):
+                return self.fc2(self.relu(self.fc1(x)))
+
+        torch.manual_seed(42)
+        model = SimpleModel()
+
+        # Create optimizer
+        if optimizer_name == "adam":
+            optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+        else:
+            optimizer = Adafactor(
+                model.parameters(), scale_parameter=False, relative_step=False, lr=0.001
+            )
+
+        # Train a few steps to build up second moments
+        for _ in range(5):
+            optimizer.zero_grad()
+            out = model(torch.randn(N, S, I))
+            loss = (out**2).sum()
+            loss.backward()
+            optimizer.step()
+
+        # Extract normalizers using the ACTUAL callback
+        from unittest.mock import Mock, patch
+
+        from bergson.huggingface import GradientCollectorCallback
+
+        # Create callback with minimal setup
+        callback = GradientCollectorCallback(
+            path=Path("/tmp/test"),
+            use_optimizer_state=True,
+            include_bias=include_bias,
+        )
+
+        # Mock the collector and processor
+        mock_collector = Mock()
+        mock_collector.processor = GradientProcessor(
+            normalizers={}, include_bias=include_bias
+        )
+        mock_collector.target_info = {"fc1": None, "fc2": None}  # Track these layers
+        callback.collector = mock_collector
+
+        # Mock on_substep_end to avoid needing train_grad_buffer
+        with patch.object(callback, "on_substep_end"):
+            # Call the ACTUAL callback method
+            callback.on_step_end(
+                args=Mock(),
+                state=Mock(epoch=0, global_step=1),
+                control=Mock(),
+                model=model,
+                optimizer=optimizer,
+            )
+
+        # Get the normalizers the callback extracted
+        normalizers = callback.collector.processor.normalizers
+
+        # Verify against raw optimizer state (independent ground truth)
+        for layer_name in ["fc1", "fc2"]:
+            layer = model.get_submodule(layer_name)
+            norm = normalizers[layer_name]
+
+            # Check normalizer type
+            assert isinstance(norm, AdafactorNormalizer)
+
+            # Get raw state from optimizer
+            weight_state = optimizer.state[layer.weight]
+            lr = optimizer.param_groups[0]["lr"]
+            lr_sqrt = lr**0.5
+
+            if optimizer_name == "adam":
+                # Ground truth: Adam stores full exp_avg_sq
+                raw_exp_avg_sq = weight_state["exp_avg_sq"]
+
+                # NOTE: We convert Adam's full second moments to Adafactor's factorized
+                # form (row + col vectors) for memory efficiency. This is a lossy
+                # rank-1 approximation that can have large reconstruction errors.
+                # We can't verify correctness here, only sanity check the factorization.
+
+                # Sanity checks on the factorized representation
+                assert norm.row.shape == (raw_exp_avg_sq.shape[0],)
+                assert norm.col.shape == (raw_exp_avg_sq.shape[1],)
+                assert (
+                    not torch.isnan(norm.row).any() and not torch.isinf(norm.row).any()
+                )
+                assert (
+                    not torch.isnan(norm.col).any() and not torch.isinf(norm.col).any()
+                )
+                assert (norm.row > 0).all() and (
+                    norm.col > 0
+                ).all()  # Second moments are positive
+
+            elif optimizer_name == "adafactor":
+                # Ground truth: Adafactor stores row/col directly
+                raw_row = weight_state["exp_avg_sq_row"]
+                raw_col = weight_state["exp_avg_sq_col"]
+
+                # Our normalizer should match (scaled by LR)
+                expected_row = raw_row * lr_sqrt
+                expected_col = raw_col * lr_sqrt
+
+                torch.testing.assert_close(norm.row, expected_row)
+                torch.testing.assert_close(norm.col, expected_col)
+
+            # Verify bias handling
+            if include_bias and layer.bias is not None:
+                bias_state = optimizer.state[layer.bias]
+                raw_bias_exp_avg_sq = bias_state["exp_avg_sq"]
+                expected_bias = raw_bias_exp_avg_sq * lr
+
+                assert norm.bias_avg_sq is not None, (
+                    f"Expected bias_avg_sq for {layer_name}"
+                )
+                torch.testing.assert_close(norm.bias_avg_sq, expected_bias)
+            else:
+                assert norm.bias_avg_sq is None, (
+                    f"Unexpected bias_avg_sq for {layer_name}"
+                )
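
One subtlety the assertions above pin down: the factored weight statistics are expected at raw * lr ** 0.5 per factor, while the unfactored bias statistic is expected at raw * lr, which matches the callback's norm.bias_avg_sq * (lr_sqrt**2) scaling. A tiny sketch of that bookkeeping (values illustrative):

import torch

lr = 0.001
lr_sqrt = lr**0.5

raw_bias = torch.rand(3)

# Callback side: bias second moments scaled by lr_sqrt ** 2
callback_scaled = raw_bias * (lr_sqrt**2)

# Test side: the expected value is the raw bias second moments scaled by lr
torch.testing.assert_close(callback_scaled, raw_bias * lr)
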
