37
37
38
38
openai_client = OpenAI ()
39
39
40
+ ## TODO: add best_path_update
41
+
40
42
class LATSAgent :
41
43
"""
42
44
Language-based Action Tree Search Agent implementation.
@@ -117,6 +119,13 @@ async def run(self, websocket=None) -> list[LATSNode]:
117
119
print_trajectory (best_node )
118
120
119
121
if websocket :
122
+ # trajectory_data = self._get_trajectory_data(best_node)
123
+ # await websocket.send_json({
124
+ # "type": "trajectory_update",
125
+ # "trajectory": trajectory_data,
126
+ # "timestamp": datetime.utcnow().isoformat()
127
+ # })
128
+ # TODO: use score instead of reward to determine success
120
129
await websocket .send_json ({
121
130
"type" : "search_complete" ,
122
131
"status" : "success" if best_node .reward == 1 else "partial_success" ,
@@ -158,12 +167,19 @@ async def lats_search(self, websocket=None) -> LATSNode:
158
167
if websocket :
159
168
await websocket .send_json ({
160
169
"type" : "step_start" ,
161
- "step" : "selection" ,
170
+ "step" : 1 ,
171
+ "step_name" : "selection" ,
162
172
"iteration" : i + 1 ,
163
173
"timestamp" : datetime .utcnow ().isoformat ()
164
174
})
165
175
166
176
node = self .select_node (self .root_node )
177
+ if websocket :
178
+ await websocket .send_json ({
179
+ "type" : "node_selected" ,
180
+ "node_id" : id (node ),
181
+ "timestamp" : datetime .utcnow ().isoformat ()
182
+ })
167
183
168
184
if node is None :
169
185
print ("All paths lead to terminal nodes with reward 0. Ending search." )
@@ -177,7 +193,8 @@ async def lats_search(self, websocket=None) -> LATSNode:
177
193
if websocket :
178
194
await websocket .send_json ({
179
195
"type" : "step_start" ,
180
- "step" : "expansion" ,
196
+ "step" : 2 ,
197
+ "step_name" : "expansion" ,
181
198
"iteration" : i + 1 ,
182
199
"timestamp" : datetime .utcnow ().isoformat ()
183
200
})
@@ -202,28 +219,65 @@ async def lats_search(self, websocket=None) -> LATSNode:
202
219
print (f"{ GREEN } Tree:{ RESET } " )
203
220
better_print (self .root_node )
204
221
print (f"" )
222
+ tree_data = self ._get_tree_data ()
223
+ await websocket .send_json ({
224
+ "type" : "tree_update" ,
225
+ "tree" : tree_data ,
226
+ "timestamp" : datetime .utcnow ().isoformat ()
227
+ })
205
228
206
229
# Step 3: Evaluation
207
230
print (f"" )
208
231
print (f"{ GREEN } Step 3: evaluation{ RESET } " )
232
+ if websocket :
233
+ await websocket .send_json ({
234
+ "type" : "step_start" ,
235
+ "step" : 3 ,
236
+ "step_name" : "evaluation" ,
237
+ "iteration" : i + 1 ,
238
+ "timestamp" : datetime .utcnow ().isoformat ()
239
+ })
209
240
await self .evaluate_node (node )
210
241
211
242
print (f"{ GREEN } Tree:{ RESET } " )
212
243
better_print (self .root_node )
213
244
print (f"" )
245
+ ## send tree update, since evaluation is added to the tree
246
+ if websocket :
247
+ tree_data = self ._get_tree_data ()
248
+ await websocket .send_json ({
249
+ "type" : "tree_update" ,
250
+ "tree" : tree_data ,
251
+ "timestamp" : datetime .utcnow ().isoformat ()
252
+ })
253
+
214
254
215
255
# Step 4: Simulation
216
256
print (f"{ GREEN } Step 4: simulation{ RESET } " )
217
- # # Find the child with the highest value
257
+ if websocket :
258
+ await websocket .send_json ({
259
+ "type" : "step_start" ,
260
+ "step" : 4 ,
261
+ "step_name" : "simulation" ,
262
+ "iteration" : i + 1 ,
263
+ "timestamp" : datetime .utcnow ().isoformat ()
264
+ })
218
265
## always = 1
219
- reward , terminal_node = await self .simulate (max (node .children , key = lambda child : child .value ), max_depth = self .config .max_depth , num_simulations = 1 )
266
+ reward , terminal_node = await self .simulate (max (node .children , key = lambda child : child .value ), max_depth = self .config .max_depth , num_simulations = 1 , websocket = websocket )
220
267
terminal_nodes .append (terminal_node )
221
268
222
269
if reward == 1 :
223
270
return terminal_node
224
271
225
272
# Step 5: Backpropagation
226
273
print (f"{ GREEN } Step 5: backpropagation{ RESET } " )
274
+ if websocket :
275
+ await websocket .send_json ({
276
+ "type" : "step_start" ,
277
+ "step" : 5 ,
278
+ "step_name" : "backpropagation" ,
279
+ "timestamp" : datetime .utcnow ().isoformat ()
280
+ })
227
281
self .backpropagate (terminal_node , reward )
228
282
print (f"{ GREEN } Tree:{ RESET } " )
229
283
better_print (self .root_node )
@@ -335,7 +389,8 @@ async def evaluate_node(self, node: LATSNode) -> None:
335
389
child .value = score
336
390
child .reward = score
337
391
338
- async def simulate (self , node : LATSNode , max_depth : int = 2 , num_simulations = 1 ) -> tuple [float , LATSNode ]:
392
+ ## TODO: make number of simulations configurable
393
+ async def simulate (self , node : LATSNode , max_depth : int = 2 , num_simulations = 1 , websocket = None ) -> tuple [float , LATSNode ]:
339
394
"""
340
395
Perform a rollout simulation from a node.
341
396
@@ -351,13 +406,39 @@ async def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1)
351
406
print_trajectory (node )
352
407
print ("print the entire tree" )
353
408
print_entire_tree (self .root_node )
354
- return await self .rollout (node , max_depth = max_depth )
409
+ if websocket :
410
+ tree_data = self ._get_tree_data ()
411
+ await websocket .send_json ({
412
+ "type" : "tree_update" ,
413
+ "tree" : tree_data ,
414
+ "timestamp" : datetime .utcnow ().isoformat ()
415
+ })
416
+ trajectory_data = self ._get_trajectory_data (node )
417
+ await websocket .send_json ({
418
+ "type" : "trajectory_update" ,
419
+ "trajectory" : trajectory_data ,
420
+ "timestamp" : datetime .utcnow ().isoformat ()
421
+ })
422
+ return await self .rollout (node , max_depth = max_depth , websocket = websocket )
355
423
356
- async def send_completion_request (self , plan , depth , node , trajectory = []):
424
+ async def send_completion_request (self , plan , depth , node , trajectory = [], websocket = None ):
357
425
print ("print the trajectory" )
358
426
print_trajectory (node )
359
427
print ("print the entire tree" )
360
428
print_entire_tree (self .root_node )
429
+ if websocket :
430
+ # tree_data = self._get_tree_data()
431
+ # await websocket.send_json({
432
+ # "type": "tree_update",
433
+ # "tree": tree_data,
434
+ # "timestamp": datetime.utcnow().isoformat()
435
+ # })
436
+ trajectory_data = self ._get_trajectory_data (node )
437
+ await websocket .send_json ({
438
+ "type" : "trajectory_update" ,
439
+ "trajectory" : trajectory_data ,
440
+ "timestamp" : datetime .utcnow ().isoformat ()
441
+ })
361
442
362
443
if depth >= self .config .max_depth :
363
444
return trajectory , node
@@ -420,20 +501,20 @@ async def send_completion_request(self, plan, depth, node, trajectory=[]):
420
501
if goal_finished :
421
502
return trajectory , new_node
422
503
423
- return await self .send_completion_request (plan , depth + 1 , new_node , trajectory )
504
+ return await self .send_completion_request (plan , depth + 1 , new_node , trajectory , websocket )
424
505
425
506
except Exception as e :
426
507
print (f"Attempt { attempt + 1 } failed with error: { e } " )
427
508
if attempt + 1 == retry_count :
428
509
print ("Max retries reached. Skipping this step and retrying the whole request." )
429
510
# Retry the entire request from the same state
430
- return await self .send_completion_request (plan , depth , node , trajectory )
511
+ return await self .send_completion_request (plan , depth , node , trajectory , websocket )
431
512
432
513
# If all retries and retries of retries fail, return the current trajectory and node
433
514
return trajectory , node
434
515
435
516
436
- async def rollout (self , node : LATSNode , max_depth : int = 2 )-> tuple [float , LATSNode ]:
517
+ async def rollout (self , node : LATSNode , max_depth : int = 2 , websocket = None )-> tuple [float , LATSNode ]:
437
518
# Reset browser state
438
519
await self ._reset_browser ()
439
520
path = self .get_path_to_root (node )
@@ -467,11 +548,24 @@ async def rollout(self, node: LATSNode, max_depth: int = 2)-> tuple[float, LATSN
467
548
## call the prompt agent
468
549
print ("current depth: " , len (path ) - 1 )
469
550
print ("max depth: " , self .config .max_depth )
470
- trajectory , node = await self .send_completion_request (self .goal , len (path ) - 1 , node = n , trajectory = trajectory )
551
+ trajectory , node = await self .send_completion_request (self .goal , len (path ) - 1 , node = n , trajectory = trajectory , websocket = websocket )
471
552
print ("print the trajectory" )
472
553
print_trajectory (node )
473
554
print ("print the entire tree" )
474
555
print_entire_tree (self .root_node )
556
+ if websocket :
557
+ # tree_data = self._get_tree_data()
558
+ # await websocket.send_json({
559
+ # "type": "tree_update",
560
+ # "tree": tree_data,
561
+ # "timestamp": datetime.utcnow().isoformat()
562
+ # })
563
+ trajectory_data = self ._get_trajectory_data (node )
564
+ await websocket .send_json ({
565
+ "type" : "trajectory_update" ,
566
+ "trajectory" : trajectory_data ,
567
+ "timestamp" : datetime .utcnow ().isoformat ()
568
+ })
475
569
476
570
page = await self .playwright_manager .get_page ()
477
571
page_info = await extract_page_info (page , self .config .fullpage , self .config .log_folder )
@@ -769,8 +863,46 @@ def _get_tree_data(self):
769
863
"is_terminal" : node .is_terminal ,
770
864
"value" : node .value ,
771
865
"visits" : node .visits ,
866
+ "feedback" : node .feedback ,
772
867
"reward" : node .reward
773
868
}
774
869
tree_data .append (node_data )
775
870
776
871
return tree_data
872
+
873
+ def _get_trajectory_data (self , terminal_node : LATSNode ):
874
+ """Get trajectory data in a format suitable for visualization
875
+
876
+ Args:
877
+ terminal_node: The leaf node to start the trajectory from
878
+
879
+ Returns:
880
+ list: List of node data dictionaries representing the trajectory
881
+ """
882
+ trajectory_data = []
883
+ path = []
884
+
885
+ # Collect path from terminal to root
886
+ current = terminal_node
887
+ while current is not None :
888
+ path .append (current )
889
+ current = current .parent
890
+
891
+ # Process nodes in order from root to terminal
892
+ for level , node in enumerate (reversed (path )):
893
+ node_data = {
894
+ "id" : id (node ),
895
+ "level" : level ,
896
+ "action" : node .action if node .action else "ROOT" ,
897
+ "description" : node .natural_language_description ,
898
+ "visits" : node .visits ,
899
+ "value" : float (f"{ node .value :.3f} " ) if hasattr (node , 'value' ) else None ,
900
+ "reward" : float (f"{ node .reward :.3f} " ) if hasattr (node , 'reward' ) else None ,
901
+ "is_terminal" : node .is_terminal ,
902
+ "feedback" : node .feedback if hasattr (node , 'feedback' ) else None ,
903
+ "is_root" : not hasattr (node , 'parent' ) or node .parent is None ,
904
+ "is_terminal_node" : node == terminal_node
905
+ }
906
+ trajectory_data .append (node_data )
907
+
908
+ return trajectory_data
0 commit comments