@@ -119,6 +119,12 @@ async def run(self, websocket=None) -> list[LATSNode]:
119119 print_trajectory (best_node )
120120
121121 if websocket :
122+ # trajectory_data = self._get_trajectory_data(best_node)
123+ # await websocket.send_json({
124+ # "type": "trajectory_update",
125+ # "trajectory": trajectory_data,
126+ # "timestamp": datetime.utcnow().isoformat()
127+ # })
122128 # TODO: use score instead of reward to determine success
123129 await websocket .send_json ({
124130 "type" : "search_complete" ,
@@ -168,6 +174,12 @@ async def lats_search(self, websocket=None) -> LATSNode:
168174 })
169175
170176 node = self .select_node (self .root_node )
177+ if websocket :
178+ await websocket .send_json ({
179+ "type" : "node_selected" ,
180+ "node_id" : id (node ),
181+ "timestamp" : datetime .utcnow ().isoformat ()
182+ })
171183
172184 if node is None :
173185 print ("All paths lead to terminal nodes with reward 0. Ending search." )
@@ -207,6 +219,12 @@ async def lats_search(self, websocket=None) -> LATSNode:
207219 print (f"{ GREEN } Tree:{ RESET } " )
208220 better_print (self .root_node )
209221 print (f"" )
222+ tree_data = self ._get_tree_data ()
223+ await websocket .send_json ({
224+ "type" : "tree_update" ,
225+ "tree" : tree_data ,
226+ "timestamp" : datetime .utcnow ().isoformat ()
227+ })
210228
211229 # Step 3: Evaluation
212230 print (f"" )
@@ -224,6 +242,15 @@ async def lats_search(self, websocket=None) -> LATSNode:
224242 print (f"{ GREEN } Tree:{ RESET } " )
225243 better_print (self .root_node )
226244 print (f"" )
245+ ## send tree update, since evaluation is added to the tree
246+ if websocket :
247+ tree_data = self ._get_tree_data ()
248+ await websocket .send_json ({
249+ "type" : "tree_update" ,
250+ "tree" : tree_data ,
251+ "timestamp" : datetime .utcnow ().isoformat ()
252+ })
253+
227254
228255 # Step 4: Simulation
229256 print (f"{ GREEN } Step 4: simulation{ RESET } " )
@@ -236,7 +263,7 @@ async def lats_search(self, websocket=None) -> LATSNode:
236263 "timestamp" : datetime .utcnow ().isoformat ()
237264 })
238265 ## always = 1
239- reward , terminal_node = await self .simulate (max (node .children , key = lambda child : child .value ), max_depth = self .config .max_depth , num_simulations = 1 )
266+ reward , terminal_node = await self .simulate (max (node .children , key = lambda child : child .value ), max_depth = self .config .max_depth , num_simulations = 1 , websocket = websocket )
240267 terminal_nodes .append (terminal_node )
241268
242269 if reward == 1 :
@@ -362,7 +389,8 @@ async def evaluate_node(self, node: LATSNode) -> None:
362389 child .value = score
363390 child .reward = score
364391
365- async def simulate (self , node : LATSNode , max_depth : int = 2 , num_simulations = 1 ) -> tuple [float , LATSNode ]:
392+ ## TODO: make number of simulations configurable
393+ async def simulate (self , node : LATSNode , max_depth : int = 2 , num_simulations = 1 , websocket = None ) -> tuple [float , LATSNode ]:
366394 """
367395 Perform a rollout simulation from a node.
368396
@@ -378,13 +406,39 @@ async def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1)
378406 print_trajectory (node )
379407 print ("print the entire tree" )
380408 print_entire_tree (self .root_node )
381- return await self .rollout (node , max_depth = max_depth )
409+ if websocket :
410+ tree_data = self ._get_tree_data ()
411+ await websocket .send_json ({
412+ "type" : "tree_update" ,
413+ "tree" : tree_data ,
414+ "timestamp" : datetime .utcnow ().isoformat ()
415+ })
416+ trajectory_data = self ._get_trajectory_data (node )
417+ await websocket .send_json ({
418+ "type" : "trajectory_update" ,
419+ "trajectory" : trajectory_data ,
420+ "timestamp" : datetime .utcnow ().isoformat ()
421+ })
422+ return await self .rollout (node , max_depth = max_depth , websocket = websocket )
382423
383- async def send_completion_request (self , plan , depth , node , trajectory = []):
424+ async def send_completion_request (self , plan , depth , node , trajectory = [], websocket = None ):
384425 print ("print the trajectory" )
385426 print_trajectory (node )
386427 print ("print the entire tree" )
387428 print_entire_tree (self .root_node )
429+ if websocket :
430+ # tree_data = self._get_tree_data()
431+ # await websocket.send_json({
432+ # "type": "tree_update",
433+ # "tree": tree_data,
434+ # "timestamp": datetime.utcnow().isoformat()
435+ # })
436+ trajectory_data = self ._get_trajectory_data (node )
437+ await websocket .send_json ({
438+ "type" : "trajectory_update" ,
439+ "trajectory" : trajectory_data ,
440+ "timestamp" : datetime .utcnow ().isoformat ()
441+ })
388442
389443 if depth >= self .config .max_depth :
390444 return trajectory , node
@@ -447,20 +501,20 @@ async def send_completion_request(self, plan, depth, node, trajectory=[]):
447501 if goal_finished :
448502 return trajectory , new_node
449503
450- return await self .send_completion_request (plan , depth + 1 , new_node , trajectory )
504+ return await self .send_completion_request (plan , depth + 1 , new_node , trajectory , websocket )
451505
452506 except Exception as e :
453507 print (f"Attempt { attempt + 1 } failed with error: { e } " )
454508 if attempt + 1 == retry_count :
455509 print ("Max retries reached. Skipping this step and retrying the whole request." )
456510 # Retry the entire request from the same state
457- return await self .send_completion_request (plan , depth , node , trajectory )
511+ return await self .send_completion_request (plan , depth , node , trajectory , websocket )
458512
459513 # If all retries and retries of retries fail, return the current trajectory and node
460514 return trajectory , node
461515
462516
463- async def rollout (self , node : LATSNode , max_depth : int = 2 )-> tuple [float , LATSNode ]:
517+ async def rollout (self , node : LATSNode , max_depth : int = 2 , websocket = None )-> tuple [float , LATSNode ]:
464518 # Reset browser state
465519 await self ._reset_browser ()
466520 path = self .get_path_to_root (node )
@@ -494,11 +548,24 @@ async def rollout(self, node: LATSNode, max_depth: int = 2)-> tuple[float, LATSN
494548 ## call the prompt agent
495549 print ("current depth: " , len (path ) - 1 )
496550 print ("max depth: " , self .config .max_depth )
497- trajectory , node = await self .send_completion_request (self .goal , len (path ) - 1 , node = n , trajectory = trajectory )
551+ trajectory , node = await self .send_completion_request (self .goal , len (path ) - 1 , node = n , trajectory = trajectory , websocket = websocket )
498552 print ("print the trajectory" )
499553 print_trajectory (node )
500554 print ("print the entire tree" )
501555 print_entire_tree (self .root_node )
556+ if websocket :
557+ # tree_data = self._get_tree_data()
558+ # await websocket.send_json({
559+ # "type": "tree_update",
560+ # "tree": tree_data,
561+ # "timestamp": datetime.utcnow().isoformat()
562+ # })
563+ trajectory_data = self ._get_trajectory_data (node )
564+ await websocket .send_json ({
565+ "type" : "trajectory_update" ,
566+ "trajectory" : trajectory_data ,
567+ "timestamp" : datetime .utcnow ().isoformat ()
568+ })
502569
503570 page = await self .playwright_manager .get_page ()
504571 page_info = await extract_page_info (page , self .config .fullpage , self .config .log_folder )
@@ -801,3 +868,40 @@ def _get_tree_data(self):
801868 tree_data .append (node_data )
802869
803870 return tree_data
871+
872+ def _get_trajectory_data (self , terminal_node : LATSNode ):
873+ """Get trajectory data in a format suitable for visualization
874+
875+ Args:
876+ terminal_node: The leaf node to start the trajectory from
877+
878+ Returns:
879+ list: List of node data dictionaries representing the trajectory
880+ """
881+ trajectory_data = []
882+ path = []
883+
884+ # Collect path from terminal to root
885+ current = terminal_node
886+ while current is not None :
887+ path .append (current )
888+ current = current .parent
889+
890+ # Process nodes in order from root to terminal
891+ for level , node in enumerate (reversed (path )):
892+ node_data = {
893+ "id" : id (node ),
894+ "level" : level ,
895+ "action" : node .action if node .action else "ROOT" ,
896+ "description" : node .natural_language_description ,
897+ "visits" : node .visits ,
898+ "value" : float (f"{ node .value :.3f} " ) if hasattr (node , 'value' ) else None ,
899+ "reward" : float (f"{ node .reward :.3f} " ) if hasattr (node , 'reward' ) else None ,
900+ "is_terminal" : node .is_terminal ,
901+ "feedback" : node .feedback if hasattr (node , 'feedback' ) else None ,
902+ "is_root" : not hasattr (node , 'parent' ) or node .parent is None ,
903+ "is_terminal_node" : node == terminal_node
904+ }
905+ trajectory_data .append (node_data )
906+
907+ return trajectory_data
0 commit comments