Skip to content

Commit c8a5d26

Browse files
committed
add unit test for session save on turn
1 parent d0fed4b commit c8a5d26

File tree

1 file changed

+98
-0
lines changed

1 file changed

+98
-0
lines changed

tests/test_agent_runner.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@
3636
get_text_message,
3737
)
3838

39+
import tempfile
40+
from pathlib import Path
41+
from agents import SQLiteSession
42+
from unittest.mock import patch
43+
3944

4045
@pytest.mark.asyncio
4146
async def test_simple_first_run():
@@ -780,3 +785,96 @@ async def add_tool() -> str:
780785

781786
assert executed["called"] is True
782787
assert result.final_output == "done"
788+
789+
790+
@pytest.mark.asyncio
791+
async def test_session_add_items_called_multiple_times_for_multi_turn_completion():
792+
"""Test that SQLiteSession.add_items is called multiple times
793+
during a multi-turn agent completion.
794+
795+
"""
796+
with tempfile.TemporaryDirectory() as temp_dir:
797+
db_path = Path(temp_dir) / "test_agent_runner_session_multi_turn_calls.db"
798+
session_id = "runner_session_multi_turn_calls"
799+
session = SQLiteSession(session_id, db_path)
800+
801+
# Define a tool that will be called by the orchestrator agent
802+
@function_tool
803+
async def echo_tool(text: str) -> str:
804+
return f"Echo: {text}"
805+
806+
# Orchestrator agent that calls the tool multiple times in one completion
807+
orchestrator_agent = Agent(
808+
name="orchestrator_agent",
809+
instructions=(
810+
"Call echo_tool twice with inputs of 'foo' and 'bar', then return a summary."
811+
),
812+
tools=[echo_tool],
813+
)
814+
815+
# Patch the model to simulate two tool calls and a final message
816+
model = FakeModel()
817+
orchestrator_agent.model = model
818+
model.add_multiple_turn_outputs(
819+
[
820+
# First turn: tool call
821+
[get_function_tool_call("echo_tool", json.dumps({"text": "foo"}), call_id="1")],
822+
# Second turn: tool call
823+
[get_function_tool_call("echo_tool", json.dumps({"text": "bar"}), call_id="2")],
824+
# Third turn: final output
825+
[get_final_output_message("Summary: Echoed foo and bar")],
826+
]
827+
)
828+
829+
# Patch add_items to count calls
830+
with patch.object(SQLiteSession, "add_items", wraps=session.add_items) as mock_add_items:
831+
result = await Runner.run(orchestrator_agent, input="foo and bar", session=session)
832+
833+
expected_items = [
834+
{"content": "foo and bar", "role": "user"},
835+
{
836+
"arguments": '{"text": "foo"}',
837+
"call_id": "1",
838+
"name": "echo_tool",
839+
"type": "function_call",
840+
"id": "1",
841+
},
842+
{"call_id": "1", "output": "Echo: foo", "type": "function_call_output"},
843+
{
844+
"arguments": '{"text": "bar"}',
845+
"call_id": "2",
846+
"name": "echo_tool",
847+
"type": "function_call",
848+
"id": "1",
849+
},
850+
{"call_id": "2", "output": "Echo: bar", "type": "function_call_output"},
851+
{
852+
"id": "1",
853+
"content": [
854+
{
855+
"annotations": [],
856+
"text": "Summary: Echoed foo and bar",
857+
"type": "output_text",
858+
}
859+
],
860+
"role": "assistant",
861+
"status": "completed",
862+
"type": "message",
863+
},
864+
]
865+
866+
expected_calls = [
867+
# First call is the initial input
868+
(([expected_items[0]],),),
869+
# Second call is the first tool call and its result
870+
(([expected_items[1], expected_items[2]],),),
871+
# Third call is the second tool call and its result
872+
(([expected_items[3], expected_items[4]],),),
873+
# Fourth call is the final output
874+
(([expected_items[5]],),),
875+
]
876+
assert mock_add_items.call_args_list == expected_calls
877+
assert result.final_output == "Summary: Echoed foo and bar"
878+
assert (await session.get_items()) == expected_items
879+
880+
session.close()

0 commit comments

Comments
 (0)