Correct loss accumulation for grpo_fast #1161
base: main
Conversation
Code Review
This pull request correctly adjusts the loss accumulation for grpo_fast by introducing a new normalization method based on the total number of tokens across all ranks. This prevents samples with fewer tokens from having an outsized effect on the gradient updates. The implementation looks solid. I've added a couple of suggestions to improve code readability and maintainability by reducing code duplication and simplifying conditional logic.
```python
    accumulation_group_tokens[group_start] = local_group_tokens_tensor.item()
else:
    accumulation_group_tokens[group_start] = local_group_tokens
```
Can you make this a function?

```python
accumulation_group_tokens = maybe_calculate_group_tokens(...)
```
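A minimal sketch of what such a helper could look like, based on the branch structure in the diff above (the name `maybe_calculate_group_tokens` comes from the suggestion; the signature is hypothetical, and device/backend handling for the allreduce is elided):

```python
import torch
import torch.distributed as dist


def maybe_calculate_group_tokens(local_group_tokens: int) -> int:
    """Sum a group's token count across ranks when running distributed,
    otherwise return the local count unchanged."""
    if dist.is_available() and dist.is_initialized():
        local_group_tokens_tensor = torch.tensor(local_group_tokens, dtype=torch.long)
        dist.all_reduce(local_group_tokens_tensor, op=dist.ReduceOp.SUM)
        return local_group_tokens_tensor.item()
    return local_group_tokens
```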
open_instruct/grpo_fast.py (Outdated)
```python
).float()
kl2_stats[i] = masked_mean(kl2, mb_response_masks_bool, loss_axis, stats_denominator).float()
kl3_stats[i] = masked_mean(kl3, mb_response_masks_bool, loss_axis, stats_denominator).float()
kl4_stats[i] = masked_mean(kl4, mb_response_masks_bool, loss_axis, stats_denominator).float()
```
Bug: Undefined variables when reference policy not loaded
Lines 1370-1372 compute kl2_stats, kl3_stats, and kl4_stats using variables kl2, kl3, and kl4 that are only defined inside the if args.load_ref_policy: block (lines 1332-1349). When args.load_ref_policy is False, these variables are undefined, causing a NameError at runtime. These three lines need to be indented to be inside the if args.load_ref_policy: block.
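A sketch of the indentation fix the comment describes, reusing the names from the diff above (the surrounding training-loop context is assumed):

```python
if args.load_ref_policy:
    # kl2/kl3/kl4 are only defined on this branch, so the stats that
    # consume them must be computed inside it as well.
    kl2_stats[i] = masked_mean(kl2, mb_response_masks_bool, loss_axis, stats_denominator).float()
    kl3_stats[i] = masked_mean(kl3, mb_response_masks_bool, loss_axis, stats_denominator).float()
    kl4_stats[i] = masked_mean(kl4, mb_response_masks_bool, loss_axis, stats_denominator).float()
```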
```python
else:
    self.gs_checkpoint_state_dir = f"{self.gs_bucket_path}/{checkpoint_dir_name}"
    if not checkpoint_dir_name.startswith("/filestore"):
        self.checkpoint_state_dir = f"/filestore{self.checkpoint_state_dir}"
```
Bug: Missing path separator in checkpoint directory
The code concatenates /filestore directly with self.checkpoint_state_dir without a separator, resulting in paths like /filestoremy_checkpoint instead of /filestore/my_checkpoint. This creates invalid filesystem paths when the checkpoint directory doesn't already start with /filestore.
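One possible fix, assuming the intent is simply to prepend the mount point (hypothetical; the PR may resolve it differently):

```python
if not checkpoint_dir_name.startswith("/filestore"):
    # Join with an explicit separator so "/filestore" + "my_checkpoint"
    # becomes "/filestore/my_checkpoint", not "/filestoremy_checkpoint".
    self.checkpoint_state_dir = f"/filestore/{self.checkpoint_state_dir.lstrip('/')}"
```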
```python
# Sum weighted values
loss = (valid_v * weights).sum()
scale = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
scale *= accumulation_steps
```
Bug: Incorrect gradient scaling in group mode
In group mode, the loss is multiplied by accumulation_steps on line 1393, causing accumulated gradients to be scaled by accumulation_steps. Since gradients are summed across minibatches during gradient accumulation, this multiplication inflates the effective gradient by a factor of accumulation_steps, making the effective learning rate much larger than intended.
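A toy numeric illustration of the inflation (all values hypothetical):

```python
# Each minibatch loss already divides by the global token count, so summing
# the per-minibatch contributions over the accumulation window is correctly
# normalized on its own.
accumulation_steps = 4
total_tokens = 100
minibatch_loss_sums = [30.0, 20.0, 25.0, 25.0]  # summed per-token loss per minibatch

correct = sum(s / total_tokens for s in minibatch_loss_sums)
buggy = sum(s * accumulation_steps / total_tokens for s in minibatch_loss_sums)

print(correct)  # 1.0
print(buggy)    # 4.0 -- effective gradient inflated by accumulation_steps
```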
Previously, grpo_fast only averaged the loss locally, not globally across ranks, and did not take gradient accumulation into account. This meant that samples with fewer tokens had an outsized effect. See https://huggingface.co/blog/gradient_accumulation for more details. This slipped my notice... we probably need to check the DPO and SFT code too.
Anyway, the fix is: introduce a new value for `masked_mean_denominator`: `token`. When set, this divides the loss by the total number of tokens across all ranks for each minibatch, which correctly scales the loss: we now normalize by the total number of tokens in the minibatch across all ranks. So while before we had, schematically:

loss = mean_over(rank, accum_step)[ sum(per_token_loss) / local_toks ]

Note here if e.g. rank0_accum0_toks >> rank1_accum1_toks, the rank1_accum1 tokens would overwhelm the update.

Now we have:

loss = sum_over(rank, accum_step)[ sum(per_token_loss) ] / total_toks

where total_toks is the token count summed over all ranks and accumulation steps.
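A minimal sketch of this "token" normalization, assuming DDP (which averages gradients across ranks, hence the world-size factor also visible in the diff above); the function name and signature are illustrative, not the PR's actual API:

```python
import torch
import torch.distributed as dist


def global_token_normalized_loss(per_token_loss: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Divide the local masked loss sum by the total token count across ranks."""
    local_sum = (per_token_loss * mask).sum()
    total_tokens = mask.sum()
    if dist.is_available() and dist.is_initialized():
        # Global token count for this minibatch across all ranks.
        dist.all_reduce(total_tokens, op=dist.ReduceOp.SUM)
        # DDP averages gradients over ranks; multiplying by the world size
        # lets the backward pass recover the global sum before the division.
        return local_sum * dist.get_world_size() / total_tokens
    return local_sum / total_tokens
```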
Gave it a test locally and it seems fine. The default right now sticks to the older (technically incorrect) behaviour. Setting `masked_mean_denominator` to an integer value (e.g. following Dr GRPO) also technically fixes this, since we replace the denominators (previously token counts) with a constant.
Note

Adds global token- and group-based loss normalization to GRPO with distributed allreduce and accumulation-aware scaling, refactors masked_mean into rl_utils, and introduces denominator validation with tests.

- `open_instruct/grpo_fast.py`: new `masked_mean_denominator` modes `"token"` (global across ranks via allreduce) and `"group"` (per-group averaging), defaulting to `"token"`; computes token counts (`calculate_token_counts`) and applies correct scaling with DDP and gradient accumulation; a `reduce_fn` (uses `masked_mean` or group-weighted reduction) for loss and stats; `utils.get_denominator` in `Args.__post_init__`; `checkpoint_state_dir` uses the `/filestore` prefix when syncing with GCS.
- `open_instruct/rl_utils.py`: `masked_mean` and `masked_group_mean` helpers.
- `open_instruct/utils.py`: `get_denominator` to validate/parse `masked_mean_denominator`.
- `open_instruct/test_utils.py`: tests for `utils.get_denominator` (valid/invalid inputs).

Written by Cursor Bugbot for commit febf44a.
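The summary above says `utils.get_denominator` validates/parses `masked_mean_denominator`; a hypothetical sketch of such a validator, assuming it accepts the two string modes or a positive constant (Dr GRPO-style) — the PR's actual implementation may differ:

```python
def get_denominator(value):
    """Parse masked_mean_denominator: "token", "group", or a positive
    number used as a constant denominator (Dr GRPO-style)."""
    if value in ("token", "group"):
        return value
    try:
        denom = float(value)
    except (TypeError, ValueError):
        raise ValueError(f"invalid masked_mean_denominator: {value!r}")
    if denom <= 0:
        raise ValueError("masked_mean_denominator must be positive")
    return denom
```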