-
Notifications
You must be signed in to change notification settings - Fork 124
Add truncated importance sampling and DrGRPO args #1394
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…airseq2 into jacklanchantin/drgrpo
| ) | ||
| per_token_scaled_advantage = per_token_scaled_advantage * tis_imp_ratio | ||
|
|
||
| if ref_logps is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
only use kl if ref_logps were computed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nitpick: this also means that beta is non-zero? does an assert make sense that it should never come here if beta is zero? or something that makes this if statement conditioned on beta for better readability?
otherwise LGTM!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lets make sure this works for bs>1 before merging ! (as discussed offline)
What does this PR do? Please describe:
tis_imp_ratio_capto use truncated importance sampling correctionadv_std_normarlization(for DrGRPO)ref_logpscomputation for kl if beta == 0 (as done in DrGRPO)loss_token_meanfor normalizing over all tokensFixes #{issue number}
Most importantly, this adds truncated importance sampling correction, as recommended by @uralik.
Does your PR introduce any breaking changes? If yes, please list them:
List of all backwards-incompatible changes.
Check list: