
Commit b96bae4

[Fix] Update TaskHandler.run_step to work with the new continue_conversation_with_tool_results interface. (#39)
* update run step
* changelog
* better cleanup
1 parent 31930a2 commit b96bae4

File tree

3 files changed: +83 -9 lines changed

CHANGELOG.md
src/llm_agents_from_scratch/core/task_handler.py
tests/test_task_handler.py


CHANGELOG.md

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/).
 
 ### Changed
 
-- Updated return type of `continue_conversation_with_tool_results` to `list[ChatMessage]` (#38)
+- Update `TaskHandler.run_step()` to work with updated `continue_conversation_with_tool_results` (#39)
+- Update return type of `continue_conversation_with_tool_results` to `list[ChatMessage]` (#38)
 
 ### Deleted
 
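Taken together, the two entries above describe one interface change: `continue_conversation_with_tool_results` now returns every new message instead of a single final one, and `TaskHandler.run_step()` was updated to consume that list. A rough sketch of the before/after call shape follows; the `Protocol` classes, the `ToolCallResult` alias, and the stand-in `ChatMessage` are illustrative assumptions, and only the method name, parameter names, and return types come from this commit.

# Illustrative sketch only; these are not the library's real classes.
from typing import Any, Protocol


class ChatMessage:  # stand-in; the real ChatMessage lives in the library
    role: str
    content: str


ToolCallResult = Any  # assumed name for the tool-result type


class LLMBefore(Protocol):
    async def continue_conversation_with_tool_results(
        self,
        tool_call_results: list[ToolCallResult],
        chat_messages: list[ChatMessage],
    ) -> ChatMessage:  # old behavior: only the final assistant message
        ...


class LLMAfter(Protocol):
    async def continue_conversation_with_tool_results(
        self,
        tool_call_results: list[ToolCallResult],
        chat_messages: list[ChatMessage],
    ) -> list[ChatMessage]:  # new behavior (#38): all new messages, reply last
        ...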

src/llm_agents_from_scratch/core/task_handler.py

Lines changed: 9 additions & 3 deletions
@@ -181,20 +181,26 @@ async def run_step(self, step: TaskStep) -> TaskStepResult:
                 tool_call_results.append(tool_call_result)
 
             # send tool call results back to llm to get result
-            final_response = (
+            new_messages = (
                 await self.llm.continue_conversation_with_tool_results(
                     tool_call_results=tool_call_results,
                     chat_messages=chat_history,
                 )
             )
 
+            # get final content and update chat history
+            final_content = new_messages[-1].content
+            chat_history += new_messages
+        else:
+            final_content = response.content
+
         # augment rollout from this turn
         async with self._lock:
             self.rollout += self._rollout_contribution_from_single_run_step(
-                chat_history=chat_history + [final_response],
+                chat_history=chat_history,
             )
 
         return TaskStepResult(
             task_step=step,
-            content=final_response.content,
+            content=final_content,
         )
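The diff above changes `run_step` from treating the LLM's reply as a single final message to treating it as a batch of new messages: the last message's content becomes the step result, the whole batch is appended to the chat history, and the new `else` branch falls back to the original response content when no tools were called. A minimal, runnable sketch of that data flow, using a stand-in `ChatMessage` dataclass rather than the library's class:

# Sketch of the new run_step data flow (stand-in types, not the real library).
from dataclasses import dataclass


@dataclass
class ChatMessage:  # assumed minimal shape, for illustration only
    role: str
    content: str


chat_history = [ChatMessage(role="user", content="What is 1 + 2?")]

# What the updated interface is assumed to hand back: tool messages plus the
# assistant's follow-up, in order.
new_messages = [
    ChatMessage(role="tool", content="3"),
    ChatMessage(role="assistant", content="1 + 2 = 3."),
]

# Mirrors the updated logic in the diff above.
final_content = new_messages[-1].content
chat_history += new_messages

print(final_content)      # 1 + 2 = 3.
print(len(chat_history))  # 3

Because the tool messages are now folded into `chat_history` directly, the rollout contribution no longer needs the `+ [final_response]` append that the old code used.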

tests/test_task_handler.py

Lines changed: 72 additions & 5 deletions
@@ -1,4 +1,5 @@
 import asyncio
+import contextlib
 from unittest.mock import AsyncMock
 
 import pytest
@@ -62,7 +63,15 @@ async def fn() -> None:
 
     handler.background_task = asyncio.create_task(fn())
     with pytest.raises(TaskHandlerError):
-        handler.background_task = asyncio.create_task(fn())
+        new_task = asyncio.create_task(fn())
+        handler.background_task = new_task
+
+    # cleanup
+    handler.background_task.cancel()
+    new_task.cancel()
+    with contextlib.suppress(asyncio.CancelledError):
+        await handler.background_task
+        await new_task
 
 
 @pytest.mark.asyncio
@@ -211,10 +220,25 @@ async def plus_two(arg1: int) -> int:
         tool_calls=tool_calls,
     )
     # continue conversation with tool calls
-    mock_return_value = ChatMessage(
-        role=ChatRole.ASSISTANT,
-        content="The final response.",
-    )
+    mock_return_value = [
+        # tool calls
+        ChatMessage(
+            role=ChatRole.TOOL,
+            content="2",
+        ),
+        ChatMessage(
+            role=ChatRole.TOOL,
+            content="3",
+        ),
+        ChatMessage(
+            role=ChatRole.TOOL,
+            content="error: tool name `plus_three` doesn't exist",
+        ),
+        ChatMessage(
+            role=ChatRole.ASSISTANT,
+            content="The final response.",
+        ),
+    ]
     mock_llm.continue_conversation_with_tool_results.return_value = (
         mock_return_value
     )
@@ -252,3 +276,46 @@ async def plus_two(arg1: int) -> int:
     mock_llm.continue_conversation_with_tool_results.assert_awaited_once()
     assert step_result.task_step == step
     assert step_result.content == "The final response."
+
+
+@pytest.mark.asyncio
+async def test_run_step_without_tool_calls() -> None:
+    """Tests run step."""
+
+    # arrange mocks
+    mock_llm = AsyncMock()
+    mock_llm.chat.return_value = ChatMessage(
+        role=ChatRole.ASSISTANT,
+        content="Initial response.",
+    )
+
+    handler = TaskHandler(
+        task=Task(instruction="mock instruction"),
+        llm=mock_llm,
+        tools=[],
+    )
+
+    # act
+    step = TaskStep(
+        instruction="Some instruction.",
+        last_step=False,
+    )
+    step_result = await handler.run_step(step)
+
+    # assert
+    mock_llm.chat.assert_awaited_once_with(
+        input="Some instruction.",
+        chat_messages=[
+            ChatMessage(
+                role=ChatRole.SYSTEM,
+                content=DEFAULT_SYSTEM_MESSAGE.format(
+                    original_instruction="mock instruction",
+                    current_rollout="",
+                ),
+            ),
+        ],
+        tools=list(handler.tools_registry.keys()),
+    )
+    mock_llm.continue_conversation_with_tool_results.assert_not_awaited()
+    assert step_result.task_step == step
+    assert step_result.content == "Initial response."
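The "better cleanup" item in the commit message refers to the cancellation handling added in the first test hunk above: awaiting a cancelled task re-raises `asyncio.CancelledError`, so the test wraps the awaits in `contextlib.suppress`. A standalone sketch of that generic asyncio pattern (no project code involved):

# Generic asyncio cleanup pattern, mirroring the test's teardown above.
import asyncio
import contextlib


async def main() -> None:
    async def forever() -> None:
        await asyncio.sleep(3600)

    task = asyncio.create_task(forever())
    await asyncio.sleep(0)  # let the task start running

    # Cancel, then await; the await re-raises CancelledError, so suppress it
    # to keep teardown quiet.
    task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await task

    print(task.cancelled())  # True


asyncio.run(main())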
