diff --git a/CLAUDE.md b/CLAUDE.md index 3d2e569a76..dd2c9a16ff 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -8,3 +8,8 @@ - To run the `./scripts/train/build_image_and_launch.sh` script, you must commit the current changes. - Launch tool use experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/tool_grpo_fast.sh`. - Launch multi-node non-tool experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/large_test_script.sh`. + +# Comments Policy +- NEVER remove existing comments from code when making edits unless they are obviously outdated, in which case ALWAYS ask for permission. +- Always preserve all existing comments, especially explanatory ones +- Only add comments when they are needed for clarity diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 1dfbc579ff..0b53d53709 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -92,6 +92,7 @@ cleanup_all_llm_judge_clients, soft_format_reward_func, ) +from open_instruct.metrics import LossStatistics from open_instruct.model_utils import ( Batch, ModelConfig, @@ -522,29 +523,177 @@ def masked_mean( return (numerator / denom).mean() -class MetricsTracker: - """A simple class to prellocate all metrics in an array - so we can do only one allreduce operation to get the metrics mean""" +def compare_logprobs( + new_logprobs: torch.Tensor, + old_logprobs: torch.Tensor, + mask: torch.Tensor, + masked_mean_axis: int | None, + masked_mean_denominator: float | None, +) -> dict[str, float]: + """Compare locally computed log probabilities with reference log probabilities. - def __init__(self, max_metrics: int = 32, device: str = "cuda"): - self.metrics = torch.zeros(max_metrics, device=device) - self.names2idx = {} - self.current_idx = 0 - self.max_metrics = max_metrics + Computes statistics on the difference between two sets of log probabilities and returns + debugging metrics including mean/max/std differences and reverse KL divergence. 
- def add(self, name: str, value: torch.tensor): - if name not in self.names2idx: - if self.current_idx >= self.max_metrics: - raise ValueError(f"Exceeded maximum number of metrics ({self.max_metrics})") - self.names2idx[name] = self.current_idx - self.current_idx += 1 + Args: + new_logprobs: Locally computed log probabilities (shape: [batch, seq_len]) + old_logprobs: Reference log probabilities from behavior policy (shape: [batch, seq_len]) + mask: Boolean mask indicating valid response tokens (shape: [batch, seq_len]) + masked_mean_axis: Axis for masked mean reduction + masked_mean_denominator: Denominator for masked mean computation - self.metrics[self.names2idx[name]] = value - return self + Returns: + Dictionary of debug metrics + """ + with torch.no_grad(): + valid_mask = mask & ~torch.isnan(old_logprobs) + logprob_diff = (new_logprobs - old_logprobs).abs() + masked_diff = torch.masked_fill(logprob_diff, ~valid_mask, 0.0) + mean_diff = masked_diff.sum() / valid_mask.sum() if valid_mask.sum() > 0 else 0.0 + max_diff = masked_diff.max() + std_diff = masked_diff[valid_mask].std() if valid_mask.sum() > 1 else 0.0 + reverse_kl = masked_mean( + torch.expm1(old_logprobs - new_logprobs) + (old_logprobs - new_logprobs), + mask, + masked_mean_axis, + masked_mean_denominator, + ) + return { + "debug/vllm_vs_local_logprob_diff_mean": mean_diff.item(), + "debug/vllm_vs_local_logprob_diff_max": max_diff.item(), + "debug/vllm_vs_local_logprob_diff_std": std_diff.item(), + "debug/vllm_local_reverse_kl": reverse_kl.item(), + } + + +def maybe_apply_importance_sampling( + unclipped_pg_loss: torch.Tensor, + clipped_pg_loss: torch.Tensor, + old_logprobs: torch.Tensor, + vllm_logprobs: torch.Tensor | None, + response_mask: torch.Tensor, + args: Args, +) -> tuple[torch.Tensor, torch.Tensor]: + """Apply truncated importance sampling (TIS) to policy gradient losses if enabled. + + Importance sampling corrects for the distribution mismatch between the behavior policy + (vLLM inference) and the current policy. The importance ratio is capped to prevent + high variance in gradient estimates. + + Args: + unclipped_pg_loss: Unclipped policy gradient losses (shape: [batch, seq_len]) + clipped_pg_loss: Clipped policy gradient losses (shape: [batch, seq_len]) + old_logprobs: Log probabilities from current policy (shape: [batch, seq_len]) + vllm_logprobs: Log probabilities from behavior policy, or None (shape: [batch, seq_len]) + response_mask: Boolean mask indicating valid response tokens (shape: [batch, seq_len]) + args: Training arguments containing truncated_importance_sampling_ratio_cap + + Returns: + Tuple of (potentially scaled unclipped_pg_loss, potentially scaled clipped_pg_loss) + """ + if args.truncated_importance_sampling_ratio_cap <= 0 or vllm_logprobs is None: + return unclipped_pg_loss, clipped_pg_loss + + old_logprobs_mask = old_logprobs != INVALID_LOGPROB + vllm_logprobs_mask = vllm_logprobs != INVALID_LOGPROB + + assert torch.allclose(old_logprobs_mask.float(), response_mask.float()), ( + f"Old logprobs mask should match response mask. " + f"old_mask sum={old_logprobs_mask.sum()}, " + f"response_mask sum={response_mask.sum()}" + ) + assert torch.allclose(vllm_logprobs_mask.float(), response_mask.float()), ( + f"vLLM logprobs mask should match response mask. 
" + f"vllm_mask sum={vllm_logprobs_mask.sum()}, " + f"response_mask sum={response_mask.sum()}" + ) + + valid_mask = response_mask + importance_ratio = torch.ones_like(old_logprobs) + + if valid_mask.any(): + logprob_diff = old_logprobs - vllm_logprobs + logprob_diff = torch.where(valid_mask, logprob_diff.clamp(-10.0, 10.0), torch.zeros_like(logprob_diff)) + importance_ratio = torch.where(valid_mask, torch.exp(logprob_diff), importance_ratio) + importance_ratio = torch.clamp(importance_ratio, max=args.truncated_importance_sampling_ratio_cap) + + unclipped_pg_loss = unclipped_pg_loss * importance_ratio + clipped_pg_loss = clipped_pg_loss * importance_ratio + + return unclipped_pg_loss, clipped_pg_loss + + +def calculate_loss_and_backward( + model: deepspeed.DeepSpeedEngine, + i: int, + loss_statistics: LossStatistics, + local_logprobs: torch.Tensor, + old_logprobs: torch.Tensor, + ref_logprob: torch.Tensor, + advantages: torch.Tensor, + response_masks_bool: torch.Tensor, + vllm_logprobs: torch.Tensor | None, + entropy: torch.Tensor, + accumulation_steps: int, + local_step: int, + args: Args, +) -> int: + """Calculate GRPO loss and perform backward pass for a single minibatch. + + Computes the policy gradient loss using the clipped surrogate objective from PPO, + combines it with a KL penalty term, and performs the backward pass. + + Args: + model: Model wrapper with backward() method (e.g., DeepSpeed engine) + i: Minibatch index for tracking statistics + loss_statistics: LossStatistics object to accumulate training metrics + local_logprobs: Log probabilities from current policy (shape: [batch, seq_len]) + old_logprobs: Log probabilities from old/cached policy (shape: [batch, seq_len]) + ref_logprob: Log probabilities from reference model (shape: [batch, seq_len]) + advantages: Advantage estimates for policy gradient (shape: [batch, seq_len+1]) + response_masks_bool: Boolean mask for valid response tokens (shape: [batch, seq_len]) + vllm_logprobs: Log probabilities from behavior policy, or None (shape: [batch, seq_len]) + entropy: Entropy of the policy distribution (shape: [batch, seq_len]) + accumulation_steps: Number of gradient accumulation steps before optimizer update + local_step: Current local training step (used to determine when to step optimizer) + args: Training arguments containing hyperparameters + + Returns: + Updated local_step (incremented by 1) + """ + logprobs_diff = local_logprobs - old_logprobs + ratio = torch.exp(logprobs_diff) + + # PPO clipped surrogate objective: we compute two losses and take the element-wise maximum. + # - unclipped_pg_loss: standard policy gradient loss using the raw importance ratio + # - clipped_pg_loss: policy gradient loss with the ratio clipped to prevent large updates + # Taking the maximum implements a pessimistic bound that prevents the policy from + # deviating too far from the old policy. The clipfrac metric tracks how often clipping + # is active (clipped_pg_loss > unclipped_pg_loss), which indicates constraint saturation. 
+ unclipped_pg_loss = -advantages[:, 1:] * ratio + clipped_pg_loss = -advantages[:, 1:] * torch.clamp(ratio, 1.0 - args.clip_lower, 1.0 + args.clip_higher) + + unclipped_pg_loss, clipped_pg_loss = maybe_apply_importance_sampling( + unclipped_pg_loss, clipped_pg_loss, old_logprobs, vllm_logprobs, response_masks_bool, args + ) + + pg_loss_max = torch.max(unclipped_pg_loss, clipped_pg_loss) - def get_metrics_list(self) -> dict[str, float]: - metrics_list = self.metrics.tolist() - return {name: metrics_list[idx] for name, idx in self.names2idx.items()} + ref_logprobs_diff = (local_logprobs - ref_logprob).clamp(-40.0, 40.0) + kl = loss_statistics.update_kl_estimates(i, ref_logprobs_diff, ratio, response_masks_bool, args) + + loss = masked_mean( + pg_loss_max + (args.beta * kl), response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + ) + loss = loss / accumulation_steps + model.backward(loss) + + with torch.no_grad(): + loss_statistics.update_stats( + i, response_masks_bool, unclipped_pg_loss, clipped_pg_loss, pg_loss_max, ratio, loss, entropy, args + ) + + return local_step + 1 def collate_fn(tensors_list: list[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor: @@ -860,7 +1009,6 @@ def load(self, path: str, map_location=None): else: self.ref_policy.load_state_dict(state_dict) logger.info(f"{self.rank=}: Loaded reference policy checkpoint from {self.ref_policy_checkpoint_path}") - self.local_metrics = MetricsTracker(max_metrics=32, device=self.device) return optimization_steps_done def forward( @@ -998,6 +1146,7 @@ def train( num_mini_batches: int, ): args = self.args + local_metrics = {} to_device_inplace(collated_query_responses, self.device) to_device_inplace(collated_tool_masks, self.device) to_device_inplace(collated_attention_masks, self.device) @@ -1089,18 +1238,10 @@ def train( local_step = 0 # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch with Timer("[Training Processes] Loss calculation", noop=self.rank != 0): - kl1_stats = torch.zeros(len(collated_query_responses)) - kl2_stats = torch.zeros(len(collated_query_responses)) - kl3_stats = torch.zeros(len(collated_query_responses)) - kl4_stats = torch.zeros(len(collated_query_responses)) - kl_loss_stats = torch.zeros(len(collated_query_responses)) - pg_clipfrac_stats = torch.zeros(len(collated_query_responses)) - pg_loss_stats = torch.zeros(len(collated_query_responses)) - loss_stats = torch.zeros(len(collated_query_responses)) - ratio_stats = torch.zeros(len(collated_query_responses)) - entropy_stats = torch.zeros(len(collated_query_responses)) + loss_statistics = LossStatistics(len(collated_query_responses), record_entropy=args.record_entropy) for epoch_idx in range(args.num_epochs): for i in range(len(collated_query_responses)): + # mb = mini-batch mb_ref_logprob = collated_ref_logprobs[i] mb_query_responses = collated_query_responses[i] mb_tool_mask = collated_tool_masks[i] @@ -1127,25 +1268,13 @@ def train( # Replace any remaining NaN values (query tokens in packed sequences are set to NaN by pack_sequences in rl_utils.py) mb_vllm_logprobs = torch.nan_to_num(mb_vllm_logprobs, nan=INVALID_LOGPROB) - # Compare vLLM logprobs with local logprobs - with torch.no_grad(): - valid_mask = mb_response_masks_bool & ~torch.isnan(mb_vllm_logprobs) - logprob_diff = (mb_local_logprobs - mb_vllm_logprobs).abs() - masked_diff = torch.masked_fill(logprob_diff, ~valid_mask, 0.0) - mean_diff = masked_diff.sum() / valid_mask.sum() if valid_mask.sum() > 0 
else 0.0 - max_diff = masked_diff.max() - std_diff = masked_diff[valid_mask].std() if valid_mask.sum() > 1 else 0.0 - - self.local_metrics.add("debug/vllm_vs_local_logprob_diff_mean", mean_diff.item()) - self.local_metrics.add("debug/vllm_vs_local_logprob_diff_max", max_diff.item()) - self.local_metrics.add("debug/vllm_vs_local_logprob_diff_std", std_diff.item()) - - reverse_kl = torch.exp(mb_vllm_logprobs) * (mb_vllm_logprobs - mb_local_logprobs) - masked_reverse_kl = torch.masked_fill(reverse_kl, ~valid_mask, 0.0) - mean_reverse_kl = masked_reverse_kl.sum() / valid_mask.sum() if valid_mask.sum() > 0 else 0.0 - self.local_metrics.add("debug/vllm_local_reverse_kl", mean_reverse_kl.item()) - - mb_new_logprobs = mb_local_logprobs + local_metrics |= compare_logprobs( + mb_local_logprobs, + mb_vllm_logprobs, + mb_response_masks_bool, + args.masked_mean_axis, + args.masked_mean_denominator, + ) # Cache the old logprobs if num_mini_batches > 1: @@ -1153,10 +1282,9 @@ def train( else: with torch.no_grad(): if epoch_idx == 0: - if args.use_vllm_logprobs: - old_logprobs[i] = mb_vllm_logprobs - else: - old_logprobs[i] = mb_local_logprobs.detach() + old_logprobs[i] = ( + mb_vllm_logprobs if args.use_vllm_logprobs else mb_local_logprobs.detach() + ) mb_old_logprobs = old_logprobs[i] old_logprobs_mask = mb_old_logprobs != INVALID_LOGPROB @@ -1166,140 +1294,27 @@ def train( f"response_mask sum={mb_response_masks_bool.sum()}" ) - # Calculate the policy's loss - logprobs_diff = mb_new_logprobs - mb_old_logprobs - ratio = torch.exp(logprobs_diff) - pg_losses = -mb_advantages[:, 1:] * ratio - pg_losses2 = -mb_advantages[:, 1:] * torch.clamp( - ratio, 1.0 - args.clip_lower, 1.0 + args.clip_higher - ) - - # Apply truncated importance sampling if enabled - if args.truncated_importance_sampling_ratio_cap > 0 and mb_vllm_logprobs is not None: - old_logprobs_mask = mb_old_logprobs != INVALID_LOGPROB - vllm_logprobs_mask = mb_vllm_logprobs != INVALID_LOGPROB - - assert torch.all(old_logprobs_mask == mb_response_masks_bool), ( - f"Old logprobs mask should match response mask. " - f"old_mask sum={old_logprobs_mask.sum()}, " - f"response_mask sum={mb_response_masks_bool.sum()}" - ) - assert torch.all(vllm_logprobs_mask == mb_response_masks_bool), ( - f"vLLM logprobs mask should match response mask. 
" - f"vllm_mask sum={vllm_logprobs_mask.sum()}, " - f"response_mask sum={mb_response_masks_bool.sum()}" - ) - - valid_mask = mb_response_masks_bool - - # Initialize importance ratio to 1.0 (no effect) for all positions - tis_imp_ratio = torch.ones_like(mb_old_logprobs) - - if valid_mask.any(): - # Calculate logprob difference only for valid positions - logprob_diff_is = mb_old_logprobs - mb_vllm_logprobs - # Clamp to prevent numerical overflow in exp - logprob_diff_is = torch.where( - valid_mask, logprob_diff_is.clamp(-10.0, 10.0), torch.zeros_like(logprob_diff_is) - ) - # Compute importance ratio only for valid positions - tis_imp_ratio = torch.where(valid_mask, torch.exp(logprob_diff_is), tis_imp_ratio) - # Apply cap - tis_imp_ratio = torch.clamp( - tis_imp_ratio, max=args.truncated_importance_sampling_ratio_cap - ) - - # Apply importance sampling to losses - pg_losses = pg_losses * tis_imp_ratio - pg_losses2 = pg_losses2 * tis_imp_ratio - - pg_loss_max = torch.max(pg_losses, pg_losses2) - - # Here we recalculate kl: we want the KL loss to backpropagate through the model - # We also clamp the KL loss to avoid numerical instability - # https://chatgpt.com/share/679d0ed9-8f48-8011-926e-e274b15ae8ae - ref_logprobs_diff = (mb_new_logprobs - mb_ref_logprob).clamp(-40.0, 40.0) - kl1 = ref_logprobs_diff - kl2 = (ref_logprobs_diff) ** 2 / 2 - kl3 = torch.expm1(-ref_logprobs_diff) + ref_logprobs_diff # this is more numerically stable - kl4 = ratio * ref_logprobs_diff - if args.kl_estimator == "kl1": - kl = kl1 - elif args.kl_estimator == "kl2": - kl = kl2 - elif args.kl_estimator == "kl3": - kl = kl3 - elif args.kl_estimator == "kl4": - kl = kl4 - - # grpo change: directly subtract KL in loss (add) - loss = masked_mean( - pg_loss_max + (args.beta * kl), + local_step = calculate_loss_and_backward( + self.model, + i, + loss_statistics, + mb_local_logprobs, + mb_old_logprobs, + mb_ref_logprob, + mb_advantages, mb_response_masks_bool, - args.masked_mean_axis, - args.masked_mean_denominator, + mb_vllm_logprobs, + mb_entropy, + accumulation_steps, + local_step, + args, ) - loss = loss / accumulation_steps - self.model.backward(loss) - if (local_step + 1) % accumulation_steps == 0: + if local_step % accumulation_steps == 0: self.model.step() - local_step += 1 - with torch.no_grad(): - # NOTE: in packed implementation, kl calculation are averages over response tokens - kl1_stats[i] = masked_mean( - kl1, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() - kl2_stats[i] = masked_mean( - kl2, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() - kl3_stats[i] = masked_mean( - kl3, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() - kl4_stats[i] = masked_mean( - kl4, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() - if args.kl_estimator == "kl1": - kl_loss_stats[i] = kl1_stats[i] * args.beta - elif args.kl_estimator == "kl2": - kl_loss_stats[i] = kl2_stats[i] * args.beta - elif args.kl_estimator == "kl3": - kl_loss_stats[i] = kl3_stats[i] * args.beta - elif args.kl_estimator == "kl4": - kl_loss_stats[i] = kl4_stats[i] * args.beta - pg_clipfrac_stats[i] = masked_mean( - (pg_losses2 > pg_losses).float(), - mb_response_masks_bool, - args.masked_mean_axis, - args.masked_mean_denominator, - ) - pg_loss_stats[i] = masked_mean( - pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ) - loss_stats[i] = loss - ratio_stats[i] = 
masked_mean( - ratio, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ) - if args.record_entropy: - # Calculate entropy statistics - entropy_stats[i] = masked_mean( - mb_entropy, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() - with torch.no_grad(): - self.local_metrics.add("objective/kl_avg", kl1_stats.mean()) - self.local_metrics.add("objective/kl2_avg", kl2_stats.mean()) - self.local_metrics.add("objective/kl3_avg", kl3_stats.mean()) - self.local_metrics.add("objective/kl4_avg", kl4_stats.mean()) - self.local_metrics.add("loss/policy_avg", pg_loss_stats.mean()) - self.local_metrics.add("loss/kl_avg", kl_loss_stats.mean()) - self.local_metrics.add("loss/total_avg", loss_stats.mean()) - self.local_metrics.add("policy/clipfrac_avg", pg_clipfrac_stats.mean()) - self.local_metrics.add("val/ratio", ratio_stats.mean()) - self.local_metrics.add("val/ratio_var", ratio_stats.var()) - if args.record_entropy: - self.local_metrics.add("policy/entropy_avg", entropy_stats.mean()) - self.local_metrics.add("lr", self.scheduler.get_last_lr()[0]) - return self.local_metrics.get_metrics_list() + local_metrics |= loss_statistics.to_dict() + local_metrics["lr"] = self.scheduler.get_last_lr()[0] + return local_metrics def save_checkpoint_state(self, checkpoint_state_dir: str, client_state: dict[str, Any]) -> None: args = self.args diff --git a/open_instruct/metrics.py b/open_instruct/metrics.py new file mode 100644 index 0000000000..53dcddee3f --- /dev/null +++ b/open_instruct/metrics.py @@ -0,0 +1,136 @@ +import torch + + +def masked_mean( + values: torch.Tensor, mask: torch.Tensor, axis: int | None = None, denominator: float | None = None +) -> torch.Tensor: + """Compute the mean of tensor values considering only masked (valid) positions. + + Args: + values: Input tensor to compute mean over + mask: Boolean or binary mask tensor (same shape as values) indicating valid positions + axis: Axis along which to sum before computing mean, or None for all axes + denominator: Optional fixed denominator to use instead of mask.sum(). Useful when + the denominator should be consistent across batches. + + Returns: + Masked mean of the input values as a scalar tensor + """ + numerator = (values * mask).sum(axis=axis) + denom = mask.sum(axis=axis) if denominator is None else denominator + return (numerator / denom).mean() + + +class LossStatistics: + """Accumulates training statistics across minibatches for GRPO training. + + Tracks KL divergence estimates, policy gradient losses, clipping statistics, + and importance ratios across multiple minibatches. Provides methods to update + statistics and convert accumulated values to a metrics dictionary. + """ + + def __init__(self, num_batches: int, record_entropy: bool = False): + """Initialize loss statistics storage. + + Args: + num_batches: Number of minibatches to track statistics for + record_entropy: Whether to track policy entropy statistics + """ + self.kl_stats = torch.zeros(4, num_batches) + self.kl_loss_stats = torch.zeros(num_batches) + self.pg_clipfrac_stats = torch.zeros(num_batches) + self.pg_loss_stats = torch.zeros(num_batches) + self.loss_stats = torch.zeros(num_batches) + self.ratio_stats = torch.zeros(num_batches) + self.entropy_stats = torch.zeros(num_batches) if record_entropy else None + + def update_kl_estimates(self, i, ref_logprobs_diff, ratio, mb_response_masks_bool, args): + """Compute and store KL divergence estimates for a minibatch. 
+ + Computes four different KL estimators (kl1-kl4) based on log probability + differences between the current policy and reference policy. + + Args: + i: Minibatch index + ref_logprobs_diff: Log probability differences (new - ref) [batch, seq_len] + ratio: Importance ratio exp(new_logprobs - old_logprobs) [batch, seq_len] + mb_response_masks_bool: Boolean mask for valid response tokens [batch, seq_len] + args: Training arguments containing kl_estimator, masked_mean settings + + Returns: + KL divergence values for the selected estimator (shape: [batch, seq_len]) + """ + kl_values = torch.stack( + [ + ref_logprobs_diff, + ref_logprobs_diff**2 / 2, + torch.expm1(-ref_logprobs_diff) + ref_logprobs_diff, + ratio * ref_logprobs_diff, + ] + ) + + vmapped_fn = torch.vmap( + lambda v: masked_mean(v, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator) + ) + self.kl_stats[:, i] = vmapped_fn(kl_values).float() + + kl_idx = {"kl1": 0, "kl2": 1, "kl3": 2, "kl4": 3}[args.kl_estimator] + return kl_values[kl_idx] + + def update_stats( + self, i, mb_response_masks_bool, pg_losses, pg_losses2, pg_loss_max, ratio, loss, mb_entropy, args + ): + """Update all training statistics for a minibatch. + + Args: + i: Minibatch index + mb_response_masks_bool: Boolean mask for valid response tokens [batch, seq_len] + pg_losses: Unclipped policy gradient losses [batch, seq_len] + pg_losses2: Clipped policy gradient losses [batch, seq_len] + pg_loss_max: Element-wise max of pg_losses and pg_losses2 [batch, seq_len] + ratio: Importance ratio [batch, seq_len] + loss: Total loss value (scalar) + mb_entropy: Policy entropy [batch, seq_len] + args: Training arguments containing beta, record_entropy, masked_mean settings + """ + kl_idx = {"kl1": 0, "kl2": 1, "kl3": 2, "kl4": 3}[args.kl_estimator] + self.kl_loss_stats[i] = self.kl_stats[kl_idx, i] * args.beta + self.pg_clipfrac_stats[i] = masked_mean( + (pg_losses2 > pg_losses).float(), + mb_response_masks_bool, + args.masked_mean_axis, + args.masked_mean_denominator, + ) + self.pg_loss_stats[i] = masked_mean( + pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + ) + self.loss_stats[i] = loss + self.ratio_stats[i] = masked_mean( + ratio, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + ) + if args.record_entropy and self.entropy_stats is not None: + self.entropy_stats[i] = masked_mean( + mb_entropy, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + ).float() + + def to_dict(self) -> dict[str, float]: + """Convert accumulated statistics to a metrics dictionary. 
+ + Returns: + Dictionary mapping metric names to their averaged values across all minibatches + """ + metrics = { + "objective/kl_avg": self.kl_stats[0].mean().item(), + "objective/kl2_avg": self.kl_stats[1].mean().item(), + "objective/kl3_avg": self.kl_stats[2].mean().item(), + "objective/kl4_avg": self.kl_stats[3].mean().item(), + "loss/policy_avg": self.pg_loss_stats.mean().item(), + "loss/kl_avg": self.kl_loss_stats.mean().item(), + "loss/total_avg": self.loss_stats.mean().item(), + "policy/clipfrac_avg": self.pg_clipfrac_stats.mean().item(), + "val/ratio": self.ratio_stats.mean().item(), + "val/ratio_var": self.ratio_stats.var().item(), + } + if self.entropy_stats is not None: + metrics["policy/entropy_avg"] = self.entropy_stats.mean().item() + return metrics diff --git a/open_instruct/test_grpo_fast.py b/open_instruct/test_grpo_fast.py index e633f5314b..21a8a51073 100644 --- a/open_instruct/test_grpo_fast.py +++ b/open_instruct/test_grpo_fast.py @@ -1086,5 +1086,241 @@ def test_distribution_and_structure( self.assertTrue(torch.all(row[first_pad_idx:] == pad_token_id)) +class TestCompareLogprobs(unittest.TestCase): + @parameterized.expand( + [ + ( + torch.tensor([[1.0, 2.0, 3.0]]), + torch.tensor([[1.1, 2.1, 3.1]]), + torch.tensor([[True, True, True]]), + None, + None, + ), + ( + torch.tensor([[1.0, 2.0], [3.0, 4.0]]), + torch.tensor([[1.2, 2.2], [3.2, 4.2]]), + torch.tensor([[True, True], [True, True]]), + None, + None, + ), + ] + ) + def test_basic_functionality(self, new_logprobs, old_logprobs, mask, masked_mean_axis, masked_mean_denominator): + result = grpo_fast.compare_logprobs( + new_logprobs, old_logprobs, mask, masked_mean_axis, masked_mean_denominator + ) + self.assertIsInstance(result, dict) + self.assertIn("debug/vllm_vs_local_logprob_diff_mean", result) + self.assertIn("debug/vllm_vs_local_logprob_diff_max", result) + self.assertIn("debug/vllm_vs_local_logprob_diff_std", result) + self.assertIn("debug/vllm_local_reverse_kl", result) + self.assertGreater(result["debug/vllm_vs_local_logprob_diff_mean"], 0.0) + + def test_with_nan_values(self): + new_logprobs = torch.tensor([[1.0, 2.0, 3.0]]) + old_logprobs = torch.tensor([[1.1, float("nan"), 3.1]]) + mask = torch.tensor([[True, True, True]]) + result = grpo_fast.compare_logprobs(new_logprobs, old_logprobs, mask, None, None) + self.assertIsInstance(result, dict) + self.assertFalse(torch.isnan(torch.tensor(result["debug/vllm_vs_local_logprob_diff_mean"]))) + + def test_multiple_elements(self): + new_logprobs = torch.tensor([[1.0, 2.0, 3.0, 4.0]]) + old_logprobs = torch.tensor([[1.1, 2.1, 3.1, 4.1]]) + mask = torch.tensor([[True, True, True, True]]) + result = grpo_fast.compare_logprobs(new_logprobs, old_logprobs, mask, None, None) + self.assertIsInstance(result, dict) + self.assertGreater(result["debug/vllm_vs_local_logprob_diff_mean"], 0.0) + self.assertGreater(result["debug/vllm_vs_local_logprob_diff_std"], 0.0) + + +class TestMaybeApplyImportanceSampling(unittest.TestCase): + def test_early_return_zero_cap(self): + unclipped = torch.tensor([[-1.0, -2.0]]) + clipped = torch.tensor([[-1.1, -2.1]]) + old_logprobs = torch.tensor([[-0.5, -0.6]]) + vllm_logprobs = torch.tensor([[-0.4, -0.5]]) + mask = torch.tensor([[True, True]]) + mock_args = Mock() + mock_args.truncated_importance_sampling_ratio_cap = 0.0 + result_unclipped, result_clipped = grpo_fast.maybe_apply_importance_sampling( + unclipped, clipped, old_logprobs, vllm_logprobs, mask, mock_args + ) + self.assertTrue(torch.equal(result_unclipped, unclipped)) + 
self.assertTrue(torch.equal(result_clipped, clipped)) + + def test_early_return_none_vllm_logprobs(self): + unclipped = torch.tensor([[-1.0, -2.0]]) + clipped = torch.tensor([[-1.1, -2.1]]) + old_logprobs = torch.tensor([[-0.5, -0.6]]) + mask = torch.tensor([[True, True]]) + mock_args = Mock() + mock_args.truncated_importance_sampling_ratio_cap = 2.0 + result_unclipped, result_clipped = grpo_fast.maybe_apply_importance_sampling( + unclipped, clipped, old_logprobs, None, mask, mock_args + ) + self.assertTrue(torch.equal(result_unclipped, unclipped)) + self.assertTrue(torch.equal(result_clipped, clipped)) + + def test_importance_ratio_computation(self): + unclipped = torch.tensor([[-1.0, -2.0]]) + clipped = torch.tensor([[-1.0, -2.0]]) + old_logprobs = torch.tensor([[-2.0, -2.0]]) + vllm_logprobs = torch.tensor([[-3.0, -3.0]]) + mask = torch.tensor([[True, True]]) + mock_args = Mock() + mock_args.truncated_importance_sampling_ratio_cap = 10.0 + result_unclipped, result_clipped = grpo_fast.maybe_apply_importance_sampling( + unclipped, clipped, old_logprobs, vllm_logprobs, mask, mock_args + ) + expected_ratio = torch.exp(torch.tensor([[1.0, 1.0]])) + self.assertTrue(torch.allclose(result_unclipped, unclipped * expected_ratio)) + self.assertTrue(torch.allclose(result_clipped, clipped * expected_ratio)) + + def test_ratio_capping(self): + unclipped = torch.tensor([[-1.0, -1.0]]) + clipped = torch.tensor([[-1.0, -1.0]]) + old_logprobs = torch.tensor([[-2.0, -2.0]]) + vllm_logprobs = torch.tensor([[-7.0, -7.0]]) + mask = torch.tensor([[True, True]]) + mock_args = Mock() + mock_args.truncated_importance_sampling_ratio_cap = 2.0 + result_unclipped, result_clipped = grpo_fast.maybe_apply_importance_sampling( + unclipped, clipped, old_logprobs, vllm_logprobs, mask, mock_args + ) + self.assertTrue(torch.all(result_unclipped / unclipped <= 2.0)) + self.assertTrue(torch.all(result_clipped / clipped <= 2.0)) + + def test_logprob_diff_clamping(self): + unclipped = torch.tensor([[-1.0, -1.0]]) + clipped = torch.tensor([[-1.0, -1.0]]) + old_logprobs = torch.tensor([[-2.0, -17.0]]) + vllm_logprobs = torch.tensor([[-17.0, -2.0]]) + mask = torch.tensor([[True, True]]) + mock_args = Mock() + mock_args.truncated_importance_sampling_ratio_cap = 1000.0 + result_unclipped, result_clipped = grpo_fast.maybe_apply_importance_sampling( + unclipped, clipped, old_logprobs, vllm_logprobs, mask, mock_args + ) + max_expected_ratio = torch.exp(torch.tensor(10.0)) + min_expected_ratio = torch.exp(torch.tensor(-10.0)) + self.assertTrue(torch.all(result_unclipped / unclipped <= max_expected_ratio)) + self.assertTrue(torch.all(result_unclipped / unclipped >= min_expected_ratio)) + + +class TestCalculateLossAndBackward(unittest.TestCase): + def test_basic_loss_computation(self): + mock_model = Mock() + mock_model.backward = Mock() + loss_statistics = Mock() + loss_statistics.update_kl_estimates = Mock(return_value=torch.tensor(0.5)) + loss_statistics.update_stats = Mock() + local_logprobs = torch.tensor([[1.0, 2.0]]) + old_logprobs = torch.tensor([[0.9, 1.9]]) + ref_logprob = torch.tensor([[0.8, 1.8]]) + advantages = torch.tensor([[0.0, 1.0, 1.0]]) + response_masks_bool = torch.tensor([[True, True]]) + entropy = torch.tensor([[0.5, 0.5]]) + mock_args = Mock() + mock_args.clip_lower = 0.2 + mock_args.clip_higher = 0.2 + mock_args.beta = 0.1 + mock_args.masked_mean_axis = None + mock_args.masked_mean_denominator = None + mock_args.truncated_importance_sampling_ratio_cap = 0.0 + result = grpo_fast.calculate_loss_and_backward( + 
mock_model, + 0, + loss_statistics, + local_logprobs, + old_logprobs, + ref_logprob, + advantages, + response_masks_bool, + None, + entropy, + 1, + 0, + mock_args, + ) + mock_model.backward.assert_called_once() + self.assertEqual(result, 1) + + def test_gradient_accumulation(self): + mock_model = Mock() + mock_model.backward = Mock() + loss_statistics = Mock() + loss_statistics.update_kl_estimates = Mock(return_value=torch.tensor(0.5)) + loss_statistics.update_stats = Mock() + local_logprobs = torch.tensor([[1.0, 2.0]]) + old_logprobs = torch.tensor([[0.9, 1.9]]) + ref_logprob = torch.tensor([[0.8, 1.8]]) + advantages = torch.tensor([[0.0, 1.0, 1.0]]) + response_masks_bool = torch.tensor([[True, True]]) + entropy = torch.tensor([[0.5, 0.5]]) + mock_args = Mock() + mock_args.clip_lower = 0.2 + mock_args.clip_higher = 0.2 + mock_args.beta = 0.1 + mock_args.masked_mean_axis = None + mock_args.masked_mean_denominator = None + mock_args.truncated_importance_sampling_ratio_cap = 0.0 + grpo_fast.calculate_loss_and_backward( + mock_model, + 0, + loss_statistics, + local_logprobs, + old_logprobs, + ref_logprob, + advantages, + response_masks_bool, + None, + entropy, + 4, + 0, + mock_args, + ) + mock_model.backward.assert_called_once() + loss_arg = mock_model.backward.call_args[0][0] + self.assertIsInstance(loss_arg, torch.Tensor) + + def test_advantages_slicing(self): + mock_model = Mock() + mock_model.backward = Mock() + loss_statistics = Mock() + loss_statistics.update_kl_estimates = Mock(return_value=torch.tensor(0.5)) + loss_statistics.update_stats = Mock() + local_logprobs = torch.tensor([[1.0, 2.0, 3.0]]) + old_logprobs = torch.tensor([[1.0, 2.0, 3.0]]) + ref_logprob = torch.tensor([[1.0, 2.0, 3.0]]) + advantages = torch.tensor([[0.0, 1.0, 1.0, 1.0]]) + response_masks_bool = torch.tensor([[True, True, True]]) + entropy = torch.tensor([[0.5, 0.5, 0.5]]) + mock_args = Mock() + mock_args.clip_lower = 0.2 + mock_args.clip_higher = 0.2 + mock_args.beta = 0.1 + mock_args.masked_mean_axis = None + mock_args.masked_mean_denominator = None + mock_args.truncated_importance_sampling_ratio_cap = 0.0 + result = grpo_fast.calculate_loss_and_backward( + mock_model, + 0, + loss_statistics, + local_logprobs, + old_logprobs, + ref_logprob, + advantages, + response_masks_bool, + None, + entropy, + 1, + 0, + mock_args, + ) + mock_model.backward.assert_called_once() + self.assertEqual(result, 1) + + if __name__ == "__main__": unittest.main()
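# --- Illustrative sketch (not part of the diff above) ---
# A minimal, self-contained example of the four KL estimators that
# LossStatistics.update_kl_estimates stacks as kl1..kl4. All tensor values
# below are invented toy data; the PR itself reduces each estimator with
# masked_mean via torch.vmap, while an equivalent explicit loop is used here
# for clarity. masked_mean is copied from open_instruct/metrics.py.
import torch


def masked_mean(values, mask, axis=None, denominator=None):
    numerator = (values * mask).sum(axis=axis)
    denom = mask.sum(axis=axis) if denominator is None else denominator
    return (numerator / denom).mean()


# x = new_logprobs - ref_logprobs for one packed sequence (toy values).
x = torch.tensor([[0.05, -0.02, 0.10, 0.00]])
# ratio = exp(new_logprobs - old_logprobs), the PPO importance ratio (toy values).
ratio = torch.exp(torch.tensor([[0.01, -0.01, 0.03, 0.00]]))
mask = torch.tensor([[True, True, True, False]])  # last position is a query/pad token

kl_values = torch.stack(
    [
        x,                    # kl1: raw log-ratio
        x**2 / 2,             # kl2: quadratic approximation
        torch.expm1(-x) + x,  # kl3: non-negative, numerically stable form
        ratio * x,            # kl4: importance-weighted log-ratio
    ]
)
per_estimator = torch.stack([masked_mean(v, mask) for v in kl_values])
# kl2 and kl3 agree closely for small |x|; kl1 and kl4 are single-sample
# log-ratio estimates, so they track x itself rather than its second-order terms.
print(per_estimator)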
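# --- Illustrative sketch (not part of the diff above) ---
# A toy walk-through of the truncated importance sampling (TIS) correction that
# maybe_apply_importance_sampling applies when
# truncated_importance_sampling_ratio_cap > 0. The tensors and cap value are
# invented; the mask/assertion bookkeeping from the real function is omitted.
import torch

old_logprobs = torch.tensor([[-0.5, -1.0, -2.0]])   # current policy logprobs
vllm_logprobs = torch.tensor([[-0.6, -1.0, -4.0]])  # behavior policy (vLLM) logprobs
cap = 2.0  # stands in for args.truncated_importance_sampling_ratio_cap

# Clamp the log-ratio before exponentiating to avoid overflow, then cap the ratio.
log_ratio = (old_logprobs - vllm_logprobs).clamp(-10.0, 10.0)
tis_ratio = torch.exp(log_ratio).clamp(max=cap)
print(tis_ratio)  # tensor([[1.1052, 1.0000, 2.0000]]) -- the third token hits the cap

# Both the unclipped and clipped PG losses are scaled by the same ratio, so the
# element-wise max in the PPO objective compares consistently re-weighted terms.
pg_loss = torch.tensor([[-0.3, 0.1, -0.2]])
print(pg_loss * tis_ratio)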