
Conversation

@hamishivi (Collaborator) commented Nov 10, 2025

Previously, grpo_fast only averaged loss locally, not globally across ranks, and did not take gradient accumulation into account. This meant that small samples had an outsized effect. See https://huggingface.co/blog/gradient_accumulation for more details. This slipped my notice... we probably need to check the DPO and SFT code too.

Anyway, the fix is: introduce a new value for masked_mean_denominator: tokens.

When set, this divides the loss by the total number of tokens in the minibatch across all ranks, which scales the loss correctly. So while before we had:

loss = avg(rank0(loss_accum0 / rank0_accum0_toks + loss_accum1 / rank0_accum1_toks), rank1(loss_accum0 / rank1_accum0_toks + loss_accum1 / rank1_accum1_toks))

Note here that if e.g. rank0_accum0_toks >> rank1_accum1_toks, the rank1 accum1 minibatch would overwhelm the update, since its loss is divided by a much smaller token count.

Now we have:

loss = avg(rank0((loss_accum0 + loss_accum1) / all_toks), rank1((loss_accum0 + loss_accum1) / all_toks))
= all_loss / all_toks

Gave it a test locally and it seems fine. The default right now sticks to the older (technically incorrect) behaviour. Setting masked_mean_denominator to an integer value (e.g. following Dr GRPO) also technically fixes the issue, since we replace the denominators (previously per-rank token counts) with a constant.
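For concreteness, a minimal sketch of the idea (illustrative names only, not the actual grpo_fast code): each rank counts its loss tokens over the whole minibatch, the counts are summed across ranks with an all_reduce, and every accumulation step's summed loss is divided by that global count.

import torch
import torch.distributed as dist

def global_token_count(response_mask: torch.Tensor) -> torch.Tensor:
    """Count loss tokens in this rank's full minibatch (all accumulation steps),
    then sum the counts across ranks."""
    count = response_mask.sum().float()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(count, op=dist.ReduceOp.SUM)
    return count

def accumulation_step_loss(per_token_loss: torch.Tensor,
                           response_mask: torch.Tensor,
                           total_tokens: torch.Tensor) -> torch.Tensor:
    # Old behaviour: (per_token_loss * mask).sum() / mask.sum(), later averaged
    # over ranks and accumulation steps, so token-poor minibatches dominate.
    # New behaviour: divide each step's summed loss by the global token count,
    # so the accumulated gradient corresponds to all_loss / all_toks.
    return (per_token_loss * response_mask).sum() / total_tokens

(DDP averages gradients across ranks, so in practice a world-size correction is also needed, like the scale factor visible in one of the hunks reviewed below; the sketch omits it.)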


Note

Adds global token- and group-based loss normalization to GRPO with distributed allreduce and accumulation-aware scaling, refactors masked_mean into rl_utils, and introduces denominator validation with tests.

  • GRPO training (open_instruct/grpo_fast.py):
    • Introduce masked_mean_denominator modes: "token" (global across ranks via allreduce) and "group" (per-group averaging), defaulting to "token".
    • Precompute per-accumulation token/group counts (calculate_token_counts) and apply correct scaling with DDP and gradient accumulation.
    • Replace inline reductions with pluggable reduce_fn (uses masked_mean or group-weighted reduction) for loss and stats.
    • Validate denominator via utils.get_denominator in Args.__post_init__.
    • Minor: ensure checkpoint_state_dir uses /filestore prefix when syncing with GCS.
  • RL utils (open_instruct/rl_utils.py):
    • Add masked_mean and masked_group_mean helpers (see the sketch after this list).
  • Utils (open_instruct/utils.py):
    • Add get_denominator to validate/parse masked_mean_denominator.
  • Tests (open_instruct/test_utils.py):
    • Add tests for utils.get_denominator (valid/invalid inputs).
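As a rough illustration of what such helpers could look like (a sketch under assumptions, not the actual rl_utils implementation; the group-mean semantics in particular are guessed):

import torch

def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis=None, denominator=None):
    """Mean of `values` where `mask` is True. If `denominator` is given
    (a constant as in Dr GRPO, or a global token count), it replaces the
    local mask sum."""
    masked_sum = (values * mask).sum() if axis is None else (values * mask).sum(axis)
    denom = denominator if denominator is not None else (mask.sum() if axis is None else mask.sum(axis))
    return masked_sum / denom

def masked_group_mean(values: torch.Tensor, mask: torch.Tensor, group_size: int):
    """Average within each group of `group_size` rows, then across groups,
    so every group contributes equally regardless of its token count."""
    seq_len = values.shape[-1]
    grouped_vals = (values * mask).reshape(-1, group_size, seq_len)
    grouped_mask = mask.reshape(-1, group_size, seq_len)
    per_group = grouped_vals.sum(dim=(1, 2)) / grouped_mask.sum(dim=(1, 2)).clamp(min=1)
    return per_group.mean()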

Written by Cursor Bugbot for commit febf44a.

@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello @hamishivi, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves an inaccuracy in grpo_fast's loss accumulation, particularly affecting distributed training and gradient accumulation. The core change introduces a 'token'-based normalization strategy for the masked_mean_denominator. This ensures that the loss is correctly scaled by the total number of tokens across all participating ranks and accumulation steps, leading to more stable and accurate training dynamics by preventing smaller batches from disproportionately influencing updates.

Highlights

  • Global Loss Accumulation: Introduced a new mechanism to correctly accumulate loss globally across all distributed ranks and gradient accumulation steps, addressing a previous issue where small samples had an outsized effect on the loss.
  • New masked_mean_denominator Option: Added a special string value 'token' to the masked_mean_denominator argument. When set, this normalizes the loss by the total number of tokens in a minibatch across all ranks, ensuring proper scaling.
  • Distributed Token Counting: Implemented pre-computation and all_reduce operations to gather the total token count across all distributed ranks for each accumulation group, which is then used for global loss normalization.
  • Loss Normalization Adjustment: Modified the loss calculation to conditionally skip division by accumulation_steps when 'token' normalization is active, as the global token count already accounts for the entire accumulation group (see the sketch below).
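A compressed sketch of that scaling decision (hypothetical helper and argument names; the real logic in grpo_fast may differ):

def scale_minibatch_loss(summed_loss, local_mean_loss, denominator_mode,
                         global_token_count, accumulation_steps):
    """Hypothetical helper: with "token" normalization the global token count
    already spans every rank and accumulation step, so the usual division by
    accumulation_steps is skipped; otherwise keep the classic scaling."""
    if denominator_mode == "token":
        return summed_loss / global_token_count
    return local_mean_loss / accumulation_steps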

@gemini-code-assist bot left a comment

Code Review

This pull request correctly adjusts the loss accumulation for grpo_fast by introducing a new normalization method based on the total number of tokens across all ranks. This prevents samples with fewer tokens from having an outsized effect on the gradient updates. The implementation looks solid. I've added a couple of suggestions to improve code readability and maintainability by reducing code duplication and simplifying conditional logic.

    accumulation_group_tokens[group_start] = local_group_tokens_tensor.item()
else:
    accumulation_group_tokens[group_start] = local_group_tokens

Collaborator

Can you make this a function?

accumulation_group_tokens = maybe_calculate_group_tokens(...)
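A sketch of what that helper could look like (the name and signature come from the suggestion above, not from the diff):

import torch
import torch.distributed as dist

def maybe_calculate_group_tokens(local_group_tokens: int, use_global_count: bool) -> float:
    """Return the token count for one accumulation group, all-reduced across
    ranks when global ("token") normalization is enabled."""
    if use_global_count and dist.is_available() and dist.is_initialized():
        device = "cuda" if torch.cuda.is_available() else "cpu"
        tokens = torch.tensor(float(local_group_tokens), device=device)
        dist.all_reduce(tokens, op=dist.ReduceOp.SUM)
        return tokens.item()
    return float(local_group_tokens)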

).float()
kl2_stats[i] = masked_mean(kl2, mb_response_masks_bool, loss_axis, stats_denominator).float()
kl3_stats[i] = masked_mean(kl3, mb_response_masks_bool, loss_axis, stats_denominator).float()
kl4_stats[i] = masked_mean(kl4, mb_response_masks_bool, loss_axis, stats_denominator).float()

Bug: Undefined variables when reference policy not loaded

Lines 1370-1372 compute kl2_stats, kl3_stats, and kl4_stats using variables kl2, kl3, and kl4 that are only defined inside the if args.load_ref_policy: block (lines 1332-1349). When args.load_ref_policy is False, these variables are undefined, causing a NameError at runtime. These three lines need to be indented to be inside the if args.load_ref_policy: block.
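The fix would look roughly like this (a fragment mirroring the quoted lines, with the stats moved inside the guard):

if args.load_ref_policy:
    # kl2/kl3/kl4 only exist in this branch, so the stats that consume them
    # must be computed here as well; otherwise they raise a NameError.
    kl2_stats[i] = masked_mean(kl2, mb_response_masks_bool, loss_axis, stats_denominator).float()
    kl3_stats[i] = masked_mean(kl3, mb_response_masks_bool, loss_axis, stats_denominator).float()
    kl4_stats[i] = masked_mean(kl4, mb_response_masks_bool, loss_axis, stats_denominator).float()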


else:
    self.gs_checkpoint_state_dir = f"{self.gs_bucket_path}/{checkpoint_dir_name}"
if not checkpoint_dir_name.startswith("/filestore"):
    self.checkpoint_state_dir = f"/filestore{self.checkpoint_state_dir}"

Bug: Missing path separator in checkpoint directory

The code concatenates /filestore directly with self.checkpoint_state_dir without a separator, resulting in paths like /filestoremy_checkpoint instead of /filestore/my_checkpoint. This creates invalid filesystem paths when the checkpoint directory doesn't already start with /filestore.
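A minimal sketch of the kind of fix being suggested (hypothetical helper name):

import os

def ensure_filestore_prefix(checkpoint_state_dir: str) -> str:
    """Prefix a checkpoint dir with /filestore, keeping a path separator."""
    if checkpoint_state_dir.startswith("/filestore"):
        return checkpoint_state_dir
    # os.path.join avoids the /filestoremy_checkpoint concatenation bug.
    return os.path.join("/filestore", checkpoint_state_dir.lstrip("/"))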


# Sum weighted values
loss = (valid_v * weights).sum()
scale = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1
scale *= accumulation_steps

Bug: Incorrect gradient scaling in group mode

In group mode, the loss is multiplied by accumulation_steps on line 1393, causing accumulated gradients to be scaled by accumulation_steps. Since gradients are summed across minibatches during gradient accumulation, this multiplication inflates the effective gradient by a factor of accumulation_steps, making the effective learning rate much larger than intended.

