Correct loss accumulation for grpo_fast #1161
base: main
Changes from all commits
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -107,7 +107,7 @@ | |
| push_folder_to_hub, | ||
| ) | ||
| from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics | ||
| from open_instruct.rl_utils import PackedSequences, Timer, pack_sequences | ||
| from open_instruct.rl_utils import PackedSequences, Timer, masked_mean, pack_sequences | ||
| from open_instruct.utils import ( | ||
| ArgumentParserPlus, | ||
| BeakerRuntimeConfig, | ||
|
|
@@ -118,6 +118,7 @@ | |
| combine_reward_metrics, | ||
| download_latest_checkpoint_from_gs, | ||
| get_beaker_whoami, | ||
| get_denominator, | ||
| get_eval_ds_config, | ||
| get_optimizer_grouped_parameters, | ||
| get_train_ds_config, | ||
|
|
@@ -249,10 +250,11 @@ class Args: | |
| """the KL estimator to use""" | ||
| pack_length: int = 512 | ||
| """the length of the pack (you should prob set to the max length of the model)""" | ||
| masked_mean_axis: int | None = None | ||
| """the axis to compute the mean of the masked values""" | ||
| masked_mean_denominator: float | None = None | ||
| """Optional constant denominator for masked_mean; if set, divides by this instead of mask.sum""" | ||
| masked_mean_denominator: float | str | None = "token" | ||
| """Optional constant denominator for masked_mean; if set, divides by this instead of mask.sum. | ||
| Special value "token" means use total_batch_tokens (computed across all ranks in distributed training). | ||
| Special value "group" means use group-level averaging (average across tokens in a group, then average across groups). | ||
| When using "token", total_batch_tokens is gathered via allreduce across all ranks.""" | ||
| alpha: float = 0.6 | ||
| """The alpha value for doing polyak updates (ref_param = alpha * param + (1 - alpha) * ref_param) | ||
| reference: [TR-DPO](https://huggingface.co/papers/2404.09656), but it's actually pretty commonly | ||
|
|
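For intuition, here is a minimal standalone sketch of the normalization choices the `masked_mean_denominator` docstring describes. It is not the repo's helper; the tensors, the constant `8.0`, and the hypothetical `total_batch_tokens` are invented for illustration.

```python
import torch

# Toy minibatch: 2 sequences, 4 positions; True marks response tokens.
values = torch.tensor([[1.0, 2.0, 3.0, 0.0], [4.0, 0.0, 0.0, 0.0]])
mask = torch.tensor([[1, 1, 1, 0], [1, 0, 0, 0]], dtype=torch.bool)

# Default (denominator=None): divide by the unmasked token count of this minibatch.
per_minibatch = (values * mask).sum() / mask.sum()       # (1+2+3+4) / 4 = 2.5

# Numeric denominator (e.g. masked_mean_denominator=8.0): fixed divisor, independent of the mask.
constant = (values * mask).sum() / 8.0                   # 10 / 8 = 1.25

# "token" mode: divide by the token count of the whole accumulation group,
# gathered across ranks, so every minibatch shares the same denominator.
total_batch_tokens = 20.0                                # hypothetical global count
token_mode = (values * mask).sum() / total_batch_tokens  # 10 / 20 = 0.5

# "group" mode: mean over the tokens of each prompt group, then mean over groups.
# Pretend the two sequences belong to two different prompt groups.
group_means = torch.stack([values[0][mask[0]].mean(), values[1][mask[1]].mean()])
group_mode = group_means.mean()                          # ((1+2+3)/3 + 4/1) / 2 = 3.0

print(per_minibatch.item(), constant.item(), token_mode.item(), group_mode.item())
```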
@@ -456,10 +458,7 @@ def __post_init__(self): | |
| "Cannot use both `use_vllm_logprobs` and `truncated_importance_sampling_ratio_cap`. " | ||
| "use_vllm_logprobs sets old_logprobs to vLLM logprobs, making importance sampling pointless." | ||
| ) | ||
| if self.masked_mean_denominator is not None: | ||
| assert self.masked_mean_denominator > 0, ( | ||
| f"masked_mean_denominator (={self.masked_mean_denominator}) must be greater than 0!" | ||
| ) | ||
| self.masked_mean_denominator = get_denominator(self.masked_mean_denominator) | ||
| assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!" | ||
| if self.num_samples_per_prompt_rollout == 1: | ||
| logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.") | ||
|
|
@@ -498,6 +497,8 @@ def __post_init__(self): | |
| self.gs_checkpoint_state_dir = f"{self.gs_bucket_path}/{beaker_users}/{checkpoint_dir_name}" | ||
| else: | ||
| self.gs_checkpoint_state_dir = f"{self.gs_bucket_path}/{checkpoint_dir_name}" | ||
| if not checkpoint_dir_name.startswith("/filestore"): | ||
| self.checkpoint_state_dir = f"/filestore{self.checkpoint_state_dir}" | ||
|
|
||
| if self.checkpoint_state_dir is not None: | ||
| if self.gs_checkpoint_state_dir is not None: | ||
|
|
@@ -544,15 +545,6 @@ def __post_init__(self): | |
| raise ValueError("`async_steps` must be greater than 0. Fully synchronous training is not supported.") | ||
|
|
||
|
|
||
| def masked_mean( | ||
| values: torch.Tensor, mask: torch.Tensor, axis: int | None = None, denominator: float | None = None | ||
| ) -> torch.Tensor: | ||
| """Compute mean of tensor with a masked values.""" | ||
| numerator = (values * mask).sum(axis=axis) | ||
| denom = mask.sum(axis=axis) if denominator is None else denominator | ||
| return (numerator / denom).mean() | ||
|
|
||
|
|
||
| def collate_fn(tensors_list: list[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor: | ||
| padded_tensor = torch.nn.utils.rnn.pad_sequence(tensors_list, batch_first=True, padding_value=pad_token_id) | ||
| if pin_memory: | ||
|
|
@@ -1053,6 +1045,75 @@ def compute_logprobs( | |
|
|
||
| return collated_logprobs, collated_entropies | ||
|
|
||
| def calculate_token_counts( | ||
| self, | ||
| accumulation_steps: int, | ||
| collated_response_masks: list[torch.Tensor], | ||
| collated_tool_masks: list[torch.Tensor], | ||
| mode: str = "token", | ||
| ) -> dict[int, float | torch.Tensor]: | ||
| accumulation_counts = {} | ||
|
|
||
| max_group_id = 0 | ||
| if mode == "group": | ||
| # First pass to determine max group index | ||
| for i in range(len(collated_response_masks)): | ||
| mb_response_masks = collated_response_masks[i] | ||
| # group_id = (sample_index) // n | ||
| # sample_index starts at 0. response_mask contains sample_index + 1. | ||
| # So (response_mask - 1) // n | ||
| max_group_id = max( | ||
| max_group_id, ((mb_response_masks.max().item() - 1) // self.args.num_samples_per_prompt_rollout) | ||
| ) | ||
|
|
||
| # All reduce max_group_id to ensure all ranks have the same tensor size | ||
| if dist.is_available() and dist.is_initialized(): | ||
| dist.barrier() | ||
| max_group_id_tensor = torch.tensor(max_group_id, dtype=torch.long, device=self.device) | ||
| dist.all_reduce(max_group_id_tensor, op=dist.ReduceOp.MAX, group=None) | ||
| max_group_id = max_group_id_tensor.item() | ||
|
|
||
| for group_start in range(0, len(collated_response_masks), accumulation_steps): | ||
| group_end = min(group_start + accumulation_steps, len(collated_response_masks)) | ||
|
|
||
| if mode == "group": | ||
| counts = torch.zeros(max_group_id + 1, device=self.device, dtype=torch.float32) | ||
| else: | ||
| counts = torch.tensor(0.0, device=self.device, dtype=torch.float32) | ||
|
|
||
| for i in range(group_start, group_end): | ||
| mb_response_masks = collated_response_masks[i] | ||
| mb_response_masks_bool = mb_response_masks[:, 1:].bool() | ||
| if self.args.mask_tool_use and self.args.tool_use: | ||
| mb_tool_mask = collated_tool_masks[i] | ||
| mb_response_masks_bool = mb_response_masks_bool & mb_tool_mask[:, 1:].bool() | ||
|
|
||
| if mode == "group": | ||
| # Filter valid tokens | ||
| valid_mask = mb_response_masks_bool | ||
| flat_mask = valid_mask.flatten() | ||
| if flat_mask.any(): | ||
| # Get group IDs for valid tokens | ||
| flat_response_masks = mb_response_masks[:, 1:].flatten() | ||
| valid_response_masks = flat_response_masks[flat_mask] | ||
| group_ids = (valid_response_masks - 1) // self.args.num_samples_per_prompt_rollout | ||
| # Accumulate counts | ||
| counts.scatter_add_(0, group_ids, torch.ones_like(group_ids, dtype=torch.float32)) | ||
| else: | ||
| counts += mb_response_masks_bool.sum().float() | ||
|
|
||
| # All reduce counts | ||
| if dist.is_available() and dist.is_initialized(): | ||
| dist.barrier() | ||
| dist.all_reduce(counts, op=dist.ReduceOp.SUM, group=None) | ||
|
|
||
| if mode == "token": | ||
| accumulation_counts[group_start] = counts.item() | ||
| else: | ||
| accumulation_counts[group_start] = counts | ||
|
|
||
| return accumulation_counts | ||
|
|
||
| def train( | ||
| self, | ||
| collated_query_responses, | ||
|
|
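A rough, single-rank sketch of what the token-mode counting produces; the helper name and the toy masks below are hypothetical, and the real method additionally intersects tool masks and all-reduces the counts across ranks.

```python
import torch

def token_counts_per_accumulation_window(response_masks, accumulation_steps):
    """Simplified stand-in: sum response tokens over each accumulation window.

    response_masks: list of [batch, seq_len] tensors where non-zero marks response tokens;
    positions are shifted by one to match the logprob alignment in the trainer.
    """
    counts = {}
    for group_start in range(0, len(response_masks), accumulation_steps):
        group_end = min(group_start + accumulation_steps, len(response_masks))
        total = 0.0
        for i in range(group_start, group_end):
            total += response_masks[i][:, 1:].bool().sum().item()
        counts[group_start] = total
    return counts

# Hypothetical: 4 minibatches, accumulation_steps=2.
masks = [
    torch.tensor([[0, 1, 1, 0]]),
    torch.tensor([[0, 1, 1, 1]]),
    torch.tensor([[0, 1, 0, 0]]),
    torch.tensor([[0, 1, 1, 1]]),
]
print(token_counts_per_accumulation_window(masks, accumulation_steps=2))
# {0: 5.0, 2: 4.0} -> minibatches 0-1 share denominator 5, minibatches 2-3 share denominator 4
```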
@@ -1158,6 +1219,17 @@ def train( | |
| ratio_stats = torch.zeros(len(collated_query_responses)) | ||
| entropy_stats = torch.zeros(len(collated_query_responses)) | ||
| for epoch_idx in range(args.num_epochs): | ||
| # Pre-compute total tokens for each accumulation group if using "token" normalization | ||
| # This ensures all minibatches in an accumulation group are normalized by the same total | ||
| accumulation_token_counts = {} | ||
| if args.masked_mean_denominator in ["token", "group"]: | ||
| accumulation_token_counts = self.calculate_token_counts( | ||
| accumulation_steps, | ||
| collated_response_masks, | ||
| collated_tool_masks, | ||
| mode=args.masked_mean_denominator, | ||
| ) | ||
|
|
||
|
Collaborator
Can you make this a function?
||
| for i in range(len(collated_query_responses)): | ||
| mb_query_responses = collated_query_responses[i] | ||
| mb_tool_mask = collated_tool_masks[i] | ||
|
|
@@ -1167,6 +1239,14 @@ def train( | |
| # if masking snippets, do it here. | ||
| if args.mask_tool_use and args.tool_use: | ||
| mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool() | ||
|
|
||
| # Determine the denominator for masked_mean normalization | ||
| loss_denominator = args.masked_mean_denominator | ||
| loss_axis = None | ||
| if args.masked_mean_denominator == "token": | ||
| group_start = (i // accumulation_steps) * accumulation_steps | ||
| loss_denominator = accumulation_token_counts[group_start] | ||
|
|
||
| mb_attention_mask = collated_attention_masks[i] | ||
| mb_position_id = collated_position_ids[i] | ||
| mb_local_logprobs, mb_entropy = self.forward( | ||
|
|
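The `group_start` lookup simply snaps a minibatch index to the start of its accumulation window, so every minibatch in that window shares one precomputed denominator. A quick illustration with invented numbers:

```python
accumulation_steps = 4
for i in range(10):
    group_start = (i // accumulation_steps) * accumulation_steps
    print(i, "->", group_start)
# 0-3 -> 0, 4-7 -> 4, 8-9 -> 8: each minibatch looks up the count precomputed for its window
```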
@@ -1272,6 +1352,50 @@ def train( | |
|
|
||
| pg_loss_max = torch.max(pg_losses, pg_losses2) | ||
|
|
||
| # for stats computation, for now no denominator is used | ||
| # unless masked_mean_denominator is a numeric value. | ||
| stats_denominator = ( | ||
| args.masked_mean_denominator | ||
| if args.masked_mean_denominator != "token" and args.masked_mean_denominator != "group" | ||
| else None | ||
| ) | ||
|
|
||
| # Define reduction function based on configuration | ||
| if args.masked_mean_denominator == "group": | ||
| group_start = (i // accumulation_steps) * accumulation_steps | ||
| group_counts = accumulation_token_counts[group_start] | ||
| total_active_groups = (group_counts > 0).sum().item() | ||
| group_ids = (mb_response_masks[:, 1:] - 1) // args.num_samples_per_prompt_rollout | ||
|
|
||
| def reduce_fn(v, m, a=None, d=None): | ||
| flat_v = v.flatten() | ||
| flat_m = m.flatten().bool() | ||
| flat_g = group_ids.flatten() | ||
|
|
||
| # if no valid tokens in batch. | ||
| if not flat_m.any(): | ||
| return torch.tensor(0.0, device=v.device) | ||
|
|
||
| valid_v = flat_v[flat_m] | ||
| valid_g = flat_g[flat_m] | ||
|
|
||
| valid_counts = group_counts[valid_g] | ||
| # Avoid division by zero if count is 0 (should not happen for valid tokens) | ||
| valid_counts = torch.max( | ||
| valid_counts, torch.tensor(1.0, device=valid_counts.device, dtype=valid_counts.dtype) | ||
| ) | ||
|
|
||
| weights = 1.0 / (valid_counts * total_active_groups) | ||
|
|
||
| # Sum weighted values | ||
| loss = (valid_v * weights).sum() | ||
| scale = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1 | ||
| scale *= accumulation_steps | ||
|
Bug: Incorrect gradient scaling in group mode. In group mode, the loss is multiplied by `scale` (`world_size * accumulation_steps`).
||
|
|
||
| return loss * scale | ||
| else: | ||
| reduce_fn = masked_mean | ||
|
|
||
| if args.load_ref_policy: | ||
| mb_ref_logprob = collated_ref_logprobs[i] | ||
| # Here we recalculate kl: we want the KL loss to backpropagate through the model | ||
|
|
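A small numeric check of the weighting used in the group branch (toy values only, and the `world_size * accumulation_steps` rescaling is left out): weighting each token by 1/(tokens_in_its_group * active_groups) and summing matches averaging within each group and then across groups.

```python
import torch

# Hypothetical per-token losses and their prompt-group ids (2 groups).
token_loss = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
group_ids = torch.tensor([0, 0, 0, 1, 1])

group_counts = torch.bincount(group_ids).float()   # tokens per group: [3., 2.]
active_groups = (group_counts > 0).sum()           # 2

# Weighted-sum form, as in the reduce_fn above.
weights = 1.0 / (group_counts[group_ids] * active_groups)
weighted = (token_loss * weights).sum()            # 6/(3*2) + 9/(2*2) = 1.0 + 2.25 = 3.25

# Average within each group, then across groups.
per_group = torch.stack([token_loss[group_ids == g].mean() for g in range(2)])
two_stage = per_group.mean()                       # (2.0 + 4.5) / 2 = 3.25

assert torch.allclose(weighted, two_stage)
```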
@@ -1291,17 +1415,23 @@ def train( | |
| elif args.kl_estimator == "kl4": | ||
| kl = kl4 | ||
| # grpo change: directly subtract KL in loss (add) | ||
| loss = masked_mean( | ||
| pg_loss_max + (args.beta * kl), | ||
| mb_response_masks_bool, | ||
| args.masked_mean_axis, | ||
| args.masked_mean_denominator, | ||
| loss = reduce_fn( | ||
| pg_loss_max + (args.beta * kl), mb_response_masks_bool, loss_axis, loss_denominator | ||
| ) | ||
| else: | ||
| loss = masked_mean( | ||
| pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator | ||
| ) | ||
| loss = loss / accumulation_steps | ||
| loss = reduce_fn(pg_loss_max, mb_response_masks_bool, loss_axis, loss_denominator) | ||
| if args.masked_mean_denominator == "token": | ||
| # In token mode, we divide by the GLOBAL total number of tokens. | ||
| # DDP averages gradients across ranks (dividing by world_size). | ||
| # To get the true global mean gradient (sum_all_gradients / global_tokens), | ||
| # we must multiply by world_size to cancel out DDP's division. | ||
| if dist.is_available() and dist.is_initialized(): | ||
| loss *= dist.get_world_size() | ||
| elif args.masked_mean_denominator != "group": | ||
| # For "group" mode, the scaling is handled inside reduce_fn. | ||
| # For default (None) or numeric modes, we divide by accumulation_steps here. | ||
| loss = loss / accumulation_steps | ||
|
|
||
| # Clear CUDA cache before backward pass to free memory for reduce_scatter operations | ||
| torch.cuda.empty_cache() | ||
| self.model.backward(loss) | ||
|
|
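A back-of-the-envelope check of the world_size multiplication in token mode, under the assumption that DDP averages gradients (and hence, for this illustration, losses) across ranks; all numbers are invented.

```python
# Invented numbers: 4 ranks, each rank's sum of per-token losses, and the global token count.
world_size = 4
per_rank_loss_sum = [10.0, 6.0, 8.0, 12.0]
global_tokens = 100.0

# Each rank computes sum_r / global_tokens and multiplies by world_size (as in the diff).
per_rank_loss = [world_size * s / global_tokens for s in per_rank_loss_sum]

# DDP then averages across ranks, i.e. divides by world_size again.
ddp_result = sum(per_rank_loss) / world_size        # = sum of all token losses / global_tokens

true_global_mean = sum(per_rank_loss_sum) / global_tokens   # 36 / 100 = 0.36
assert abs(ddp_result - true_global_mean) < 1e-12
```

This also suggests why the `loss / accumulation_steps` division is skipped in this branch: the global token count already spans every minibatch in the accumulation window (and all ranks), so the accumulated gradients sum to the global token mean without further rescaling.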
@@ -1311,44 +1441,29 @@ def train( | |
| with torch.no_grad(): | ||
| if args.load_ref_policy: | ||
| # NOTE: in packed implementation, kl calculation are averages over response tokens | ||
| kl1_stats[i] = masked_mean( | ||
| kl1, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator | ||
| ).float() | ||
| kl2_stats[i] = masked_mean( | ||
| kl2, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator | ||
| ).float() | ||
| kl3_stats[i] = masked_mean( | ||
| kl3, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator | ||
| ).float() | ||
| kl4_stats[i] = masked_mean( | ||
| kl4, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator | ||
| ).float() | ||
| if args.kl_estimator == "kl1": | ||
| kl_loss_stats[i] = kl1_stats[i] * args.beta | ||
| elif args.kl_estimator == "kl2": | ||
| kl_loss_stats[i] = kl2_stats[i] * args.beta | ||
| elif args.kl_estimator == "kl3": | ||
| kl_loss_stats[i] = kl3_stats[i] * args.beta | ||
| elif args.kl_estimator == "kl4": | ||
| kl_loss_stats[i] = kl4_stats[i] * args.beta | ||
| pg_clipfrac_stats[i] = masked_mean( | ||
| (pg_losses2 > pg_losses).float(), | ||
| mb_response_masks_bool, | ||
| args.masked_mean_axis, | ||
| args.masked_mean_denominator, | ||
| ) | ||
| pg_loss_stats[i] = masked_mean( | ||
| pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator | ||
| ) | ||
| loss_stats[i] = loss | ||
| ratio_stats[i] = masked_mean( | ||
| ratio, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator | ||
| ) | ||
| if args.record_entropy: | ||
| # Calculate entropy statistics | ||
| entropy_stats[i] = masked_mean( | ||
| mb_entropy, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator | ||
| ).float() | ||
| kl1_stats[i] = reduce_fn(kl1, mb_response_masks_bool, loss_axis, stats_denominator).float() | ||
| kl2_stats[i] = reduce_fn(kl2, mb_response_masks_bool, loss_axis, stats_denominator).float() | ||
| kl3_stats[i] = reduce_fn(kl3, mb_response_masks_bool, loss_axis, stats_denominator).float() | ||
| kl4_stats[i] = reduce_fn(kl4, mb_response_masks_bool, loss_axis, stats_denominator).float() | ||
| if args.kl_estimator == "kl1": | ||
| kl_loss_stats[i] = kl1_stats[i] * args.beta | ||
| elif args.kl_estimator == "kl2": | ||
| kl_loss_stats[i] = kl2_stats[i] * args.beta | ||
| elif args.kl_estimator == "kl3": | ||
| kl_loss_stats[i] = kl3_stats[i] * args.beta | ||
| elif args.kl_estimator == "kl4": | ||
| kl_loss_stats[i] = kl4_stats[i] * args.beta | ||
| pg_clipfrac_stats[i] = reduce_fn( | ||
| (pg_losses2 > pg_losses).float(), mb_response_masks_bool, loss_axis, stats_denominator | ||
| ) | ||
| pg_loss_stats[i] = reduce_fn(pg_loss_max, mb_response_masks_bool, loss_axis, stats_denominator) | ||
| loss_stats[i] = loss | ||
| ratio_stats[i] = reduce_fn(ratio, mb_response_masks_bool, loss_axis, stats_denominator) | ||
| if args.record_entropy: | ||
| # Calculate entropy statistics | ||
| entropy_stats[i] = reduce_fn( | ||
| mb_entropy, mb_response_masks_bool, loss_axis, stats_denominator | ||
| ).float() | ||
|
|
||
| with torch.no_grad(): | ||
| if args.load_ref_policy: | ||
|
|
||
Bug: Missing path separator in checkpoint directory
The code concatenates `/filestore` directly with `self.checkpoint_state_dir` without a separator, resulting in paths like `/filestoremy_checkpoint` instead of `/filestore/my_checkpoint`. This creates invalid filesystem paths when the checkpoint directory doesn't already start with `/filestore`.
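If it helps, one possible shape for the fix (a sketch only, with a hypothetical helper name, not the PR's code): keep exactly one separator when prefixing `/filestore`.

```python
import os

def with_filestore_prefix(path: str) -> str:
    """Prefix a checkpoint path with /filestore, keeping exactly one separator."""
    if path.startswith("/filestore"):
        return path
    # os.path.join discards "/filestore" when the second argument is absolute,
    # so strip any leading "/" before joining.
    return os.path.join("/filestore", path.lstrip("/"))

print(with_filestore_prefix("my_checkpoint"))          # /filestore/my_checkpoint
print(with_filestore_prefix("/ckpts/run1"))            # /filestore/ckpts/run1
print(with_filestore_prefix("/filestore/ckpts/run1"))  # unchanged
```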