5 changes: 3 additions & 2 deletions verl/models/transformers/monkey_patch.py
@@ -330,9 +330,10 @@ def apply_monkey_patch(
     try:
         num_attention_heads, num_key_value_heads = model.config.num_attention_heads, model.config.num_key_value_heads
     except AttributeError:
+        text_config = getattr(model.config, "text_config", None) or model.config.get_text_config()
         num_attention_heads, num_key_value_heads = (
-            model.config.text_config.num_attention_heads,
-            model.config.text_config.num_key_value_heads,
+            text_config.num_attention_heads,
+            text_config.num_key_value_heads,
         )
 
     assert num_attention_heads % ulysses_sp_size == 0, (
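The hunk above covers multimodal checkpoints whose attention-head counts live on a nested text_config rather than on the top-level config. A minimal sketch of the lookup order, using stand-in config classes (illustrative only; real configs come from transformers, where composite configs expose get_text_config()):

# Stand-in classes; only the attribute layout matters here.
class _TextConfig:
    num_attention_heads = 32
    num_key_value_heads = 8

class _CompositeConfig:
    def get_text_config(self):
        return _TextConfig()

config = _CompositeConfig()
try:
    heads = (config.num_attention_heads, config.num_key_value_heads)
except AttributeError:
    # Prefer a text_config attribute if present, else ask for the text part.
    text_config = getattr(config, "text_config", None) or config.get_text_config()
    heads = (text_config.num_attention_heads, text_config.num_key_value_heads)
print(heads)  # (32, 8)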
69 changes: 52 additions & 17 deletions verl/utils/fsdp_utils.py
@@ -107,7 +107,12 @@ def _get_attr(attr_name, default_value=None):
     from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy
 
     # Add lambda policy for LoRA modules if is_lora is True
-    if is_lora:
+    # Skip lambda policy when min_num_params > 0: the size policy already
+    # controls wrapping granularity. Combining both creates nested FSDP
+    # units (LoRA leaves inside decoder-layer units) whose allgather order
+    # can diverge across ranks when input lengths differ (use_remove_padding),
+    # causing NCCL deadlocks.
+    if is_lora and min_num_params == 0:
 
         def lambda_policy_fn(module):
             return bool(
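For context, get_fsdp_wrap_policy composes its auto-wrap policies with _or_policy, which is where the guard above takes effect. A hedged sketch of that composition, assuming FSDP1's private wrap helpers (internal torch APIs that may shift between releases); lambda_policy_fn here is a simplified stand-in for verl's predicate:

import functools
from torch.distributed.fsdp.wrap import (
    _or_policy,
    lambda_auto_wrap_policy,
    size_based_auto_wrap_policy,
)

def lambda_policy_fn(module):
    # Wrap leaf modules that own trainable weights (e.g. LoRA A/B matrices).
    return bool(
        len(list(module.named_children())) == 0
        and getattr(module, "weight", None) is not None
        and module.weight.requires_grad
    )

def build_auto_wrap_policy(is_lora: bool, min_num_params: int):
    policies = []
    if min_num_params > 0:
        policies.append(functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params))
    # Per this patch: add the LoRA lambda policy only when the size policy is
    # off, so LoRA leaves never nest inside decoder-layer FSDP units.
    if is_lora and min_num_params == 0:
        policies.append(functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn))
    return functools.partial(_or_policy, policies=policies) if policies else None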
@@ -590,7 +595,7 @@ def fsdp2_clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinite=False):
     return total_norm
 
 
-def layered_summon_lora_params(fsdp_module) -> OrderedDict:
+def layered_summon_lora_params(fsdp_module, is_diffusers=False) -> OrderedDict:
     from peft.utils.save_and_load import get_peft_model_state_dict
 
     def __prefix_submodules(module, prefix):
@@ -599,22 +604,35 @@ def __prefix_submodules(module, prefix):
                 yield name, submodule
 
     lora_params = OrderedDict()
-    prefix_list = [
-        # fsdp
-        "_fsdp_wrapped_module.base_model.model.",
-        "_fsdp_wrapped_module.base_model.model.model.",
-        "_fsdp_wrapped_module.base_model.model.model.layers.",
-        "_fsdp_wrapped_module.base_model.model.model.language_model.layers.",
-        # fsdp2
-        "base_model.model.",
-        "base_model.model.model.",
-        "base_model.model.model.layers.",
-        "base_model.model.model.language_model.layers.",
-    ]
+    if is_diffusers:
+        prefix_list = [
+            # fsdp
+            "_fsdp_wrapped_module.transformer_blocks.",
+            # fsdp2
+            "transformer_blocks.",
+        ]
+    else:
+        prefix_list = [
+            # fsdp
+            "_fsdp_wrapped_module.base_model.model.",
+            "_fsdp_wrapped_module.base_model.model.model.",
+            "_fsdp_wrapped_module.base_model.model.model.layers.",
+            "_fsdp_wrapped_module.base_model.model.model.language_model.layers.",
+            "_fsdp_wrapped_module.base_model.model.thinker.model.layers.",
+            # fsdp2
+            "base_model.model.",
+            "base_model.model.model.",
+            "base_model.model.model.layers.",
+            "base_model.model.model.language_model.layers.",
+            "base_model.model.thinker.model.layers.",
+        ]
     peft_model = getattr(fsdp_module, "_fsdp_wrapped_module", fsdp_module)
     for prefix in prefix_list:
         for name, submodule in __prefix_submodules(fsdp_module, prefix):
-            prefix = name.replace("_fsdp_wrapped_module.base_model.model.", "base_model.model.")
+            if is_diffusers:
+                prefix = name.replace("_fsdp_wrapped_module.", "")
+            else:
+                prefix = name.replace("_fsdp_wrapped_module.base_model.model.", "base_model.model.")
            if name.endswith(".model") or name.endswith(".layers"):
                 continue
             if fsdp_version(submodule) > 0:
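The prefix list drives a layer-by-layer summon: each prefix expands to the FSDP-wrapped modules directly beneath it (one decoder layer at a time), so LoRA weights can be gathered without materializing the whole model; the new thinker entries cover Qwen3-Omni's PeftModel layout and the is_diffusers branch covers DiT-style transformer_blocks. A sketch of the traversal, assuming __prefix_submodules yields modules exactly one name level below a prefix (the matching rule is paraphrased, not copied):

import torch.nn as nn

def prefix_submodules(module: nn.Module, prefix: str):
    # Yield submodules exactly one name level below `prefix`.
    for name, submodule in module.named_modules():
        if name.startswith(prefix) and "." not in name[len(prefix):]:
            yield name, submodule

# With prefix "base_model.model.model.layers." this yields
# ("...layers.0", <layer0>), ("...layers.1", <layer1>), ..., so each
# decoder layer can be summoned, copied to CPU, and released in turn.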
@@ -632,7 +650,9 @@ def __prefix_submodules(module, prefix):
     return lora_params
 
 
-def collect_lora_params(module: FSDP, layered_summon: bool, base_sync_done: bool) -> OrderedDict:
+def collect_lora_params(
+    module: FSDP, layered_summon: bool, base_sync_done: bool, is_diffusers: bool = False
+) -> OrderedDict:
     """
     collect lora params or full params if base model is not ready in vllm
     work with if isinstance(self.module._fsdp_wrapped_module, PeftModel)
@@ -648,7 +668,22 @@ def __prefix_submodules(module, prefix):
             "To use layered_summon, you must make sure base-model is preloaded in vllm, e.g. let "
             "rollout.load_format=safetensors"
         )
-        lora_params = layered_summon_lora_params(module)
+        lora_params = layered_summon_lora_params(module, is_diffusers=is_diffusers)
+        if not lora_params:
+            import logging
+
+            logging.getLogger(__name__).warning(
+                "layered_summon returned empty, falling back to full summon"
+            )
+            with FSDP.summon_full_params(module, writeback=False, offload_to_cpu=True):
+                lora_params = get_peft_model_state_dict(peft_model)
+                lora_params = {
+                    name: param.full_tensor().detach().cpu()
+                    if hasattr(param, "full_tensor")
+                    else param.detach().cpu()
+                    for name, param in lora_params.items()
+                }
+            get_torch_device().empty_cache()
     else:
         with FSDP.summon_full_params(module, writeback=False):
             if base_sync_done:
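In the new fallback branch, parameters gathered under FSDP2 arrive as DTensor shards, which is why the dict comprehension branches on full_tensor(). The same normalization in isolation (full_tensor() is the torch.distributed DTensor API; plain tensors just detach to CPU):

import torch

def to_full_cpu(param: torch.Tensor) -> torch.Tensor:
    # DTensor exposes full_tensor(), which allgathers shards into one tensor;
    # ordinary tensors pass straight through.
    if hasattr(param, "full_tensor"):
        param = param.full_tensor()
    return param.detach().cpu()

lora_params = {"lora_A.weight": torch.randn(8, 16)}
lora_params = {name: to_full_cpu(p) for name, p in lora_params.items()}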
56 changes: 56 additions & 0 deletions verl/utils/model.py
@@ -675,11 +675,67 @@ def load_valuehead_model(local_path, torch_dtype, model_config, trust_remote_code):
 
 _architecture_to_auto_class = {
     "ForCausalLM": AutoModelForCausalLM,
+    "Qwen3OmniMoeForConditionalGeneration": AutoModelForCausalLM,
     "ForVision2Seq": AutoModelForVision2Seq,
     "ForTokenClassification": AutoModelForTokenClassification,
     "ForSequenceClassification": AutoModelForSequenceClassification,
 }
 
+# Register Qwen3-Omni in AutoModelForCausalLM so veRL's FSDP engine can load it.
+# Qwen3OmniMoe uses the "ForConditionalGeneration" suffix but is a decoder-only causal LM.
+# We patch the full class to behave as the Thinker alone, which is what post-training trains.
+try:
+    from transformers.models.qwen3_omni_moe import (
+        Qwen3OmniMoeConfig,
+        Qwen3OmniMoeForConditionalGeneration,
+    )
+
+    def _qwen3_omni_get_input_embeddings(self):
+        return self.thinker.get_input_embeddings()
+
+    def _qwen3_omni_set_input_embeddings(self, value):
+        self.thinker.set_input_embeddings(value)
+
+    def _qwen3_omni_forward(self, input_ids=None, attention_mask=None, position_ids=None,
+                            past_key_values=None, inputs_embeds=None, labels=None,
+                            use_cache=None, output_attentions=None,
+                            output_hidden_states=None, return_dict=None, **kwargs):
+        return self.thinker(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **kwargs,
+        )
+
+    Qwen3OmniMoeForConditionalGeneration.forward = _qwen3_omni_forward
+    Qwen3OmniMoeForConditionalGeneration.get_input_embeddings = _qwen3_omni_get_input_embeddings
+    Qwen3OmniMoeForConditionalGeneration.set_input_embeddings = _qwen3_omni_set_input_embeddings
+    # Fix _no_split_modules: the full model incorrectly lists Qwen3OmniMoeDecoderLayer,
+    # which doesn't exist; the actual Thinker decoder layer is Qwen3OmniMoeThinkerTextDecoderLayer.
+    Qwen3OmniMoeForConditionalGeneration._no_split_modules = ["Qwen3OmniMoeThinkerTextDecoderLayer"]
+    Qwen3OmniMoeForConditionalGeneration._verl_strip_modules = [
+        "talker", "code2wav", "code_predictor",
+    ]
+    # Fix tie_word_embeddings: the full model config sets this to True, which forces
+    # all FSDP ranks to load on CPU (use_meta_tensor=False) and OOMs during FSDP init.
+    # Use a descriptor that returns False but has a no-op setter so config __init__ works.
+    class _FalseTieDescriptor:
+        def __get__(self, obj, objtype=None):
+            return False
+        def __set__(self, obj, value):
+            pass
+    Qwen3OmniMoeConfig.tie_word_embeddings = _FalseTieDescriptor()
+    AutoModelForCausalLM.register(Qwen3OmniMoeConfig, Qwen3OmniMoeForConditionalGeneration)
+except ImportError:
+    pass
+
 
 def get_hf_auto_model_class(hf_config):
     has_remote_code = hasattr(hf_config, "auto_map") and any(
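Once registered, AutoModelForCausalLM resolves Qwen3OmniMoeConfig to the patched class, and the data descriptor shadows the instance attribute so tie_word_embeddings reads as False on every config instance. A usage sketch (the checkpoint id is illustrative):

from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("Qwen/Qwen3-Omni-30B-A3B-Instruct")
# The checkpoint config says True, but the class descriptor wins the lookup:
assert config.tie_word_embeddings is False
# Resolves through the register(...) call above; forward and the embedding
# accessors are routed to model.thinker.
model = AutoModelForCausalLM.from_config(config)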
7 changes: 5 additions & 2 deletions verl/utils/vllm/utils.py
@@ -60,7 +60,10 @@ def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel:
        lora_tensors = None
        from vllm.lora.peft_helper import PEFTHelper
 
-        if isinstance(lora_request, TensorLoRARequest):
+        _has_tensor_lora = hasattr(lora_request, "peft_config") and hasattr(
+            lora_request, "lora_tensors"
+        )
+        if isinstance(lora_request, TensorLoRARequest) or _has_tensor_lora:
             peft_config = lora_request.peft_config
             lora_tensors = lora_request.lora_tensors
             peft_helper = PEFTHelper.from_dict(peft_config)
@@ -96,7 +99,7 @@ def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel:
             lora_request_kwargs["target_embedding_padding"] = (
                 self.vocab_size + self.lora_config.lora_extra_vocab_size
             )
-        if isinstance(lora_request, TensorLoRARequest):
+        if isinstance(lora_request, TensorLoRARequest) or _has_tensor_lora:
             lora = self._lora_model_cls.from_lora_tensors(
                 tensors=lora_tensors,
                 **lora_request_kwargs,
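The structural check guards against the dual-import problem: when TensorLoRARequest reaches vLLM under a different module path than the one used for the isinstance check (re-exports, pickling across processes), isinstance fails even though the object carries the right payload. A self-contained illustration of the failure mode and the attribute-based check:

# Two identical classes stand in for one class imported under two paths.
class _RequestFromPathA:
    def __init__(self, peft_config, lora_tensors):
        self.peft_config = peft_config
        self.lora_tensors = lora_tensors

class _RequestFromPathB:
    def __init__(self, peft_config, lora_tensors):
        self.peft_config = peft_config
        self.lora_tensors = lora_tensors

request = _RequestFromPathB({"r": 8}, {})
assert not isinstance(request, _RequestFromPathA)  # type check misses it
# The duck-typed check used in the patch still recognizes it:
assert hasattr(request, "peft_config") and hasattr(request, "lora_tensors")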
14 changes: 14 additions & 0 deletions verl/workers/engine/fsdp/transformer_impl.py
@@ -246,6 +246,12 @@ def _build_module(self):
                 config=self.model_config.hf_config,
                 trust_remote_code=self.model_config.trust_remote_code,
             )
+
+            _strip_list = getattr(module, "_verl_strip_modules", [])
+            for attr in _strip_list:
+                if hasattr(module, attr):
+                    delattr(module, attr)
+                    logger.info(f"Stripped unused sub-module '{attr}' to reduce memory")
         else:
             from verl.utils.model import load_valuehead_model
 
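The strip hook is an opt-in protocol: a model class lists sub-modules that are safe to drop before sharding (for Qwen3-Omni, the untrained speech heads), and the engine deletes them here. A toy illustration with a stand-in module (the real list is set on Qwen3OmniMoeForConditionalGeneration in verl/utils/model.py):

import torch.nn as nn

class _ToyOmni(nn.Module):
    _verl_strip_modules = ["talker"]

    def __init__(self):
        super().__init__()
        self.thinker = nn.Linear(8, 8)  # trained
        self.talker = nn.Linear(8, 8)   # unused speech head

module = _ToyOmni()
for attr in getattr(module, "_verl_strip_modules", []):
    if hasattr(module, attr):
        delattr(module, attr)  # frees the parameters before FSDP wraps them
assert not hasattr(module, "talker")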
@@ -326,6 +332,14 @@ def _build_lora_module(self, module):
             }
             module = get_peft_model(module, LoraConfig(**lora_config))
 
+            # Cast LoRA params to match base model dtype so FSDP can flatten
+            # all params in the same unit into a single contiguous tensor.
+            base_dtype = next((p.dtype for p in module.parameters() if not p.requires_grad), None)
+            if base_dtype is not None:
+                for param in module.parameters():
+                    if param.requires_grad and param.dtype != base_dtype:
+                        param.data = param.data.to(base_dtype)
+
         return module
 
     def _build_fsdp_module(self, module):
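FSDP1 flattens every parameter of a wrap unit into one contiguous buffer and requires a uniform dtype, while PEFT initializes LoRA weights in fp32 even when the frozen base is bf16; the cast above removes that mismatch. A minimal repro of the condition being fixed (dtypes illustrative):

import torch
import torch.nn as nn

base = nn.Linear(16, 16).to(torch.bfloat16)  # frozen base weights
for p in base.parameters():
    p.requires_grad_(False)
lora_a = nn.Parameter(torch.zeros(8, 16))    # PEFT default: float32

# A bf16 base weight and an fp32 LoRA leaf in the same flat-param unit
# would make FSDP1's flattening raise a uniform-dtype error.
base_dtype = next(p.dtype for p in base.parameters() if not p.requires_grad)
if lora_a.dtype != base_dtype:
    lora_a.data = lora_a.data.to(base_dtype)
assert lora_a.dtype == torch.bfloat16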
1 change: 1 addition & 0 deletions verl/workers/rollout/base.py
@@ -84,6 +84,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
     ("vllm", "async"): "verl.workers.rollout.vllm_rollout.ServerAdapter",
     ("sglang", "async"): "verl.workers.rollout.sglang_rollout.sglang_rollout.ServerAdapter",
     ("trtllm", "async"): "verl.workers.rollout.trtllm_rollout.trtllm_rollout.ServerAdapter",
+    ("vllm_omni", "async"): "verl.workers.rollout.vllm_rollout.ServerAdapter",
 }
 
 
9 changes: 9 additions & 0 deletions verl/workers/rollout/replica.py
@@ -369,6 +369,15 @@ def _load_trtllm():
 RolloutReplicaRegistry.register("trtllm", _load_trtllm)
 
 
+def _load_vllm_omni():
+    from verl_omni.workers.rollout.vllm_rollout.vllm_omni_async_server import vLLMOmniReplica
+
+    return vLLMOmniReplica
+
+
+RolloutReplicaRegistry.register("vllm_omni", _load_vllm_omni)
+
+
 # Original function for backward compatibility
 def get_rollout_replica_class(rollout: str) -> type[RolloutReplica]:
     return RolloutReplicaRegistry.get(rollout)
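Registering a loader function instead of the class keeps verl_omni an optional dependency: the import only runs when the vllm_omni backend is actually selected. Usage sketch:

replica_cls = get_rollout_replica_class("vllm_omni")  # triggers _load_vllm_omni
# replica_cls is vLLMOmniReplica; construct it like any other RolloutReplica.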