-
Notifications
You must be signed in to change notification settings - Fork 465
Adds a flag to skip loading the reference policy to save memory #1171
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
hamishivi
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, but need to fix the dschf stuff
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
see this for dschf
hamishivi
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Runs:
Note
Adds a flag to skip loading the reference policy (and KL), centralizes ref-policy loading, refactors eval DeepSpeed config, and updates training flow and scripts accordingly.
open_instruct/grpo_fast.py):Args.load_ref_policywith validation (requiresbeta=0.0when disabled).load_ref_policy.compute_logprobs; streamline old-logprobs path.load_ref_policyinopen_instruct/model_utils.py; disable dropout, DS init, optional checkpoint load.get_eval_ds_configinopen_instruct/utils.pyto return(ds_config, HfDeepSpeedConfig)and acceptper_device_train_batch_size.model_utils.save_with_accelerate.open_instruct/ppo.py):load_ref_policyand updatedget_eval_ds_configreturn signature.large_test_script.sh: switch cluster; set--beta 0.0and--load_ref_policy false.single_gpu_on_beaker.sh: set--beta 0.0and--load_ref_policy true.Written by Cursor Bugbot for commit 1974db8. This will update automatically on new commits. Configure here.