1
1
from typing import Any , Optional , Tuple , List
2
2
from datetime import datetime
3
+ import logging
4
+ import json
5
+ import time
6
+ from openai import OpenAI
3
7
from dotenv import load_dotenv
4
8
load_dotenv ()
5
9
6
- from .tree_vis import RED , better_print , print_trajectory , collect_all_nodes , GREEN , RESET , print_entire_tree
10
+ from .tree_vis import RED , GREEN , RESET , better_print , print_trajectory , collect_all_nodes , print_entire_tree
7
11
from .lats_node import LATSNode
8
12
from .base_agent import BaseAgent
13
+ from .trajectory_score import create_llm_prompt , score_trajectory_with_openai
14
+ from ...replay_async import generate_feedback , playwright_step_execution
15
+ from ...webagent_utils_async .browser_env .observation import extract_page_info
16
+ from ...webagent_utils_async .action .prompt_functions import extract_top_actions
17
+ from ...webagent_utils_async .utils .utils import parse_function_args , locate_element
18
+ from ...evaluation_async .evaluators import goal_finished_evaluator
19
+
20
+ openai_client = OpenAI ()
21
+
22
+ logger = logging .getLogger (__name__ )
23
+ logger .setLevel (logging .INFO )
9
24
10
25
class MCTSAgent (BaseAgent ):
11
- async def run (self , websocket = None ) -> list [LATSNode ]:
26
+ """
27
+ Monte Carlo Tree Search Agent for web navigation tasks.
28
+ This implementation uses reflection-based search to improve performance.
29
+ """
30
+
31
+ async def run (self , websocket = None ) -> List [dict [str , Any ]]:
32
+ """
33
+ Run the MCTS algorithm based on configuration.
34
+
35
+ Args:
36
+ websocket: Optional WebSocket connection to send updates to
37
+
38
+ Returns:
39
+ List[Dict[str, Any]]: List of actions in the best path found
40
+ """
12
41
if websocket :
13
42
await websocket .send_json ({
14
43
"type" : "search_status" ,
@@ -17,4 +46,270 @@ async def run(self, websocket=None) -> list[LATSNode]:
17
46
"timestamp" : datetime .utcnow ().isoformat ()
18
47
})
19
48
20
- pass
49
+ # Reset browser to initial state
50
+ live_browser_url , session_id = await self ._reset_browser (websocket )
51
+
52
+ best_node = await self .mcts_search (websocket )
53
+ print_trajectory (best_node )
54
+
55
+ return best_node
56
+
57
async def node_selection(self, node: "LATSNode", websocket=None) -> Optional["LATSNode"]:
    """Descend from ``node``, letting the evaluation model pick a child at each level.

    At every level the trajectory walked so far and the candidate children are
    sent to ``self.config.evaluation_model``. The descent stops at a node
    without children, at a terminal node, or as soon as the model's reply
    cannot be used (unparsable JSON, or a missing, non-integer or
    out-of-range ``selected_child_index``).

    Args:
        node: Node the descent starts from.
        websocket: Optional connection forwarded to ``websocket_node_selection``.

    Returns:
        The node at which the descent stopped, or ``None`` when ``node`` itself
        is terminal.
    """
    if node.is_terminal:
        return None

    current_node = node
    path = [current_node]
    selection_depth = 0

    while current_node.children and not current_node.is_terminal:
        children = current_node.children
        logger.info("\nSelection Step %d:", selection_depth + 1)
        logger.info("Current node action: %s", current_node.action)
        logger.info("Number of children: %d", len(children))

        # Trajectory handed to the model; the starting node is skipped.
        trajectory = [
            {
                "natural_language_description": n.natural_language_description,
                "action": n.action,
                "feedback": getattr(n, "feedback", None),
            }
            for n in path[1:]
        ]
        candidates = [
            {
                "action": child.action,
                "description": child.natural_language_description,
                "visits": child.visits,
                "value": getattr(child, "value", 0),
            }
            for child in children
        ]

        prompt = (
            "Given the current trajectory and goal, select the most promising child node to explore next.\n"
            "Consider the overall progress, efficiency, and likelihood of success.\n"
            "\n"
            f"Goal: {self.goal}\n"
            "\n"
            "Current Trajectory:\n"
            f"{json.dumps(trajectory, indent=2)}\n"
            "\n"
            "Available Children:\n"
            f"{json.dumps(candidates, indent=2)}\n"
            "\n"
            "Return a JSON response with:\n"
            "{\n"
            '    "selected_child_index": int,  # Index of the selected child\n'
            '    "explanation": str  # Brief explanation of the selection\n'
            "}"
        )

        response = openai_client.chat.completions.create(
            model=self.config.evaluation_model,
            messages=[
                {"role": "system", "content": "You are an expert at selecting promising paths in a search tree."},
                {"role": "user", "content": prompt},
            ],
            response_format={"type": "json_object"},
        )

        # The reply is model-generated text: never trust its shape.
        try:
            selection = json.loads(response.choices[0].message.content)
        except (TypeError, ValueError):
            logger.warning("Could not parse selection response, breaking selection")
            break
        if not isinstance(selection, dict):
            selection = {}

        selected_index = selection.get("selected_child_index")
        index_is_usable = (
            isinstance(selected_index, int)
            and not isinstance(selected_index, bool)
            and 0 <= selected_index < len(children)
        )
        if not index_is_usable:
            logger.warning("Invalid child index %s, breaking selection", selected_index)
            break

        current_node = children[selected_index]
        path.append(current_node)
        logger.info("Selected child %d: %s", selected_index + 1, current_node.action)
        logger.info("Selection explanation: %s", selection.get("explanation"))

        selection_depth += 1

    # Send final node selection update
    await self.websocket_node_selection(current_node, websocket=websocket)
    return current_node
