Skip to content

Commit 1ee1afb

Browse files
committed
add route_to_tool_node function
Signed-off-by: Tyler Slaton <[email protected]>
1 parent df07459 commit 1ee1afb

File tree

1 file changed

+32
-5
lines changed

1 file changed

+32
-5
lines changed

agent/agent.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import Any, List
77
from typing_extensions import Literal
88
from langchain_openai import ChatOpenAI
9-
from langchain_core.messages import SystemMessage
9+
from langchain_core.messages import SystemMessage, BaseMessage
1010
from langchain_core.runnables import RunnableConfig
1111
from langchain.tools import tool
1212
from langgraph.graph import StateGraph, END
@@ -39,11 +39,15 @@ def get_weather(location: str):
3939
# print(f"Your tool logic here")
4040
# return "Your tool response here."
4141

42-
tools = [
42+
backend_tools = [
4343
get_weather
4444
# your_tool_here
4545
]
4646

47+
# Extract tool names from backend_tools for comparison
48+
backend_tool_names = [tool.name for tool in backend_tools]
49+
50+
4751
async def chat_node(state: AgentState, config: RunnableConfig) -> Command[Literal["tool_node", "__end__"]]:
4852
"""
4953
Standard chat node based on the ReAct design pattern. It handles:
@@ -63,7 +67,7 @@ async def chat_node(state: AgentState, config: RunnableConfig) -> Command[Litera
6367
model_with_tools = model.bind_tools(
6468
[
6569
*state.get("tools", []), # bind tools defined by ag-ui
66-
get_weather,
70+
*backend_tools,
6771
# your_tool_here
6872
],
6973

@@ -84,18 +88,41 @@ async def chat_node(state: AgentState, config: RunnableConfig) -> Command[Litera
8488
*state["messages"],
8589
], config)
8690

91+
# only route to tool node if tool is not in the tools list
92+
if route_to_tool_node(response):
93+
print("routing to tool node")
94+
return Command(
95+
goto="tool_node",
96+
update={
97+
"messages": [response],
98+
}
99+
)
100+
87101
# 5. We've handled all tool calls, so we can end the graph.
88102
return Command(
89103
goto=END,
90104
update={
91-
"messages": response
105+
"messages": [response],
92106
}
93107
)
94108

109+
def route_to_tool_node(response: BaseMessage):
110+
"""
111+
Route to tool node if any tool call in the response matches a backend tool name.
112+
"""
113+
tool_calls = getattr(response, "tool_calls", None)
114+
if not tool_calls:
115+
return False
116+
117+
for tool_call in tool_calls:
118+
if tool_call.get("name") in backend_tool_names:
119+
return True
120+
return False
121+
95122
# Define the workflow graph
96123
workflow = StateGraph(AgentState)
97124
workflow.add_node("chat_node", chat_node)
98-
workflow.add_node("tool_node", ToolNode(tools=tools))
125+
workflow.add_node("tool_node", ToolNode(tools=backend_tools))
99126
workflow.add_edge("tool_node", "chat_node")
100127
workflow.set_entry_point("chat_node")
101128

0 commit comments

Comments
 (0)