-
Notifications
You must be signed in to change notification settings - Fork 22
Expand file tree
/
Copy pathgraph.py
More file actions
83 lines (60 loc) · 2.47 KB
/
graph.py
File metadata and controls
83 lines (60 loc) · 2.47 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from typing import Literal
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.constants import START
from langgraph.graph.state import CompiledStateGraph, RunnableConfig, StateGraph
from pydantic import ValidationError
from rich import print
from examples.ex007.state import State
from examples.ex007.tools import TOOLS, TOOLS_BY_NAME
from examples.ex007.utils import load_llm
def call_llm(state: State, config: RunnableConfig) -> State:
    """Invoke the tool-bound LLM on the conversation so far.

    Picks a model per user tier from ``config["configurable"]["user_type"]``
    and appends the model's reply to the message list.

    Args:
        state: Graph state; only ``state["messages"]`` is read.
        config: Runnable config; ``configurable.user_type`` selects the model.

    Returns:
        A state delta with the new AI message appended to ``messages``.
    """
    print("> call llm")
    user_type = config.get("configurable", {}).get("user_type")
    # Both tiers currently run on Ollama; kept as a named variable so a
    # second provider can be slotted in without reshaping the code.
    model_provider = "ollama"
    # "plus" users get the larger model; everyone else the coder model.
    model = "gpt-oss:20b" if user_type == "plus" else "qwen3-coder:30b"
    llm_with_tools = load_llm().bind_tools(TOOLS)
    llm_with_config = llm_with_tools.with_config(
        config={
            "configurable": {
                "model": model,
                "model_provider": model_provider,
            }
        }
    )
    # NOTE: the previous debug line `print(llm_with_config.temperature)` was
    # removed — `with_config` returns a RunnableBinding, which has no
    # `temperature` attribute, so that access raised AttributeError.
    result = llm_with_config.invoke(state["messages"])
    return {"messages": [result]}
def tool_node(state: State) -> State:
    """Execute the tool calls requested by the latest AI message.

    Args:
        state: Graph state; the last entry of ``state["messages"]`` is
            expected to be the AI message carrying ``tool_calls``.

    Returns:
        A state delta with one ``ToolMessage`` per tool call, or the input
        state unchanged when there is nothing to execute.
    """
    print("> tool node")
    llm_response = state["messages"][-1]
    if not isinstance(llm_response, AIMessage) or not getattr(
        llm_response, "tool_calls", None
    ):
        return state
    tool_messages: list[ToolMessage] = []
    # Answer EVERY tool call, not just the last one: chat APIs require a
    # ToolMessage for each tool_call_id, so dropping calls breaks the run.
    for call in llm_response.tool_calls:
        name, args, id_ = call["name"], call["args"], call["id"]
        try:
            content = TOOLS_BY_NAME[name].invoke(args)
            status = "success"
        except (KeyError, IndexError, TypeError, ValidationError, ValueError) as error:
            # Feed the failure back to the model so it can self-correct.
            content = f"Please, fix your mistakes: {error}"
            status = "error"
        tool_messages.append(
            ToolMessage(content=content, tool_call_id=id_, status=status)
        )
    return {"messages": tool_messages}
def router(state: State) -> Literal["tool_node", "__end__"]:
    """Choose the next hop after the LLM node.

    Routes to ``tool_node`` when the latest message carries pending tool
    calls, otherwise terminates the graph.
    """
    print("> router")
    last_message = state["messages"][-1]
    pending_calls = getattr(last_message, "tool_calls", None)
    return "tool_node" if pending_calls else "__end__"
def build_graph() -> CompiledStateGraph[State, None, State, State]:
    """Assemble and compile the two-node agent loop.

    Wires an LLM node and a tool-execution node into the classic ReAct
    cycle: the LLM either finishes or requests tools, and tool results
    feed back into the LLM.
    """
    graph = StateGraph(State)
    # Register the two workers.
    graph.add_node("call_llm", call_llm)
    graph.add_node("tool_node", tool_node)
    # Entry point -> LLM; the router then ends or detours through tools,
    # and tool output always returns to the LLM.
    graph.add_edge(START, "call_llm")
    graph.add_conditional_edges("call_llm", router, ["tool_node", "__end__"])
    graph.add_edge("tool_node", "call_llm")
    # In-memory checkpointer enables thread-scoped conversation state.
    return graph.compile(checkpointer=InMemorySaver())