[fsdp] fix: FSDP2 CPUOffloadPolicy crashes in get_per_tensor_param during weight sync #6188
xiefan46 wants to merge 1 commit into verl-project:main
Conversation
Code Review
This pull request addresses a device mismatch crash in FSDP2 when using CPUOffloadPolicy by replacing the standard state_dict() call with get_model_state_dict within the get_per_tensor_param method. A new regression test has been added to verify this behavior. Review feedback suggests that the same fix should be applied to the LoRA merged path to ensure consistency and prevent similar crashes. Additionally, the regression test contains a reference to a non-existent Qwen3Config class, which should be updated to a valid configuration like Qwen2Config.
```python
        params = normalize_peft_param_name(params)
    else:
        params = self.module.state_dict()
        if fsdp_version(self.module) == 2:
```
The fix for the FSDP2 state_dict() crash is also required for the LoRA merged path at line 771. Currently, if merge_lora is enabled with FSDP2 and CPUOffloadPolicy, the call to self.module.state_dict() at line 771 will still trigger the device mismatch error. Consider applying the same get_model_state_dict logic to the LoRA path to ensure consistency and prevent crashes during weight synchronization for LoRA models.
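A minimal sketch of the suggested LoRA-path change, assuming the helpers this PR already uses (`fsdp_version`, `get_fsdp_full_state_dict`); the surrounding LoRA merge context around line 771 is not shown here:

```python
# Sketch only: mirror the non-LoRA fix on the merged-LoRA path
# (assumed placement; the actual merge context is not shown here).
if fsdp_version(self.module) == 2:
    # get_fsdp_full_state_dict wraps get_model_state_dict(full_state_dict=True)
    # and tolerates CPU-offloaded DTensor params, unlike nn.Module.state_dict().
    params = get_fsdp_full_state_dict(self.module)
else:
    params = self.module.state_dict()
```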
```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from transformers import AutoConfig, AutoModelForCausalLM, Qwen3Config
```
Qwen3Config is not a standard class in the transformers library (the latest stable release only covers models up to Qwen2.5). This will cause an ImportError when running the test in most environments. It should likely be Qwen2Config or Qwen2_5Config.
```diff
-from transformers import AutoConfig, AutoModelForCausalLM, Qwen3Config
+from transformers import AutoConfig, AutoModelForCausalLM, Qwen2Config
```
```python
if world_size < 2:
    pytest.skip("Need at least 2 GPUs")

config = Qwen3Config(num_hidden_layers=2)
```
[fsdp] fix: FSDP2 CPUOffloadPolicy crashes in get_per_tensor_param during weight sync (verl-project#5995)

When using FSDP2 with offload_policy=True (CPUOffloadPolicy), get_per_tensor_param() crashes during weight sync to the vLLM rollout engine with "Attempted to set the storage of a tensor on device cpu to a storage on different device cuda:0".

Root cause: After training, CPUOffloadPolicy offloads params to CPU. get_per_tensor_param() calls self.module.state_dict(), which crashes because FSDP2 does not override state_dict() the way FSDP1 does; it falls through to PyTorch's native nn.Module.state_dict(), which cannot handle CPU-offloaded DTensor params. Additionally, load_fsdp_model_to_gpu() (model.to(device)) is a no-op under CPUOffloadPolicy and leaves the model in an inconsistent state if called.

Fix:
- Guard load_fsdp_model_to_gpu() with an _is_offload_param check to skip it when CPUOffloadPolicy manages param placement.
- For FSDP2 + CPUOffloadPolicy, use get_fsdp_full_state_dict() (which calls get_model_state_dict with full_state_dict=True) instead of self.module.state_dict().

FSDP2 manual offload and FSDP1 paths are unchanged.
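A sketch of the first bullet, under the assumption that `_is_offload_param` marks verl's manual offload mode (the commit message names these identifiers; the exact call site is not shown in this thread):

```python
# Sketch only: skip the manual reload when CPUOffloadPolicy owns placement.
# Under CPUOffloadPolicy, model.to(device) is a no-op that can leave the
# module in an inconsistent state, so only reload manually offloaded params.
if self._is_offload_param:
    load_fsdp_model_to_gpu(self.module)
```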
xiefan46 force-pushed from f67c534 to 0d978d3
What does this PR do?
Fixes #5995. When using FSDP2 with offload_policy=True (CPUOffloadPolicy), get_per_tensor_param() crashes during weight sync to the vLLM rollout engine:

```
RuntimeError: Attempted to set the storage of a tensor on device "cpu"
to a storage on different device "cuda:0"
```

Root cause: After training, FSDP2's CPUOffloadPolicy offloads params to CPU. get_per_tensor_param() then calls self.module.state_dict(), which crashes because FSDP2 does not override state_dict(); the native nn.Module.state_dict() it falls back on cannot handle CPU-offloaded params.

Fix: For FSDP2, use get_fsdp_full_state_dict() instead of self.module.state_dict().
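For reference, a sketch of what get_fsdp_full_state_dict() is described as doing in the commit message; `get_model_state_dict` and `StateDictOptions` are real `torch.distributed.checkpoint.state_dict` APIs, while the wrapper's exact signature in verl is assumed:

```python
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict


def get_fsdp_full_state_dict(module, cpu_offload: bool = True):
    """Sketch: gather a full (unsharded) state dict via DCP, which handles
    CPU-offloaded DTensor params that nn.Module.state_dict() cannot."""
    options = StateDictOptions(full_state_dict=True, cpu_offload=cpu_offload)
    return get_model_state_dict(module, options=options)
```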
Checklist Before Starting

- Title format: `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `veomni`, `sglang`, `vllm`, `vllm_omni`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, `cfg`, `reward`, `fully_async`, `one_step_off`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If the PR is breaking, add `[BREAKING]` to the beginning of the title, e.g. `[BREAKING][fsdp, megatron] feat: dynamic batching`

Test
Reproduced the crash in a regression test (test_fsdp2_cpuoffload_state_dict.py). Without the fix the test fails with the device mismatch error; with the fix both cases pass. Verified on 2×H100.
```bash
pytest tests/models/test_fsdp2_cpuoffload_state_dict.py -s -x -v
```
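For readers without access to the test file, a minimal standalone sketch of the failure mode and fix; this is not the PR's actual test, and it assumes the public FSDP2 API from torch >= 2.6 plus a toy model in place of the 2-layer Qwen config:

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard


def _worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Tiny stand-in model; the real test builds a 2-layer Qwen model instead.
    model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64)).cuda()
    fully_shard(model, offload_policy=CPUOffloadPolicy())

    # With CPUOffloadPolicy, model.state_dict() raises the cpu/cuda storage
    # mismatch reported in #5995; gathering via DCP's get_model_state_dict
    # works regardless of offload.
    sd = get_model_state_dict(
        model, options=StateDictOptions(full_state_dict=True, cpu_offload=True)
    )
    assert all(t.device.type == "cpu" for t in sd.values())
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    assert world_size >= 2, "Need at least 2 GPUs"
    mp.spawn(_worker, args=(world_size,), nprocs=world_size, join=True)
```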
API and Usage Example
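This PR introduces no new public API; a hypothetical sketch of the weight-sync path it repairs (`actor_worker` and `rollout_model` are placeholder names, and get_per_tensor_param is assumed to yield (name, tensor) pairs):

```python
# Hypothetical sketch: with strategy=fsdp2 and offload_policy=True
# (CPUOffloadPolicy), weight sync to the rollout engine now succeeds.
per_tensor_params = actor_worker.get_per_tensor_param()  # fixed by this PR
rollout_model.load_weights(per_tensor_params)  # e.g. a vLLM-style consumer
```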
Design & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Run `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`.
- Request CI approval in the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)
- If this PR changes the `recipe` submodule, please also update the reference to the submodule commit via `git submodule update --remote` or `cd recipe && git pull origin main`.