[megatron] feat: support Megatron-FSDP mode for Megatron backend #5423
conver334 wants to merge 30 commits into verl-project:main
Conversation
Signed-off-by: conver334 <conver334@gmail.com>
Contributor
Code Review
This pull request introduces support for Megatron-FSDP as a new training backend, including configuration flags, automatic DDP settings, and state lifecycle management. The overall implementation is sound, but I've identified a critical issue in the FSDP parameter synchronization logic that could lead to inconsistent model states during inference. Additionally, the new example script has a hardcoded model path, which impacts its portability. I've provided suggestions to fix these issues.
Resolve 3 conflicts:
- megatron_utils.py: keep the FSDP ddp_config in main's cleaner structure; take main's GDN-aware grad buffer handling
- megatron_checkpoint_manager.py: apply the FSDP while-loop unwrap in main's should_generate_model_sections structure; use main's refactored PEFT save path with an FSDP skip for the HF checkpoint
- run_sft_engine.sh: keep both the megatron_fsdp and automodel backends

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Collaborator
I tested both SFT and GRPO (without ALL_OFFLOAD) on 8×H100 with Qwen2.5-0.5B, Megatron-LM main, and Megatron-Bridge PR #1910:
@conver334 All looks good to me now.
# Conflicts:
#   verl/workers/megatron_workers.py
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Simiao Zuo <simiaoz@nvidia.com>
Update the Megatron-Bridge pin in the Megatron-FSDP CI job and example doc to 6fea5bb (merge commit of NVIDIA-NeMo/Megatron-Bridge#3512), which is the preferred version now that the HF<->Megatron-FSDP weight conversion PR has landed. Also drop the now-merged Megatron-LM PR3191 and Megatron-Bridge PR1910 checkouts from the example doc in favor of the same pinned commits used in CI, and refresh the doc's "Last updated" date. Co-authored-by: Claude
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: conver334 <conver334@gmail.com>
…path Co-authored-by: Claude
Co-authored-by: OpenAI Codex <codex@openai.com>
Save and load Megatron-FSDP trainer checkpoints through PyTorch DCP, including model, optimizer, scheduler, and RNG state. Preserve HF export support through Megatron-Bridge and document the current example assumptions. Co-authored-by: OpenAI Codex <codex@openai.com> Signed-off-by: conver334 <conver334@gmail.com>
What does this PR do?
Support using Megatron-FSDP for SFT and RL.
Add Megatron-FSDP as a new training backend option for the Megatron engine. This is an implementation of #5244.
Key changes:
- GRPO example: examples/grpo_trainer/run_qwen2-7b_math_megatron_fsdp.sh
- SFT example: examples/sft/gsm8k/run_qwen_megatron_fsdp.sh

Run these following the user guide: docs/examples/megatron_fsdp_example.rst
Checklist Before Starting
- The PR title follows the format [{modules}] {type}: {description} (this will be checked by the CI).
- {modules} include fsdp, megatron, veomni, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, fully_async, one_step_off; multiple modules look like [megatron, fsdp, doc].
- {type} is in feat, fix, refactor, chore, test.
- If the PR is breaking, add [BREAKING] to the beginning of the title, e.g. [BREAKING][fsdp, megatron] feat: dynamic batching.
Test
Loss is normal in SFT. Tested on 8*H100 with Moonlight-16B-A3B-Instruct, GSM8K SFT dataset

Loss is normal in SFT. Tested on 8*H100 with Qwen2.5-Math-7B, GSM8K SFT dataset

Reward is normal in GRPO. Tested on 8*H100 with Qwen2.5-Math-7B, GSM8K

MFU

API and Usage Example
Enable Megatron-FSDP by setting three config flags:

```shell
actor_rollout_ref.actor.megatron.use_mbridge=True \
actor_rollout_ref.actor.megatron.vanilla_mbridge=False \
actor_rollout_ref.actor.megatron.use_megatron_fsdp=True \
```

The FSDP-specific DDP settings (sharding strategy, overlap, etc.) are auto-configured with defaults. Advanced users can override them:

```shell
actor_rollout_ref.actor.megatron.override_ddp_config.data_parallel_sharding_strategy=optim_grads \
actor_rollout_ref.actor.megatron.override_ddp_config.overlap_grad_reduce=False \
```

Design & Code Changes
Megatron-FSDP uses the same training loop as Megatron.
Conversion between HuggingFace format and Megatron-FSDP DTensor is implemented via NVIDIA-NeMo/Megatron-Bridge#1910.
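Putting the flags above together, a launch command might look like the following sketch. This is illustrative only: the `verl.trainer.main_ppo` entrypoint and the surrounding data/model overrides are assumptions based on verl's existing GRPO examples, not taken from this PR; only the three Megatron-FSDP flags and the `override_ddp_config` keys come from the description above.

```shell
# Sketch, not a tested command: entrypoint and non-FSDP overrides are assumed.
# The Megatron-FSDP flags shown are the ones introduced by this PR.
python3 -m verl.trainer.main_ppo \
    actor_rollout_ref.actor.megatron.use_mbridge=True \
    actor_rollout_ref.actor.megatron.vanilla_mbridge=False \
    actor_rollout_ref.actor.megatron.use_megatron_fsdp=True \
    actor_rollout_ref.actor.megatron.override_ddp_config.data_parallel_sharding_strategy=optim_grads \
    actor_rollout_ref.actor.megatron.override_ddp_config.overlap_grad_reduce=False
```

The override keys are optional; omitting them keeps the auto-configured FSDP DDP defaults.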
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Run pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always.
- Request CI via the ci-request channel in the verl Slack workspace. (If not accessible, please try the Feishu group (飞书群).)
- If the PR changes the recipe submodule, please also update the reference to the submodule commit via git submodule update --remote or cd recipe && git pull origin main.