From e10eea6e939a9b30c2630e73031f4a538cb2aff0 Mon Sep 17 00:00:00 2001
From: Qing An
Date: Sun, 3 May 2026 13:00:28 -0700
Subject: [PATCH 1/4] feat: Qwen3-Omni Thinker RL support

---
 verl/models/transformers/monkey_patch.py     |  5 +-
 verl/utils/fsdp_utils.py                     | 84 ++++++++++++++-----
 verl/utils/model.py                          | 56 +++++++++++++
 verl/utils/reward_score/__init__.py          |  4 +-
 verl/utils/reward_score/gsm8k_thinker.py     | 44 ++++++++++
 verl/utils/vllm/utils.py                     |  7 +-
 verl/workers/engine/fsdp/transformer_impl.py | 20 +++++
 7 files changed, 196 insertions(+), 24 deletions(-)
 create mode 100644 verl/utils/reward_score/gsm8k_thinker.py

diff --git a/verl/models/transformers/monkey_patch.py b/verl/models/transformers/monkey_patch.py
index a5117d73af4..4c246b3112f 100644
--- a/verl/models/transformers/monkey_patch.py
+++ b/verl/models/transformers/monkey_patch.py
@@ -330,9 +330,10 @@ def apply_monkey_patch(
     try:
         num_attention_heads, num_key_value_heads = model.config.num_attention_heads, model.config.num_key_value_heads
     except AttributeError:
+        text_config = getattr(model.config, "text_config", None) or model.config.get_text_config()
         num_attention_heads, num_key_value_heads = (
-            model.config.text_config.num_attention_heads,
-            model.config.text_config.num_key_value_heads,
+            text_config.num_attention_heads,
+            text_config.num_key_value_heads,
         )
 
     assert num_attention_heads % ulysses_sp_size == 0, (
diff --git a/verl/utils/fsdp_utils.py b/verl/utils/fsdp_utils.py
index 8ec9c6c0b2b..bff15465695 100644
--- a/verl/utils/fsdp_utils.py
+++ b/verl/utils/fsdp_utils.py
@@ -107,7 +107,12 @@ def _get_attr(attr_name, default_value=None):
     from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy
 
     # Add lambda policy for LoRA modules if is_lora is True
-    if is_lora:
+    # Skip lambda policy when min_num_params > 0: the size policy already
+    # controls wrapping granularity. Combining both creates nested FSDP
+    # units (LoRA leaves inside decoder-layer units) whose allgather order
+    # can diverge across ranks when input lengths differ (use_remove_padding),
+    # causing NCCL deadlocks.
+    if is_lora and min_num_params == 0:
 
         def lambda_policy_fn(module):
             return bool(
@@ -590,7 +595,7 @@ def fsdp2_clip_grad_norm_(parameters, max_norm, norm_type=2.0, error_if_nonfinit
     return total_norm
 
 
-def layered_summon_lora_params(fsdp_module) -> OrderedDict:
+def layered_summon_lora_params(fsdp_module, is_diffusers=False) -> OrderedDict:
     from peft.utils.save_and_load import get_peft_model_state_dict
 
     def __prefix_submodules(module, prefix):
@@ -598,26 +603,47 @@ def __prefix_submodules(module, prefix):
         if name.startswith(prefix) and "." not in name[len(prefix) :]:
             yield name, submodule
 
+    import logging as _logging
+    _log = _logging.getLogger(__name__)
+
     lora_params = OrderedDict()
-    prefix_list = [
-        # fsdp
-        "_fsdp_wrapped_module.base_model.model.",
-        "_fsdp_wrapped_module.base_model.model.model.",
-        "_fsdp_wrapped_module.base_model.model.model.layers.",
-        "_fsdp_wrapped_module.base_model.model.model.language_model.layers.",
-        # fsdp2
-        "base_model.model.",
-        "base_model.model.model.",
-        "base_model.model.model.layers.",
-        "base_model.model.model.language_model.layers.",
-    ]
+    if is_diffusers:
+        prefix_list = [
+            # fsdp
+            "_fsdp_wrapped_module.transformer_blocks.",
+            # fsdp2
+            "transformer_blocks.",
+        ]
+    else:
+        prefix_list = [
+            # fsdp
+            "_fsdp_wrapped_module.base_model.model.",
+            "_fsdp_wrapped_module.base_model.model.model.",
+            "_fsdp_wrapped_module.base_model.model.model.layers.",
+            "_fsdp_wrapped_module.base_model.model.model.language_model.layers.",
+            # fsdp — Qwen3-Omni: layers are under base_model.model.thinker.model.layers
+            "_fsdp_wrapped_module.base_model.model.thinker.model.layers.",
+            # fsdp2
+            "base_model.model.",
+            "base_model.model.model.",
+            "base_model.model.model.layers.",
+            "base_model.model.model.language_model.layers.",
+            # fsdp2 — Qwen3-Omni
+            "base_model.model.thinker.model.layers.",
+        ]
     peft_model = getattr(fsdp_module, "_fsdp_wrapped_module", fsdp_module)
     for prefix in prefix_list:
         for name, submodule in __prefix_submodules(fsdp_module, prefix):
-            prefix = name.replace("_fsdp_wrapped_module.base_model.model.", "base_model.model.")
+            if is_diffusers:
+                prefix = name.replace("_fsdp_wrapped_module.", "")
+            else:
+                prefix = name.replace("_fsdp_wrapped_module.base_model.model.", "base_model.model.")
             if name.endswith(".model") or name.endswith(".layers"):
+                _log.info("[layered_summon] SKIP (ends with .model/.layers): %s", name)
                 continue
-            if fsdp_version(submodule) > 0:
+            fv = fsdp_version(submodule)
+            if fv > 0:
+                _log.info("[layered_summon] MATCH prefix=%s name=%s fsdp_ver=%d", prefix, name, fv)
                 with FSDP.summon_full_params(submodule, writeback=False):
                     sub_lora_params = get_peft_model_state_dict(peft_model, state_dict=submodule.state_dict())
                     sub_lora_params = {
@@ -626,13 +652,21 @@ def __prefix_submodules(module, prefix):
                         else param.detach().cpu()
                         for name, param in sub_lora_params.items()
                     }
+                    _log.info("[layered_summon] collected %d lora params from %s, keys=%s",
+                              len(sub_lora_params), name, list(sub_lora_params.keys())[:3])
                     lora_params.update(sub_lora_params)
                 submodule._is_root = False
                 get_torch_device().empty_cache()
+            else:
+                _log.info("[layered_summon] SKIP (fsdp_ver=0): %s", name)
+    _log.info("[layered_summon] TOTAL lora_params: %d, sample_keys=%s",
+              len(lora_params), list(lora_params.keys())[:5])
     return lora_params
 
 
-def collect_lora_params(module: FSDP, layered_summon: bool, base_sync_done: bool) -> OrderedDict:
+def collect_lora_params(
+    module: FSDP, layered_summon: bool, base_sync_done: bool, is_diffusers: bool = False
+) -> OrderedDict:
     """
     collect lora params or full params if base model is not ready in vllm
     work with if isinstance(self.module._fsdp_wrapped_module, PeftModel)
@@ -648,7 +682,21 @@ def collect_lora_params(module: FSDP, layered_summon: bool, base_sync_done: bool
                 "To use layered_summon, you must make sure base-model is preloaded in vllm, e.g. let "
                 "rollout.load_format=safetensors"
             )
-        lora_params = layered_summon_lora_params(module)
+        lora_params = layered_summon_lora_params(module, is_diffusers=is_diffusers)
+        if not lora_params:
+            import logging as _fallback_log
+            _fallback_log.getLogger(__name__).warning(
+                "layered_summon returned empty — falling back to full summon with offload_to_cpu"
+            )
+            with FSDP.summon_full_params(module, writeback=False, offload_to_cpu=True):
+                lora_params = get_peft_model_state_dict(peft_model)
+                lora_params = {
+                    name: param.full_tensor().detach().cpu()
+                    if hasattr(param, "full_tensor")
+                    else param.detach().cpu()
+                    for name, param in lora_params.items()
+                }
+            get_torch_device().empty_cache()
     else:
         with FSDP.summon_full_params(module, writeback=False):
             if base_sync_done:
diff --git a/verl/utils/model.py b/verl/utils/model.py
index d4f1939a8e4..4bcfca9f55e 100644
--- a/verl/utils/model.py
+++ b/verl/utils/model.py
@@ -675,11 +675,67 @@ def load_valuehead_model(local_path, torch_dtype, model_config, trust_remote_cod
 
 _architecture_to_auto_class = {
     "ForCausalLM": AutoModelForCausalLM,
+    "ForConditionalGeneration": AutoModelForCausalLM,
     "ForVision2Seq": AutoModelForVision2Seq,
     "ForTokenClassification": AutoModelForTokenClassification,
     "ForSequenceClassification": AutoModelForSequenceClassification,
 }
 
+# Register Qwen3-Omni Thinker in AutoModelForCausalLM so veRL's FSDP engine can load it.
+# Qwen3OmniMoe uses "ForConditionalGeneration" suffix but is a decoder-only causal LM.
+# We register the Thinker-only class since that's what we train in Thinker post-training.
+try:
+    from transformers.models.qwen3_omni_moe import (
+        Qwen3OmniMoeConfig,
+        Qwen3OmniMoeForConditionalGeneration,
+    )
+
+    def _qwen3_omni_get_input_embeddings(self):
+        return self.thinker.get_input_embeddings()
+
+    def _qwen3_omni_set_input_embeddings(self, value):
+        self.thinker.set_input_embeddings(value)
+
+    def _qwen3_omni_forward(self, input_ids=None, attention_mask=None, position_ids=None,
+                            past_key_values=None, inputs_embeds=None, labels=None,
+                            use_cache=None, output_attentions=None,
+                            output_hidden_states=None, return_dict=None, **kwargs):
+        return self.thinker(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            labels=labels,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **kwargs,
+        )
+
+    Qwen3OmniMoeForConditionalGeneration.forward = _qwen3_omni_forward
+    Qwen3OmniMoeForConditionalGeneration.get_input_embeddings = _qwen3_omni_get_input_embeddings
+    Qwen3OmniMoeForConditionalGeneration.set_input_embeddings = _qwen3_omni_set_input_embeddings
+    # Fix _no_split_modules: the full model incorrectly lists Qwen3OmniMoeDecoderLayer
+    # which doesn't exist; the actual Thinker decoder layer is Qwen3OmniMoeThinkerTextDecoderLayer.
+    Qwen3OmniMoeForConditionalGeneration._no_split_modules = ["Qwen3OmniMoeThinkerTextDecoderLayer"]
+    Qwen3OmniMoeForConditionalGeneration._verl_strip_modules = [
+        "talker", "code2wav", "code_predictor",
+    ]
+    # Fix tie_word_embeddings: the full model config sets this to True, which forces
+    # all FSDP ranks to load on CPU (use_meta_tensor=False) and OOMs during FSDP init.
+    # Use a descriptor that returns False but has a no-op setter so config __init__ works.
+    class _FalseTieDescriptor:
+        def __get__(self, obj, objtype=None):
+            return False
+        def __set__(self, obj, value):
+            pass
+    Qwen3OmniMoeConfig.tie_word_embeddings = _FalseTieDescriptor()
+    AutoModelForCausalLM.register(Qwen3OmniMoeConfig, Qwen3OmniMoeForConditionalGeneration)
+except Exception:
+    pass
+
 
 def get_hf_auto_model_class(hf_config):
     has_remote_code = hasattr(hf_config, "auto_map") and any(
diff --git a/verl/utils/reward_score/__init__.py b/verl/utils/reward_score/__init__.py
index 180dc6b1474..8c53747c224 100644
--- a/verl/utils/reward_score/__init__.py
+++ b/verl/utils/reward_score/__init__.py
@@ -42,9 +42,9 @@ def default_compute_score(
             NotImplementedError: If the reward function is not implemented for the given data source.
     """
     if data_source == "openai/gsm8k":
-        from . import gsm8k
+        from . import gsm8k_thinker
 
-        res = gsm8k.compute_score(solution_str, ground_truth)
+        res = gsm8k_thinker.compute_score(solution_str, ground_truth)
     elif data_source in ["lighteval/MATH", "DigitalLearningGmbH/MATH-lighteval", "HuggingFaceH4/MATH-500"]:
         from . import math_reward
diff --git a/verl/utils/reward_score/gsm8k_thinker.py b/verl/utils/reward_score/gsm8k_thinker.py
new file mode 100644
index 00000000000..ea6d1a458af
--- /dev/null
+++ b/verl/utils/reward_score/gsm8k_thinker.py
@@ -0,0 +1,44 @@
+import re
+
+_ANSWER_CLIP_CHARS = 1000
+
+
+def _get_answer_region(solution_str: str) -> str:
+    """Use content after </think> if present, else fall back to the tail."""
+    think_end = solution_str.rfind("</think>")
+    if think_end != -1:
+        region = solution_str[think_end + len("</think>"):]
+    else:
+        region = solution_str
+    if len(region) > _ANSWER_CLIP_CHARS:
+        region = region[-_ANSWER_CLIP_CHARS:]
+    return region
+
+
+def extract_solution(solution_str: str) -> str | None:
+    region = _get_answer_region(solution_str)
+
+    # #### (standard GSM8K format)
+    m = re.findall(r"#### (\-?[0-9\.,]+)", region)
+    if m:
+        return m[-1].replace(",", "").replace("$", "")
+
+    # \boxed{} (Qwen3 thinking model format)
+    m = re.findall(r"\\boxed\{(\-?[0-9\.,]+)\}", region)
+    if m:
+        return m[-1].replace(",", "")
+
+    # last plain number as fallback
+    m = re.findall(r"(\-?[0-9\.,]+)", region)
+    for candidate in reversed(m):
+        if candidate not in ("", "."):
+            return candidate.replace(",", "")
+
+    return None
+
+
+def compute_score(solution_str: str, ground_truth: str, score: float = 1.0, format_score: float = 0.0) -> float:
+    answer = extract_solution(solution_str)
+    if answer is None:
+        return 0.0
+    return score if answer == ground_truth else format_score
diff --git a/verl/utils/vllm/utils.py b/verl/utils/vllm/utils.py
index 1ac655fcf60..dca4f0740cb 100644
--- a/verl/utils/vllm/utils.py
+++ b/verl/utils/vllm/utils.py
@@ -60,7 +60,10 @@ def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel:
         lora_tensors = None
         from vllm.lora.peft_helper import PEFTHelper
 
-        if isinstance(lora_request, TensorLoRARequest):
+        _has_tensor_lora = hasattr(lora_request, "peft_config") and hasattr(
+            lora_request, "lora_tensors"
+        )
+        if isinstance(lora_request, TensorLoRARequest) or _has_tensor_lora:
             peft_config = lora_request.peft_config
             lora_tensors = lora_request.lora_tensors
             peft_helper = PEFTHelper.from_dict(peft_config)
@@ -96,7 +99,7 @@ def hijack__load_adapter(self, lora_request: TensorLoRARequest) -> LoRAModel:
             lora_request_kwargs["target_embedding_padding"] = (
                 self.vocab_size + self.lora_config.lora_extra_vocab_size
             )
-        if isinstance(lora_request, TensorLoRARequest):
+        if isinstance(lora_request, TensorLoRARequest) or _has_tensor_lora:
             lora = self._lora_model_cls.from_lora_tensors(
                 tensors=lora_tensors,
                 **lora_request_kwargs,
diff --git a/verl/workers/engine/fsdp/transformer_impl.py b/verl/workers/engine/fsdp/transformer_impl.py
index 657d3e2e1c9..6eabfc2f5ca 100644
--- a/verl/workers/engine/fsdp/transformer_impl.py
+++ b/verl/workers/engine/fsdp/transformer_impl.py
@@ -246,6 +246,18 @@ def _build_module(self):
                 config=self.model_config.hf_config,
                 trust_remote_code=self.model_config.trust_remote_code,
             )
+
+            _strip_list = getattr(module, '_verl_strip_modules', [])
+            print(f"[STRIP DEBUG] module class: {type(module).__name__}, "
+                  f"_verl_strip_modules: {_strip_list}, "
+                  f"children: {[n for n, _ in module.named_children()]}")
+            for attr in _strip_list:
+                if hasattr(module, attr):
+                    delattr(module, attr)
+                    print(f"[STRIP DEBUG] Stripped '{attr}'")
+                    logger.info(f"Stripped unused sub-module '{attr}' to reduce memory")
+                else:
+                    print(f"[STRIP DEBUG] '{attr}' not found on module")
         else:
             from verl.utils.model import load_valuehead_model
 
@@ -326,6 +338,14 @@ def _build_lora_module(self, module):
         }
 
         module = get_peft_model(module, LoraConfig(**lora_config))
+        # Cast LoRA params to match base model dtype so FSDP can flatten
+        # all params in the same unit into a single contiguous tensor.
+        base_dtype = next((p.dtype for p in module.parameters() if not p.requires_grad), None)
+        if base_dtype is not None:
+            for param in module.parameters():
+                if param.requires_grad and param.dtype != base_dtype:
+                    param.data = param.data.to(base_dtype)
+
         return module
 
     def _build_fsdp_module(self, module):

From e8564d461bbb30e4f2e5eaf61de8c873d504f442 Mon Sep 17 00:00:00 2001
From: Qing An
Date: Sun, 3 May 2026 13:51:32 -0700
Subject: [PATCH 2/4] feat: register vllm_omni and vllm_omni_ar in _ROLLOUT_REGISTRY

---
 verl/workers/rollout/base.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/verl/workers/rollout/base.py b/verl/workers/rollout/base.py
index 033e7684c1d..9dc5ef1bbeb 100644
--- a/verl/workers/rollout/base.py
+++ b/verl/workers/rollout/base.py
@@ -84,6 +84,8 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
     ("vllm", "async"): "verl.workers.rollout.vllm_rollout.ServerAdapter",
     ("sglang", "async"): "verl.workers.rollout.sglang_rollout.sglang_rollout.ServerAdapter",
     ("trtllm", "async"): "verl.workers.rollout.trtllm_rollout.trtllm_rollout.ServerAdapter",
+    ("vllm_omni", "async"): "verl.workers.rollout.vllm_rollout.ServerAdapter",
+    ("vllm_omni_ar", "async"): "verl.workers.rollout.vllm_rollout.ServerAdapter",
 }

From 23790f30389d7eed71c42e44a92bba9774db0d25 Mon Sep 17 00:00:00 2001
From: Qing An
Date: Sun, 3 May 2026 13:59:52 -0700
Subject: [PATCH 3/4] Register vllm_omni and vllm_omni_ar in RolloutReplicaRegistry

---
 verl/workers/rollout/replica.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/verl/workers/rollout/replica.py b/verl/workers/rollout/replica.py
index 1f271ce64b0..bf6a03a61fe 100644
--- a/verl/workers/rollout/replica.py
+++ b/verl/workers/rollout/replica.py
@@ -369,6 +369,22 @@ def _load_trtllm():
 RolloutReplicaRegistry.register("trtllm", _load_trtllm)
 
 
+def _load_vllm_omni():
+    from verl_omni.workers.rollout.vllm_rollout.vllm_omni_async_server import vLLMOmniReplica
+
+    return vLLMOmniReplica
+
+
+def _load_vllm_omni_ar():
+    from verl_omni.workers.rollout.vllm_rollout.vllm_omni_async_server import vLLMOmniARReplica
+
+    return vLLMOmniARReplica
+
+
+RolloutReplicaRegistry.register("vllm_omni", _load_vllm_omni)
+RolloutReplicaRegistry.register("vllm_omni_ar", _load_vllm_omni_ar)
+
+
 # Original function for backward compatibility
 def get_rollout_replica_class(rollout: str) -> type[RolloutReplica]:
     return RolloutReplicaRegistry.get(rollout)

From d6e0d82cac49d51718626d35f911890878f0d9fa Mon Sep 17 00:00:00 2001
From: Qing An
Date: Mon, 4 May 2026 11:52:49 -0700
Subject: [PATCH 4/4] Fix comments and clean up code

---
 verl/utils/fsdp_utils.py                     | 23 +++-------
 verl/utils/model.py                          |  4 +-
 verl/utils/reward_score/__init__.py          |  4 +-
 verl/utils/reward_score/gsm8k_thinker.py     | 44 --------------------
 verl/workers/engine/fsdp/transformer_impl.py |  6 ---
 verl/workers/rollout/base.py                 |  1 -
 verl/workers/rollout/replica.py              |  7 ----
 7 files changed, 9 insertions(+), 80 deletions(-)
 delete mode 100644 verl/utils/reward_score/gsm8k_thinker.py

diff --git a/verl/utils/fsdp_utils.py b/verl/utils/fsdp_utils.py
index bff15465695..55c38e3e538 100644
--- a/verl/utils/fsdp_utils.py
+++ b/verl/utils/fsdp_utils.py
@@ -603,9 +603,6 @@ def __prefix_submodules(module, prefix):
         if name.startswith(prefix) and "." not in name[len(prefix) :]:
             yield name, submodule
 
-    import logging as _logging
-    _log = _logging.getLogger(__name__)
-
     lora_params = OrderedDict()
     if is_diffusers:
         prefix_list = [
@@ -621,14 +618,12 @@ def __prefix_submodules(module, prefix):
             "_fsdp_wrapped_module.base_model.model.model.",
             "_fsdp_wrapped_module.base_model.model.model.layers.",
             "_fsdp_wrapped_module.base_model.model.model.language_model.layers.",
-            # fsdp — Qwen3-Omni: layers are under base_model.model.thinker.model.layers
             "_fsdp_wrapped_module.base_model.model.thinker.model.layers.",
             # fsdp2
             "base_model.model.",
             "base_model.model.model.",
             "base_model.model.model.layers.",
             "base_model.model.model.language_model.layers.",
-            # fsdp2 — Qwen3-Omni
             "base_model.model.thinker.model.layers.",
         ]
     peft_model = getattr(fsdp_module, "_fsdp_wrapped_module", fsdp_module)
@@ -639,11 +634,8 @@ def __prefix_submodules(module, prefix):
             else:
                 prefix = name.replace("_fsdp_wrapped_module.base_model.model.", "base_model.model.")
             if name.endswith(".model") or name.endswith(".layers"):
-                _log.info("[layered_summon] SKIP (ends with .model/.layers): %s", name)
                 continue
-            fv = fsdp_version(submodule)
-            if fv > 0:
-                _log.info("[layered_summon] MATCH prefix=%s name=%s fsdp_ver=%d", prefix, name, fv)
+            if fsdp_version(submodule) > 0:
                 with FSDP.summon_full_params(submodule, writeback=False):
                     sub_lora_params = get_peft_model_state_dict(peft_model, state_dict=submodule.state_dict())
                     sub_lora_params = {
@@ -652,15 +644,9 @@ def __prefix_submodules(module, prefix):
                         else param.detach().cpu()
                         for name, param in sub_lora_params.items()
                     }
-                    _log.info("[layered_summon] collected %d lora params from %s, keys=%s",
-                              len(sub_lora_params), name, list(sub_lora_params.keys())[:3])
                     lora_params.update(sub_lora_params)
                 submodule._is_root = False
                 get_torch_device().empty_cache()
-            else:
-                _log.info("[layered_summon] SKIP (fsdp_ver=0): %s", name)
-    _log.info("[layered_summon] TOTAL lora_params: %d, sample_keys=%s",
-              len(lora_params), list(lora_params.keys())[:5])
     return lora_params
 
 
@@ -684,9 +670,10 @@ def collect_lora_params(
             )
         lora_params = layered_summon_lora_params(module, is_diffusers=is_diffusers)
         if not lora_params:
-            import logging as _fallback_log
-            _fallback_log.getLogger(__name__).warning(
-                "layered_summon returned empty — falling back to full summon with offload_to_cpu"
+            import logging
+
+            logging.getLogger(__name__).warning(
+                "layered_summon returned empty, falling back to full summon"
             )
             with FSDP.summon_full_params(module, writeback=False, offload_to_cpu=True):
                 lora_params = get_peft_model_state_dict(peft_model)
diff --git a/verl/utils/model.py b/verl/utils/model.py
index 4bcfca9f55e..ff2d5af02c3 100644
--- a/verl/utils/model.py
+++ b/verl/utils/model.py
@@ -675,7 +675,7 @@ def load_valuehead_model(local_path, torch_dtype, model_config, trust_remote_cod
 
 _architecture_to_auto_class = {
     "ForCausalLM": AutoModelForCausalLM,
-    "ForConditionalGeneration": AutoModelForCausalLM,
+    "Qwen3OmniMoeForConditionalGeneration": AutoModelForCausalLM,
     "ForVision2Seq": AutoModelForVision2Seq,
     "ForTokenClassification": AutoModelForTokenClassification,
     "ForSequenceClassification": AutoModelForSequenceClassification,
@@ -733,7 +733,7 @@ def __set__(self, obj, value):
             pass
     Qwen3OmniMoeConfig.tie_word_embeddings = _FalseTieDescriptor()
     AutoModelForCausalLM.register(Qwen3OmniMoeConfig, Qwen3OmniMoeForConditionalGeneration)
-except Exception:
+except ImportError:
     pass
 
diff --git a/verl/utils/reward_score/__init__.py b/verl/utils/reward_score/__init__.py
index 8c53747c224..180dc6b1474 100644
--- a/verl/utils/reward_score/__init__.py
+++ b/verl/utils/reward_score/__init__.py
@@ -42,9 +42,9 @@ def default_compute_score(
             NotImplementedError: If the reward function is not implemented for the given data source.
     """
     if data_source == "openai/gsm8k":
-        from . import gsm8k_thinker
+        from . import gsm8k
 
-        res = gsm8k_thinker.compute_score(solution_str, ground_truth)
+        res = gsm8k.compute_score(solution_str, ground_truth)
     elif data_source in ["lighteval/MATH", "DigitalLearningGmbH/MATH-lighteval", "HuggingFaceH4/MATH-500"]:
         from . import math_reward
diff --git a/verl/utils/reward_score/gsm8k_thinker.py b/verl/utils/reward_score/gsm8k_thinker.py
deleted file mode 100644
index ea6d1a458af..00000000000
--- a/verl/utils/reward_score/gsm8k_thinker.py
+++ /dev/null
@@ -1,44 +0,0 @@
-import re
-
-_ANSWER_CLIP_CHARS = 1000
-
-
-def _get_answer_region(solution_str: str) -> str:
-    """Use content after </think> if present, else fall back to the tail."""
-    think_end = solution_str.rfind("</think>")
-    if think_end != -1:
-        region = solution_str[think_end + len("</think>"):]
-    else:
-        region = solution_str
-    if len(region) > _ANSWER_CLIP_CHARS:
-        region = region[-_ANSWER_CLIP_CHARS:]
-    return region
-
-
-def extract_solution(solution_str: str) -> str | None:
-    region = _get_answer_region(solution_str)
-
-    # #### (standard GSM8K format)
-    m = re.findall(r"#### (\-?[0-9\.,]+)", region)
-    if m:
-        return m[-1].replace(",", "").replace("$", "")
-
-    # \boxed{} (Qwen3 thinking model format)
-    m = re.findall(r"\\boxed\{(\-?[0-9\.,]+)\}", region)
-    if m:
-        return m[-1].replace(",", "")
-
-    # last plain number as fallback
-    m = re.findall(r"(\-?[0-9\.,]+)", region)
-    for candidate in reversed(m):
-        if candidate not in ("", "."):
-            return candidate.replace(",", "")
-
-    return None
-
-
-def compute_score(solution_str: str, ground_truth: str, score: float = 1.0, format_score: float = 0.0) -> float:
-    answer = extract_solution(solution_str)
-    if answer is None:
-        return 0.0
-    return score if answer == ground_truth else format_score
diff --git a/verl/workers/engine/fsdp/transformer_impl.py b/verl/workers/engine/fsdp/transformer_impl.py
index 6eabfc2f5ca..b5537ee9eee 100644
--- a/verl/workers/engine/fsdp/transformer_impl.py
+++ b/verl/workers/engine/fsdp/transformer_impl.py
@@ -248,16 +248,10 @@ def _build_module(self):
             )
 
             _strip_list = getattr(module, '_verl_strip_modules', [])
-            print(f"[STRIP DEBUG] module class: {type(module).__name__}, "
-                  f"_verl_strip_modules: {_strip_list}, "
-                  f"children: {[n for n, _ in module.named_children()]}")
             for attr in _strip_list:
                 if hasattr(module, attr):
                     delattr(module, attr)
-                    print(f"[STRIP DEBUG] Stripped '{attr}'")
                     logger.info(f"Stripped unused sub-module '{attr}' to reduce memory")
-                else:
-                    print(f"[STRIP DEBUG] '{attr}' not found on module")
         else:
             from verl.utils.model import load_valuehead_model
diff --git a/verl/workers/rollout/base.py b/verl/workers/rollout/base.py
index 9dc5ef1bbeb..8280d05871b 100644
--- a/verl/workers/rollout/base.py
+++ b/verl/workers/rollout/base.py
@@ -85,7 +85,6 @@ def generate_sequences(self, prompts: DataProto) -> DataProto:
     ("sglang", "async"): "verl.workers.rollout.sglang_rollout.sglang_rollout.ServerAdapter",
     ("trtllm", "async"): "verl.workers.rollout.trtllm_rollout.trtllm_rollout.ServerAdapter",
     ("vllm_omni", "async"): "verl.workers.rollout.vllm_rollout.ServerAdapter",
-    ("vllm_omni_ar", "async"): "verl.workers.rollout.vllm_rollout.ServerAdapter",
 }
diff --git a/verl/workers/rollout/replica.py b/verl/workers/rollout/replica.py
index bf6a03a61fe..b139cab72b9 100644
--- a/verl/workers/rollout/replica.py
+++ b/verl/workers/rollout/replica.py
@@ -375,14 +375,7 @@ def _load_vllm_omni():
     return vLLMOmniReplica
 
 
-def _load_vllm_omni_ar():
-    from verl_omni.workers.rollout.vllm_rollout.vllm_omni_async_server import vLLMOmniARReplica
-
-    return vLLMOmniARReplica
-
-
 RolloutReplicaRegistry.register("vllm_omni", _load_vllm_omni)
-RolloutReplicaRegistry.register("vllm_omni_ar", _load_vllm_omni_ar)
 
 
 # Original function for backward compatibility
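
Note for reviewers: below is a minimal, self-contained sketch (not part of the patches) of the data-descriptor trick PATCH 1 applies to Qwen3OmniMoeConfig.tie_word_embeddings in verl/utils/model.py. DummyConfig here is a hypothetical stand-in for the HF config class; the point is that a class-level data descriptor intercepts both reads and instance writes, so the config's __init__ can still "assign" the attribute while every read stays pinned to False.

    class _FalseTieDescriptor:
        def __get__(self, obj, objtype=None):
            return False

        def __set__(self, obj, value):
            pass  # swallow writes so the __init__ assignment is a harmless no-op


    class DummyConfig:  # hypothetical stand-in for Qwen3OmniMoeConfig
        tie_word_embeddings = _FalseTieDescriptor()

        def __init__(self, tie_word_embeddings=True):
            # Data descriptors take precedence over instance attributes,
            # so this routes through __set__ and never lands in __dict__.
            self.tie_word_embeddings = tie_word_embeddings


    cfg = DummyConfig(tie_word_embeddings=True)
    assert cfg.tie_word_embeddings is False  # reads always see False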
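And a quick standalone sanity check of the </think>-aware answer extraction from gsm8k_thinker.py (added in PATCH 1, reverted in PATCH 4). extract_gsm8k_answer is a hypothetical helper that inlines _get_answer_region plus the "####" branch of extract_solution under the file's default clip length, so it runs with only the standard library:

    import re

    def extract_gsm8k_answer(solution_str: str, clip: int = 1000) -> str | None:
        # Inlined _get_answer_region: keep only text after the last </think>.
        think_end = solution_str.rfind("</think>")
        region = solution_str[think_end + len("</think>"):] if think_end != -1 else solution_str
        region = region[-clip:] if len(region) > clip else region
        # "####" branch of extract_solution.
        m = re.findall(r"#### (\-?[0-9\.,]+)", region)
        return m[-1].replace(",", "") if m else None

    # The stray "#### 18" inside the think block is ignored; only the
    # post-</think> region is scanned.
    sample = "<think>maybe #### 18? no, recompute.</think>The answer is #### 1,800"
    assert extract_gsm8k_answer(sample) == "1800"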