32 changes: 13 additions & 19 deletions src/agents/run.py
@@ -419,6 +419,9 @@ async def run(
         current_agent = starting_agent
         should_run_agent_start_hooks = True

+        # save the original input to the session if enabled
+        await self._save_result_to_session(session, original_input, [])
+
         try:
             while True:
                 all_tools = await AgentRunner._get_all_tools(current_agent, context_wrapper)
@@ -516,9 +519,7 @@ async def run(
                         output_guardrail_results=output_guardrail_results,
                         context_wrapper=context_wrapper,
                     )
-
-                    # Save the conversation to session if enabled
-                    await self._save_result_to_session(session, input, result)
+                    await self._save_result_to_session(session, [], turn_result.new_step_items)

                     return result
                 elif isinstance(turn_result.next_step, NextStepHandoff):
@@ -527,7 +528,7 @@
                     current_span = None
                     should_run_agent_start_hooks = True
                 elif isinstance(turn_result.next_step, NextStepRunAgain):
-                    pass
+                    await self._save_result_to_session(session, [], turn_result.new_step_items)
                 else:
                     raise AgentsException(
                         f"Unknown next step type: {type(turn_result.next_step)}"
@@ -758,6 +759,8 @@ async def _start_streaming(
             # Update the streamed result with the prepared input
             streamed_result.input = prepared_input

+            await AgentRunner._save_result_to_session(session, starting_input, [])
+
             while True:
                 if streamed_result.is_complete:
                     break
@@ -860,24 +863,15 @@ async def _start_streaming(
                     streamed_result.is_complete = True

                     # Save the conversation to session if enabled
-                    # Create a temporary RunResult for session saving
-                    temp_result = RunResult(
-                        input=streamed_result.input,
-                        new_items=streamed_result.new_items,
-                        raw_responses=streamed_result.raw_responses,
-                        final_output=streamed_result.final_output,
-                        _last_agent=current_agent,
-                        input_guardrail_results=streamed_result.input_guardrail_results,
-                        output_guardrail_results=streamed_result.output_guardrail_results,
-                        context_wrapper=context_wrapper,
-                    )
                     await AgentRunner._save_result_to_session(
-                        session, starting_input, temp_result
+                        session, [], turn_result.new_step_items
                     )

                     streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
                 elif isinstance(turn_result.next_step, NextStepRunAgain):
-                    pass
+                    await AgentRunner._save_result_to_session(
+                        session, [], turn_result.new_step_items
+                    )
         except AgentsException as exc:
             streamed_result.is_complete = True
             streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
@@ -1448,7 +1442,7 @@ async def _save_result_to_session(
         cls,
         session: Session | None,
         original_input: str | list[TResponseInputItem],
-        result: RunResult,
+        new_items: list[RunItem],
     ) -> None:
         """Save the conversation turn to session."""
         if session is None:
@@ -1458,7 +1452,7 @@
         input_list = ItemHelpers.input_to_new_input_list(original_input)

         # Convert new items to input format
-        new_items_as_input = [item.to_input_item() for item in result.new_items]
+        new_items_as_input = [item.to_input_item() for item in new_items]

         # Save all items from this turn
         items_to_save = input_list + new_items_as_input
97 changes: 97 additions & 0 deletions tests/test_agent_runner.py
@@ -1,7 +1,10 @@
 from __future__ import annotations

 import json
+import tempfile
+from pathlib import Path
 from typing import Any
+from unittest.mock import patch

 import pytest
 from typing_extensions import TypedDict
@@ -20,6 +23,7 @@
     RunConfig,
     RunContextWrapper,
     Runner,
+    SQLiteSession,
     UserError,
     handoff,
 )
@@ -780,3 +784,96 @@ async def add_tool() -> str:

     assert executed["called"] is True
     assert result.final_output == "done"
+
+
+@pytest.mark.asyncio
+async def test_session_add_items_called_multiple_times_for_multi_turn_completion():
+    """Test that SQLiteSession.add_items is called multiple times
+    during a multi-turn agent completion.
+
+    """
+    with tempfile.TemporaryDirectory() as temp_dir:
+        db_path = Path(temp_dir) / "test_agent_runner_session_multi_turn_calls.db"
+        session_id = "runner_session_multi_turn_calls"
+        session = SQLiteSession(session_id, db_path)
+
+        # Define a tool that will be called by the orchestrator agent
+        @function_tool
+        async def echo_tool(text: str) -> str:
+            return f"Echo: {text}"
+
+        # Orchestrator agent that calls the tool multiple times in one completion
+        orchestrator_agent = Agent(
+            name="orchestrator_agent",
+            instructions=(
+                "Call echo_tool twice with inputs of 'foo' and 'bar', then return a summary."
+            ),
+            tools=[echo_tool],
+        )
+
+        # Patch the model to simulate two tool calls and a final message
+        model = FakeModel()
+        orchestrator_agent.model = model
+        model.add_multiple_turn_outputs(
+            [
+                # First turn: tool call
+                [get_function_tool_call("echo_tool", json.dumps({"text": "foo"}), call_id="1")],
+                # Second turn: tool call
+                [get_function_tool_call("echo_tool", json.dumps({"text": "bar"}), call_id="2")],
+                # Third turn: final output
+                [get_final_output_message("Summary: Echoed foo and bar")],
+            ]
+        )
+
+        # Patch add_items to count calls
+        with patch.object(SQLiteSession, "add_items", wraps=session.add_items) as mock_add_items:
+            result = await Runner.run(orchestrator_agent, input="foo and bar", session=session)
+
+            expected_items = [
+                {"content": "foo and bar", "role": "user"},
+                {
+                    "arguments": '{"text": "foo"}',
+                    "call_id": "1",
+                    "name": "echo_tool",
+                    "type": "function_call",
+                    "id": "1",
+                },
+                {"call_id": "1", "output": "Echo: foo", "type": "function_call_output"},
+                {
+                    "arguments": '{"text": "bar"}',
+                    "call_id": "2",
+                    "name": "echo_tool",
+                    "type": "function_call",
+                    "id": "1",
+                },
+                {"call_id": "2", "output": "Echo: bar", "type": "function_call_output"},
+                {
+                    "id": "1",
+                    "content": [
+                        {
+                            "annotations": [],
+                            "text": "Summary: Echoed foo and bar",
+                            "type": "output_text",
+                        }
+                    ],
+                    "role": "assistant",
+                    "status": "completed",
+                    "type": "message",
+                },
+            ]
+
+            expected_calls = [
+                # First call is the initial input
+                (([expected_items[0]],),),
+                # Second call is the first tool call and its result
+                (([expected_items[1], expected_items[2]],),),
+                # Third call is the second tool call and its result
+                (([expected_items[3], expected_items[4]],),),
+                # Fourth call is the final output
+                (([expected_items[5]],),),
+            ]
+            assert mock_add_items.call_args_list == expected_calls
+            assert result.final_output == "Summary: Echoed foo and bar"
+            assert (await session.get_items()) == expected_items
+
+        session.close()
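
For context, a minimal caller-side sketch of the incremental session writes this diff introduces. It is not part of the PR; it assumes the `Agent`, `Runner`, and `SQLiteSession` names used in the test above are importable from the `agents` package and that a model/API key is configured for a real run.

# Sketch (not part of this PR): with a session attached, the runner persists the
# original input up front and appends each turn's new items as they are produced.
import asyncio

from agents import Agent, Runner, SQLiteSession


async def main() -> None:
    session = SQLiteSession("demo_session", "conversation.db")
    agent = Agent(name="assistant", instructions="Reply concisely.")

    result = await Runner.run(agent, input="hello", session=session)
    print(result.final_output)

    # Everything saved during the run (input, tool calls/outputs, final message)
    # is now readable from the session, even across multiple turns.
    print(await session.get_items())
    session.close()


asyncio.run(main())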