129
+
130
async def evaluate_selected_path(self, path) -> float:
    """Score the trajectory described by ``path`` with the evaluation model.

    Args:
        path: Nodes ordered from the root to the node being evaluated. The
            first entry (the root) is left out of the scored trajectory.

    Returns:
        The ``overall_score`` entry of the scoring result.
    """
    # Nodes that never received feedback are reported with ``None`` instead of
    # raising, matching how the other methods of this class read the attribute.
    trajectory = [
        {
            "natural_language_description": n.natural_language_description,
            "action": n.action,
            "feedback": getattr(n, "feedback", None),
        }
        for n in path[1:]
    ]

    # TODO: if node is terminal, score is 0?
    prompt = create_llm_prompt(trajectory, self.goal)
    print(f"prompt: {prompt}")
    result = score_trajectory_with_openai(
        prompt,
        openai_client,
        model=self.config.evaluation_model,
    )
    print(f"result: {result}")

    score = result["overall_score"]
    print("Simulation Results, evaluate selected path:")
    print(f"Overall Score: {score:.3f}")
    print(f"Efficiency Score: {result['efficiency_score']:.3f}")
    print(f"Accuracy Score: {result['accuracy_score']:.3f}")
    print(f"Robustness Score: {result['robustness_score']:.3f}")
    return score
163
+
164
async def reflection_backtracking(self, path) -> List["LATSNode"]:
    """Ask the evaluation model where to backtrack and truncate ``path`` there.

    The path is re-scored with ``evaluate_selected_path``, then the model is
    asked for a step to backtrack to. When the reply names a valid step,
    every node after that step is removed from ``path`` in place. A reply
    that cannot be parsed, or that has a missing, non-integer or out-of-range
    step, leaves ``path`` untouched.

    Args:
        path: Nodes ordered from the root to the current node. Modified in place.

    Returns:
        The same ``path`` list, possibly shortened.
    """
    # Trajectory shown to the model; the root node is skipped.
    trajectory = [
        {
            "natural_language_description": n.natural_language_description,
            "action": n.action,
            "feedback": getattr(n, "feedback", None),
        }
        for n in path[1:]
    ]

    score = await self.evaluate_selected_path(path)
    print(f"\nReflection Step (Score {score:.3f} < {self.config.reflection_score}):")

    reflection_prompt = (
        "Analyze the current trajectory and suggest improvements for the current website.\n"
        "\n"
        f"Goal: {self.goal}\n"
        "\n"
        "Current Trajectory:\n"
        f"{json.dumps(trajectory, indent=2)}\n"
        "\n"
        f"Score: {score}\n"
        "\n"
        "Return a JSON response with:\n"
        "{\n"
        '    "backtrack_to_step": int,  # Which step to backtrack to (0-based index)\n'
        '    "reason": str,  # Why backtrack to this step\n'
        '    "suggested_improvements": [str]  # List of suggested improvements specific to current websites\n'
        "}"
    )

    reflection = openai_client.chat.completions.create(
        model=self.config.evaluation_model,
        messages=[
            {"role": "system", "content": "You are an expert at analyzing and improving search trajectories."},
            {"role": "user", "content": reflection_prompt},
        ],
        response_format={"type": "json_object"},
    )

    # The reply is model-generated text: validate it before acting on it.
    try:
        reflection_result = json.loads(reflection.choices[0].message.content)
    except (TypeError, ValueError):
        reflection_result = {}
    if not isinstance(reflection_result, dict):
        reflection_result = {}

    backtrack_step = reflection_result.get("backtrack_to_step")
    step_is_usable = (
        isinstance(backtrack_step, int)
        and not isinstance(backtrack_step, bool)
        and 0 <= backtrack_step < len(path)
    )
    if not step_is_usable:
        print(f"No usable backtrack step in reflection response ({backtrack_step!r}); path unchanged")
        return path

    # Prevent backtracking to root when we have actions
    if backtrack_step == 0 and len(path) > 1:
        backtrack_step = 1
        print("Adjusted backtracking to maintain at least one action")

    # Remove nodes after the backtrack point (in place: callers may hold
    # a reference to this very list).
    del path[backtrack_step + 1:]

    print(f"Backtracking to step {backtrack_step}")
    print(f"Reason: {reflection_result.get('reason')}")
    print("Suggested improvements:")
    improvements = reflection_result.get("suggested_improvements") or []
    if isinstance(improvements, str):
        improvements = [improvements]
    for improvement in improvements:
        print(f"- {improvement}")

    return path
