
Conversation

pramodith (Collaborator)

What does this PR do?

Introduces the idea of a ReplayBuffer to GRPO.

Implementation details:

  • A ReplayBuffer is implemented as a heap/priority queue. The buffer stores a score and a dict with the same keys as the output of _generate_and_score_completions.
  • The ReplayBuffer stores entire groups and all the keys associated with a group that'd be needed for computing the loss. Storing the old/ref_log_probs ensures that we don't run any extra forward passes through models.
  • Currently, a group's score is the sum of its absolute advantages multiplied by the group's standard deviation.
  • Every time _generate_and_score_completions is called we check whether (1) there are any groups with non-zero variance, which are candidates to be added to the ReplayBuffer, and (2) there are any groups with zero variance, which need to be substituted with entries from the replay buffer. A minimal sketch of this follows below.
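
Below is a minimal sketch of what such a buffer could look like, assuming a fixed-size min-heap that keeps the highest-scoring groups (heapq keeps the smallest score at the root, so the lowest-scoring group is evicted first). The class and method names here are illustrative, not necessarily the ones used in the PR:

```python
import heapq
from typing import Any

import torch


def group_score(advantages: torch.Tensor, std: torch.Tensor) -> float:
    # Score described above: the summed product of a group's absolute
    # advantages and its standard deviation.
    return (advantages.abs() * std).sum().item()


class ReplayBuffer:
    """Illustrative fixed-size buffer of (score, group) pairs.

    A min-heap of at most `max_size` entries keeps the highest-scoring
    groups. Each `group` is a dict holding everything needed to compute
    the loss (completions, advantages, old/ref log-probs, ...), so
    reusing a stored group requires no extra forward passes.
    """

    def __init__(self, max_size: int):
        self.max_size = max_size
        self._heap: list[tuple[float, int, dict[str, Any]]] = []
        self._counter = 0  # tie-breaker so heapq never compares the dicts

    def __len__(self) -> int:
        return len(self._heap)

    def add(self, score: float, group: dict[str, Any]) -> None:
        entry = (score, self._counter, group)
        self._counter += 1
        if len(self._heap) < self.max_size:
            heapq.heappush(self._heap, entry)
        elif score > self._heap[0][0]:
            # Evict the lowest-scoring stored group to make room.
            heapq.heapreplace(self._heap, entry)

    def sample(self) -> dict[str, Any]:
        # Return the highest-scoring stored group; the PR's actual
        # sampling strategy may differ (e.g. score-weighted random).
        return max(self._heap, key=lambda e: e[0])[2]
```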

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.


pramodith commented Sep 10, 2025

I'm still mulling over the multi-GPU scenario: I'm wondering whether we should have the same buffer shared across all GPUs/processes or whether it's okay for each GPU/process to have its own buffer. Happy to hear your views on this @qgallouedec

Also still need to add an e2e test for training with the ReplayBuffer.


pramodith commented Sep 11, 2025

I should probably break the update_with_replay_buffer function into two smaller functions, since it's too big right now (a rough sketch of the split is included below):

  1. add_to_buffer
  2. replace_from_buffer.

I also need to add new test cases to confirm that the code works when the sequence lengths in the buffer differ from those in the current batch.
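
A rough sketch of what that split could look like, reusing the illustrative `ReplayBuffer` and `group_score` from the earlier sketch; the helper signatures and the per-group dict layout are assumptions, not the PR's actual code:

```python
from typing import Any


def add_to_buffer(buffer: "ReplayBuffer", groups: list[dict[str, Any]]) -> None:
    # Groups whose advantages have non-zero variance carry a learning
    # signal, so they are candidates for storage.
    for group in groups:
        std = group["advantages"].std()
        if std > 0:
            buffer.add(group_score(group["advantages"], std), group)


def replace_from_buffer(
    buffer: "ReplayBuffer", groups: list[dict[str, Any]]
) -> list[dict[str, Any]]:
    # Groups with zero variance contribute nothing to the loss, so swap
    # them for stored groups when the buffer is non-empty. Stored and
    # current groups may have different sequence lengths, so padding may
    # need re-aligning here (the test case mentioned above).
    replaced = []
    for group in groups:
        if group["advantages"].std() == 0 and len(buffer) > 0:
            replaced.append(buffer.sample())
        else:
            replaced.append(group)
    return replaced


def update_with_replay_buffer(
    buffer: "ReplayBuffer", groups: list[dict[str, Any]]
) -> list[dict[str, Any]]:
    add_to_buffer(buffer, groups)
    return replace_from_buffer(buffer, groups)
```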

@qgallouedec (Member)

can you migrate this into trl.experimental? 🙏

@pramodith (Collaborator, Author)

can you migrate this into trl.experimental? 🙏

Yes, will do.

@pramodith marked this pull request as ready for review on September 18, 2025 at 21:43
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Review comment from a Member on the new config class:

    class GRPOWithReplayBufferConfig(GRPOConfig):
        """
        New Parameters:
            replay_buffer_size (`int`, *optional*, defaults to `0`):

Suggested change:

    -         replay_buffer_size (`int`, *optional*, defaults to `0`):
    +         replay_buffer_size (`int`, *optional*, defaults to `64`):
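
For context, a minimal usage sketch with the suggested default of 64. The import path is an assumption about where the experimental class ends up, and `output_dir` is just an illustrative training-arguments value:

```python
# Hypothetical import path; the class is being moved under trl.experimental
# per the review above, so the exact module name is an assumption.
from trl.experimental.grpo_with_replay_buffer import GRPOWithReplayBufferConfig

config = GRPOWithReplayBufferConfig(
    output_dir="grpo-replay-buffer-demo",
    # New parameter documented above; 64 is the default suggested in the review.
    replay_buffer_size=64,
)
```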

@qgallouedec (Member) left a comment:

lgtm!

@pramodith merged commit d1e24df into huggingface:main on Sep 24, 2025
1 of 10 checks passed