Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions tests/test_reasoning_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from vllm_mlx.reasoning import (
DeltaMessage,
GLM4ReasoningParser,
ReasoningParser,
get_parser,
list_parsers,
Expand All @@ -28,6 +29,7 @@ def test_list_parsers_includes_builtin(self):
parsers = list_parsers()
assert "qwen3" in parsers
assert "deepseek_r1" in parsers
assert "glm4" in parsers

def test_get_parser_qwen3(self):
"""Should be able to get Qwen3 parser."""
Expand All @@ -41,6 +43,12 @@ def test_get_parser_deepseek(self):
parser = parser_cls()
assert isinstance(parser, ReasoningParser)

def test_get_parser_glm4(self):
    """The registry should resolve "glm4" to a class producing a ReasoningParser."""
    cls = get_parser("glm4")
    instance = cls()
    assert isinstance(instance, ReasoningParser)

def test_get_unknown_parser_raises(self):
"""Unknown parser name should raise KeyError."""
with pytest.raises(KeyError) as exc_info:
Expand Down Expand Up @@ -914,8 +922,7 @@ def test_streaming_constrain_format(self, parser):
def test_constrain_tokens_stripped(self, parser):
"""<|constrain|> should not leak into output."""
output = (
"<|channel|>final <|constrain|>JSON<|message|>"
'{"hello":"world"}<|return|>'
'<|channel|>final <|constrain|>JSON<|message|>{"hello":"world"}<|return|>'
)
reasoning, content = parser.extract_reasoning(output)
assert "<|constrain|>" not in (content or "")
Expand Down
61 changes: 58 additions & 3 deletions tests/test_tool_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
AutoToolParser,
DeepSeekToolParser,
FunctionaryToolParser,
Glm47ToolParser,
GraniteToolParser,
HermesToolParser,
KimiToolParser,
Expand Down Expand Up @@ -811,9 +812,7 @@ def test_qwen3_coder_multiline_parameter(self):
def test_bare_function_without_tool_call_wrapper(self):
"""Test bare <function=...> blocks without <tool_call> wrapper."""
parser = HermesToolParser()
text = (
"<function=get_weather>" "<parameter=city>Berlin</parameter>" "</function>"
)
text = "<function=get_weather><parameter=city>Berlin</parameter></function>"
result = parser.extract_tool_calls(text)

assert result.tools_called
Expand Down Expand Up @@ -1160,3 +1159,59 @@ def test_streaming_bare_multi_function_blocks(self):
assert len(emitted_calls) == 2
assert emitted_calls[0]["function"]["name"] == "func1"
assert emitted_calls[1]["function"]["name"] == "func2"


class TestGLM47ToolParser:
    """Tests for GLM47 tool parser."""

    def test_zero_arguments_tool_call(self):
        """Fix 2: a <tool_call> holding only a name must yield empty JSON args."""
        result = Glm47ToolParser().extract_tool_calls(
            "<tool_call>get_current_time</tool_call>"
        )

        assert result.tools_called is True
        assert len(result.tool_calls) == 1
        call = result.tool_calls[0]
        assert call["name"] == "get_current_time"
        assert json.loads(call["arguments"]) == {}

    def test_with_arguments(self):
        """A call carrying <arg_key>/<arg_value> pairs is parsed into JSON args."""
        text = (
            "<tool_call>search\n<arg_key>query</arg_key>"
            "<arg_value>Python</arg_value></tool_call>"
        )
        result = Glm47ToolParser().extract_tool_calls(text)

        assert result.tools_called is True
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0]["name"] == "search"
        assert json.loads(result.tool_calls[0]["arguments"])["query"] == "Python"

    def test_streaming_zero_args(self):
        """Fix 2: streaming a zero-argument call must still surface the tool call."""
        parser = Glm47ToolParser()

        seen_call = False
        text_so_far = ""
        for piece in ("<tool_call>", "get_status", "</tool_call>"):
            before = text_so_far
            text_so_far = before + piece
            delta = parser.extract_tool_calls_streaming(
                previous_text=before,
                current_text=text_so_far,
                delta_text=piece,
            )
            if delta is not None and "tool_calls" in delta:
                seen_call = True
                fn = delta["tool_calls"][0]["function"]
                assert fn["name"] == "get_status"
                assert json.loads(fn["arguments"]) == {}

        assert seen_call, "Zero-argument tool call should have been detected"
