Skip to content

Commit cf8bca0

Browse files
committed
add 5) Browser Set Up, Account Reset, 6) Tree Update & Trajectory Update, 7) node_selected for LATS
1 parent 210fadb commit cf8bca0

File tree

5 files changed

+196
-11
lines changed

5 files changed

+196
-11
lines changed

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py

Lines changed: 112 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,12 @@ async def run(self, websocket=None) -> list[LATSNode]:
119119
print_trajectory(best_node)
120120

121121
if websocket:
122+
# trajectory_data = self._get_trajectory_data(best_node)
123+
# await websocket.send_json({
124+
# "type": "trajectory_update",
125+
# "trajectory": trajectory_data,
126+
# "timestamp": datetime.utcnow().isoformat()
127+
# })
122128
# TODO: use score instead of reward to determine success
123129
await websocket.send_json({
124130
"type": "search_complete",
@@ -168,6 +174,12 @@ async def lats_search(self, websocket=None) -> LATSNode:
168174
})
169175

170176
node = self.select_node(self.root_node)
177+
if websocket:
178+
await websocket.send_json({
179+
"type": "node_selected",
180+
"node_id": id(node),
181+
"timestamp": datetime.utcnow().isoformat()
182+
})
171183

172184
if node is None:
173185
print("All paths lead to terminal nodes with reward 0. Ending search.")
@@ -207,6 +219,12 @@ async def lats_search(self, websocket=None) -> LATSNode:
207219
print(f"{GREEN}Tree:{RESET}")
208220
better_print(self.root_node)
209221
print(f"")
222+
tree_data = self._get_tree_data()
223+
await websocket.send_json({
224+
"type": "tree_update",
225+
"tree": tree_data,
226+
"timestamp": datetime.utcnow().isoformat()
227+
})
210228

211229
# Step 3: Evaluation
212230
print(f"")
@@ -224,6 +242,15 @@ async def lats_search(self, websocket=None) -> LATSNode:
224242
print(f"{GREEN}Tree:{RESET}")
225243
better_print(self.root_node)
226244
print(f"")
245+
## send tree update, since evaluation is added to the tree
246+
if websocket:
247+
tree_data = self._get_tree_data()
248+
await websocket.send_json({
249+
"type": "tree_update",
250+
"tree": tree_data,
251+
"timestamp": datetime.utcnow().isoformat()
252+
})
253+
227254

228255
# Step 4: Simulation
229256
print(f"{GREEN}Step 4: simulation{RESET}")
@@ -236,7 +263,7 @@ async def lats_search(self, websocket=None) -> LATSNode:
236263
"timestamp": datetime.utcnow().isoformat()
237264
})
238265
## always = 1
239-
reward, terminal_node = await self.simulate(max(node.children, key=lambda child: child.value), max_depth=self.config.max_depth, num_simulations=1)
266+
reward, terminal_node = await self.simulate(max(node.children, key=lambda child: child.value), max_depth=self.config.max_depth, num_simulations=1, websocket=websocket)
240267
terminal_nodes.append(terminal_node)
241268

242269
if reward == 1:
@@ -362,7 +389,8 @@ async def evaluate_node(self, node: LATSNode) -> None:
362389
child.value = score
363390
child.reward = score
364391

