Description
Your current environment
Collecting environment information...
uv is set
==============================
System Info
==============================
OS : Ubuntu 22.04.4 LTS (x86_64)
GCC version : (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version : Could not collect
CMake version : version 3.22.1
Libc version : glibc-2.35
==============================
PyTorch Info
==============================
PyTorch version : 2.10.0+cu129
Is debug build : False
CUDA used to build PyTorch : 12.9
ROCM used to build PyTorch : N/A
==============================
Python Environment
==============================
Python version : 3.13.12 (main, Feb 12 2026, 00:45:41) [Clang 21.1.4 ] (64-bit runtime)
Python platform : Linux-6.5.0-1015-aws-x86_64-with-glibc2.35
==============================
CUDA / GPU Info
==============================
Is CUDA available : True
CUDA runtime version : 11.5.119
CUDA_MODULE_LOADING set to :
GPU models and configuration : GPU 0: NVIDIA A10G
Nvidia driver version : 535.183.01
cuDNN version : Could not collect
HIP runtime version : N/A
MIOpen runtime version : N/A
Is XNNPACK available : True
==============================
CPU Info
==============================
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: AuthenticAMD
Model name: AMD EPYC 7R32
CPU family: 23
Model: 49
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 0
BogoMIPS: 5599.99
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr rdpru wbnoinvd arat npt nrip_save rdpid
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 256 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 4 MiB (8 instances)
L3 cache: 32 MiB (2 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow: Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
==============================
Versions of relevant libraries
==============================
[pip3] flashinfer-python==0.6.4
[pip3] numpy==2.2.6
[pip3] nvidia-cublas-cu12==12.9.1.4
[pip3] nvidia-cuda-cupti-cu12==12.9.79
[pip3] nvidia-cuda-nvrtc-cu12==12.9.86
[pip3] nvidia-cuda-runtime-cu12==12.9.79
[pip3] nvidia-cudnn-cu12==9.10.2.21
[pip3] nvidia-cudnn-frontend==1.18.0
[pip3] nvidia-cufft-cu12==11.4.1.4
[pip3] nvidia-cufile-cu12==1.14.1.1
[pip3] nvidia-curand-cu12==10.3.10.19
[pip3] nvidia-cusolver-cu12==11.7.5.82
[pip3] nvidia-cusparse-cu12==12.5.10.65
[pip3] nvidia-cusparselt-cu12==0.7.1
[pip3] nvidia-cutlass-dsl==4.4.1
[pip3] nvidia-cutlass-dsl-libs-base==4.4.1
[pip3] nvidia-ml-py==13.590.48
[pip3] nvidia-nccl-cu12==2.27.5
[pip3] nvidia-nvjitlink-cu12==12.9.86
[pip3] nvidia-nvshmem-cu12==3.4.5
[pip3] nvidia-nvtx-cu12==12.9.79
[pip3] pyzmq==27.1.0
[pip3] torch==2.10.0+cu129
[pip3] torch-c-dlpack-ext==0.1.5
[pip3] torchaudio==2.10.0+cu129
[pip3] torchvision==0.25.0+cu129
[pip3] transformers==4.57.6
[pip3] triton==3.6.0
[conda] Could not collect
==============================
vLLM Info
==============================
ROCM Version : Could not collect
vLLM Version : 0.17.0rc1.dev154+gfde4771bb (git sha: fde4771bb)
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled
GPU Topology:
GPU0 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X 0-15 0 N/A
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
==============================
Environment Variables
==============================
PYTORCH_NVML_BASED_CUDA_CHECK=1
TORCHINDUCTOR_COMPILE_THREADS=1
TORCHINDUCTOR_CACHE_DIR=/tmp/torchinductor_ubuntu
🐛 Describe the bug
When using the Responses API (/v1/responses) with streaming enabled and a non-harmony model that uses XML-based tool calls (e.g. Qwen3.5), tool call content is emitted as response.output_text.delta events containing raw XML instead of proper response.function_call_arguments.delta / response.function_call_arguments.done events.
The non-streaming Responses API, Chat Completions streaming, and Chat Completions non-streaming all work correctly; only Responses API streaming is affected.
Root Cause
In vllm/entrypoints/openai/responses/serving.py, _process_simple_streaming_events has two issues:
- `reasoning_parser` and `tool_parser` are mutually exclusive in the `if`/`elif` chain. When a model has both a reasoning parser (e.g. `<think>` tags) and a tool parser, the `elif tool_parser:` branch is never reached because `if reasoning_parser:` always takes priority. After reasoning ends, subsequent tokens (including `<tool_call>` XML) fall through as plain content.
- No tool call event emission. Even if the tool parser were reached, there was no code to convert `DeltaMessage.tool_calls` into `ResponseFunctionCallArgumentsDeltaEvent` / `ResponseFunctionCallArgumentsDoneEvent`. The original code had `# todo(kebe7jun) tool call support`.
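To illustrate the first issue, here is a toy sketch of the control flow. It is not the vLLM code: `route`, the `chained` flag, and the token-level matching are hypothetical stand-ins for the real parsers, chosen only to show why tool-call tokens end up as text when the tool stage is never consulted after reasoning ends.

```python
THINK_END = "</think>"
TOOL_START, TOOL_END = "<tool_call>", "</tool_call>"


def route(tokens, chained):
    """Classify each token as reasoning, text, or tool_call.

    chained=False mimics the exclusive if/elif shape: once a reasoning
    parser is present, the tool stage is never consulted.
    chained=True runs the tool stage after reasoning has ended.
    """
    events = []
    in_reasoning = True
    in_tool_call = False
    for tok in tokens:
        if in_reasoning:
            if tok == THINK_END:
                in_reasoning = False
            else:
                events.append(("reasoning", tok))
            continue
        if not chained:
            # Exclusive shape: everything after reasoning is plain text,
            # so the tool-call markup leaks into text deltas.
            events.append(("text", tok))
            continue
        # Chained shape: hand post-reasoning tokens to the tool stage.
        if tok == TOOL_START:
            in_tool_call = True
        elif tok == TOOL_END:
            in_tool_call = False
        elif in_tool_call:
            events.append(("tool_call", tok))
        else:
            events.append(("text", tok))
    return events


tokens = ["hmm", THINK_END, TOOL_START, "get_current_weather", TOOL_END]
print(route(tokens, chained=False))
print(route(tokens, chained=True))
```

With `chained=False` the `<tool_call>` markers come out as text entries, which mirrors the leak reported here; with `chained=True` only the tool-call payload is routed to the tool stage.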
Reproduce
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

# Streaming Responses API with tool calls
with client.responses.create(
    model="Qwen/Qwen3.5-9B",  # or any non-harmony model with tool call support
    input="What's the weather in Boston today?",
    tools=[{
        "type": "function",
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA"
                },
                "unit": {
                    "type": "string",
                    "enum": ["celsius", "fahrenheit"]
                }
            },
            "required": ["location", "unit"]
        }
    }],
    stream=True,
) as stream:
    for event in stream:
        print(event)

Server started with:

vllm serve Qwen/Qwen3.5-9B \
    --enable-auto-tool-choice \
    --tool-call-parser qwen3_coder

Actual behavior
Tool call XML leaks into text delta events:
response.output_text.delta → "<tool_call>"
response.output_text.delta → "\n<function=get_current_weather>"
response.output_text.delta → "\n<parameter=location>"
response.output_text.delta → "Boston, MA"
...
response.output_text.done → "\n\n<tool_call>\n<function=get_current_weather>\n<parameter=location>\nBoston, MA\n</parameter>\n<parameter=unit>\nfahrenheit\n</parameter>\n</function>\n</tool_call>"
Expected behavior
Proper function call events (matching OpenAI's Responses API behavior):
response.output_item.added → {type: "function_call", name: "get_current_weather", status: "in_progress"}
response.function_call_arguments.delta → "{"
response.function_call_arguments.delta → '"location": "Boston, MA"'
response.function_call_arguments.delta → ', "unit": "fahrenheit"'
response.function_call_arguments.delta → "}"
response.function_call_arguments.done → '{"location": "Boston, MA", "unit": "fahrenheit"}'
response.output_item.done → {type: "function_call", arguments: '{"location": "Boston, MA", "unit": "fahrenheit"}'}
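For context on why the event types matter, a client typically rebuilds the call arguments from the argument deltas rather than from text. The sketch below assumes events are represented as plain dicts with `type`, `output_index`, `delta`, and `arguments` keys; `collect_function_calls` is a hypothetical helper, not part of the OpenAI SDK or vLLM.

```python
import json


def collect_function_calls(events):
    """Accumulate argument deltas per output_index; parse them on 'done'."""
    buffers = {}  # output_index -> concatenated argument deltas
    calls = {}    # output_index -> parsed arguments
    for ev in events:
        kind = ev["type"]
        idx = ev.get("output_index", 0)
        if kind == "response.function_call_arguments.delta":
            buffers[idx] = buffers.get(idx, "") + ev["delta"]
        elif kind == "response.function_call_arguments.done":
            # Prefer the full string on the done event, else use the buffer.
            raw = ev.get("arguments") or buffers.get(idx, "")
            calls[idx] = json.loads(raw)
    return calls


delta = "response.function_call_arguments.delta"
events = [
    {"type": delta, "output_index": 0, "delta": "{"},
    {"type": delta, "output_index": 0, "delta": '"location": "Boston, MA"'},
    {"type": delta, "output_index": 0, "delta": ', "unit": "fahrenheit"'},
    {"type": delta, "output_index": 0, "delta": "}"},
    {"type": "response.function_call_arguments.done", "output_index": 0},
]
print(collect_function_calls(events))
```

A consumer written this way sees no function call at all under the actual behavior, because only `response.output_text.delta` events arrive.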
Related
- [Bug]: Responses API: Streaming returns ResponseTextDeltaEvent instead of ResponseFunctionCallArgumentsDeltaEvent for tool calls while using non-harmony models #29725
- [Frontend] OpenAI Responses API supports Tool/Function calling - non-harmony #26874 (fixed non-streaming path only)
Fix
I have a fix at https://github.com/herve-ves/vllm/tree/fix/responses-streaming-tool-call and am happy to open a PR if desired.
The fix modifies _process_simple_streaming_events to:
- Handle both reasoning and tool calls (reasoning first, then tool calls after `is_reasoning_end()`, matching Chat Completions behavior)
- Emit `ResponseFunctionCallArgumentsDeltaEvent` / `ResponseFunctionCallArgumentsDoneEvent` via the existing `emit_function_call_delta_events` / `emit_function_call_done_events` helpers
- Properly close message output items before function call events when content precedes tool calls
- Sync `tool_streaming_state.current_output_index` with the main output index
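The last two points concern event ordering and index bookkeeping. As a rough sketch of the intended sequence (not the code in the branch above): `emit_events` and its `(kind, payload)` segment input are hypothetical, and content-part events are omitted for brevity.

```python
def emit_events(segments):
    """Turn (kind, payload) segments into (event_type, output_index, payload).

    kind is "message" or "function_call". When the kind changes, the open
    item is closed first and the output index advances by one.
    """
    events = []
    output_index = -1
    open_kind = None
    buffer = ""

    def close_open_item():
        nonlocal open_kind, buffer
        if open_kind == "message":
            events.append(("response.output_text.done", output_index, buffer))
        elif open_kind == "function_call":
            events.append(
                ("response.function_call_arguments.done", output_index, buffer)
            )
        if open_kind is not None:
            events.append(("response.output_item.done", output_index, open_kind))
        open_kind, buffer = None, ""

    for kind, payload in segments:
        if kind != open_kind:
            close_open_item()      # finish the previous item first
            output_index += 1      # keep one shared index for all items
            open_kind = kind
            events.append(("response.output_item.added", output_index, kind))
        delta_type = (
            "response.output_text.delta"
            if kind == "message"
            else "response.function_call_arguments.delta"
        )
        events.append((delta_type, output_index, payload))
        buffer += payload
    close_open_item()
    return events


for ev in emit_events(
    [("message", "Let me check."), ("function_call", "{"), ("function_call", "}")]
):
    print(ev)
```

The message item at index 0 is fully closed before the function call item is added at index 1, so the two items never share an index or interleave.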