diff --git a/tests/entrypoints/openai/test_response_api_mcp_tools.py b/tests/entrypoints/openai/test_response_api_mcp_tools.py index cd338b5555c5..1184b1202b43 100644 --- a/tests/entrypoints/openai/test_response_api_mcp_tools.py +++ b/tests/entrypoints/openai/test_response_api_mcp_tools.py @@ -1,11 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project + import pytest import pytest_asyncio -from openai import OpenAI from openai_harmony import ToolDescription, ToolNamespaceConfig +from openai import OpenAI from vllm.entrypoints.tool_server import MCPToolServer from ...utils import RemoteOpenAIServer @@ -206,6 +207,67 @@ async def test_mcp_tool_env_flag_disabled(mcp_disabled_client: OpenAI, model_nam ) +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_mcp_tool_calling_streaming_types( + mcp_enabled_client: OpenAI, model_name: str +): + pairs_of_event_types = { + "response.completed": "response.created", + "response.output_item.done": "response.output_item.added", + "response.content_part.done": "response.content_part.added", + "response.output_text.done": "response.output_text.delta", + "response.reasoning_text.done": "response.reasoning_text.delta", + "response.reasoning_part.done": "response.reasoning_part.added", + "response.mcp_call_arguments.done": "response.mcp_call_arguments.delta", + "response.mcp_call.completed": "response.mcp_call.in_progress", + } + + tools = [ + { + "type": "mcp", + "server_label": "code_interpreter", + } + ] + input_text = "What is 13 * 24? Use python to calculate the result." + + stream_response = await mcp_enabled_client.responses.create( + model=model_name, + input=input_text, + tools=tools, + stream=True, + instructions=( + "You must use the Python tool to execute code. Never simulate execution." + ), + ) + + stack_of_event_types = [] + saw_mcp_type = False + async for event in stream_response: + if event.type == "response.created": + stack_of_event_types.append(event.type) + elif event.type == "response.completed": + assert stack_of_event_types[-1] == pairs_of_event_types[event.type] + stack_of_event_types.pop() + elif ( + event.type.endswith("added") + or event.type == "response.mcp_call.in_progress" + ): + stack_of_event_types.append(event.type) + elif event.type.endswith("delta"): + if stack_of_event_types[-1] == event.type: + continue + stack_of_event_types.append(event.type) + elif event.type.endswith("done") or event.type == "response.mcp_call.completed": + assert stack_of_event_types[-1] == pairs_of_event_types[event.type] + if "mcp_call" in event.type: + saw_mcp_type = True + stack_of_event_types.pop() + + assert len(stack_of_event_types) == 0 + assert saw_mcp_type, "Should have seen at least one mcp call" + + def test_get_tool_description(): """Test MCPToolServer.get_tool_description filtering logic. diff --git a/tests/entrypoints/openai/test_response_api_with_harmony.py b/tests/entrypoints/openai/test_response_api_with_harmony.py index 6dd2d0e86f9d..f568b82d5050 100644 --- a/tests/entrypoints/openai/test_response_api_with_harmony.py +++ b/tests/entrypoints/openai/test_response_api_with_harmony.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import importlib import importlib.util import json import time @@ -8,11 +7,12 @@ import pytest import pytest_asyncio import requests -from openai import BadRequestError, NotFoundError, OpenAI from openai_harmony import ( Message, ) +from openai import BadRequestError, NotFoundError, OpenAI + from ...utils import RemoteOpenAIServer MODEL_NAME = "openai/gpt-oss-20b" @@ -44,6 +44,8 @@ def server(): env_dict = dict( VLLM_ENABLE_RESPONSES_API_STORE="1", PYTHON_EXECUTION_BACKEND="dangerously_use_uv", + VLLM_GPT_OSS_SYSTEM_TOOL_MCP_LABELS="code_interpreter,container,web_search_preview", + VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS="1", ) with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server: @@ -855,6 +857,237 @@ async def test_function_calling_with_stream(client: OpenAI, model_name: str): assert event.response.output_text is not None +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_function_calling_no_code_interpreter_events( + client: OpenAI, model_name: str +): + """Verify that function calls don't trigger code_interpreter events. + + This test ensures that function calls (functions.*) use their own + function_call event types and don't incorrectly emit code_interpreter + events during streaming. + """ + tools = [GET_WEATHER_SCHEMA] + input_list = [ + { + "role": "user", + "content": "What's the weather like in Paris today?", + } + ] + stream_response = await client.responses.create( + model=model_name, + input=input_list, + tools=tools, + stream=True, + ) + + # Track which event types we see + event_types_seen = set() + function_call_found = False + + async for event in stream_response: + event_types_seen.add(event.type) + + if ( + event.type == "response.output_item.added" + and event.item.type == "function_call" + ): + function_call_found = True + + # Ensure NO code_interpreter events are emitted for function calls + assert "code_interpreter" not in event.type, ( + "Found code_interpreter event " + f"'{event.type}' during function call. Function calls should only " + "emit function_call events, not code_interpreter events." + ) + + # Verify we actually saw a function call + assert function_call_found, "Expected to see a function_call in the stream" + + # Verify we saw the correct function call event types + assert ( + "response.function_call_arguments.delta" in event_types_seen + or "response.function_call_arguments.done" in event_types_seen + ), "Expected to see function_call_arguments events" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_mcp_code_interpreter_streaming(client: OpenAI, model_name: str, server): + tools = [ + { + "type": "mcp", + "server_label": "code_interpreter", + } + ] + input_text = ( + "Calculate 15 * 32 using python. " + "The python interpreter is not stateful and you must print to see the output." + ) + + stream_response = await client.responses.create( + model=model_name, + input=input_text, + tools=tools, + stream=True, + temperature=0.0, + instructions=( + "You must use the Python tool to execute code. Never simulate execution." + ), + ) + + mcp_call_added = False + mcp_call_in_progress = False + mcp_arguments_delta_seen = False + mcp_arguments_done = False + mcp_call_completed = False + mcp_item_done = False + + code_interpreter_events_seen = False + + async for event in stream_response: + if "code_interpreter" in event.type: + code_interpreter_events_seen = True + + if event.type == "response.output_item.added": + if hasattr(event.item, "type") and event.item.type == "mcp_call": + mcp_call_added = True + assert event.item.name == "python" + assert event.item.server_label == "code_interpreter" + + elif event.type == "response.mcp_call.in_progress": + mcp_call_in_progress = True + + elif event.type == "response.mcp_call_arguments.delta": + mcp_arguments_delta_seen = True + assert event.delta is not None + + elif event.type == "response.mcp_call_arguments.done": + mcp_arguments_done = True + assert event.name == "python" + assert event.arguments is not None + + elif event.type == "response.mcp_call.completed": + mcp_call_completed = True + + elif ( + event.type == "response.output_item.done" + and hasattr(event.item, "type") + and event.item.type == "mcp_call" + ): + mcp_item_done = True + assert event.item.name == "python" + assert event.item.status == "completed" + + assert mcp_call_added, "MCP call was not added" + assert mcp_call_in_progress, "MCP call in_progress event not seen" + assert mcp_arguments_delta_seen, "MCP arguments delta event not seen" + assert mcp_arguments_done, "MCP arguments done event not seen" + assert mcp_call_completed, "MCP call completed event not seen" + assert mcp_item_done, "MCP item done event not seen" + + assert not code_interpreter_events_seen, ( + "Should not see code_interpreter events when using MCP type" + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_mcp_tool_multi_turn(client: OpenAI, model_name: str, server): + """Test MCP tool calling across multiple turns. + + This test verifies that MCP tools work correctly in multi-turn conversations, + maintaining state across turns via the previous_response_id mechanism. + """ + tools = [ + { + "type": "mcp", + "server_label": "code_interpreter", + } + ] + + # First turn - make a calculation + response1 = await client.responses.create( + model=model_name, + input="Calculate 123 * 456 using python and print the result.", + tools=tools, + temperature=0.0, + instructions=( + "You must use the Python tool to execute code. Never simulate execution." + ), + extra_body={"enable_response_messages": True}, + ) + + assert response1 is not None + assert response1.status == "completed" + + # Verify MCP call in first response by checking output_messages + tool_call_found = False + tool_response_found = False + for message in response1.output_messages: + recipient = message.get("recipient") + if recipient and recipient.startswith("python"): + tool_call_found = True + + author = message.get("author", {}) + if ( + author.get("role") == "tool" + and author.get("name") + and author.get("name").startswith("python") + ): + tool_response_found = True + + # Verify MCP tools were actually used + assert tool_call_found, "MCP tool call not found in output_messages" + assert tool_response_found, "MCP tool response not found in output_messages" + + # Verify input messages: Should have system message with tool, NO developer message + developer_messages = [ + msg for msg in response1.input_messages if msg["author"]["role"] == "developer" + ] + assert len(developer_messages) == 0, ( + "No developer message expected for elevated tools" + ) + + # Second turn - reference previous calculation + response2 = await client.responses.create( + model=model_name, + input="Now divide that result by 2.", + tools=tools, + temperature=0.0, + instructions=( + "You must use the Python tool to execute code. Never simulate execution." + ), + previous_response_id=response1.id, + extra_body={"enable_response_messages": True}, + ) + + assert response2 is not None + assert response2.status == "completed" + + # Verify input messages are correct: should have two messages - + # one to the python recipient on analysis channel and one from tool role + mcp_recipient_messages = [] + tool_role_messages = [] + for msg in response2.input_messages: + if msg["author"]["role"] == "assistant": + # Check if this is a message to MCP recipient on analysis channel + if msg.get("channel") == "analysis" and msg.get("recipient"): + recipient = msg.get("recipient") + if recipient.startswith("code_interpreter") or recipient == "python": + mcp_recipient_messages.append(msg) + elif msg["author"]["role"] == "tool": + tool_role_messages.append(msg) + + assert len(mcp_recipient_messages) > 0, ( + "Expected message(s) to MCP recipient on analysis channel" + ) + assert len(tool_role_messages) > 0, ( + "Expected message(s) from tool role after MCP call" + ) + + @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_output_messages_enabled(client: OpenAI, model_name: str, server): diff --git a/vllm/entrypoints/openai/serving_responses.py b/vllm/entrypoints/openai/serving_responses.py index fcc4ad826cca..844aa672026e 100644 --- a/vllm/entrypoints/openai/serving_responses.py +++ b/vllm/entrypoints/openai/serving_responses.py @@ -9,6 +9,7 @@ from collections.abc import AsyncGenerator, AsyncIterator, Callable, Sequence from contextlib import AsyncExitStack from copy import copy +from dataclasses import dataclass from http import HTTPStatus from typing import Final @@ -27,6 +28,10 @@ ResponseFunctionCallArgumentsDoneEvent, ResponseFunctionToolCall, ResponseFunctionWebSearch, + ResponseMcpCallArgumentsDeltaEvent, + ResponseMcpCallArgumentsDoneEvent, + ResponseMcpCallCompletedEvent, + ResponseMcpCallInProgressEvent, ResponseOutputItem, ResponseOutputItemAddedEvent, ResponseOutputItemDoneEvent, @@ -44,6 +49,7 @@ response_function_web_search, response_text_delta_event, ) +from openai.types.responses.response_output_item import McpCall from openai.types.responses.response_output_text import Logprob, LogprobTopLogprob from openai.types.responses.response_reasoning_item import ( Content as ResponseReasoningTextContent, @@ -119,6 +125,23 @@ logger = init_logger(__name__) +@dataclass +class HarmonyStreamingState: + """Mutable state for harmony streaming event processing.""" + + current_content_index: int = -1 + current_output_index: int = 0 + current_item_id: str = "" + sent_output_item_added: bool = False + is_first_function_call_delta: bool = False + + def reset_for_new_item(self) -> None: + """Reset state when expecting a new output item.""" + self.current_output_index += 1 + self.sent_output_item_added = False + self.is_first_function_call_delta = False + + def _extract_allowed_tools_from_mcp_requests( tools: list[Tool], ) -> dict[str, list[str] | None]: @@ -740,6 +763,26 @@ async def responses_full_generator( self.response_store[response.id] = response return response + def _is_mcp_tool_by_namespace(self, recipient: str | None) -> bool: + """ + Determine if a tool call is an MCP tool based on recipient prefix. + + - Tools starting with "functions." are function calls + - Everything else is an MCP tool + """ + if recipient is None: + return False + + # Function calls have "functions." prefix + # Everything else is an MCP tool + return not recipient.startswith("functions.") + + _TOOL_NAME_TO_MCP_SERVER_LABEL: Final[dict[str, str]] = { + "python": "code_interpreter", + "container": "container", + "browser": "web_search_preview", + } + def _topk_logprobs( self, logprobs: dict[int, SampleLogprob], @@ -1036,8 +1079,7 @@ def _construct_input_messages_with_harmony( del prev_msgs[prev_final_msg_idx + 1 :] for msg in recent_turn_msgs: assert isinstance(msg, OpenAIHarmonyMessage) - if msg.channel != "analysis": - prev_msgs.append(msg) + prev_msgs.append(msg) messages.extend(prev_msgs) # Append the new input. # Responses API supports simple text inputs without chat format. @@ -1520,6 +1562,816 @@ async def _process_simple_streaming_events( ) ) + def _emit_function_call_done_events( + self, + previous_item, + state: HarmonyStreamingState, + ) -> list[StreamingResponsesResponse]: + """Emit events when a function call completes.""" + function_name = previous_item.recipient[len("functions.") :] + events = [] + events.append( + ResponseFunctionCallArgumentsDoneEvent( + type="response.function_call_arguments.done", + arguments=previous_item.content[0].text, + name=function_name, + item_id=state.current_item_id, + output_index=state.current_output_index, + sequence_number=-1, + ) + ) + function_call_item = ResponseFunctionToolCall( + type="function_call", + arguments=previous_item.content[0].text, + name=function_name, + item_id=state.current_item_id, + output_index=state.current_output_index, + sequence_number=-1, + call_id=f"fc_{random_uuid()}", + status="completed", + ) + events.append( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=state.current_output_index, + item=function_call_item, + ) + ) + return events + + def _emit_mcp_call_done_events( + self, + previous_item, + state: HarmonyStreamingState, + ) -> list[StreamingResponsesResponse]: + """Emit events when an MCP tool call completes.""" + server_label = self._TOOL_NAME_TO_MCP_SERVER_LABEL.get( + previous_item.recipient, previous_item.recipient + ) + events = [] + events.append( + ResponseMcpCallArgumentsDoneEvent( + type="response.mcp_call_arguments.done", + arguments=previous_item.content[0].text, + name=previous_item.recipient, + item_id=state.current_item_id, + output_index=state.current_output_index, + sequence_number=-1, + ) + ) + events.append( + ResponseMcpCallCompletedEvent( + type="response.mcp_call.completed", + sequence_number=-1, + output_index=state.current_output_index, + item_id=state.current_item_id, + ) + ) + events.append( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=state.current_output_index, + item=McpCall( + type="mcp_call", + arguments=previous_item.content[0].text, + name=previous_item.recipient, + id=state.current_item_id, + server_label=server_label, + status="completed", + ), + ) + ) + return events + + def _emit_reasoning_done_events( + self, + previous_item, + state: HarmonyStreamingState, + ) -> list[StreamingResponsesResponse]: + """Emit events when a reasoning (analysis) item completes.""" + content = ResponseReasoningTextContent( + text=previous_item.content[0].text, + type="reasoning_text", + ) + reasoning_item = ResponseReasoningItem( + type="reasoning", + content=[content], + status="completed", + id=state.current_item_id, + summary=[], + ) + events = [] + events.append( + ResponseReasoningTextDoneEvent( + type="response.reasoning_text.done", + item_id=state.current_item_id, + sequence_number=-1, + output_index=state.current_output_index, + content_index=state.current_content_index, + text=previous_item.content[0].text, + ) + ) + events.append( + ResponseReasoningPartDoneEvent( + type="response.reasoning_part.done", + sequence_number=-1, + item_id=state.current_item_id, + output_index=state.current_output_index, + content_index=state.current_content_index, + part=content, + ) + ) + events.append( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=state.current_output_index, + item=reasoning_item, + ) + ) + return events + + def _emit_text_output_done_events( + self, + previous_item, + state: HarmonyStreamingState, + ) -> list[StreamingResponsesResponse]: + """Emit events when a final text output item completes.""" + text_content = ResponseOutputText( + type="output_text", + text=previous_item.content[0].text, + annotations=[], + ) + events = [] + events.append( + ResponseTextDoneEvent( + type="response.output_text.done", + sequence_number=-1, + output_index=state.current_output_index, + content_index=state.current_content_index, + text=previous_item.content[0].text, + logprobs=[], + item_id=state.current_item_id, + ) + ) + events.append( + ResponseContentPartDoneEvent( + type="response.content_part.done", + sequence_number=-1, + item_id=state.current_item_id, + output_index=state.current_output_index, + content_index=state.current_content_index, + part=text_content, + ) + ) + events.append( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=state.current_output_index, + item=ResponseOutputMessage( + id=state.current_item_id, + type="message", + role="assistant", + content=[text_content], + status="completed", + ), + ) + ) + return events + + def _emit_previous_item_done_events( + self, + previous_item, + state: HarmonyStreamingState, + ) -> list[StreamingResponsesResponse]: + """Emit done events for the previous item when expecting a new start.""" + if previous_item.recipient is not None: + # Deal with tool call + if previous_item.recipient.startswith("functions."): + return self._emit_function_call_done_events(previous_item, state) + elif ( + self._is_mcp_tool_by_namespace(previous_item.recipient) + and state.current_item_id is not None + and state.current_item_id.startswith("mcp_") + ): + return self._emit_mcp_call_done_events(previous_item, state) + elif previous_item.channel == "analysis": + return self._emit_reasoning_done_events(previous_item, state) + elif previous_item.channel == "final": + return self._emit_text_output_done_events(previous_item, state) + return [] + + def _emit_final_channel_delta_events( + self, + ctx: StreamingHarmonyContext, + state: HarmonyStreamingState, + ) -> list[StreamingResponsesResponse]: + """Emit events for final channel text delta streaming.""" + events = [] + if not state.sent_output_item_added: + state.sent_output_item_added = True + state.current_item_id = f"msg_{random_uuid()}" + events.append( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=state.current_output_index, + item=ResponseOutputMessage( + id=state.current_item_id, + type="message", + role="assistant", + content=[], + status="in_progress", + ), + ) + ) + state.current_content_index += 1 + events.append( + ResponseContentPartAddedEvent( + type="response.content_part.added", + sequence_number=-1, + output_index=state.current_output_index, + item_id=state.current_item_id, + content_index=state.current_content_index, + part=ResponseOutputText( + type="output_text", + text="", + annotations=[], + logprobs=[], + ), + ) + ) + events.append( + ResponseTextDeltaEvent( + type="response.output_text.delta", + sequence_number=-1, + content_index=state.current_content_index, + output_index=state.current_output_index, + item_id=state.current_item_id, + delta=ctx.parser.last_content_delta, + # TODO, use logprobs from ctx.last_request_output + logprobs=[], + ) + ) + return events + + def _emit_analysis_channel_delta_events( + self, + ctx: StreamingHarmonyContext, + state: HarmonyStreamingState, + ) -> list[StreamingResponsesResponse]: + """Emit events for analysis channel reasoning delta streaming.""" + events = [] + if not state.sent_output_item_added: + state.sent_output_item_added = True + state.current_item_id = f"msg_{random_uuid()}" + events.append( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=state.current_output_index, + item=ResponseReasoningItem( + type="reasoning", + id=state.current_item_id, + summary=[], + status="in_progress", + ), + ) + ) + state.current_content_index += 1 + events.append( + ResponseReasoningPartAddedEvent( + type="response.reasoning_part.added", + sequence_number=-1, + output_index=state.current_output_index, + item_id=state.current_item_id, + content_index=state.current_content_index, + part=ResponseReasoningTextContent( + text="", + type="reasoning_text", + ), + ) + ) + events.append( + ResponseReasoningTextDeltaEvent( + type="response.reasoning_text.delta", + item_id=state.current_item_id, + output_index=state.current_output_index, + content_index=state.current_content_index, + delta=ctx.parser.last_content_delta, + sequence_number=-1, + ) + ) + return events + + def _emit_mcp_tool_delta_events( + self, + ctx: StreamingHarmonyContext, + state: HarmonyStreamingState, + recipient: str, + ) -> list[StreamingResponsesResponse]: + """Emit events for MCP tool delta streaming.""" + server_label = self._TOOL_NAME_TO_MCP_SERVER_LABEL.get(recipient, recipient) + events = [] + if not state.sent_output_item_added: + state.sent_output_item_added = True + state.current_item_id = f"mcp_{random_uuid()}" + events.append( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=state.current_output_index, + item=McpCall( + type="mcp_call", + id=state.current_item_id, + name=recipient, + arguments="", + server_label=server_label, + status="in_progress", + ), + ) + ) + events.append( + ResponseMcpCallInProgressEvent( + type="response.mcp_call.in_progress", + sequence_number=-1, + output_index=state.current_output_index, + item_id=state.current_item_id, + ) + ) + events.append( + ResponseMcpCallArgumentsDeltaEvent( + type="response.mcp_call_arguments.delta", + sequence_number=-1, + output_index=state.current_output_index, + item_id=state.current_item_id, + delta=ctx.parser.last_content_delta, + ) + ) + return events + + def _emit_code_interpreter_delta_events( + self, + ctx: StreamingHarmonyContext, + state: HarmonyStreamingState, + ) -> list[StreamingResponsesResponse]: + """Emit events for code interpreter delta streaming.""" + events = [] + if not state.sent_output_item_added: + state.sent_output_item_added = True + state.current_item_id = f"tool_{random_uuid()}" + events.append( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=state.current_output_index, + item=ResponseCodeInterpreterToolCallParam( + type="code_interpreter_call", + id=state.current_item_id, + code=None, + container_id="auto", + outputs=None, + status="in_progress", + ), + ) + ) + events.append( + ResponseCodeInterpreterCallInProgressEvent( + type="response.code_interpreter_call.in_progress", + sequence_number=-1, + output_index=state.current_output_index, + item_id=state.current_item_id, + ) + ) + events.append( + ResponseCodeInterpreterCallCodeDeltaEvent( + type="response.code_interpreter_call_code.delta", + sequence_number=-1, + output_index=state.current_output_index, + item_id=state.current_item_id, + delta=ctx.parser.last_content_delta, + ) + ) + return events + + def _emit_mcp_prefix_delta_events( + self, + ctx: StreamingHarmonyContext, + state: HarmonyStreamingState, + ) -> list[StreamingResponsesResponse]: + """Emit events for MCP prefix (mcp.*) delta streaming.""" + events = [] + if not state.sent_output_item_added: + state.sent_output_item_added = True + state.current_item_id = f"mcp_{random_uuid()}" + mcp_name = ctx.parser.current_recipient[len("mcp.") :] + + events.append( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=state.current_output_index, + item=McpCall( + type="mcp_call", + id=state.current_item_id, + name=mcp_name, + arguments="", + server_label=mcp_name, + status="in_progress", + ), + ) + ) + events.append( + ResponseMcpCallInProgressEvent( + type="response.mcp_call.in_progress", + sequence_number=-1, + output_index=state.current_output_index, + item_id=state.current_item_id, + ) + ) + + events.append( + ResponseMcpCallArgumentsDeltaEvent( + type="response.mcp_call_arguments.delta", + sequence_number=-1, + output_index=state.current_output_index, + item_id=state.current_item_id, + delta=ctx.parser.last_content_delta, + ) + ) + return events + + def _emit_content_delta_events( + self, + ctx: StreamingHarmonyContext, + state: HarmonyStreamingState, + ) -> list[StreamingResponsesResponse]: + """Emit events for content delta streaming based on channel type.""" + if not ctx.parser.last_content_delta: + return [] + + if ( + ctx.parser.current_channel == "final" + and ctx.parser.current_recipient is None + ): + return self._emit_final_channel_delta_events(ctx, state) + elif ( + ctx.parser.current_channel == "analysis" + and ctx.parser.current_recipient is None + ): + return self._emit_analysis_channel_delta_events(ctx, state) + # built-in tools will be triggered on the analysis channel + # However, occasionally built-in tools will + # still be output to commentary. + elif ( + ctx.parser.current_channel == "commentary" + or ctx.parser.current_channel == "analysis" + ) and ctx.parser.current_recipient is not None: + recipient = ctx.parser.current_recipient + # Check for function calls first - they have their own event handling + if recipient.startswith("functions."): + return self._emit_function_call_delta_events(ctx, state) + is_mcp_tool = self._is_mcp_tool_by_namespace(recipient) + if is_mcp_tool: + return self._emit_mcp_tool_delta_events(ctx, state, recipient) + else: + return self._emit_code_interpreter_delta_events(ctx, state) + elif ( + ( + ctx.parser.current_channel == "commentary" + or ctx.parser.current_channel == "analysis" + ) + and ctx.parser.current_recipient is not None + and ctx.parser.current_recipient.startswith("mcp.") + ): + return self._emit_mcp_prefix_delta_events(ctx, state) + + return [] + + def _emit_browser_tool_events( + self, + previous_item, + state: HarmonyStreamingState, + ) -> list[StreamingResponsesResponse]: + """Emit events for browser tool calls (web search).""" + function_name = previous_item.recipient[len("browser.") :] + parsed_args = json.loads(previous_item.content[0].text) + action = None + + if function_name == "search": + action = response_function_web_search.ActionSearch( + type="search", + query=parsed_args["query"], + ) + elif function_name == "open": + action = response_function_web_search.ActionOpenPage( + type="open_page", + # TODO: translate to url + url=f"cursor:{parsed_args.get('cursor', '')}", + ) + elif function_name == "find": + action = response_function_web_search.ActionFind( + type="find", + pattern=parsed_args["pattern"], + # TODO: translate to url + url=f"cursor:{parsed_args.get('cursor', '')}", + ) + else: + raise ValueError(f"Unknown function name: {function_name}") + + state.current_item_id = f"tool_{random_uuid()}" + events = [] + events.append( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=state.current_output_index, + item=response_function_web_search.ResponseFunctionWebSearch( + # TODO: generate a unique id for web search call + type="web_search_call", + id=state.current_item_id, + action=action, + status="in_progress", + ), + ) + ) + events.append( + ResponseWebSearchCallInProgressEvent( + type="response.web_search_call.in_progress", + sequence_number=-1, + output_index=state.current_output_index, + item_id=state.current_item_id, + ) + ) + events.append( + ResponseWebSearchCallSearchingEvent( + type="response.web_search_call.searching", + sequence_number=-1, + output_index=state.current_output_index, + item_id=state.current_item_id, + ) + ) + # enqueue + events.append( + ResponseWebSearchCallCompletedEvent( + type="response.web_search_call.completed", + sequence_number=-1, + output_index=state.current_output_index, + item_id=state.current_item_id, + ) + ) + events.append( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=state.current_output_index, + item=ResponseFunctionWebSearch( + type="web_search_call", + id=state.current_item_id, + action=action, + status="completed", + ), + ) + ) + return events + + def _emit_mcp_tool_completion_events( + self, + previous_item, + state: HarmonyStreamingState, + ) -> list[StreamingResponsesResponse]: + """Emit events when an MCP tool completes during assistant action turn.""" + recipient = previous_item.recipient + server_label = self._TOOL_NAME_TO_MCP_SERVER_LABEL.get(recipient, recipient) + events = [] + events.append( + ResponseMcpCallArgumentsDoneEvent( + type="response.mcp_call_arguments.done", + sequence_number=-1, + output_index=state.current_output_index, + item_id=state.current_item_id, + arguments=previous_item.content[0].text, + name=recipient, + ) + ) + events.append( + ResponseMcpCallCompletedEvent( + type="response.mcp_call.completed", + sequence_number=-1, + output_index=state.current_output_index, + item_id=state.current_item_id, + ) + ) + events.append( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=state.current_output_index, + item=McpCall( + type="mcp_call", + id=state.current_item_id, + name=recipient, + arguments=previous_item.content[0].text, + server_label=server_label, + status="completed", + ), + ) + ) + return events + + def _emit_code_interpreter_completion_events( + self, + previous_item, + state: HarmonyStreamingState, + ) -> list[StreamingResponsesResponse]: + """Emit events when code interpreter completes.""" + events = [] + events.append( + ResponseCodeInterpreterCallCodeDoneEvent( + type="response.code_interpreter_call_code.done", + sequence_number=-1, + output_index=state.current_output_index, + item_id=state.current_item_id, + code=previous_item.content[0].text, + ) + ) + events.append( + ResponseCodeInterpreterCallInterpretingEvent( + type="response.code_interpreter_call.interpreting", + sequence_number=-1, + output_index=state.current_output_index, + item_id=state.current_item_id, + ) + ) + events.append( + ResponseCodeInterpreterCallCompletedEvent( + type="response.code_interpreter_call.completed", + sequence_number=-1, + output_index=state.current_output_index, + item_id=state.current_item_id, + ) + ) + events.append( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=state.current_output_index, + item=ResponseCodeInterpreterToolCallParam( + type="code_interpreter_call", + id=state.current_item_id, + code=previous_item.content[0].text, + container_id="auto", + outputs=[], + status="completed", + ), + ) + ) + return events + + def _emit_mcp_prefix_completion_events( + self, + previous_item, + state: HarmonyStreamingState, + ) -> list[StreamingResponsesResponse]: + """Emit events when an MCP prefix tool (mcp.*) completes.""" + mcp_name = previous_item.recipient[len("mcp.") :] + events = [] + events.append( + ResponseMcpCallArgumentsDoneEvent( + type="response.mcp_call_arguments.done", + sequence_number=-1, + output_index=state.current_output_index, + item_id=state.current_item_id, + arguments=previous_item.content[0].text, + name=mcp_name, + ) + ) + events.append( + ResponseMcpCallCompletedEvent( + type="response.mcp_call.completed", + sequence_number=-1, + output_index=state.current_output_index, + item_id=state.current_item_id, + ) + ) + events.append( + ResponseOutputItemDoneEvent( + type="response.output_item.done", + sequence_number=-1, + output_index=state.current_output_index, + item=McpCall( + type="mcp_call", + id=state.current_item_id, + name=mcp_name, + arguments=previous_item.content[0].text, + server_label=mcp_name, + status="completed", + ), + ) + ) + return events + + def _emit_tool_action_events( + self, + ctx: StreamingHarmonyContext, + state: HarmonyStreamingState, + ) -> list[StreamingResponsesResponse]: + """Emit events for tool action turn.""" + if not ctx.is_assistant_action_turn() or len(ctx.parser.messages) == 0: + return [] + + events = [] + previous_item = ctx.parser.messages[-1] + + # Handle browser tool + if ( + self.tool_server is not None + and self.tool_server.has_tool("browser") + and previous_item.recipient is not None + and previous_item.recipient.startswith("browser.") + ): + events.extend(self._emit_browser_tool_events(previous_item, state)) + + # Handle tool completion + if ( + self.tool_server is not None + and previous_item.recipient is not None + and state.current_item_id is not None + and state.sent_output_item_added + ): + recipient = previous_item.recipient + # Handle MCP prefix tool completion first + if recipient.startswith("mcp."): + events.extend( + self._emit_mcp_prefix_completion_events(previous_item, state) + ) + else: + # Handle other MCP tool and code interpreter completion + is_mcp_tool = self._is_mcp_tool_by_namespace( + recipient + ) and state.current_item_id.startswith("mcp_") + if is_mcp_tool: + events.extend( + self._emit_mcp_tool_completion_events(previous_item, state) + ) + else: + events.extend( + self._emit_code_interpreter_completion_events( + previous_item, state + ) + ) + + return events + + def _emit_function_call_delta_events( + self, + ctx: StreamingHarmonyContext, + state: HarmonyStreamingState, + ) -> list[StreamingResponsesResponse]: + """Emit events for developer function calls on commentary channel.""" + if not ( + ctx.parser.current_channel == "commentary" + and ctx.parser.current_recipient + and ctx.parser.current_recipient.startswith("functions.") + ): + return [] + + events = [] + if state.is_first_function_call_delta is False: + state.is_first_function_call_delta = True + fc_name = ctx.parser.current_recipient[len("functions.") :] + tool_call_item = ResponseFunctionToolCall( + name=fc_name, + type="function_call", + id=state.current_item_id, + call_id=f"call_{random_uuid()}", + arguments="", + status="in_progress", + ) + state.current_item_id = f"fc_{random_uuid()}" + events.append( + ResponseOutputItemAddedEvent( + type="response.output_item.added", + sequence_number=-1, + output_index=state.current_output_index, + item=tool_call_item, + ) + ) + else: + events.append( + ResponseFunctionCallArgumentsDeltaEvent( + item_id=state.current_item_id, + delta=ctx.parser.last_content_delta, + output_index=state.current_output_index, + sequence_number=-1, + type="response.function_call_arguments.delta", + ) + ) + return events + async def _process_harmony_streaming_events( self, request: ResponsesRequest, @@ -1534,11 +2386,8 @@ async def _process_harmony_streaming_events( [StreamingResponsesResponse], StreamingResponsesResponse ], ) -> AsyncGenerator[StreamingResponsesResponse, None]: - current_content_index = -1 - current_output_index = 0 - current_item_id: str = "" - sent_output_item_added = False - is_first_function_call_delta = False + state = HarmonyStreamingState() + async for ctx in result_generator: assert isinstance(ctx, StreamingHarmonyContext) @@ -1546,435 +2395,26 @@ async def _process_harmony_streaming_events( self._raise_if_error(ctx.finish_reason, request.request_id) if ctx.is_expecting_start(): - current_output_index += 1 - sent_output_item_added = False - is_first_function_call_delta = False + state.reset_for_new_item() if len(ctx.parser.messages) > 0: previous_item = ctx.parser.messages[-1] - if previous_item.recipient is not None: - # Deal with tool call - if previous_item.recipient.startswith("functions."): - function_name = previous_item.recipient[len("functions.") :] - yield _increment_sequence_number_and_return( - ResponseFunctionCallArgumentsDoneEvent( - type="response.function_call_arguments.done", - arguments=previous_item.content[0].text, - name=function_name, - item_id=current_item_id, - output_index=current_output_index, - sequence_number=-1, - ) - ) - function_call_item = ResponseFunctionToolCall( - type="function_call", - arguments=previous_item.content[0].text, - name=function_name, - item_id=current_item_id, - output_index=current_output_index, - sequence_number=-1, - call_id=f"fc_{random_uuid()}", - status="completed", - ) - yield _increment_sequence_number_and_return( - ResponseOutputItemDoneEvent( - type="response.output_item.done", - sequence_number=-1, - output_index=current_output_index, - item=function_call_item, - ) - ) - elif previous_item.channel == "analysis": - content = ResponseReasoningTextContent( - text=previous_item.content[0].text, - type="reasoning_text", - ) - reasoning_item = ResponseReasoningItem( - type="reasoning", - content=[content], - status="completed", - id=current_item_id, - summary=[], - ) - yield _increment_sequence_number_and_return( - ResponseReasoningTextDoneEvent( - type="response.reasoning_text.done", - item_id=current_item_id, - sequence_number=-1, - output_index=current_output_index, - content_index=current_content_index, - text=previous_item.content[0].text, - ) - ) - yield _increment_sequence_number_and_return( - ResponseReasoningPartDoneEvent( - type="response.reasoning_part.done", - sequence_number=-1, - item_id=current_item_id, - output_index=current_output_index, - content_index=current_content_index, - part=content, - ) - ) - yield _increment_sequence_number_and_return( - ResponseOutputItemDoneEvent( - type="response.output_item.done", - sequence_number=-1, - output_index=current_output_index, - item=reasoning_item, - ) - ) - elif previous_item.channel == "final": - text_content = ResponseOutputText( - type="output_text", - text=previous_item.content[0].text, - annotations=[], - ) - yield _increment_sequence_number_and_return( - ResponseTextDoneEvent( - type="response.output_text.done", - sequence_number=-1, - output_index=current_output_index, - content_index=current_content_index, - text=previous_item.content[0].text, - logprobs=[], - item_id=current_item_id, - ) - ) - yield _increment_sequence_number_and_return( - ResponseContentPartDoneEvent( - type="response.content_part.done", - sequence_number=-1, - item_id=current_item_id, - output_index=current_output_index, - content_index=current_content_index, - part=text_content, - ) - ) - yield _increment_sequence_number_and_return( - ResponseOutputItemDoneEvent( - type="response.output_item.done", - sequence_number=-1, - output_index=current_output_index, - item=ResponseOutputMessage( - id=current_item_id, - type="message", - role="assistant", - content=[text_content], - status="completed", - ), - ) - ) - - # stream the output of a harmony message - if ctx.parser.last_content_delta: - if ( - ctx.parser.current_channel == "final" - and ctx.parser.current_recipient is None - ): - if not sent_output_item_added: - sent_output_item_added = True - current_item_id = f"msg_{random_uuid()}" - yield _increment_sequence_number_and_return( - ResponseOutputItemAddedEvent( - type="response.output_item.added", - sequence_number=-1, - output_index=current_output_index, - item=ResponseOutputMessage( - id=current_item_id, - type="message", - role="assistant", - content=[], - status="in_progress", - ), - ) - ) - current_content_index += 1 - yield _increment_sequence_number_and_return( - ResponseContentPartAddedEvent( - type="response.content_part.added", - sequence_number=-1, - output_index=current_output_index, - item_id=current_item_id, - content_index=current_content_index, - part=ResponseOutputText( - type="output_text", - text="", - annotations=[], - logprobs=[], - ), - ) - ) - yield _increment_sequence_number_and_return( - ResponseTextDeltaEvent( - type="response.output_text.delta", - sequence_number=-1, - content_index=current_content_index, - output_index=current_output_index, - item_id=current_item_id, - delta=ctx.parser.last_content_delta, - # TODO, use logprobs from ctx.last_request_output - logprobs=[], - ) - ) - elif ( - ctx.parser.current_channel == "analysis" - and ctx.parser.current_recipient is None - ): - if not sent_output_item_added: - sent_output_item_added = True - current_item_id = f"msg_{random_uuid()}" - yield _increment_sequence_number_and_return( - ResponseOutputItemAddedEvent( - type="response.output_item.added", - sequence_number=-1, - output_index=current_output_index, - item=ResponseReasoningItem( - type="reasoning", - id=current_item_id, - summary=[], - status="in_progress", - ), - ) - ) - current_content_index += 1 - yield _increment_sequence_number_and_return( - ResponseReasoningPartAddedEvent( - type="response.reasoning_part.added", - sequence_number=-1, - output_index=current_output_index, - item_id=current_item_id, - content_index=current_content_index, - part=ResponseReasoningTextContent( - text="", - type="reasoning_text", - ), - ) - ) - yield _increment_sequence_number_and_return( - ResponseReasoningTextDeltaEvent( - type="response.reasoning_text.delta", - item_id=current_item_id, - output_index=current_output_index, - content_index=current_content_index, - delta=ctx.parser.last_content_delta, - sequence_number=-1, - ) - ) - # built-in tools will be triggered on the analysis channel - # However, occasionally built-in tools will - # still be output to commentary. - elif ( - ctx.parser.current_channel == "commentary" - or ctx.parser.current_channel == "analysis" - ) and ctx.parser.current_recipient == "python": - if not sent_output_item_added: - sent_output_item_added = True - current_item_id = f"tool_{random_uuid()}" - yield _increment_sequence_number_and_return( - ResponseOutputItemAddedEvent( - type="response.output_item.added", - sequence_number=-1, - output_index=current_output_index, - item=ResponseCodeInterpreterToolCallParam( - type="code_interpreter_call", - id=current_item_id, - code=None, - container_id="auto", - outputs=None, - status="in_progress", - ), - ) - ) - yield _increment_sequence_number_and_return( - ResponseCodeInterpreterCallInProgressEvent( - type="response.code_interpreter_call.in_progress", - sequence_number=-1, - output_index=current_output_index, - item_id=current_item_id, - ) - ) - yield _increment_sequence_number_and_return( - ResponseCodeInterpreterCallCodeDeltaEvent( - type="response.code_interpreter_call_code.delta", - sequence_number=-1, - output_index=current_output_index, - item_id=current_item_id, - delta=ctx.parser.last_content_delta, - ) - ) + for event in self._emit_previous_item_done_events( + previous_item, state + ): + yield _increment_sequence_number_and_return(event) - # stream tool call outputs - if ctx.is_assistant_action_turn() and len(ctx.parser.messages) > 0: - previous_item = ctx.parser.messages[-1] - if ( - self.tool_server is not None - and self.tool_server.has_tool("browser") - and previous_item.recipient is not None - and previous_item.recipient.startswith("browser.") - ): - function_name = previous_item.recipient[len("browser.") :] - action = None - parsed_args = json.loads(previous_item.content[0].text) - if function_name == "search": - action = response_function_web_search.ActionSearch( - type="search", - query=parsed_args["query"], - ) - elif function_name == "open": - action = response_function_web_search.ActionOpenPage( - type="open_page", - # TODO: translate to url - url=f"cursor:{parsed_args.get('cursor', '')}", - ) - elif function_name == "find": - action = response_function_web_search.ActionFind( - type="find", - pattern=parsed_args["pattern"], - # TODO: translate to url - url=f"cursor:{parsed_args.get('cursor', '')}", - ) - else: - raise ValueError(f"Unknown function name: {function_name}") + # Stream the output of a harmony message + for event in self._emit_content_delta_events(ctx, state): + yield _increment_sequence_number_and_return(event) - current_item_id = f"tool_{random_uuid()}" - yield _increment_sequence_number_and_return( - ResponseOutputItemAddedEvent( - type="response.output_item.added", - sequence_number=-1, - output_index=current_output_index, - item=response_function_web_search.ResponseFunctionWebSearch( - # TODO: generate a unique id for web search call - type="web_search_call", - id=current_item_id, - action=action, - status="in_progress", - ), - ) - ) - yield _increment_sequence_number_and_return( - ResponseWebSearchCallInProgressEvent( - type="response.web_search_call.in_progress", - sequence_number=-1, - output_index=current_output_index, - item_id=current_item_id, - ) - ) - yield _increment_sequence_number_and_return( - ResponseWebSearchCallSearchingEvent( - type="response.web_search_call.searching", - sequence_number=-1, - output_index=current_output_index, - item_id=current_item_id, - ) - ) + # Stream tool call outputs + for event in self._emit_tool_action_events(ctx, state): + yield _increment_sequence_number_and_return(event) - # enqueue - yield _increment_sequence_number_and_return( - ResponseWebSearchCallCompletedEvent( - type="response.web_search_call.completed", - sequence_number=-1, - output_index=current_output_index, - item_id=current_item_id, - ) - ) - yield _increment_sequence_number_and_return( - ResponseOutputItemDoneEvent( - type="response.output_item.done", - sequence_number=-1, - output_index=current_output_index, - item=ResponseFunctionWebSearch( - type="web_search_call", - id=current_item_id, - action=action, - status="completed", - ), - ) - ) - - if ( - self.tool_server is not None - and self.tool_server.has_tool("python") - and previous_item.recipient is not None - and previous_item.recipient.startswith("python") - ): - yield _increment_sequence_number_and_return( - ResponseCodeInterpreterCallCodeDoneEvent( - type="response.code_interpreter_call_code.done", - sequence_number=-1, - output_index=current_output_index, - item_id=current_item_id, - code=previous_item.content[0].text, - ) - ) - yield _increment_sequence_number_and_return( - ResponseCodeInterpreterCallInterpretingEvent( - type="response.code_interpreter_call.interpreting", - sequence_number=-1, - output_index=current_output_index, - item_id=current_item_id, - ) - ) - yield _increment_sequence_number_and_return( - ResponseCodeInterpreterCallCompletedEvent( - type="response.code_interpreter_call.completed", - sequence_number=-1, - output_index=current_output_index, - item_id=current_item_id, - ) - ) - yield _increment_sequence_number_and_return( - ResponseOutputItemDoneEvent( - type="response.output_item.done", - sequence_number=-1, - output_index=current_output_index, - item=ResponseCodeInterpreterToolCallParam( - type="code_interpreter_call", - id=current_item_id, - code=previous_item.content[0].text, - container_id="auto", - # TODO: add outputs here - outputs=[], - status="completed", - ), - ) - ) - # developer tools will be triggered on the commentary channel + # Developer tools will be triggered on the commentary channel # and recipient starts with "functions.TOOL_NAME" - if ( - ctx.parser.current_channel == "commentary" - and ctx.parser.current_recipient - and ctx.parser.current_recipient.startswith("functions.") - ): - if is_first_function_call_delta is False: - is_first_function_call_delta = True - fc_name = ctx.parser.current_recipient[len("functions.") :] - tool_call_item = ResponseFunctionToolCall( - name=fc_name, - type="function_call", - id=current_item_id, - call_id=f"call_{random_uuid()}", - arguments="", - status="in_progress", - ) - current_item_id = f"fc_{random_uuid()}" - yield _increment_sequence_number_and_return( - ResponseOutputItemAddedEvent( - type="response.output_item.added", - sequence_number=-1, - output_index=current_output_index, - item=tool_call_item, - ) - ) - else: - yield _increment_sequence_number_and_return( - ResponseFunctionCallArgumentsDeltaEvent( - item_id=current_item_id, - delta=ctx.parser.last_content_delta, - output_index=current_output_index, - sequence_number=-1, - type="response.function_call_arguments.delta", - ) - ) + for event in self._emit_function_call_delta_events(ctx, state): + yield _increment_sequence_number_and_return(event) async def responses_stream_generator( self,