feat: add Qwen3-Omni Thinker GSPO support #6238

Draft
qinganrice wants to merge 3 commits into verl-project:main from qinganrice:qwen3-omni-thinker-v2

Conversation

@qinganrice

Summary

  • Register Qwen3-Omni model in AutoModelForCausalLM with forward redirect to Thinker, fix tie_word_embeddings and _no_split_modules for FSDP compatibility
  • Fix FSDP LoRA deadlock: skip lambda wrap policy when min_num_params > 0 to avoid nested FSDP allgather divergence
  • Cast LoRA params to base model dtype after get_peft_model so FSDP can flatten mixed-dtype units
  • Strip unused sub-modules (Talker/Code2Wav) after from_pretrained via _verl_strip_modules
  • Add Thinker layer prefixes to layered_summon with fallback to full summon when layered returns empty
  • Fix text_config fallback in monkey_patch for models without top-level num_attention_heads
  • Duck-typing fix for vLLM LoRA request to support vllm-omni's LoRARequest
  • Add gsm8k_thinker reward with </think> extraction and \boxed{} support
  • Register vllm_omni / vllm_omni_ar in rollout and replica registries for verl-omni integration

Test plan

  • End-to-end GSPO LoRA training with Qwen3-Omni thinker model

@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
You have signed the CLA already but the status is still pending? Let us recheck it.

Contributor

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces support for the Qwen3-Omni model architecture and enhances FSDP and LoRA handling. Key changes include registering the Qwen3-Omni Thinker as a causal language model with custom forward and embedding logic, implementing a module stripping mechanism to reduce memory usage during FSDP initialization, and adding a new reward scoring utility (gsm8k_thinker) designed for models that output reasoning steps. Additionally, the PR updates LoRA parameter collection to support diffusers and adds a fallback mechanism for parameter summoning. Review feedback highlights the need to narrow broad architecture mappings to prevent conflicts with encoder-decoder models, improve exception handling during model registration, refine regex patterns in the reward scorer to handle currency symbols, and remove debug print statements from production code.

Comment thread verl/utils/model.py

_architecture_to_auto_class = {
"ForCausalLM": AutoModelForCausalLM,
"ForConditionalGeneration": AutoModelForCausalLM,
Contributor


high

Mapping the generic suffix ForConditionalGeneration to AutoModelForCausalLM is too broad and will cause issues for encoder-decoder models (like T5 or BART) that also use this suffix but are not causal language models. It is safer to use a more specific key for the Qwen3-Omni model.

Suggested change
"ForConditionalGeneration": AutoModelForCausalLM,
"Qwen3OmniMoeForConditionalGeneration": AutoModelForCausalLM,

Comment thread verl/utils/model.py
Comment on lines +736 to +737
except Exception:
pass
Contributor


high

Catching a broad Exception and silently passing is dangerous as it can hide legitimate issues during model registration or monkey-patching. It is better to catch specific exceptions like ImportError if the goal is to handle missing dependencies.

Suggested change
except Exception:
pass
except ImportError:
pass
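
The narrowing can be generalized into a small helper (a sketch, not code from this PR): swallow only the error class that signals a missing optional dependency, and let genuine registration bugs propagate.

```python
import importlib


def optional_import(module_name: str):
    """Return the module if importable, else None. Only ImportError is
    swallowed; any other exception raised during import still surfaces."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        return None
```

A registration site can then guard on the return value instead of wrapping the whole block in a bare `except Exception: pass`.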

Comment on lines +22 to +29
m = re.findall(r"#### (\-?[0-9\.,]+)", region)
if m:
return m[-1].replace(",", "").replace("$", "")

# \boxed{<number>} (Qwen3 thinking model format)
m = re.findall(r"\\boxed\{(\-?[0-9\.,]+)\}", region)
if m:
return m[-1].replace(",", "")
Contributor


high

The current regular expressions for GSM8K and boxed formats do not account for an optional currency symbol (e.g., $), which is common in math datasets. Although line 24 attempts to strip $, the regex on line 22 will fail to match the string entirely if a $ is present before the number. The regex should be updated to optionally match the $ symbol outside of the capture group.

Suggested change
m = re.findall(r"#### (\-?[0-9\.,]+)", region)
if m:
return m[-1].replace(",", "").replace("$", "")
# \boxed{<number>} (Qwen3 thinking model format)
m = re.findall(r"\\boxed\{(\-?[0-9\.,]+)\}", region)
if m:
return m[-1].replace(",", "")
# #### <number> (standard GSM8K format)
m = re.findall(r"#### \$?(\-?[0-9\.,]+)", region)
if m:
return m[-1].replace(",", "")
# \boxed{<number>} (Qwen3 thinking model format)
m = re.findall(r"\\boxed\{\$?(\-?[0-9\.,]+)\}", region)
if m:
return m[-1].replace(",", "")
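
The suggested patterns can be exercised in isolation. Below is a standalone sketch of the extraction helper built around them; the function name and the `region` argument are illustrative, not taken from the PR:

```python
import re


def extract_answer(region: str):
    """Pull the final numeric answer, tolerating an optional leading '$'
    outside the capture group so '#### $1,250' still matches."""
    # #### <number> (standard GSM8K format)
    m = re.findall(r"#### \$?(\-?[0-9\.,]+)", region)
    if m:
        return m[-1].replace(",", "")
    # \boxed{<number>} (Qwen3 thinking model format)
    m = re.findall(r"\\boxed\{\$?(\-?[0-9\.,]+)\}", region)
    if m:
        return m[-1].replace(",", "")
    return None
```

Because the `$` sits outside the capture group, the trailing `.replace("$", "")` from the original code becomes unnecessary.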

Comment on lines +250 to +260
_strip_list = getattr(module, '_verl_strip_modules', [])
print(f"[STRIP DEBUG] module class: {type(module).__name__}, "
f"_verl_strip_modules: {_strip_list}, "
f"children: {[n for n, _ in module.named_children()]}")
for attr in _strip_list:
if hasattr(module, attr):
delattr(module, attr)
print(f"[STRIP DEBUG] Stripped '{attr}'")
logger.info(f"Stripped unused sub-module '{attr}' to reduce memory")
else:
print(f"[STRIP DEBUG] '{attr}' not found on module")
Contributor


high

The debug print statements should be removed from the production code. The relevant information is already being logged via logger.info on line 258.

Suggested change
_strip_list = getattr(module, '_verl_strip_modules', [])
print(f"[STRIP DEBUG] module class: {type(module).__name__}, "
f"_verl_strip_modules: {_strip_list}, "
f"children: {[n for n, _ in module.named_children()]}")
for attr in _strip_list:
if hasattr(module, attr):
delattr(module, attr)
print(f"[STRIP DEBUG] Stripped '{attr}'")
logger.info(f"Stripped unused sub-module '{attr}' to reduce memory")
else:
print(f"[STRIP DEBUG] '{attr}' not found on module")
_strip_list = getattr(module, '_verl_strip_modules', [])
for attr in _strip_list:
if hasattr(module, attr):
delattr(module, attr)
logger.info(f"Stripped unused sub-module '{attr}' to reduce memory")
