 import os
+from pathlib import Path
+
+from torch import nn
+
+from bergson import GradientProcessor
+from bergson.gradients import AdafactorNormalizer
 
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 os.environ["WANDB_MODE"] = "disabled"
 
 import pytest
 import torch
 from datasets import Dataset
-from transformers import AutoConfig, AutoModelForCausalLM, Trainer, TrainingArguments
+from transformers import (
+    Adafactor,
+    AutoConfig,
+    AutoModelForCausalLM,
+    Trainer,
+    TrainingArguments,
+)
 from trl import SFTConfig, SFTTrainer
 
 from bergson.data import load_gradients
@@ -245,3 +257,144 @@ def test_sft_trainer(self, tmp_path, model, dataset): |
         saved_order = Dataset.load_from_disk(str(order_file))
         assert len(saved_order) > 0
         assert all(key in saved_order[0] for key in ["_idx", "global_step", "epoch"])
+
+    @pytest.mark.parametrize("optimizer_name", ["adam", "adafactor"])
+    @pytest.mark.parametrize("include_bias", [True, False])
+    def test_optimizer_state_extraction(self, optimizer_name: str, include_bias: bool):
+        """Test that normalizers are correctly extracted from the optimizer state.
+
+        This exercises the GradientCollectorCallback in huggingface.py by:
+        1. Training a small model with the given optimizer
+        2. Calling the callback's on_step_end method
+        3. Verifying the extracted normalizers against the raw optimizer state
+        """
+        torch.manual_seed(42)
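+        # Dimensions for the toy data: inputs of shape (N, S, I) are mapped
+        # to O outputs by the model below.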
+        N = 4
+        S = 6
+        I = 5
+        O = 3
+
+        class SimpleModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc1 = nn.Linear(I, O * 2, bias=include_bias)
+                self.relu = nn.ReLU()
+                self.fc2 = nn.Linear(O * 2, O, bias=include_bias)
+
+            def forward(self, x):
+                return self.fc2(self.relu(self.fc1(x)))
+
+        torch.manual_seed(42)
+        model = SimpleModel()
+
+        # Create optimizer
+        if optimizer_name == "adam":
+            optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+        else:
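+            # With relative_step=False and scale_parameter=False, transformers'
+            # Adafactor uses the explicit lr instead of its time-dependent schedule.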
+            optimizer = Adafactor(
+                model.parameters(), scale_parameter=False, relative_step=False, lr=0.001
+            )
+
+        # Train a few steps to build up second moments
+        for _ in range(5):
+            optimizer.zero_grad()
+            out = model(torch.randn(N, S, I))
+            loss = (out**2).sum()
+            loss.backward()
+            optimizer.step()
+
+        # Extract normalizers using the ACTUAL callback
+        from unittest.mock import Mock, patch
+
+        from bergson.huggingface import GradientCollectorCallback
+
+        # Create callback with minimal setup
+        callback = GradientCollectorCallback(
+            path=Path("/tmp/test"),
+            use_optimizer_state=True,
+            include_bias=include_bias,
+        )
+
+        # Mock the collector and processor
+        mock_collector = Mock()
+        mock_collector.processor = GradientProcessor(
+            normalizers={}, include_bias=include_bias
+        )
+        mock_collector.target_info = {"fc1": None, "fc2": None}  # Track these layers
+        callback.collector = mock_collector
+
+        # Mock on_substep_end to avoid needing train_grad_buffer
+        with patch.object(callback, "on_substep_end"):
+            # Call the ACTUAL callback method
+            callback.on_step_end(
+                args=Mock(),
+                state=Mock(epoch=0, global_step=1),
+                control=Mock(),
+                model=model,
+                optimizer=optimizer,
+            )
+
+        # Get the normalizers the callback extracted
+        normalizers = callback.collector.processor.normalizers
+
+        # Verify against raw optimizer state (independent ground truth)
+        for layer_name in ["fc1", "fc2"]:
+            layer = model.get_submodule(layer_name)
+            norm = normalizers[layer_name]
+
+            # Check normalizer type
+            assert isinstance(norm, AdafactorNormalizer)
+
+            # Get raw state from optimizer
+            weight_state = optimizer.state[layer.weight]
+            lr = optimizer.param_groups[0]["lr"]
+            lr_sqrt = lr**0.5
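+            # The comparisons below scale the raw second moments by sqrt(lr) for
+            # factored row/col statistics and by lr for the unfactored bias term.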
+
+            if optimizer_name == "adam":
+                # Ground truth: Adam stores full exp_avg_sq
+                raw_exp_avg_sq = weight_state["exp_avg_sq"]
+
+                # NOTE: We convert Adam's full second moments to Adafactor's factorized
+                # form (row + col vectors) for memory efficiency. This is a lossy
+                # rank-1 approximation that can have large reconstruction errors.
+                # We can't verify correctness here, only sanity check the factorization.
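+                # (For reference, the Adafactor-style factorization approximates a
+                # second-moment matrix V by the rank-1 outer product of its row and
+                # column sums, r * c^T / sum(r); the checks below only look at shapes,
+                # finiteness, and positivity, not the exact scaling convention.)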
+
+                # Sanity checks on the factorized representation
+                assert norm.row.shape == (raw_exp_avg_sq.shape[0],)
+                assert norm.col.shape == (raw_exp_avg_sq.shape[1],)
+                assert (
+                    not torch.isnan(norm.row).any() and not torch.isinf(norm.row).any()
+                )
+                assert (
+                    not torch.isnan(norm.col).any() and not torch.isinf(norm.col).any()
+                )
+                assert (norm.row > 0).all() and (
+                    norm.col > 0
+                ).all()  # Second moments are positive
+
+            elif optimizer_name == "adafactor":
+                # Ground truth: Adafactor stores row/col directly
+                raw_row = weight_state["exp_avg_sq_row"]
+                raw_col = weight_state["exp_avg_sq_col"]
+
+                # Our normalizer should match (scaled by LR)
+                expected_row = raw_row * lr_sqrt
+                expected_col = raw_col * lr_sqrt
+
+                torch.testing.assert_close(norm.row, expected_row)
+                torch.testing.assert_close(norm.col, expected_col)
+
+            # Verify bias handling
+            if include_bias and layer.bias is not None:
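+                # Both optimizers keep an unfactored exp_avg_sq for the 1-D bias
+                # (Adafactor only factorizes parameters with at least two dims).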
+                bias_state = optimizer.state[layer.bias]
+                raw_bias_exp_avg_sq = bias_state["exp_avg_sq"]
+                expected_bias = raw_bias_exp_avg_sq * lr
+
+                assert norm.bias_avg_sq is not None, (
+                    f"Expected bias_avg_sq for {layer_name}"
+                )
+                torch.testing.assert_close(norm.bias_avg_sq, expected_bias)
+            else:
+                assert norm.bias_avg_sq is None, (
+                    f"Unexpected bias_avg_sq for {layer_name}"
+                )