-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathsave_grad.py
More file actions
279 lines (243 loc) · 12.8 KB
/
save_grad.py
File metadata and controls
279 lines (243 loc) · 12.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
# -*- coding: utf-8 -*-
# ================================================================================
# SAVE_GRAD.PY - PROXY STUDENT GRADIENT COMPUTATION
# ================================================================================
# This script computes and saves gradients from a proxy student model on holdout
# traces for use in antidistillation sampling (ADS).
#
# Purpose:
# - Load holdout reasoning traces generated by the teacher model
# - Compute gradients from a proxy student model trying to learn from these traces
# - Save averaged gradients across the dataset for later use in ADS
#
# These gradients are crucial for ADS as they represent how the student model
# would be affected by changes to the teacher's outputs. The antidistillation
# mechanism uses these gradients in a finite difference approximation to modify
# the teacher's sampling distribution in ways that hurt student learning.
#
# Key technical details:
# - Uses completion-only training (only computes loss on assistant responses)
# - Accumulates gradients across the entire holdout dataset
# - Saves averaged gradients for use in finite difference perturbations (+/-ε)
# ================================================================================
import argparse
import os
import sys
import datasets
import torch
import yaml
import socket
from accelerate import Accelerator
from datasets import load_from_disk
from rich import print as rprint
from rich.console import Console
from rich.panel import Panel
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import logging as hf_logging
from trl import DataCollatorForCompletionOnlyLM
from utils import init
# ================================================================================
# DISTRIBUTED SETUP AND LOGGING CONFIGURATION
# ================================================================================
accelerator = Accelerator()

# Only the main process should produce log/progress output; every other rank
# is silenced so multi-process runs do not interleave duplicate messages.
if not accelerator.is_main_process:
    hf_logging.set_verbosity_error()
    hf_logging.disable_progress_bar()
    datasets.disable_progress_bar()

    def tqdm(x, *args, **kwargs):
        """No-op stand-in for tqdm: hand back the iterable unchanged."""
        return x
def log_color(content, title=""):
    """Print ``content`` inside a cyan-bordered rich panel on stdout.

    Args:
        content: Renderable (e.g. a string) placed in the panel body.
        title: Optional panel title, left-aligned on the border.

    Note:
        The function itself performs no process-rank check; callers that want
        main-process-only output must guard the call themselves.
    """
    stdout_console = Console(highlight=True, file=sys.stdout)
    framed = Panel(
        content,
        title=title,
        border_style="cyan",
        title_align="left",
    )
    stdout_console.print(framed)
