Commit 4722e6e

fix: sum bias gradients over sequence dim only, not batch + tests
1 parent 78fa643 commit 4722e6e

2 files changed (+178, -15 lines)

bergson/gradients.py

Lines changed: 9 additions & 14 deletions
@@ -334,8 +334,8 @@ class GradientCollector(ContextDecorator):
     of the parameters, which are expected to be precomputed and passed in.

     We assume that the input to `model` is of shape `[N, S, I]`, where `N` is the
-    batch size, `S` is the sequence length, and `I` is the input dimension. We take the
-    mean over the sequence length to obtain a single gradient per sequence.
+    batch size, `S` is the sequence length, and `I` is the input dimension. We
+    sum over the sequence dimension to obtain a single gradient per sequence.
     """

     model: nn.Module
@@ -565,24 +565,19 @@ def _process_grad(self, module: nn.Module, _, grad_out):
         if isinstance(norm, AdamNormalizer) or include_bias:

             P = G.mT @ I  # [N, O, S] @ [N, S, I] → [N, O, I]
+            if isinstance(norm, AdamNormalizer):
+                # Normalize the gradients using the second moment matrix
+                P /= norm.avg_sq.sqrt().add_(1e-8)
+
             if include_bias:
-                # Append the bias gradient to the input
+                # TODO: should we normalize the bias gradients?
+                # Append the raw bias gradient to the input
                 P = torch.cat(
-                    [
-                        P,
-                        G.sum(dim=(0, 1))
-                        .unsqueeze(0)
-                        .unsqueeze(2)
-                        .expand(P.shape[0], -1, 1),
-                    ],
+                    [P, G.sum(dim=1).unsqueeze(2)],  # [N, S, O] -> [N, O] -> [N, O, 1]
                     dim=2,
                 )
                 i += 1

-            if isinstance(norm, AdamNormalizer):
-                # Normalize the gradients using the second moment matrix
-                P /= norm.avg_sq.sqrt().add_(1e-8)
-
         if self.processor.reshape_to_square:
             P = reshape_to_nearest_square(P)
             o, i = P.shape[-2:]
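
Why the new reduction is correct: for a linear layer `y = x @ W.T + b`, the gradient of a per-sequence loss with respect to the bias is the sum of the output gradients over that sequence's positions only, so the old `G.sum(dim=(0, 1))` collapsed the batch dimension and handed every sample the same batch-aggregated bias gradient. A minimal sketch in plain PyTorch (illustrative names, not Bergson internals; `G` stands in for the `grad_out` captured by the backward hook) that checks the new reduction against autograd:

    import torch

    N, S, I_dim, O = 4, 6, 5, 3
    layer = torch.nn.Linear(I_dim, O)
    x = torch.randn(N, S, I_dim)

    y = layer(x)                              # [N, S, O]
    per_sample_loss = (y**2).sum(dim=(1, 2))  # [N]

    # G plays the role of grad_out in the hook: dL/dy for every token.
    G = torch.autograd.grad(per_sample_loss.sum(), y, retain_graph=True)[0]  # [N, S, O]

    fixed = G.sum(dim=1)       # [N, O]: one bias gradient per sequence (new code)
    buggy = G.sum(dim=(0, 1))  # [O]: summed over the batch as well (old code)

    # Ground truth for sample 0 via its own backward pass.
    g0 = torch.autograd.grad(per_sample_loss[0], layer.bias, retain_graph=True)[0]
    torch.testing.assert_close(fixed[0], g0)
    assert not torch.allclose(buggy, g0)  # the old reduction mixes samples together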

tests/test_gradients.py

Lines changed: 169 additions & 1 deletion
@@ -1,7 +1,10 @@
 import tempfile
+from collections import defaultdict
 from pathlib import Path

+import pytest
 import torch
+import torch.nn as nn
 from transformers import AutoConfig, AutoModelForCausalLM

 from bergson.gradients import (
@@ -13,7 +16,7 @@
 )


-def test_phi3():
+def test_gradient_collector_proj_norm():
     temp_dir = Path(tempfile.mkdtemp())

     config = AutoConfig.from_pretrained("trl-internal-testing/tiny-Phi3ForCausalLM")
@@ -105,3 +108,168 @@ def closure(name: str, g: torch.Tensor):
         )

        previous_collected_grads = collected_grads.copy()
+
+
+@pytest.mark.parametrize("include_bias", [True, False])
+def test_gradient_collector_batched(include_bias: bool):
+    torch.manual_seed(42)
+    N = 4
+    S = 6
+    I = 5
+    O = 3
+
+    class SimpleModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fc1 = nn.Linear(I, O * 2, bias=include_bias)
+            self.relu = nn.ReLU()
+            self.fc2 = nn.Linear(O * 2, O, bias=include_bias)
+
+        def forward(self, x):
+            return self.fc2(self.relu(self.fc1(x)))
+
+    torch.manual_seed(42)
+    model = SimpleModel()
+
+    optimizer = torch.optim.Adam(model.parameters())
+
+    # Run a few training steps to build up second moments
+    for _ in range(5):
+        optimizer.zero_grad()
+        out = model(torch.randn(N, S, I))
+        loss = (out**2).sum()
+        loss.backward()
+        optimizer.step()
+
+    normalizers = {}
+    for name, param in model.named_parameters():
+        if "weight" in name:
+            layer_name = name.replace(".weight", "")
+            # Adam stores second moments as 'exp_avg_sq'
+            exp_avg_sq = optimizer.state[param]["exp_avg_sq"]
+            normalizers[layer_name] = AdamNormalizer(exp_avg_sq)
+
+    # collect gradients
+    collected_grads = {}
+
+    def closure(name: str, g: torch.Tensor):
+        """Store the gradients in a dictionary for later comparison."""
+        collected_grads[name] = g
+
+    processor = GradientProcessor(
+        normalizers=normalizers, projection_dim=None, include_bias=include_bias
+    )
+    collector = GradientCollector(model, closure, processor)
+
+    x = torch.randn(N, S, I)
+    with collector:
+        model.zero_grad()
+        out = model(x)
+        loss = (out**2).sum()
+        loss.backward()
+
+    def compute_ground_truth():
+        """Compute gradients using individual backward passes, with normalization."""
+        model.zero_grad()
+        output = model(x)  # [N, S, O]
+
+        # Per-sample losses
+        per_sample_losses = (output**2).sum(dim=(1, 2))  # [N]
+
+        ground_truth_grads = defaultdict(list)
+        for n in range(N):
+            model.zero_grad()
+            per_sample_losses[n].backward(retain_graph=True)
+
+            # manually normalize
+            for layer_name in ["fc1", "fc2"]:
+                layer = model.get_submodule(layer_name)
+                grad = layer.weight.grad.clone()

+                grad = normalizers[layer_name].normalize_(grad)
+
+                if include_bias:
+                    bias_grad = layer.bias.grad.clone()
+                    bias_grad = bias_grad.unsqueeze(1)
+                    grad = torch.cat([grad, bias_grad], dim=1)
+
+                ground_truth_grads[layer_name].append(grad)
+
+        for layer_name in ["fc1", "fc2"]:
+            ground_truth_grads[layer_name] = torch.stack(ground_truth_grads[layer_name])
+
+        return ground_truth_grads
+
+    ground_truth = compute_ground_truth()
+    for layer_name in ["fc1", "fc2"]:
+        torch.testing.assert_close(
+            collected_grads[layer_name], ground_truth[layer_name]
+        )
+
+
+def test_bias_gradients():
+    """Test that per-sample bias gradients are correctly computed."""
+    torch.manual_seed(42)
+    N = 4
+    S = 6
+    I = 5
+    O = 3
+
+    class SimpleModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fc = torch.nn.Linear(I, O, bias=True)
+
+        def forward(self, x):
+            return self.fc(x)
+
+    model = SimpleModel()
+    x = torch.randn(N, S, I)
+
+    # bias gradient is a sum over sequence dimension for each n
+    def compute_ground_truth(model) -> torch.Tensor:
+        """Compute gradients using individual backward passes."""
+        model.zero_grad()
+        output = model(x)  # [N, S, O]
+
+        per_sample_losses = (output**2).sum(dim=(1, 2))  # [N]
+
+        bias_grads = []
+        for n in range(N):
+            model.zero_grad()
+            per_sample_losses[n].backward(retain_graph=True)
+            bias_grads.append(model.fc.bias.grad.clone())
+
+        return torch.stack(bias_grads, dim=0)  # [N, O]
+
+    ground_truth = compute_ground_truth(model)
+
+    # GradientCollector with include_bias=True
+    collected_grads = {}
+
+    def closure(name: str, g: torch.Tensor):
+        collected_grads[name] = g
+
+    processor = GradientProcessor(include_bias=True, projection_dim=None)
+    collector = GradientCollector(model, closure, processor, target_modules={"fc"})
+
+    with collector:
+        model.zero_grad()
+        output = model(x)
+        loss = (output**2).sum()
+        loss.backward()
+
+    # the last column is bias
+    bias_grads = collected_grads["fc"][..., -1]
+
+    assert bias_grads.shape == (
+        N,
+        3,
+    ), f"Expected shape ({N}, {O}), got {bias_grads.shape}"
+    assert ground_truth.shape == (
+        N,
+        3,
+    ), f"Expected shape ({N}, {O}), got {ground_truth.shape}"
+
+    # Compare to ground truth
+    torch.testing.assert_close(bias_grads, ground_truth)
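
For reference, the layout that `test_bias_gradients` relies on: with `include_bias=True` and `projection_dim=None`, the collected gradient for a `Linear(I, O)` layer is the `[N, O, I]` per-sample weight gradient with the per-sample bias gradient appended as one extra column, giving `[N, O, I + 1]`. A small sketch mirroring the `torch.cat` in the diff (illustrative tensors; `X` stands for the hook's saved input, written as `I` in the library code):

    import torch

    N, S, I_dim, O = 4, 6, 5, 3
    G = torch.randn(N, S, O)      # gradient of the loss w.r.t. the layer output
    X = torch.randn(N, S, I_dim)  # layer input

    P = G.mT @ X                                          # [N, O, I] weight gradients
    P = torch.cat([P, G.sum(dim=1).unsqueeze(2)], dim=2)  # [N, O, I + 1]

    weight_part = P[..., :-1]  # [N, O, I]
    bias_part = P[..., -1]     # [N, O], identical to G.sum(dim=1)
    assert torch.equal(bias_part, G.sum(dim=1))

Reading the bias back out is just the last column, which is exactly what the test asserts against per-sample autograd gradients.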
