[fsdp] fix: FSDP2 CPUOffloadPolicy crashes in get_per_tensor_param during weight sync #6188

Open
xiefan46 wants to merge 1 commit into verl-project:main from xiefan46:fix/fsdp2-cpuoffload-state-dict-5995

Conversation

@xiefan46 (Contributor) commented Apr 28, 2026

What does this PR do?

Fixes #5995. When using FSDP2 with offload_policy=True (CPUOffloadPolicy), get_per_tensor_param() crashes during weight sync to the vLLM rollout engine:

RuntimeError: Attempted to set the storage of a tensor on device "cpu"
to a storage on different device "cuda:0"

Root cause: After training, FSDP2's CPUOffloadPolicy leaves params on CPU. get_per_tensor_param() then calls self.module.state_dict(), which crashes because FSDP2 does not override state_dict(); the plain nn.Module.state_dict() cannot handle the CPU-offloaded DTensor params.

Fix: For FSDP2, use get_fsdp_full_state_dict() instead of self.module.state_dict().
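
A minimal sketch of the guarded path, assuming verl's existing fsdp_version() and get_fsdp_full_state_dict() helpers (both named in this PR and its commit message); the surrounding lines are illustrative, not the exact diff:

if fsdp_version(self.module) == 2:
    # nn.Module.state_dict() cannot rebuild CUDA storages for params that
    # CPUOffloadPolicy keeps on CPU, so gather a full state dict through
    # the dedicated helper instead.
    params = get_fsdp_full_state_dict(self.module)
else:
    # FSDP1 overrides state_dict() itself, so the native call is safe here.
    params = self.module.state_dict()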

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, veomni, sglang, vllm, vllm_omni, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, fully_async, one_step_off
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

Reproduced the crash in a regression test (test_fsdp2_cpuoffload_state_dict.py). Without the fix the test fails with the device mismatch error; with the fix both cases pass. Verified on 2×H100.

pytest tests/models/test_fsdp2_cpuoffload_state_dict.py -s -x -v
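
A hypothetical skeleton of the regression test's shape, consistent with the imports quoted in the review below; the function bodies and names here are assumptions, not the actual test file:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def _worker(rank: int, world_size: int):
    # One process per GPU; NCCL backend matches the 2xH100 setup above.
    dist.init_process_group("nccl", init_method="tcp://127.0.0.1:29501",
                            rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    # ...build a small model, fully_shard(..., offload_policy=CPUOffloadPolicy()),
    # run a training step, then exercise the state-dict path that used to crash...
    dist.destroy_process_group()

def test_fsdp2_cpuoffload_state_dict():
    mp.spawn(_worker, args=(2,), nprocs=2, join=True)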

API and Usage Example

No user-facing API change; the fix is internal to get_per_tensor_param(). For context, a minimal repro of the underlying PyTorch behavior is sketched below.
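
This sketch uses plain PyTorch rather than verl's wrappers, and assumes a recent release where fully_shard and CPUOffloadPolicy are importable from torch.distributed.fsdp (older versions expose them under torch.distributed._composable.fsdp); it also assumes torch.distributed is already initialized, e.g. under torchrun:

import torch.nn as nn
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict

model = nn.Linear(16, 16)
fully_shard(model, offload_policy=CPUOffloadPolicy())

# Crashes: nn.Module.state_dict() tries to restore CUDA storages onto
# params that CPUOffloadPolicy keeps on CPU.
# state = model.state_dict()

# Works: gather a full, unsharded state dict through the DCP helper,
# which is what get_fsdp_full_state_dict() wraps per the commit message.
state = get_model_state_dict(model, options=StateDictOptions(full_state_dict=True, cpu_offload=True))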

Design & Code Changes

Two targeted changes (detailed in the commit message below): guard load_fsdp_model_to_gpu() so it is skipped when CPUOffloadPolicy manages param placement, and route FSDP2 state-dict gathering through get_fsdp_full_state_dict().

Checklist Before Submitting

Important

Please check all the following items before requesting a review; otherwise the reviewer may deprioritize this PR.

@gemini-code-assist (Bot) left a comment


Code Review

This pull request addresses a device mismatch crash in FSDP2 when using CPUOffloadPolicy by replacing the standard state_dict() call with get_model_state_dict within the get_per_tensor_param method. A new regression test has been added to verify this behavior. Review feedback suggests that the same fix should be applied to the LoRA merged path to ensure consistency and prevent similar crashes. Additionally, the regression test contains a reference to a non-existent Qwen3Config class, which should be updated to a valid configuration like Qwen2Config.

    params = normalize_peft_param_name(params)
else:
    params = self.module.state_dict()
    if fsdp_version(self.module) == 2:
gemini-code-assist (Bot) commented (severity: high):

The fix for the FSDP2 state_dict() crash is also required for the LoRA merged path at line 771. Currently, if merge_lora is enabled with FSDP2 and CPUOffloadPolicy, the call to self.module.state_dict() at line 771 will still trigger the device mismatch error. Consider applying the same get_model_state_dict logic to the LoRA path to ensure consistency and prevent crashes during weight synchronization for LoRA models.
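
A hedged sketch of what the reviewer is suggesting; the LoRA merged-path code around line 771 is not quoted here, so the surrounding structure is an assumption:

# Inside the LoRA merged path, mirror the FSDP2 guard instead of
# calling self.module.state_dict() unconditionally:
if fsdp_version(self.module) == 2:
    params = get_fsdp_full_state_dict(self.module)
else:
    params = self.module.state_dict()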

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from transformers import AutoConfig, AutoModelForCausalLM, Qwen3Config
gemini-code-assist (Bot) commented (severity: high):

Qwen3Config is not a standard class in the transformers library (the latest stable release supports models up to Qwen2.5), so this import will raise an ImportError in most environments. It should likely be Qwen2Config or Qwen2_5Config.

Suggested change
- from transformers import AutoConfig, AutoModelForCausalLM, Qwen3Config
+ from transformers import AutoConfig, AutoModelForCausalLM, Qwen2Config

if world_size < 2:
    pytest.skip("Need at least 2 GPUs")

config = Qwen3Config(num_hidden_layers=2)
gemini-code-assist (Bot) commented (severity: high):

Update the configuration class to match the corrected import (likely Qwen2Config).

Suggested change
- config = Qwen3Config(num_hidden_layers=2)
+ config = Qwen2Config(num_hidden_layers=2)

[fsdp] fix: FSDP2 CPUOffloadPolicy crashes in get_per_tensor_param during weight sync (verl-project#5995)

When using FSDP2 with offload_policy=True (CPUOffloadPolicy),
get_per_tensor_param() crashes during weight sync to the vLLM rollout
engine with "Attempted to set the storage of a tensor on device cpu to
a storage on different device cuda:0".

Root cause: After training, CPUOffloadPolicy offloads params to CPU.
get_per_tensor_param() calls self.module.state_dict() which crashes
because FSDP2 does not override state_dict() like FSDP1 — it uses
PyTorch's native nn.Module.state_dict() which cannot handle
CPU-offloaded DTensor params. Additionally, load_fsdp_model_to_gpu()
(model.to(device)) is a no-op under CPUOffloadPolicy and leaves the
model in an inconsistent state if called.

Fix:
- Guard load_fsdp_model_to_gpu() with _is_offload_param check to skip
  it when CPUOffloadPolicy manages param placement.
- For FSDP2 + CPUOffloadPolicy, use get_fsdp_full_state_dict() (which
  calls get_model_state_dict with full_state_dict=True) instead of
  self.module.state_dict(). FSDP2 manual offload and FSDP1 paths
  unchanged.
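
A sketch of the first bullet's guard, using the helper and flag names from this commit message; the exact call site and the flag's polarity in verl are assumptions:

# Only reload params that verl offloaded manually; when CPUOffloadPolicy
# owns placement, skip the model.to(device) round-trip entirely.
if self._is_offload_param:
    load_fsdp_model_to_gpu(self.module)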
xiefan46 force-pushed the fix/fsdp2-cpuoffload-state-dict-5995 branch from f67c534 to 0d978d3 on April 28, 2026 at 10:09
xiefan46 marked this pull request as ready for review on April 28, 2026 at 10:11


Development

Successfully merging this pull request may close these issues.

[Bug] FSDP2 CPUOffloadPolicy + state_dict() crashes with device mismatch during update_weights (non-LoRA full-weight training)
