diff --git a/verl/models/transformers/monkey_patch.py b/verl/models/transformers/monkey_patch.py
index a5117d73af4..4c246b3112f 100644
--- a/verl/models/transformers/monkey_patch.py
+++ b/verl/models/transformers/monkey_patch.py
@@ -330,9 +330,10 @@ def apply_monkey_patch(
     try:
         num_attention_heads, num_key_value_heads = model.config.num_attention_heads, model.config.num_key_value_heads
     except AttributeError:
+        text_config = getattr(model.config, "text_config", None) or model.config.get_text_config()
         num_attention_heads, num_key_value_heads = (
-            model.config.text_config.num_attention_heads,
-            model.config.text_config.num_key_value_heads,
+            text_config.num_attention_heads,
+            text_config.num_key_value_heads,
         )
 
     assert num_attention_heads % ulysses_sp_size == 0, (
diff --git a/verl/utils/fsdp_utils.py b/verl/utils/fsdp_utils.py
index 8ec9c6c0b2b..55c38e3e538 100644
--- a/verl/utils/fsdp_utils.py
+++ b/verl/utils/fsdp_utils.py
@@ -107,7 +107,12 @@ def _get_attr(attr_name, default_value=None):
     from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy
 
     # Add lambda policy for LoRA modules if is_lora is True
-    if is_lora:
+    # Skip lambda policy when min_num_params > 0: the size policy already
+    # controls wrapping granularity. Combining both creates nested FSDP
+    # units (LoRA leaves inside decoder-layer units) whose allgather order
+    # can diverge across ranks when input lengths differ (use_remove_padding),
+    # causing NCCL deadlocks.
+    if is_lora and min_num_params == 0:
 
         def lambda_policy_fn(module):
             return bool(
@@ -590,7 +595,7 @@ def fsdp2_clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinit
     return total_norm
 
 
-def layered_summon_lora_params(fsdp_module) -> OrderedDict:
+def layered_summon_lora_params(fsdp_module, is_diffusers=False) -> OrderedDict:
     from peft.utils.save_and_load import get_peft_model_state_dict
 
     def __prefix_submodules(module, prefix):
@@ -599,22 +604,35 @@ def __prefix_submodules(module, prefix):
                 yield name, submodule
 
     lora_params = OrderedDict()
-    prefix_list = [
-        # fsdp
-        "_fsdp_wrapped_module.base_model.model.",
-        "_fsdp_wrapped_module.base_model.model.model.",
-        "_fsdp_wrapped_module.base_model.model.model.layers.",
-        "_fsdp_wrapped_module.base_model.model.model.language_model.layers.",
-        # fsdp2
-        "base_model.model.",
-        "base_model.model.model.",
-        "base_model.model.model.layers.",
-        "base_model.model.model.language_model.layers.",
-    ]
+    if is_diffusers:
+        prefix_list = [
+            # fsdp
+            "_fsdp_wrapped_module.transformer_blocks.",
+            # fsdp2
+            "transformer_blocks.",
+        ]
+    else:
+        prefix_list = [
+            # fsdp
+            "_fsdp_wrapped_module.base_model.model.",
+            "_fsdp_wrapped_module.base_model.model.model.",
+            "_fsdp_wrapped_module.base_model.model.model.layers.",
+            "_fsdp_wrapped_module.base_model.model.model.language_model.layers.",
+            "_fsdp_wrapped_module.base_model.model.thinker.model.layers.",
+            # fsdp2
+            "base_model.model.",
+            "base_model.model.model.",
+            "base_model.model.model.layers.",
+            "base_model.model.model.language_model.layers.",
+            "base_model.model.thinker.model.layers.",
+        ]
     peft_model = getattr(fsdp_module, "_fsdp_wrapped_module", fsdp_module)
     for prefix in prefix_list:
         for name, submodule in __prefix_submodules(fsdp_module, prefix):
-            prefix = name.replace("_fsdp_wrapped_module.base_model.model.", "base_model.model.")
+            if is_diffusers:
+                prefix = name.replace("_fsdp_wrapped_module.", "")
+            else:
+                prefix = name.replace("_fsdp_wrapped_module.base_model.model.", "base_model.model.")
             if name.endswith(".model") or name.endswith(".layers"):
                 continue
             if fsdp_version(submodule) > 0:
@@ -632,7 +650,9 @@ def __prefix_submodules(module, prefix):
     return lora_params
 
 
-def collect_lora_params(module: FSDP, layered_summon: bool, base_sync_done: bool) -> OrderedDict:
+def collect_lora_params(
+    module: FSDP, layered_summon: bool, base_sync_done: bool, is_diffusers: bool = False
+) -> OrderedDict:
     """
     collect lora params or full params if base model is not ready in vllm
     work with if isinstance(self.module._fsdp_wrapped_module, PeftModel)
@@ -648,7 +668,22 @@
             "To use layered_summon, you must make sure base-model is preloaded in vllm, e.g. let "
             "rollout.load_format=safetensors"
         )
-        lora_params = layered_summon_lora_params(module)
+        lora_params = layered_summon_lora_params(module, is_diffusers=is_diffusers)
+        if not lora_params:
+            import logging
+
+            logging.getLogger(__name__).warning(
+                "layered_summon returned empty, falling back to full summon"
+            )
+            with FSDP.summon_full_params(module, writeback=False, offload_to_cpu=True):
+                lora_params = get_peft_model_state_dict(peft_model)
+                lora_params = {
+                    name: param.full_tensor().detach().cpu()
+                    if hasattr(param, "full_tensor")
+                    else param.detach().cpu()
+                    for name, param in lora_params.items()
+                }
+            get_torch_device().empty_cache()
     else:
         with FSDP.summon_full_params(module, writeback=False):
             if base_sync_done:
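
Context for the `min_num_params == 0` guard: `get_fsdp_wrap_policy` composes its sub-policies with `_or_policy`, so any module matched by *either* the size-based policy or the LoRA lambda policy becomes its own FSDP unit, and the trainable LoRA leaves end up nested inside the decoder-layer units. Below is a simplified sketch of that composition (the `lambda_policy_fn` body and parameter names mirror the hunk above; the surrounding helper is illustrative, not the full verl implementation):

```python
# Simplified sketch of how the wrap policies are combined; not the full verl code.
import functools
from torch.distributed.fsdp.wrap import (
    _or_policy,
    lambda_auto_wrap_policy,
    size_based_auto_wrap_policy,
)


def build_wrap_policy(is_lora: bool, min_num_params: int):
    policies = []
    if min_num_params > 0:
        # One FSDP unit per subtree above the size threshold.
        policies.append(functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params))
    # Guard from the patch: only add the LoRA lambda policy when no size policy is
    # active, otherwise _or_policy would also wrap every trainable LoRA leaf and
    # create nested units with rank-dependent allgather ordering.
    if is_lora and min_num_params == 0:

        def lambda_policy_fn(module):
            return bool(
                len(list(module.named_children())) == 0
                and getattr(module, "weight", None) is not None
                and module.weight.requires_grad
            )

        policies.append(functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn))
    return functools.partial(_or_policy, policies=policies) if policies else None
```
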
diff --git a/verl/utils/model.py b/verl/utils/model.py
index d4f1939a8e4..ff2d5af02c3 100644
--- a/verl/utils/model.py
+++ b/verl/utils/model.py
@@ -675,11 +675,67 @@ def load_valuehead_model(local_path, torch_dtype, model_config, trust_remote_cod
 _architecture_to_auto_class = {
     "ForCausalLM": AutoModelForCausalLM,
+    "Qwen3OmniMoeForConditionalGeneration": AutoModelForCausalLM,
     "ForVision2Seq": AutoModelForVision2Seq,
     "ForTokenClassification": AutoModelForTokenClassification,
     "ForSequenceClassification": AutoModelForSequenceClassification,
 }
 
+
+# Register Qwen3-Omni Thinker in AutoModelForCausalLM so veRL's FSDP engine can load it.
+# Qwen3OmniMoe uses "ForConditionalGeneration" suffix but is a decoder-only causal LM.
+# We register the Thinker-only class since that's what we train in Thinker post-training.
+try:
+    from transformers.models.qwen3_omni_moe import (
+        Qwen3OmniMoeConfig,
+        Qwen3OmniMoeForConditionalGeneration,
+    )
+
+    def _qwen3_omni_get_input_embeddings(self):
+        return self.thinker.get_input_embeddings()
+
+    def _qwen3_omni_set_input_embeddings(self, value):
+        self.thinker.set_input_embeddings(value)
+
+    def _qwen3_omni_forward(self, input_ids=None, attention_mask=None, position_ids=None,
+                            past_key_values=None, inputs_embeds=None, labels=None,
+                            use_cache=None, output_attentions=None,
+                            output_hidden_states=None, return_dict=None, **kwargs):
+        return self.thinker(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **kwargs,
+        )
+
+    Qwen3OmniMoeForConditionalGeneration.forward = _qwen3_omni_forward
+    Qwen3OmniMoeForConditionalGeneration.get_input_embeddings = _qwen3_omni_get_input_embeddings
+    Qwen3OmniMoeForConditionalGeneration.set_input_embeddings = _qwen3_omni_set_input_embeddings
+    # Fix _no_split_modules: the full model incorrectly lists Qwen3OmniMoeDecoderLayer
+    # which doesn't exist; the actual Thinker decoder layer is Qwen3OmniMoeThinkerTextDecoderLayer.
+    Qwen3OmniMoeForConditionalGeneration._no_split_modules = ["Qwen3OmniMoeThinkerTextDecoderLayer"]
+    Qwen3OmniMoeForConditionalGeneration._verl_strip_modules = [
+        "talker", "code2wav", "code_predictor",
+    ]
+    # Fix tie_word_embeddings: the full model config sets this to True, which forces
+    # all FSDP ranks to load on CPU (use_meta_tensor=False) and OOMs during FSDP init.
+    # Use a descriptor that returns False but has a no-op setter so config __init__ works.
+    class _FalseTieDescriptor:
+        def __get__(self, obj, objtype=None):
+            return False
+
+        def __set__(self, obj, value):
+            pass
+
+    Qwen3OmniMoeConfig.tie_word_embeddings = _FalseTieDescriptor()
+    AutoModelForCausalLM.register(Qwen3OmniMoeConfig, Qwen3OmniMoeForConditionalGeneration)
+except ImportError:
+    pass
+
 
 def get_hf_auto_model_class(hf_config):
     has_remote_code = hasattr(hf_config, "auto_map") and any(
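
A note on the `_FalseTieDescriptor` trick above: `PretrainedConfig.__init__` assigns `tie_word_embeddings` on the config *instance*, so a plain class-level `False` would simply be shadowed. A data descriptor (one that defines `__set__`) takes precedence over the instance `__dict__`, which is why reads always return `False` while the init-time assignment becomes a no-op. A minimal, self-contained sketch of the same mechanism with hypothetical names:

```python
# Plain-Python illustration of the data-descriptor precedence used above.
class _FalseTie:
    def __get__(self, obj, objtype=None):
        return False

    def __set__(self, obj, value):
        pass  # swallow instance-level assignment


class Cfg:
    tie_word_embeddings = _FalseTie()


c = Cfg()
c.tie_word_embeddings = True   # silently ignored by the no-op __set__
print(c.tie_word_embeddings)   # -> False
```
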
diff --git a/verl/utils/vllm/utils.py b/verl/utils/vllm/utils.py
index 1ac655fcf60..dca4f0740cb 100644
--- a/verl/utils/vllm/utils.py
+++ b/verl/utils/vllm/utils.py
@@ -60,7 +60,10 @@ def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel:
         lora_tensors = None
         from vllm.lora.peft_helper import PEFTHelper
 
-        if isinstance(lora_request, TensorLoRARequest):
+        _has_tensor_lora = hasattr(lora_request, "peft_config") and hasattr(
+            lora_request, "lora_tensors"
+        )
+        if isinstance(lora_request, TensorLoRARequest) or _has_tensor_lora:
             peft_config = lora_request.peft_config
             lora_tensors = lora_request.lora_tensors
             peft_helper = PEFTHelper.from_dict(peft_config)
@@ -96,7 +99,7 @@ def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel:
             lora_request_kwargs["target_embedding_padding"] = (
                 self.vocab_size + self.lora_config.lora_extra_vocab_size
             )
-        if isinstance(lora_request, TensorLoRARequest):
+        if isinstance(lora_request, TensorLoRARequest) or _has_tensor_lora:
             lora = self._lora_model_cls.from_lora_tensors(
                 tensors=lora_tensors,
                 **lora_request_kwargs,
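
The `_has_tensor_lora` duck-typing check presumably covers cases where the incoming request is not literally an instance of this module's `TensorLoRARequest` class (the diff does not spell out the trigger, but re-imported or reconstructed request objects are the usual culprits); anything that already carries the two fields the tensor path needs is accepted. The predicate, isolated as a hypothetical helper:

```python
def _looks_like_tensor_lora_request(lora_request) -> bool:
    # Structurally equivalent to TensorLoRARequest: carries an in-memory peft
    # config plus the LoRA tensors, regardless of which class produced it.
    return hasattr(lora_request, "peft_config") and hasattr(lora_request, "lora_tensors")
```
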
diff --git a/verl/workers/engine/fsdp/transformer_impl.py b/verl/workers/engine/fsdp/transformer_impl.py
index 657d3e2e1c9..b5537ee9eee 100644
--- a/verl/workers/engine/fsdp/transformer_impl.py
+++ b/verl/workers/engine/fsdp/transformer_impl.py
@@ -246,6 +246,12 @@ def _build_module(self):
                     config=self.model_config.hf_config,
                     trust_remote_code=self.model_config.trust_remote_code,
                 )
+
+                _strip_list = getattr(module, '_verl_strip_modules', [])
+                for attr in _strip_list:
+                    if hasattr(module, attr):
+                        delattr(module, attr)
+                        logger.info(f"Stripped unused sub-module '{attr}' to reduce memory")
             else:
                 from verl.utils.model import load_valuehead_model
 
@@ -326,6 +332,14 @@ def _build_lora_module(self, module):
         }
         module = get_peft_model(module, LoraConfig(**lora_config))
 
+        # Cast LoRA params to match base model dtype so FSDP can flatten
+        # all params in the same unit into a single contiguous tensor.
+        base_dtype = next((p.dtype for p in module.parameters() if not p.requires_grad), None)
+        if base_dtype is not None:
+            for param in module.parameters():
+                if param.requires_grad and param.dtype != base_dtype:
+                    param.data = param.data.to(base_dtype)
+
         return module
 
     def _build_fsdp_module(self, module):
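
Why the dtype cast in `_build_lora_module` matters: FSDP flattens every parameter of a wrapped unit into one contiguous `FlatParameter`, which requires a single dtype per unit, while recent peft versions tend to keep freshly initialized LoRA weights in fp32 even when the base model is bf16 (the `autocast_adapter_dtype=True` default). A toy reproduction of the mismatch and the same cast, using a plain `nn.Linear` instead of a real policy model (illustrative; the exact dtypes printed depend on your peft version):

```python
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model

base = nn.Sequential(nn.Linear(16, 16)).to(torch.bfloat16)
model = get_peft_model(base, LoraConfig(r=4, target_modules=["0"]))
print({p.dtype for p in model.parameters()})  # typically {torch.bfloat16, torch.float32}

# Same cast as in _build_lora_module: align trainable LoRA params with the frozen base dtype.
base_dtype = next((p.dtype for p in model.parameters() if not p.requires_grad), None)
if base_dtype is not None:
    for p in model.parameters():
        if p.requires_grad and p.dtype != base_dtype:
            p.data = p.data.to(base_dtype)
print({p.dtype for p in model.parameters()})  # single dtype, safe to flatten into one FSDP unit
```
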
diff --git a/verl/workers/rollout/base.py b/verl/workers/rollout/base.py
index 033e7684c1d..8280d05871b 100644
--- a/verl/workers/rollout/base.py
+++ b/verl/workers/rollout/base.py
@@ -84,6 +84,7 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
     ("vllm", "async"): "verl.workers.rollout.vllm_rollout.ServerAdapter",
     ("sglang", "async"): "verl.workers.rollout.sglang_rollout.sglang_rollout.ServerAdapter",
     ("trtllm", "async"): "verl.workers.rollout.trtllm_rollout.trtllm_rollout.ServerAdapter",
+    ("vllm_omni", "async"): "verl.workers.rollout.vllm_rollout.ServerAdapter",
 }
diff --git a/verl/workers/rollout/replica.py b/verl/workers/rollout/replica.py
index 1f271ce64b0..b139cab72b9 100644
--- a/verl/workers/rollout/replica.py
+++ b/verl/workers/rollout/replica.py
@@ -369,6 +369,15 @@ def _load_trtllm():
 RolloutReplicaRegistry.register("trtllm", _load_trtllm)
 
 
+def _load_vllm_omni():
+    from verl_omni.workers.rollout.vllm_rollout.vllm_omni_async_server import vLLMOmniReplica
+
+    return vLLMOmniReplica
+
+
+RolloutReplicaRegistry.register("vllm_omni", _load_vllm_omni)
+
+
 # Original function for backward compatibility
 def get_rollout_replica_class(rollout: str) -> type[RolloutReplica]:
     return RolloutReplicaRegistry.get(rollout)
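
With the registry entry and the adapter mapping above, the new backend resolves lazily by name just like the existing vllm/sglang/trtllm entries, so verl itself gains no hard import dependency on the external `verl_omni` package. A minimal lookup sketch (assumes `verl_omni` is installed; the rollout config key that supplies the `vllm_omni` name is outside this diff):

```python
from verl.workers.rollout.replica import get_rollout_replica_class

# The import of vLLMOmniReplica inside _load_vllm_omni only runs at lookup time.
replica_cls = get_rollout_replica_class("vllm_omni")
```
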