|
1 | 1 | import asyncio
|
| 2 | +import contextlib |
2 | 3 | from unittest.mock import AsyncMock
|
3 | 4 |
|
4 | 5 | import pytest
|
@@ -62,7 +63,15 @@ async def fn() -> None:
|
62 | 63 |
|
63 | 64 | handler.background_task = asyncio.create_task(fn())
|
64 | 65 | with pytest.raises(TaskHandlerError):
|
65 |
| - handler.background_task = asyncio.create_task(fn()) |
| 66 | + new_task = asyncio.create_task(fn()) |
| 67 | + handler.background_task = new_task |
| 68 | + |
| 69 | + # cleanup |
| 70 | + handler.background_task.cancel() |
| 71 | + new_task.cancel() |
| 72 | + with contextlib.suppress(asyncio.CancelledError): |
| 73 | + await handler.background_task |
| 74 | + await new_task |
66 | 75 |
|
67 | 76 |
|
68 | 77 | @pytest.mark.asyncio
|
@@ -211,10 +220,25 @@ async def plus_two(arg1: int) -> int:
|
211 | 220 | tool_calls=tool_calls,
|
212 | 221 | )
|
213 | 222 | # continue conversation with tool calls
|
214 |
| - mock_return_value = ChatMessage( |
215 |
| - role=ChatRole.ASSISTANT, |
216 |
| - content="The final response.", |
217 |
| - ) |
| 223 | + mock_return_value = [ |
| 224 | + # tool calls |
| 225 | + ChatMessage( |
| 226 | + role=ChatRole.TOOL, |
| 227 | + content="2", |
| 228 | + ), |
| 229 | + ChatMessage( |
| 230 | + role=ChatRole.TOOL, |
| 231 | + content="3", |
| 232 | + ), |
| 233 | + ChatMessage( |
| 234 | + role=ChatRole.TOOL, |
| 235 | + content="error: tool name `plus_three` doesn't exist", |
| 236 | + ), |
| 237 | + ChatMessage( |
| 238 | + role=ChatRole.ASSISTANT, |
| 239 | + content="The final response.", |
| 240 | + ), |
| 241 | + ] |
218 | 242 | mock_llm.continue_conversation_with_tool_results.return_value = (
|
219 | 243 | mock_return_value
|
220 | 244 | )
|
@@ -252,3 +276,46 @@ async def plus_two(arg1: int) -> int:
|
252 | 276 | mock_llm.continue_conversation_with_tool_results.assert_awaited_once()
|
253 | 277 | assert step_result.task_step == step
|
254 | 278 | assert step_result.content == "The final response."
|
| 279 | + |
| 280 | + |
@pytest.mark.asyncio
async def test_run_step_without_tool_calls() -> None:
    """Run a step whose LLM reply carries no tool calls.

    The mocked ``chat`` coroutine returns a plain assistant message. The test
    checks that ``chat`` is awaited exactly once with the system message built
    from the task instruction, that
    ``continue_conversation_with_tool_results`` is never awaited, and that
    the step result carries the step and the assistant's content.
    """
    # arrange: an LLM stub that answers with a plain assistant message
    llm_stub = AsyncMock()
    llm_stub.chat.return_value = ChatMessage(
        role=ChatRole.ASSISTANT,
        content="Initial response.",
    )
    task_handler = TaskHandler(
        task=Task(instruction="mock instruction"),
        llm=llm_stub,
        tools=[],
    )
    task_step = TaskStep(instruction="Some instruction.", last_step=False)

    # act
    result = await task_handler.run_step(task_step)

    # assert: a single chat call seeded with the formatted system message
    expected_system_message = ChatMessage(
        role=ChatRole.SYSTEM,
        content=DEFAULT_SYSTEM_MESSAGE.format(
            original_instruction="mock instruction",
            current_rollout="",
        ),
    )
    llm_stub.chat.assert_awaited_once_with(
        input="Some instruction.",
        chat_messages=[expected_system_message],
        tools=list(task_handler.tools_registry.keys()),
    )
    # no tool calls were returned, so the follow-up conversation is skipped
    llm_stub.continue_conversation_with_tool_results.assert_not_awaited()
    assert result.task_step == task_step
    assert result.content == "Initial response."
0 commit comments