|
36 | 36 | get_text_message,
|
37 | 37 | )
|
38 | 38 |
|
| 39 | +import tempfile |
| 40 | +from pathlib import Path |
| 41 | +from agents import SQLiteSession |
| 42 | +from unittest.mock import patch |
| 43 | + |
39 | 44 |
|
40 | 45 | @pytest.mark.asyncio
|
41 | 46 | async def test_simple_first_run():
|
@@ -780,3 +785,96 @@ async def add_tool() -> str:
|
780 | 785 |
|
781 | 786 | assert executed["called"] is True
|
782 | 787 | assert result.final_output == "done"
|
| 788 | + |
| 789 | + |
| 790 | +@pytest.mark.asyncio |
| 791 | +async def test_session_add_items_called_multiple_times_for_multi_turn_completion(): |
| 792 | + """Test that SQLiteSession.add_items is called multiple times |
| 793 | + during a multi-turn agent completion. |
| 794 | +
|
| 795 | + """ |
| 796 | + with tempfile.TemporaryDirectory() as temp_dir: |
| 797 | + db_path = Path(temp_dir) / "test_agent_runner_session_multi_turn_calls.db" |
| 798 | + session_id = "runner_session_multi_turn_calls" |
| 799 | + session = SQLiteSession(session_id, db_path) |
| 800 | + |
| 801 | + # Define a tool that will be called by the orchestrator agent |
| 802 | + @function_tool |
| 803 | + async def echo_tool(text: str) -> str: |
| 804 | + return f"Echo: {text}" |
| 805 | + |
| 806 | + # Orchestrator agent that calls the tool multiple times in one completion |
| 807 | + orchestrator_agent = Agent( |
| 808 | + name="orchestrator_agent", |
| 809 | + instructions=( |
| 810 | + "Call echo_tool twice with inputs of 'foo' and 'bar', then return a summary." |
| 811 | + ), |
| 812 | + tools=[echo_tool], |
| 813 | + ) |
| 814 | + |
| 815 | + # Patch the model to simulate two tool calls and a final message |
| 816 | + model = FakeModel() |
| 817 | + orchestrator_agent.model = model |
| 818 | + model.add_multiple_turn_outputs( |
| 819 | + [ |
| 820 | + # First turn: tool call |
| 821 | + [get_function_tool_call("echo_tool", json.dumps({"text": "foo"}), call_id="1")], |
| 822 | + # Second turn: tool call |
| 823 | + [get_function_tool_call("echo_tool", json.dumps({"text": "bar"}), call_id="2")], |
| 824 | + # Third turn: final output |
| 825 | + [get_final_output_message("Summary: Echoed foo and bar")], |
| 826 | + ] |
| 827 | + ) |
| 828 | + |
| 829 | + # Patch add_items to count calls |
| 830 | + with patch.object(SQLiteSession, "add_items", wraps=session.add_items) as mock_add_items: |
| 831 | + result = await Runner.run(orchestrator_agent, input="foo and bar", session=session) |
| 832 | + |
| 833 | + expected_items = [ |
| 834 | + {"content": "foo and bar", "role": "user"}, |
| 835 | + { |
| 836 | + "arguments": '{"text": "foo"}', |
| 837 | + "call_id": "1", |
| 838 | + "name": "echo_tool", |
| 839 | + "type": "function_call", |
| 840 | + "id": "1", |
| 841 | + }, |
| 842 | + {"call_id": "1", "output": "Echo: foo", "type": "function_call_output"}, |
| 843 | + { |
| 844 | + "arguments": '{"text": "bar"}', |
| 845 | + "call_id": "2", |
| 846 | + "name": "echo_tool", |
| 847 | + "type": "function_call", |
| 848 | + "id": "1", |
| 849 | + }, |
| 850 | + {"call_id": "2", "output": "Echo: bar", "type": "function_call_output"}, |
| 851 | + { |
| 852 | + "id": "1", |
| 853 | + "content": [ |
| 854 | + { |
| 855 | + "annotations": [], |
| 856 | + "text": "Summary: Echoed foo and bar", |
| 857 | + "type": "output_text", |
| 858 | + } |
| 859 | + ], |
| 860 | + "role": "assistant", |
| 861 | + "status": "completed", |
| 862 | + "type": "message", |
| 863 | + }, |
| 864 | + ] |
| 865 | + |
| 866 | + expected_calls = [ |
| 867 | + # First call is the initial input |
| 868 | + (([expected_items[0]],),), |
| 869 | + # Second call is the first tool call and its result |
| 870 | + (([expected_items[1], expected_items[2]],),), |
| 871 | + # Third call is the second tool call and its result |
| 872 | + (([expected_items[3], expected_items[4]],),), |
| 873 | + # Fourth call is the final output |
| 874 | + (([expected_items[5]],),), |
| 875 | + ] |
| 876 | + assert mock_add_items.call_args_list == expected_calls |
| 877 | + assert result.final_output == "Summary: Echoed foo and bar" |
| 878 | + assert (await session.get_items()) == expected_items |
| 879 | + |
| 880 | + session.close() |
0 commit comments