
Commit 85da7ee

Merge pull request #81 from PathOnAI/mcts-backend

2 parents: c484e68 + 7b9b578
File tree

4 files changed: +435 -66 lines changed
Lines changed: 298 additions & 3 deletions

@@ -1,14 +1,43 @@
 from typing import Any, Optional, Tuple, List
 from datetime import datetime
+import logging
+import json
+import time
+from openai import OpenAI
 from dotenv import load_dotenv
 load_dotenv()
 
-from .tree_vis import RED, better_print, print_trajectory, collect_all_nodes, GREEN, RESET, print_entire_tree
+from .tree_vis import RED, GREEN, RESET, better_print, print_trajectory, collect_all_nodes, print_entire_tree
 from .lats_node import LATSNode
 from .base_agent import BaseAgent
+from .trajectory_score import create_llm_prompt, score_trajectory_with_openai
+from ...replay_async import generate_feedback, playwright_step_execution
+from ...webagent_utils_async.browser_env.observation import extract_page_info
+from ...webagent_utils_async.action.prompt_functions import extract_top_actions
+from ...webagent_utils_async.utils.utils import parse_function_args, locate_element
+from ...evaluation_async.evaluators import goal_finished_evaluator
+
+openai_client = OpenAI()
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
 
 class MCTSAgent(BaseAgent):
-    async def run(self, websocket=None) -> list[LATSNode]:
+    """
+    Monte Carlo Tree Search Agent for web navigation tasks.
+    This implementation uses reflection-based search to improve performance.
+    """
+
+    async def run(self, websocket=None) -> List[dict[str, Any]]:
+        """
+        Run the MCTS algorithm based on configuration.
+
+        Args:
+            websocket: Optional WebSocket connection to send updates to
+
+        Returns:
+            List[Dict[str, Any]]: List of actions in the best path found
+        """
         if websocket:
             await websocket.send_json({
                 "type": "search_status",
@@ -17,4 +46,270 @@ async def run(self, websocket=None) -> list[LATSNode]:
                 "timestamp": datetime.utcnow().isoformat()
             })
 
-        pass
+        # Reset browser to initial state
+        live_browser_url, session_id = await self._reset_browser(websocket)
+
+        best_node = await self.mcts_search(websocket)
+        print_trajectory(best_node)
+
+        return best_node
+
+    async def node_selection(self, node: LATSNode, websocket=None) -> Optional[LATSNode]:
+        if node.is_terminal:
+            return None
+
+        current_node = node
+        path = [current_node]
+        selection_depth = 0
+
+        while current_node.children and not current_node.is_terminal:
+            logger.info(f"\nSelection Step {selection_depth + 1}:")
+            logger.info(f"Current node action: {current_node.action}")
+            logger.info(f"Number of children: {len(current_node.children)}")
+
+            # Get trajectory for GPT-4 to evaluate
+            trajectory = []
+            for n in path[1:]: # Skip root node
+                trajectory.append({
+                    "natural_language_description": n.natural_language_description,
+                    "action": n.action,
+                    "feedback": n.feedback if hasattr(n, 'feedback') else None
+                })
+
+            # Create prompt for GPT-4
+            prompt = f"""Given the current trajectory and goal, select the most promising child node to explore next.
+Consider the overall progress, efficiency, and likelihood of success.
+
+Goal: {self.goal}
+
+Current Trajectory:
+{json.dumps(trajectory, indent=2)}
+
+Available Children:
+{json.dumps([{
+    'action': child.action,
+    'description': child.natural_language_description,
+    'visits': child.visits,
+    'value': child.value if hasattr(child, 'value') else 0
+} for child in current_node.children], indent=2)}
+
+Return a JSON response with:
+{{
+    "selected_child_index": int, # Index of the selected child
+    "explanation": str # Brief explanation of the selection
+}}"""
+
+            response = openai_client.chat.completions.create(
+                model=self.config.evaluation_model,
+                messages=[
+                    {"role": "system", "content": "You are an expert at selecting promising paths in a search tree."},
+                    {"role": "user", "content": prompt}
+                ],
+                response_format={"type": "json_object"}
+            )
+
+            selection = json.loads(response.choices[0].message.content)
+            selected_index = selection["selected_child_index"]
+
+            if 0 <= selected_index < len(current_node.children):
+                current_node = current_node.children[selected_index]
+                path.append(current_node)
+                logger.info(f"Selected child {selected_index + 1}: {current_node.action}")
+                logger.info(f"Selection explanation: {selection['explanation']}")
+            else:
+                logger.warning(f"Invalid child index {selected_index}, breaking selection")
+                break
+
+
+            selection_depth += 1
+
+        # Send final node selection update
+        await self.websocket_node_selection(current_node, websocket=websocket)
+        return current_node
+
+    async def evaluate_selected_path(self, path) -> None:
+        """Evaluate the current node and assign its score."""
+        # Get the path from root to this node
+        # path = self.get_path_to_root(node)
+
+        # Create trajectory for scoring (skip root node)
+        trajectory = []
+        for n in path[1:]: # Skip root node
+            trajectory.append({
+                "natural_language_description": n.natural_language_description,
+                "action": n.action,
+                "feedback": n.feedback
+            })
+
+        # Score the trajectory
+        # TODO: if node is terminal, score is 0?
+        # if node.is_terminal:
+        #     score = 0
+        prompt = create_llm_prompt(trajectory, self.goal)
+        print(f"prompt: {prompt}")
+        result = score_trajectory_with_openai(
+            prompt,
+            openai_client,
+            model=self.config.evaluation_model
+        )
+        print(f"result: {result}")
+        score = result["overall_score"]
+        print(f"Simulation Results, evaluate selected path:")
+        print(f"Overall Score: {score:.3f}")
+        print(f"Efficiency Score: {result['efficiency_score']:.3f}")
+        print(f"Accuracy Score: {result['accuracy_score']:.3f}")
+        print(f"Robustness Score: {result['robustness_score']:.3f}")
+        return score
+
+    async def reflection_backtracking(self, path) -> List[LATSNode]:
+        """
+        Implement reflection-based backtracking to improve search trajectory.
+
+        Args:
+            node: Current node
+            path: Current path from root to node
+
+        Returns:
+            List[LATSNode]: Modified path after backtracking
+        """
+        # Create trajectory for reflection
+        trajectory = []
+        for n in path[1:]: # Skip root node
+            trajectory.append({
+                "natural_language_description": n.natural_language_description,
+                "action": n.action,
+                "feedback": n.feedback if hasattr(n, 'feedback') else None
+            })
+
+        score = await self.evaluate_selected_path(path)
+        print(f"\nReflection Step (Score {score:.3f} < {self.config.reflection_score}):")
+
+        # Generate reflection prompt
+        reflection_prompt = f"""Analyze the current trajectory and suggest improvements for the current website.
+
+Goal: {self.goal}
+
+Current Trajectory:
+{json.dumps(trajectory, indent=2)}
+
+Score: {score}
+
+Return a JSON response with:
+{{
+    "backtrack_to_step": int, # Which step to backtrack to (0-based index)
+    "reason": str, # Why backtrack to this step
+    "suggested_improvements": [str] # List of suggested improvements specific to current websites
+}}"""
+
+        reflection = openai_client.chat.completions.create(
+            model=self.config.evaluation_model,
+            messages=[
+                {"role": "system", "content": "You are an expert at analyzing and improving search trajectories."},
+                {"role": "user", "content": reflection_prompt}
+            ],
+            response_format={"type": "json_object"}
+        )
+
+        reflection_result = json.loads(reflection.choices[0].message.content)
+        backtrack_step = reflection_result["backtrack_to_step"]
+
+        # Backtrack to the suggested step
+        if 0 <= backtrack_step < len(path):
+            # Prevent backtracking to root when we have actions
+            if backtrack_step == 0 and len(path) > 1:
+                backtrack_step = 1
+                print("Adjusted backtracking to maintain at least one action")
+
+            current_node = path[backtrack_step]
+            # Remove nodes after the backtrack point
+            while len(path) > backtrack_step + 1:
+                path.pop()
+
+            print(f"Backtracking to step {backtrack_step}")
+            print(f"Reason: {reflection_result['reason']}")
+            print("Suggested improvements:")
+            for improvement in reflection_result["suggested_improvements"]:
+                print(f"- {improvement}")
+
+        return path
+
+    async def mcts_search(self, websocket=None) -> Optional[LATSNode]:
+        best_score = float('-inf')
+        best_node = None
+        print(f"iterations: {self.config.iterations}")
+
+        for i in range(self.config.iterations):
+            await self.websocket_iteration_start(i, websocket=websocket)
+
+            print(f"\n{'='*50}")
+            print(f"MCTS Iteration {i + 1}/{self.config.iterations}")
+            print(f"{'='*50}\n")
+
+            # Step 1: Node Selection (contain simulation)
+            # "node selection" combines selection and partial simulation
+            print(f"{GREEN}Step 1: Node Selection{RESET}")
+            await self.websocket_step_start(step=1, step_name="node_selection", websocket=websocket)
+            node = await self.node_selection(self.root_node, websocket)
+
+            if node is None:
+                logger.warning("All paths lead to terminal nodes. Ending search.")
+                break
+
+            # Step 2: Node Expansion
+            print(f"{GREEN}Step 2: Node Expansion{RESET}")
+            await self.websocket_step_start(step=2, step_name="node_expansion", websocket=websocket)
+            await self.node_expansion(node, websocket)
+            if node is None:
+                # all the nodes are terminal, stop the search
+                print(f"{RED}All nodes are terminal, stopping search{RESET}")
+                break
+            tree_data = self._get_tree_data()
+            if websocket:
+                await self.websocket_tree_update(type="tree_update_node_expansion", websocket=websocket, tree_data=tree_data)
+            else:
+                print_entire_tree(self.root_node)
+
+
+            # Step 3: simulation using the current node, (generate a path using the current node, and score the path)
+            # TODO: implement simulation using openai
+            print(f"{GREEN}Step 3: Simulation{RESET}")
+            await self.websocket_step_start(step=3, step_name="simulation", websocket=websocket)
+            path = self.get_path_to_root(node)
+            score = await self.evaluate_selected_path(path)
+            # change to reward later?
+            if score > best_score:
+                best_score = score
+                best_path = path
+                print(f"\nNew best path found!")
+                print(f"Previous best score: {best_score:.3f}")
+                print(f"New best score: {score:.3f}")
+
+
+            ## Step 4: reflection backtracking
+            print(f"{GREEN}Step 4: Reflection Backtracking{RESET}")
+            await self.websocket_step_start(step=4, step_name="reflection_backtracking", websocket=websocket)
+            if score >= self.config.reflection_score:
+                # Convert path to serializable trajectory
+                trajectory = [node.action for node in path if node.action is not None]
+                await self.websocket_search_complete("success", score, trajectory, websocket=websocket)
+                return node
+
+            print(f"path: {path}")
+            path = await self.reflection_backtracking(path)
+            print(f"path: {path}")
+
+            # Step 5: backpropagation
+            print(f"{GREEN}Step 5: Backpropagation{RESET}")
+            await self.websocket_step_start(step=5, step_name="backpropagation", websocket=websocket)
+            for node in path:
+                old_value = node.value
+                node.visits += 1
+                node.value = (node.value * (node.visits - 1) + score) / node.visits
+                print(f"Node {node.action}:")
+                print(f" Visits: {node.visits}")
+                print(f" Value: {old_value:.3f} -> {node.value:.3f}")
+        if best_node:
+            # Convert node to serializable trajectory
+            trajectory = [n.action for n in self.get_path_to_root(best_node) if n.action is not None]
+            await self.websocket_search_complete("partial_success", best_node.value, best_node.get_trajectory(), websocket=websocket)
+        return best_node

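The backpropagation step above (Step 5) folds each simulation score into a running average stored on every node along the path: value becomes the mean of all scores propagated through the node, and visits counts how many updates it has received. Below is a minimal standalone sketch of just that update, using a hypothetical _Node stand-in for LATSNode that keeps only the two fields the update touches.

# Sketch only: _Node is a stand-in for LATSNode (the real class lives in lats_node.py).
class _Node:
    def __init__(self):
        self.visits = 0
        self.value = 0.0

def backpropagate(path, score):
    # Incremental mean, identical to the update in mcts_search above.
    for node in path:
        node.visits += 1
        node.value = (node.value * (node.visits - 1) + score) / node.visits

root, child = _Node(), _Node()
backpropagate([root, child], 0.8)   # both nodes now hold 0.8
backpropagate([root], 0.4)          # root becomes (0.8 + 0.4) / 2 = 0.6
print(root.value, child.value)      # 0.6 0.8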
visual-tree-search-backend/app/api/lwats/core_async/config.py

Lines changed: 2 additions & 0 deletions

@@ -24,6 +24,8 @@ class AgentConfig:
     max_depth: int = 3
     num_simulations: int = 1
     account_reset: bool = True
+
+    reflection_score: float = 0.75
 
     # Features
     features: List[str] = field(default_factory=lambda: ['axtree'])

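The new reflection_score threshold is the gate used in mcts_search above: a simulated path scoring at or above it ends the search as a success, while a lower score triggers reflection_backtracking before backpropagation. A rough usage sketch follows, with a trimmed stand-in for the dataclass; only the fields visible in this hunk are reproduced, and the real AgentConfig has additional fields not shown here.

from dataclasses import dataclass, field
from typing import List

@dataclass
class AgentConfig:  # trimmed stand-in mirroring only the fields shown in this diff
    max_depth: int = 3
    num_simulations: int = 1
    account_reset: bool = True
    reflection_score: float = 0.75
    features: List[str] = field(default_factory=lambda: ['axtree'])

# Raising the threshold makes the agent reflect and backtrack more often
# before it accepts a trajectory as final.
config = AgentConfig(reflection_score=0.9)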
visual-tree-search-backend/test/test-tree-search-ws-lats.py

Lines changed: 11 additions & 3 deletions

@@ -77,7 +77,8 @@ async def connect_and_test_search(
     starting_url: str,
     goal: str,
     search_algorithm: str = "bfs",
-    max_depth: int = 3
+    max_depth: int = 3,
+    iterations: int = 5
 ):
     """
     Connect to the WebSocket endpoint and test the tree search functionality.
@@ -88,6 +89,7 @@ async def connect_and_test_search(
         goal: Goal to achieve
         search_algorithm: Search algorithm to use (bfs or dfs)
         max_depth: Maximum depth for the search tree
+        iterations: Number of iterations for LATS algorithm
     """
     logger.info(f"Connecting to WebSocket at {ws_url}")
 
@@ -107,7 +109,8 @@ async def connect_and_test_search(
         "starting_url": starting_url,
         "goal": goal,
         "search_algorithm": search_algorithm,
-        "max_depth": max_depth
+        "max_depth": max_depth,
+        "iterations": iterations
     }
 
     logger.info(f"Sending search request: {request}")
@@ -156,6 +159,9 @@ def parse_arguments():
     parser.add_argument("--max-depth", type=int, default=3,
                         help="Maximum depth for the search tree (default: 3)")
 
+    parser.add_argument("--iterations", type=int, default=5,
+                        help="Number of iterations for LATS algorithm (default: 5)")
+
     # Add the new argument for log file
     parser.add_argument("--log-file", type=str,
                         help="File to save the colored output to")
@@ -196,14 +202,16 @@ def flush(self):
     logger.info(f"Goal: {args.goal}")
     logger.info(f"Algorithm: {args.algorithm}")
     logger.info(f"Max depth: {args.max_depth}")
+    logger.info(f"Iterations: {args.iterations}")
 
     try:
         await connect_and_test_search(
             ws_url=args.ws_url,
             starting_url=args.starting_url,
             goal=args.goal,
            search_algorithm=args.algorithm,
-            max_depth=args.max_depth
+            max_depth=args.max_depth,
+            iterations=args.iterations
        )
    finally:
        # Clean up if logging to file