# ================================================================================
# MAIN GRADIENT COMPUTATION SCRIPT
# ================================================================================
if __name__ == "__main__":
    # ============================================================================
    # COMMAND LINE ARGUMENT PARSING
    # ============================================================================
    # The positional argument is the config.yaml written by holdout trace
    # generation. Optional flags left unset stay None and are filled in from
    # that config further below.
    parser = argparse.ArgumentParser(description="Compute proxy student gradients for antidistillation sampling.")
    parser.add_argument("holdout_config", type=str, help="Path to the holdout config.yaml file")
    parser.add_argument("--proxy_student", type=str, help="Proxy student model to use for gradient computation")
    parser.add_argument("--tokenizer", type=str, help="Tokenizer model to use (should match proxy student)")
    parser.add_argument("--seed", type=int, help="Random seed for reproducibility")
    parser.add_argument("--trace_colname", type=str, help="Column name for reasoning traces in dataset")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for gradient computation")
    args = parser.parse_args()

    # ============================================================================
    # CONFIGURATION LOADING AND SETUP
    # ============================================================================
    # Explicit encoding so the config parses the same regardless of locale.
    with open(args.holdout_config, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # A config value is used when argparse does not define the attribute at all
    # (e.g. trace_path, exp_dir) or when the flag was left as None.
    # --batch_size has a non-None default, so the config never replaces it.
    for key, value in config.items():
        if not hasattr(args, key) or getattr(args, key) is None:
            setattr(args, key, value)

    # Project helper from utils; presumably seeds RNGs and does host-specific
    # setup (third argument flags "babel" hosts) -- confirm against utils.init.
    init(os.getenv("USER"), args.seed, "babel" in socket.gethostname())

    # ============================================================================
    # TOKENIZER SETUP
    # ============================================================================
    # Configure tokenizer to match the one used for trace generation so teacher
    # traces and student training share the same tokenization.
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, use_fast=True, padding_side="left")

    # Model-specific tokenizer configuration (same as in gentraces.py)
    if "llama" in args.tokenizer.lower():
        # Hard-coded Llama special-token ids; eot_token_id is not used further
        # in this script.
        eot_token_id = 128009
        eos_token_id = 128001
        tokenizer.pad_token_id = 128004
        tokenizer.eos_token_id = eos_token_id
        tokenizer.add_eos_token = False
        eos_token = tokenizer.eos_token
    else:
        eos_token = tokenizer.eos_token
        # Register a dedicated pad token; the embedding matrix is resized below
        # to account for it.
        special_tokens = {"pad_token": "[PAD]"}
        tokenizer.add_special_tokens(special_tokens)

    # ============================================================================
    # PROXY STUDENT MODEL SETUP
    # ============================================================================
    # Load the proxy student in float32 for more precise gradient computation.
    # No attention implementation is requested, so the library default is used.
    model = AutoModelForCausalLM.from_pretrained(
        args.proxy_student,
        trust_remote_code=True,
        torch_dtype=torch.float32,  # Higher precision for gradients
        use_cache=True,
    )
    # Keep the embedding table in sync with any tokens added to the tokenizer.
    model.resize_token_embeddings(len(tokenizer))
    if accelerator.is_main_process:
        print(f"Student model {args.proxy_student} loaded with {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M parameters")

    # ============================================================================
    # DATASET LOADING AND PREPROCESSING
    # ============================================================================
    # Only the main process tokenizes; the result is written to a shared path
    # that every process reloads after the barrier below.
    if accelerator.is_main_process:
        # Load the holdout dataset generated by gentraces.py
        ds = load_from_disk(args.trace_path)

        def preprocessor(examples):
            """Tokenize a batch of reasoning traces for completion-only training.

            Args:
                examples: Batched mapping from ``datasets.map``; the column
                    ``args.trace_colname`` holds the raw trace strings.

            Returns:
                The tokenizer output (unpadded, truncated, with attention
                masks) for the batch.
            """
            traces = examples[args.trace_colname]
            # Strip the first BOS occurrence so the tokenizer does not end up
            # duplicating it. Tokenizers without a BOS token pass the traces
            # through unchanged (previously `traces` was left unbound here).
            if tokenizer.bos_token:
                traces = [text.replace(tokenizer.bos_token, "", 1) for text in traces]
            tokenized = tokenizer(
                traces,
                padding=False,
                truncation=True,
                return_attention_mask=True,
                return_tensors=None,
            )
            return tokenized

        # Replace every original column with the tokenized fields.
        input_ds = ds.map(
            preprocessor,
            batched=True,
            num_proc=96,
            remove_columns=ds.column_names,
            desc="Preprocessing dataset",
            load_from_cache_file=True,
        )
        dataset_size = len(input_ds)
        # NOTE(review): fixed local path -- assumes all processes share this
        # filesystem and no concurrent run uses the same location.
        input_ds.save_to_disk("/tmp/cached_ds")
        print(f"Loaded {args.trace_path} with {dataset_size} samples")
        print(f"Example trace: {tokenizer.decode(input_ds[0]['input_ids'])}")

    # Barrier: nobody reads the cached dataset before the main process wrote it.
    accelerator.wait_for_everyone()
    input_ds = load_from_disk("/tmp/cached_ds")

    # ============================================================================
    # DATA COLLATOR SETUP FOR COMPLETION-ONLY TRAINING
    # ============================================================================
    # Completion-only collator: loss is computed only on the text following the
    # assistant marker, not on user prompts or system messages.
    response_str = "<|Assistant|>"  # Assistant response marker for loss computation
    data_collator = DataCollatorForCompletionOnlyLM(
        response_template=tokenizer.encode(response_str, add_special_tokens=False),
        tokenizer=tokenizer,
        mlm=False,  # Not using masked language modeling
    )

    # Create dataloader for batch processing
    dataloader = DataLoader(
        input_ds,
        batch_size=args.batch_size,
        shuffle=False,  # No shuffling needed for gradient computation
        collate_fn=data_collator,
        num_workers=1,
        pin_memory=True,
    )

    # ============================================================================
    # DISTRIBUTED TRAINING PREPARATION
    # ============================================================================
    model, dataloader = accelerator.prepare(model, dataloader)

    # ============================================================================
    # GRADIENT ACCUMULATION SETUP
    # ============================================================================
    # One zero-initialised accumulator per trainable parameter, matching the
    # parameter's shape, dtype and device.
    # NOTE(review): keys come from the prepared model's named_parameters(); a
    # wrapper may prefix them (e.g. "module.") -- confirm the consumer expects that.
    grads = {}
    for name, param in model.named_parameters():
        if param.requires_grad:
            grads[name] = torch.zeros_like(param.data)

    # ============================================================================
    # GRADIENT COMPUTATION LOOP
    # ============================================================================
    local_samples = 0
    model.train()  # Set model to training mode for gradient computation
    for batch in tqdm(dataloader, desc="Accumulating gradients", disable=not accelerator.is_main_process):
        batch_len = batch["input_ids"].size(0)
        local_samples += batch_len

        # Forward pass: the collator-provided labels restrict the loss to the
        # completion portion.
        outputs = model(**batch)
        # Scale by batch size so the final division by the sample count
        # yields a per-sample average.
        loss = outputs.loss * batch_len
        accelerator.backward(loss)

        # Fold this batch's gradients into the running sums.
        for name, param in model.named_parameters():
            if param.requires_grad and param.grad is not None:
                grads[name].add_(param.grad)

        # Clear gradients for next iteration
        model.zero_grad()

    # ============================================================================
    # GRADIENT REDUCTION AND AVERAGING
    # ============================================================================
    # Sum per-process sample counts to get the number of samples processed.
    local_tensor = torch.tensor([local_samples], device=accelerator.device)
    accelerator.wait_for_everyone()
    reduced_tensor = accelerator.reduce(local_tensor, reduction="sum")
    total_samples = reduced_tensor.item()
    if accelerator.is_main_process:
        print(f"Processed a total of {total_samples} samples across all processes")

    for name in grads:
        # accelerator.reduce returns a new tensor and leaves its argument
        # untouched, so the result must be stored; discarding it (as before)
        # kept only this rank's local accumulator.
        # NOTE(review): assumes sum-reduction is the right combination for the
        # distributed wrapper in use -- confirm for non-DDP backends.
        grads[name] = accelerator.reduce(grads[name], reduction="sum")
        if accelerator.is_main_process:
            # Only the main process saves, so only it needs the average.
            grads[name] = grads[name] / total_samples

    # ============================================================================
    # GRADIENT SAVING
    # ============================================================================
    # Save the averaged gradients for use in antidistillation sampling
    if accelerator.is_main_process:
        grad_save_path = os.path.join(args.exp_dir, "student_grads.pt")
        torch.save(grads, grad_save_path)
        print(f"Saved average gradients to {grad_save_path}")

        # Global L2 norm over all saved tensors, for debugging.
        total_grad_norm = sum(torch.norm(grad).item() ** 2 for grad in grads.values()) ** 0.5
        print(f"Total gradient norm: {total_grad_norm:.2e}")
        print(f"Number of parameters with gradients: {len(grads)}")

    accelerator.end_training()