365-
async def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1) -> tuple[float, LATSNode]:
392+
## TODO: make number of simulations configurable
393+
async def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1, websocket=None) -> tuple[float, LATSNode]:
366394
"""
367395
Perform a rollout simulation from a node.
368396
@@ -378,13 +406,39 @@ async def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1)
378406
print_trajectory(node)
379407
print("print the entire tree")
380408
print_entire_tree(self.root_node)
381-
return await self.rollout(node, max_depth=max_depth)
409+
if websocket:
410+
tree_data = self._get_tree_data()
411+
await websocket.send_json({
412+
"type": "tree_update",
413+
"tree": tree_data,
414+
"timestamp": datetime.utcnow().isoformat()
415+
})
416+
trajectory_data = self._get_trajectory_data(node)
417+
await websocket.send_json({
418+
"type": "trajectory_update",
419+
"trajectory": trajectory_data,
420+
"timestamp": datetime.utcnow().isoformat()
421+
})
422+
return await self.rollout(node, max_depth=max_depth, websocket=websocket)
382423

383-
async def send_completion_request(self, plan, depth, node, trajectory=[]):
424+
async def send_completion_request(self, plan, depth, node, trajectory=[], websocket=None):
384425
print("print the trajectory")
385426
print_trajectory(node)
386427
print("print the entire tree")
387428
print_entire_tree(self.root_node)
429+
if websocket:
430+
# tree_data = self._get_tree_data()
431+
# await websocket.send_json({
432+
# "type": "tree_update",
433+
# "tree": tree_data,
434+
# "timestamp": datetime.utcnow().isoformat()
435+
# })
436+
trajectory_data = self._get_trajectory_data(node)
437+
await websocket.send_json({
438+
"type": "trajectory_update",
439+
"trajectory": trajectory_data,
440+
"timestamp": datetime.utcnow().isoformat()
441+
})
388442

389443
if depth >= self.config.max_depth:
390444
return trajectory, node
@@ -447,20 +501,20 @@ async def send_completion_request(self, plan, depth, node, trajectory=[]):
447501
if goal_finished:
448502
return trajectory, new_node
449503

450-
return await self.send_completion_request(plan, depth + 1, new_node, trajectory)
504+
return await self.send_completion_request(plan, depth + 1, new_node, trajectory, websocket)
451505

452506
except Exception as e:
453507
print(f"Attempt {attempt + 1} failed with error: {e}")
454508
if attempt + 1 == retry_count:
455509
print("Max retries reached. Skipping this step and retrying the whole request.")
456510
# Retry the entire request from the same state
457-
return await self.send_completion_request(plan, depth, node, trajectory)
511+
return await self.send_completion_request(plan, depth, node, trajectory, websocket)
458512

459513
# If all retries and retries of retries fail, return the current trajectory and node
460514
return trajectory, node
461515

462516

463-
async def rollout(self, node: LATSNode, max_depth: int = 2)-> tuple[float, LATSNode]:
517+
async def rollout(self, node: LATSNode, max_depth: int = 2, websocket=None)-> tuple[float, LATSNode]:
464518
# Reset browser state
465519
await self._reset_browser()
466520
path = self.get_path_to_root(node)
@@ -494,11 +548,24 @@ async def rollout(self, node: LATSNode, max_depth: int = 2)-> tuple[float, LATSN
494548
## call the prompt agent
495549
print("current depth: ", len(path) - 1)
496550
print("max depth: ", self.config.max_depth)
497-
trajectory, node = await self.send_completion_request(self.goal, len(path) - 1, node=n, trajectory=trajectory)
551+
trajectory, node = await self.send_completion_request(self.goal, len(path) - 1, node=n, trajectory=trajectory, websocket=websocket)
498552
print("print the trajectory")
499553
print_trajectory(node)
500554
print("print the entire tree")
501555
print_entire_tree(self.root_node)
556+
if websocket:
557+
# tree_data = self._get_tree_data()
558+
# await websocket.send_json({
559+
# "type": "tree_update",
560+
# "tree": tree_data,
561+
# "timestamp": datetime.utcnow().isoformat()
562+
# })
563+
trajectory_data = self._get_trajectory_data(node)
564+
await websocket.send_json({
565+
"type": "trajectory_update",
566+
"trajectory": trajectory_data,
567+
"timestamp": datetime.utcnow().isoformat()
568+
})
502569

503570
page = await self.playwright_manager.get_page()
504571
page_info = await extract_page_info(page, self.config.fullpage, self.config.log_folder)
@@ -801,3 +868,40 @@ def _get_tree_data(self):
801868
tree_data.append(node_data)
802869

803870
return tree_data
871+
872+
def _get_trajectory_data(self, terminal_node: LATSNode):
873+
"""Get trajectory data in a format suitable for visualization
874+
875+
Args:
876+
terminal_node: The leaf node to start the trajectory from
877+
878+
Returns:
879+
list: List of node data dictionaries representing the trajectory
880+
"""
881+
trajectory_data = []
882+
path = []
883+
884+
# Collect path from terminal to root
885+
current = terminal_node
886+
while current is not None:
887+
path.append(current)
888+
current = current.parent
889+
890+
# Process nodes in order from root to terminal
891+
for level, node in enumerate(reversed(path)):
892+
node_data = {
893+
"id": id(node),
894+
"level": level,
895+
"action": node.action if node.action else "ROOT",
896+
"description": node.natural_language_description,
897+
"visits": node.visits,
898+
"value": float(f"{node.value:.3f}") if hasattr(node, 'value') else None,
899+
"reward": float(f"{node.reward:.3f}") if hasattr(node, 'reward') else None,
900+
"is_terminal": node.is_terminal,
901+
"feedback": node.feedback if hasattr(node, 'feedback') else None,
902+
"is_root": not hasattr(node, 'parent') or node.parent is None,
903+
"is_terminal_node": node == terminal_node
904+
}
905+
trajectory_data.append(node_data)
906+
907+
return trajectory_data

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/simple_search_agent.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1145,9 +1145,51 @@ def _get_tree_data(self):
11451145
"action": node.action if node.action else "ROOT",
11461146
"description": node.natural_language_description,
11471147
"depth": node.depth,
1148-
"is_terminal": node.is_terminal
1148+
"is_terminal": node.is_terminal,
1149+
"value": node.value,
1150+
"visits": node.visits,
1151+
"reward": node.reward
11491152
}
11501153
tree_data.append(node_data)
11511154

11521155
return tree_data
1156+
1157+
def _get_trajectory_data(self, terminal_node: LATSNode):
1158+
"""Get trajectory data in a format suitable for visualization
1159+
1160+
Args:
1161+
terminal_node: The leaf node to start the trajectory from
1162+
1163+
Returns:
1164+
list: List of node data dictionaries representing the trajectory
1165+
"""
1166+
trajectory_data = []
1167+
path = []
1168+
1169+
# Collect path from terminal to root
1170+
current = terminal_node
1171+
while current is not None:
1172+
path.append(current)
1173+
current = current.parent
1174+
1175+
# Process nodes in order from root to terminal
1176+
for level, node in enumerate(reversed(path)):
1177+
node_data = {
1178+
"id": id(node),
1179+
"level": level,
1180+
"action": node.action if node.action else "ROOT",
1181+
"description": node.natural_language_description,
1182+
"visits": node.visits,
1183+
"value": float(f"{node.value:.3f}") if hasattr(node, 'value') else None,
1184+
"reward": float(f"{node.reward:.3f}") if hasattr(node, 'reward') else None,
1185+
"is_terminal": node.is_terminal,
1186+
"feedback": node.feedback if hasattr(node, 'feedback') else None,
1187+
"is_root": not hasattr(node, 'parent') or node.parent is None,
1188+
"is_terminal_node": node == terminal_node
1189+
}
1190+
trajectory_data.append(node_data)
1191+
1192+
return trajectory_data
1193+
1194+
11531195

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
def _get_trajectory_data(self, terminal_node: LATSNode):
2+
"""Get trajectory data in a format suitable for visualization
3+
4+
Args:
5+
terminal_node: The leaf node to start the trajectory from
6+
7+
Returns:
8+
list: List of node data dictionaries representing the trajectory
9+
"""
10+
trajectory_data = []
11+
path = []
12+
13+
# Collect path from terminal to root
14+
current = terminal_node
15+
while current is not None:
16+
path.append(current)
17+
current = current.parent
18+
19+
# Process nodes in order from root to terminal
20+
for level, node in enumerate(reversed(path)):
21+
node_data = {
22+
"id": id(node),
23+
"level": level,
24+
"action": node.action if node.action else "ROOT",
25+
"description": node.natural_language_description,
26+
"visits": node.visits,
27+
"value": float(f"{node.value:.3f}") if hasattr(node, 'value') else None,
28+
"reward": float(f"{node.reward:.3f}") if hasattr(node, 'reward') else None,
29+
"is_terminal": node.is_terminal,
30+
"feedback": node.feedback if hasattr(node, 'feedback') else None,
31+
"is_root": not hasattr(node, 'parent') or node.parent is None,
32+
"is_terminal_node": node == terminal_node
33+
}
34+
trajectory_data.append(node_data)
35+
36+
return trajectory_data

visual-tree-search-backend/test/test-tree-search-ws-lats.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,11 @@ async def connect_and_test_search(
8585
if status == "scored":
8686
logger.info(f"Node score: {data.get('score')}")
8787

88+
elif msg_type == "trajectory_update":
89+
logger.info(f"Trajectory update received with {data.get('trajectory')}")
90+
8891
elif msg_type == "tree_update":
89-
logger.info(f"Tree update received with {len(data.get('nodes', []))} nodes")
92+
logger.info(f"Tree update received with {data.get('tree')}")
9093

9194
elif msg_type == "best_path_update":
9295
logger.info(f"Best path update: score={data.get('score')}, path length={len(data.get('path', []))}")

visual-tree-search-backend/test/test-tree-search-ws-simple.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ async def connect_and_test_search(
8080
logger.info(f"Node score: {data.get('score')}")
8181

8282
elif msg_type == "tree_update":
83-
logger.info(f"Tree update received with {len(data.get('nodes', []))} nodes")
83+
logger.info(f"Tree update received with {data.get('tree')}")
8484

8585
elif msg_type == "best_path_update":
8686
logger.info(f"Best path update: score={data.get('score')}, path={data.get('path')}")

0 commit comments

Comments
 (0)