Skip to content

Commit 9ad2949

Browse files
authored
Save session on turn rather than at final response (#1550)
1 parent f37f70b commit 9ad2949

File tree

2 files changed

+110
-19
lines changed

2 files changed

+110
-19
lines changed

src/agents/run.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,9 @@ async def run(
438438
current_agent = starting_agent
439439
should_run_agent_start_hooks = True
440440

441+
# save the original input to the session if enabled
442+
await self._save_result_to_session(session, original_input, [])
443+
441444
try:
442445
while True:
443446
all_tools = await AgentRunner._get_all_tools(current_agent, context_wrapper)
@@ -537,9 +540,7 @@ async def run(
537540
output_guardrail_results=output_guardrail_results,
538541
context_wrapper=context_wrapper,
539542
)
540-
541-
# Save the conversation to session if enabled
542-
await self._save_result_to_session(session, input, result)
543+
await self._save_result_to_session(session, [], turn_result.new_step_items)
543544

544545
return result
545546
elif isinstance(turn_result.next_step, NextStepHandoff):
@@ -548,7 +549,7 @@ async def run(
548549
current_span = None
549550
should_run_agent_start_hooks = True
550551
elif isinstance(turn_result.next_step, NextStepRunAgain):
551-
pass
552+
await self._save_result_to_session(session, [], turn_result.new_step_items)
552553
else:
553554
raise AgentsException(
554555
f"Unknown next step type: {type(turn_result.next_step)}"
@@ -784,6 +785,8 @@ async def _start_streaming(
784785
# Update the streamed result with the prepared input
785786
streamed_result.input = prepared_input
786787

788+
await AgentRunner._save_result_to_session(session, starting_input, [])
789+
787790
while True:
788791
if streamed_result.is_complete:
789792
break
@@ -887,24 +890,15 @@ async def _start_streaming(
887890
streamed_result.is_complete = True
888891

889892
# Save the conversation to session if enabled
890-
# Create a temporary RunResult for session saving
891-
temp_result = RunResult(
892-
input=streamed_result.input,
893-
new_items=streamed_result.new_items,
894-
raw_responses=streamed_result.raw_responses,
895-
final_output=streamed_result.final_output,
896-
_last_agent=current_agent,
897-
input_guardrail_results=streamed_result.input_guardrail_results,
898-
output_guardrail_results=streamed_result.output_guardrail_results,
899-
context_wrapper=context_wrapper,
900-
)
901893
await AgentRunner._save_result_to_session(
902-
session, starting_input, temp_result
894+
session, [], turn_result.new_step_items
903895
)
904896

905897
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
906898
elif isinstance(turn_result.next_step, NextStepRunAgain):
907-
pass
899+
await AgentRunner._save_result_to_session(
900+
session, [], turn_result.new_step_items
901+
)
908902
except AgentsException as exc:
909903
streamed_result.is_complete = True
910904
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
@@ -1510,7 +1504,7 @@ async def _save_result_to_session(
15101504
cls,
15111505
session: Session | None,
15121506
original_input: str | list[TResponseInputItem],
1513-
result: RunResult,
1507+
new_items: list[RunItem],
15141508
) -> None:
15151509
"""Save the conversation turn to session."""
15161510
if session is None:
@@ -1520,7 +1514,7 @@ async def _save_result_to_session(
15201514
input_list = ItemHelpers.input_to_new_input_list(original_input)
15211515

15221516
# Convert new items to input format
1523-
new_items_as_input = [item.to_input_item() for item in result.new_items]
1517+
new_items_as_input = [item.to_input_item() for item in new_items]
15241518

15251519
# Save all items from this turn
15261520
items_to_save = input_list + new_items_as_input

tests/test_agent_runner.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from __future__ import annotations
22

33
import json
4+
import tempfile
5+
from pathlib import Path
46
from typing import Any
7+
from unittest.mock import patch
58

69
import pytest
710
from typing_extensions import TypedDict
@@ -20,6 +23,7 @@
2023
RunConfig,
2124
RunContextWrapper,
2225
Runner,
26+
SQLiteSession,
2327
UserError,
2428
handoff,
2529
)
@@ -780,3 +784,96 @@ async def add_tool() -> str:
780784

781785
assert executed["called"] is True
782786
assert result.final_output == "done"
787+
788+
789+
@pytest.mark.asyncio
790+
async def test_session_add_items_called_multiple_times_for_multi_turn_completion():
791+
"""Test that SQLiteSession.add_items is called multiple times
792+
during a multi-turn agent completion.
793+
794+
"""
795+
with tempfile.TemporaryDirectory() as temp_dir:
796+
db_path = Path(temp_dir) / "test_agent_runner_session_multi_turn_calls.db"
797+
session_id = "runner_session_multi_turn_calls"
798+
session = SQLiteSession(session_id, db_path)
799+
800+
# Define a tool that will be called by the orchestrator agent
801+
@function_tool
802+
async def echo_tool(text: str) -> str:
803+
return f"Echo: {text}"
804+
805+
# Orchestrator agent that calls the tool multiple times in one completion
806+
orchestrator_agent = Agent(
807+
name="orchestrator_agent",
808+
instructions=(
809+
"Call echo_tool twice with inputs of 'foo' and 'bar', then return a summary."
810+
),
811+
tools=[echo_tool],
812+
)
813+
814+
# Patch the model to simulate two tool calls and a final message
815+
model = FakeModel()
816+
orchestrator_agent.model = model
817+
model.add_multiple_turn_outputs(
818+
[
819+
# First turn: tool call
820+
[get_function_tool_call("echo_tool", json.dumps({"text": "foo"}), call_id="1")],
821+
# Second turn: tool call
822+
[get_function_tool_call("echo_tool", json.dumps({"text": "bar"}), call_id="2")],
823+
# Third turn: final output
824+
[get_final_output_message("Summary: Echoed foo and bar")],
825+
]
826+
)
827+
828+
# Patch add_items to count calls
829+
with patch.object(SQLiteSession, "add_items", wraps=session.add_items) as mock_add_items:
830+
result = await Runner.run(orchestrator_agent, input="foo and bar", session=session)
831+
832+
expected_items = [
833+
{"content": "foo and bar", "role": "user"},
834+
{
835+
"arguments": '{"text": "foo"}',
836+
"call_id": "1",
837+
"name": "echo_tool",
838+
"type": "function_call",
839+
"id": "1",
840+
},
841+
{"call_id": "1", "output": "Echo: foo", "type": "function_call_output"},
842+
{
843+
"arguments": '{"text": "bar"}',
844+
"call_id": "2",
845+
"name": "echo_tool",
846+
"type": "function_call",
847+
"id": "1",
848+
},
849+
{"call_id": "2", "output": "Echo: bar", "type": "function_call_output"},
850+
{
851+
"id": "1",
852+
"content": [
853+
{
854+
"annotations": [],
855+
"text": "Summary: Echoed foo and bar",
856+
"type": "output_text",
857+
}
858+
],
859+
"role": "assistant",
860+
"status": "completed",
861+
"type": "message",
862+
},
863+
]
864+
865+
expected_calls = [
866+
# First call is the initial input
867+
(([expected_items[0]],),),
868+
# Second call is the first tool call and its result
869+
(([expected_items[1], expected_items[2]],),),
870+
# Third call is the second tool call and its result
871+
(([expected_items[3], expected_items[4]],),),
872+
# Fourth call is the final output
873+
(([expected_items[5]],),),
874+
]
875+
assert mock_add_items.call_args_list == expected_calls
876+
assert result.final_output == "Summary: Echoed foo and bar"
877+
assert (await session.get_items()) == expected_items
878+
879+
session.close()

0 commit comments

Comments
 (0)