235
+
236
async def mcts_search(self, websocket=None) -> Optional["LATSNode"]:
    """Run up to ``self.config.iterations`` rounds of reflective tree search.

    Each round selects a node, expands it, scores the path from the root to
    it, and either finishes (score at or above ``self.config.reflection_score``)
    or backtracks via reflection and backpropagates the score along the
    remaining path.

    Args:
        websocket: Optional connection that receives progress updates.

    Returns:
        The node that reached the score threshold; otherwise the node with
        the highest path score seen, or ``None`` if no node was ever scored.
    """
    best_score = float("-inf")
    best_node = None
    print(f"iterations: {self.config.iterations}")

    for i in range(self.config.iterations):
        await self.websocket_iteration_start(i, websocket=websocket)

        print(f"\n{'=' * 50}")
        print(f"MCTS Iteration {i + 1}/{self.config.iterations}")
        print(f"{'=' * 50}\n")

        # Step 1: Node Selection
        # "node selection" combines selection and partial simulation
        print(f"{GREEN}Step 1: Node Selection{RESET}")
        await self.websocket_step_start(step=1, step_name="node_selection", websocket=websocket)
        node = await self.node_selection(self.root_node, websocket)

        if node is None:
            logger.warning("All paths lead to terminal nodes. Ending search.")
            break

        # Step 2: Node Expansion
        print(f"{GREEN}Step 2: Node Expansion{RESET}")
        await self.websocket_step_start(step=2, step_name="node_expansion", websocket=websocket)
        await self.node_expansion(node, websocket)
        tree_data = self._get_tree_data()
        if websocket:
            await self.websocket_tree_update(
                type="tree_update_node_expansion", websocket=websocket, tree_data=tree_data
            )
        else:
            print_entire_tree(self.root_node)

        # Step 3: Simulation - score the path from the root to the selected node
        # TODO: implement simulation using openai
        print(f"{GREEN}Step 3: Simulation{RESET}")
        await self.websocket_step_start(step=3, step_name="simulation", websocket=websocket)
        path = self.get_path_to_root(node)
        score = await self.evaluate_selected_path(path)
        # change to reward later?
        if score > best_score:
            # Report before overwriting, so "previous" really is the previous best.
            print("\nNew best path found!")
            print(f"Previous best score: {best_score:.3f}")
            print(f"New best score: {score:.3f}")
            best_score = score
            # Remember the node itself: it is what the caller gets back when
            # no iteration reaches the threshold.
            best_node = node

        # Step 4: Reflection Backtracking
        print(f"{GREEN}Step 4: Reflection Backtracking{RESET}")
        await self.websocket_step_start(step=4, step_name="reflection_backtracking", websocket=websocket)
        if score >= self.config.reflection_score:
            # Convert path to serializable trajectory
            trajectory = [n.action for n in path if n.action is not None]
            await self.websocket_search_complete("success", score, trajectory, websocket=websocket)
            return node

        print(f"path: {path}")
        path = await self.reflection_backtracking(path)
        print(f"path: {path}")

        # Step 5: Backpropagation - running average over the (possibly shortened) path
        print(f"{GREEN}Step 5: Backpropagation{RESET}")
        await self.websocket_step_start(step=5, step_name="backpropagation", websocket=websocket)
        for visited in path:
            old_value = visited.value
            visited.visits += 1
            visited.value = (visited.value * (visited.visits - 1) + score) / visited.visits
            print(f"Node {visited.action}:")
            print(f"  Visits: {visited.visits}")
            print(f"  Value: {old_value:.3f} -> {visited.value:.3f}")

    if best_node is not None:
        # Same payload shape as the success branch: the path score and the
        # list of actions leading to the node.
        trajectory = [n.action for n in self.get_path_to_root(best_node) if n.action is not None]
        await self.websocket_search_complete("partial_success", best_score, trajectory, websocket=websocket)
    return best_node
0 commit comments