7 changes: 7 additions & 0 deletions vllm_mlx/reasoning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"""

from .base import DeltaMessage, ReasoningParser
from .glm4_parser import GLM4ReasoningParser
from .think_parser import BaseThinkingReasoningParser

# Parser registry
Expand Down Expand Up @@ -76,10 +77,12 @@ def list_parsers() -> list[str]:
def _register_builtin_parsers():
"""Register built-in parsers."""
from .deepseek_r1_parser import DeepSeekR1ReasoningParser
from .glm4_parser import GLM4ReasoningParser
from .gpt_oss_parser import GptOssReasoningParser
from .harmony_parser import HarmonyReasoningParser
from .qwen3_parser import Qwen3ReasoningParser

register_parser("glm4", GLM4ReasoningParser)
register_parser("qwen3", Qwen3ReasoningParser)
register_parser("deepseek_r1", DeepSeekR1ReasoningParser)
register_parser("gpt_oss", GptOssReasoningParser)
Expand All @@ -99,4 +102,8 @@ def _register_builtin_parsers():
"register_parser",
"get_parser",
"list_parsers",
# Built-in parsers
"GLM4ReasoningParser",
"Qwen3ReasoningParser",
"DeepSeekR1ReasoningParser",
]
75 changes: 75 additions & 0 deletions vllm_mlx/reasoning/glm4_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# SPDX-License-Identifier: Apache-2.0
"""
Reasoning parser for GLM4 models.

GLM4 uses <think>...</think> tags for reasoning content. The GLM-4.7 chat
template injects <think> directly into the prompt, so the model never
outputs it natively - it only outputs </think> to end thinking.

This desynchronizes standard parsers, so we need special handling.
"""

from .think_parser import BaseThinkingReasoningParser


class GLM4ReasoningParser(BaseThinkingReasoningParser):
    """
    Reasoning parser for GLM4 models.

    Handles the GLM-specific case where:
    1. The chat template injects <think> into the prompt
    2. The model starts its output already "in reasoning"
    3. The model only outputs </think> to end thinking

    This is different from Qwen3 where both tags may appear in output.

    Example:
        Model output: "Let me analyze this...</think>The answer is 42."
        Output: reasoning="Let me analyze this...", content="The answer is 42."
    """

    @property
    def start_token(self) -> str:
        return "<think>"

    @property
    def end_token(self) -> str:
        return "</think>"

    def extract_reasoning(
        self,
        model_output: str,
    ) -> tuple[str | None, str | None]:
        """
        Extract reasoning from GLM4 output.

        GLM4 typically only outputs </think> (not <think>) because the start
        token was injected in the prompt by the chat template. Because of
        that, when BOTH tags appear, </think> usually precedes any literal
        <think> occurring inside the final content — so the tags must be
        handled in positional order, not mere presence.

        Args:
            model_output: Complete model output text.

        Returns:
            (reasoning, content) tuple; each element is None when empty.
        """
        text = model_output
        start_idx = text.find(self.start_token)
        end_idx = text.find(self.end_token)

        # Case 1: Both tags present IN ORDER (<think> ... </think>).
        # Requiring start_idx < end_idx avoids misparsing GLM output like
        # "reason</think>content<think>..." where a literal <think> appears
        # after the real closing tag; that input falls through to Case 2.
        if start_idx != -1 and end_idx != -1 and start_idx < end_idx:
            after_start = text[start_idx + len(self.start_token):]
            reasoning, _, content = after_start.partition(self.end_token)
            return reasoning.strip() or None, content.strip() or None

        # Case 2: Closing tag first or alone (most common for GLM).
        # Model was already "in reasoning" due to prompt injection.
        if end_idx != -1:
            reasoning = text[:end_idx]
            content = text[end_idx + len(self.end_token):]
            return reasoning.strip() or None, content.strip() or None

        # Case 3: Only start tag (reasoning still in progress).
        if start_idx != -1:
            reasoning = text[start_idx + len(self.start_token):]
            return reasoning.strip() or None, None

        # Case 4: No tags - pure content (thinking disabled).
        return None, model_output
Loading
Loading