6
6
from typing import Any , List
7
7
from typing_extensions import Literal
8
8
from langchain_openai import ChatOpenAI
9
- from langchain_core .messages import SystemMessage
9
+ from langchain_core .messages import SystemMessage , BaseMessage
10
10
from langchain_core .runnables import RunnableConfig
11
11
from langchain .tools import tool
12
12
from langgraph .graph import StateGraph , END
@@ -39,11 +39,15 @@ def get_weather(location: str):
39
39
# print(f"Your tool logic here")
40
40
# return "Your tool response here."
41
41
42
- tools = [
42
+ backend_tools = [
43
43
get_weather
44
44
# your_tool_here
45
45
]
46
46
47
+ # Extract tool names from backend_tools for comparison
48
+ backend_tool_names = [tool .name for tool in backend_tools ]
49
+
50
+
47
51
async def chat_node (state : AgentState , config : RunnableConfig ) -> Command [Literal ["tool_node" , "__end__" ]]:
48
52
"""
49
53
Standard chat node based on the ReAct design pattern. It handles:
@@ -63,7 +67,7 @@ async def chat_node(state: AgentState, config: RunnableConfig) -> Command[Litera
63
67
model_with_tools = model .bind_tools (
64
68
[
65
69
* state .get ("tools" , []), # bind tools defined by ag-ui
66
- get_weather ,
70
+ * backend_tools ,
67
71
# your_tool_here
68
72
],
69
73
@@ -84,18 +88,41 @@ async def chat_node(state: AgentState, config: RunnableConfig) -> Command[Litera
84
88
* state ["messages" ],
85
89
], config )
86
90
91
+ # only route to tool node if tool is not in the tools list
92
+ if route_to_tool_node (response ):
93
+ print ("routing to tool node" )
94
+ return Command (
95
+ goto = "tool_node" ,
96
+ update = {
97
+ "messages" : [response ],
98
+ }
99
+ )
100
+
87
101
# 5. We've handled all tool calls, so we can end the graph.
88
102
return Command (
89
103
goto = END ,
90
104
update = {
91
- "messages" : response
105
+ "messages" : [ response ],
92
106
}
93
107
)
94
108
109
+ def route_to_tool_node (response : BaseMessage ):
110
+ """
111
+ Route to tool node if any tool call in the response matches a backend tool name.
112
+ """
113
+ tool_calls = getattr (response , "tool_calls" , None )
114
+ if not tool_calls :
115
+ return False
116
+
117
+ for tool_call in tool_calls :
118
+ if tool_call .get ("name" ) in backend_tool_names :
119
+ return True
120
+ return False
121
+
95
122
# Define the workflow graph
96
123
workflow = StateGraph (AgentState )
97
124
workflow .add_node ("chat_node" , chat_node )
98
- workflow .add_node ("tool_node" , ToolNode (tools = tools ))
125
+ workflow .add_node ("tool_node" , ToolNode (tools = backend_tools ))
99
126
workflow .add_edge ("tool_node" , "chat_node" )
100
127
workflow .set_entry_point ("chat_node" )
101
128
0 commit comments