feat: add Qwen3-Omni Thinker GSPO support #6238

qinganrice wants to merge 3 commits into verl-project:main from
Conversation
Code Review
This pull request introduces support for the Qwen3-Omni model architecture and enhances FSDP and LoRA handling. Key changes include registering the Qwen3-Omni Thinker as a causal language model with custom forward and embedding logic, implementing a module stripping mechanism to reduce memory usage during FSDP initialization, and adding a new reward scoring utility (gsm8k_thinker) designed for models that output reasoning steps. Additionally, the PR updates LoRA parameter collection to support diffusers and adds a fallback mechanism for parameter summoning. Review feedback highlights the need to narrow broad architecture mappings to prevent conflicts with encoder-decoder models, improve exception handling during model registration, refine regex patterns in the reward scorer to handle currency symbols, and remove debug print statements from production code.
```python
_architecture_to_auto_class = {
    "ForCausalLM": AutoModelForCausalLM,
    "ForConditionalGeneration": AutoModelForCausalLM,
```
Mapping the generic suffix ForConditionalGeneration to AutoModelForCausalLM is too broad and will cause issues for encoder-decoder models (like T5 or BART) that also use this suffix but are not causal language models. It is safer to use a more specific key for the Qwen3-Omni model.
| "ForConditionalGeneration": AutoModelForCausalLM, | |
| "Qwen3OmniMoeForConditionalGeneration": AutoModelForCausalLM, |
```python
        except Exception:
            pass
```
Catching a broad Exception and silently passing is dangerous as it can hide legitimate issues during model registration or monkey-patching. It is better to catch specific exceptions like ImportError if the goal is to handle missing dependencies.
Suggested change:

```diff
-        except Exception:
-            pass
+        except ImportError:
+            pass
```
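A minimal sketch of the suggested narrowing (the helper name is illustrative, not from the PR): only a missing optional dependency is tolerated, and it is logged rather than silently swallowed, while any other registration error still surfaces.

```python
import logging

logger = logging.getLogger(__name__)

def try_register(register_fn):
    """Run a model-registration callable, tolerating only missing deps."""
    try:
        register_fn()
    except ImportError as e:
        # Optional dependency not installed: skip registration but record it.
        logger.info("Skipping optional model registration: %s", e)

def missing_dep():
    raise ImportError("package not installed")

try_register(missing_dep)  # swallowed and logged, no crash
```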
```python
    m = re.findall(r"#### (\-?[0-9\.,]+)", region)
    if m:
        return m[-1].replace(",", "").replace("$", "")

    # \boxed{<number>} (Qwen3 thinking model format)
    m = re.findall(r"\\boxed\{(\-?[0-9\.,]+)\}", region)
    if m:
        return m[-1].replace(",", "")
```
The current regular expressions for GSM8K and boxed formats do not account for an optional currency symbol (e.g., $), which is common in math datasets. Although line 24 attempts to strip $, the regex on line 22 will fail to match the string entirely if a $ is present before the number. The regex should be updated to optionally match the $ symbol outside of the capture group.
Suggested change:

```diff
-    m = re.findall(r"#### (\-?[0-9\.,]+)", region)
-    if m:
-        return m[-1].replace(",", "").replace("$", "")
-    # \boxed{<number>} (Qwen3 thinking model format)
-    m = re.findall(r"\\boxed\{(\-?[0-9\.,]+)\}", region)
-    if m:
-        return m[-1].replace(",", "")
+    # #### <number> (standard GSM8K format)
+    m = re.findall(r"#### \$?(\-?[0-9\.,]+)", region)
+    if m:
+        return m[-1].replace(",", "")
+    # \boxed{<number>} (Qwen3 thinking model format)
+    m = re.findall(r"\\boxed\{\$?(\-?[0-9\.,]+)\}", region)
+    if m:
+        return m[-1].replace(",", "")
```
```python
        _strip_list = getattr(module, '_verl_strip_modules', [])
        print(f"[STRIP DEBUG] module class: {type(module).__name__}, "
              f"_verl_strip_modules: {_strip_list}, "
              f"children: {[n for n, _ in module.named_children()]}")
        for attr in _strip_list:
            if hasattr(module, attr):
                delattr(module, attr)
                print(f"[STRIP DEBUG] Stripped '{attr}'")
                logger.info(f"Stripped unused sub-module '{attr}' to reduce memory")
            else:
                print(f"[STRIP DEBUG] '{attr}' not found on module")
```
The debug print statements should be removed from the production code. The relevant information is already being logged via logger.info on line 258.
Suggested change:

```diff
-        _strip_list = getattr(module, '_verl_strip_modules', [])
-        print(f"[STRIP DEBUG] module class: {type(module).__name__}, "
-              f"_verl_strip_modules: {_strip_list}, "
-              f"children: {[n for n, _ in module.named_children()]}")
-        for attr in _strip_list:
-            if hasattr(module, attr):
-                delattr(module, attr)
-                print(f"[STRIP DEBUG] Stripped '{attr}'")
-                logger.info(f"Stripped unused sub-module '{attr}' to reduce memory")
-            else:
-                print(f"[STRIP DEBUG] '{attr}' not found on module")
+        _strip_list = getattr(module, '_verl_strip_modules', [])
+        for attr in _strip_list:
+            if hasattr(module, attr):
+                delattr(module, attr)
+                logger.info(f"Stripped unused sub-module '{attr}' to reduce memory")
```
Summary

- `AutoModelForCausalLM` with forward redirect to Thinker; fix `tie_word_embeddings` and `_no_split_modules` for FSDP compatibility
- `min_num_params > 0` to avoid nested FSDP allgather divergence
- `get_peft_model` so FSDP can flatten mixed-dtype units
- `from_pretrained` via `_verl_strip_modules`
- `layered_summon` with fallback to full summon when layered returns empty
- `text_config` fallback in monkey_patch for models without top-level `num_attention_heads`
- `LoRARequest`
- `gsm8k_thinker` reward with `</think>` extraction and `\boxed{}` support
- `vllm_omni`/`vllm_omni_ar` in rollout and replica registries for verl-omni integration

Test plan