Skip to content

Commit b460be5

Browse files
authored
Merge pull request #53 from PathOnAI/consolidate-websocket
Consolidate websocket
2 parents 2df5493 + bcb3ecc commit b460be5

File tree

5 files changed

+248
-25
lines changed

5 files changed

+248
-25
lines changed

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/lats_agent.py

Lines changed: 143 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737

3838
openai_client = OpenAI()
3939

40+
## TODO: add best_path_update
41+
4042
class LATSAgent:
4143
"""
4244
Language-based Action Tree Search Agent implementation.
@@ -117,6 +119,13 @@ async def run(self, websocket=None) -> list[LATSNode]:
117119
print_trajectory(best_node)
118120

119121
if websocket:
122+
# trajectory_data = self._get_trajectory_data(best_node)
123+
# await websocket.send_json({
124+
# "type": "trajectory_update",
125+
# "trajectory": trajectory_data,
126+
# "timestamp": datetime.utcnow().isoformat()
127+
# })
128+
# TODO: use score instead of reward to determine success
120129
await websocket.send_json({
121130
"type": "search_complete",
122131
"status": "success" if best_node.reward == 1 else "partial_success",
@@ -158,12 +167,19 @@ async def lats_search(self, websocket=None) -> LATSNode:
158167
if websocket:
159168
await websocket.send_json({
160169
"type": "step_start",
161-
"step": "selection",
170+
"step": 1,
171+
"step_name": "selection",
162172
"iteration": i + 1,
163173
"timestamp": datetime.utcnow().isoformat()
164174
})
165175

166176
node = self.select_node(self.root_node)
177+
if websocket:
178+
await websocket.send_json({
179+
"type": "node_selected",
180+
"node_id": id(node),
181+
"timestamp": datetime.utcnow().isoformat()
182+
})
167183

168184
if node is None:
169185
print("All paths lead to terminal nodes with reward 0. Ending search.")
@@ -177,7 +193,8 @@ async def lats_search(self, websocket=None) -> LATSNode:
177193
if websocket:
178194
await websocket.send_json({
179195
"type": "step_start",
180-
"step": "expansion",
196+
"step": 2,
197+
"step_name": "expansion",
181198
"iteration": i + 1,
182199
"timestamp": datetime.utcnow().isoformat()
183200
})
@@ -202,28 +219,65 @@ async def lats_search(self, websocket=None) -> LATSNode:
202219
print(f"{GREEN}Tree:{RESET}")
203220
better_print(self.root_node)
204221
print(f"")
222+
tree_data = self._get_tree_data()
223+
await websocket.send_json({
224+
"type": "tree_update",
225+
"tree": tree_data,
226+
"timestamp": datetime.utcnow().isoformat()
227+
})
205228

206229
# Step 3: Evaluation
207230
print(f"")
208231
print(f"{GREEN}Step 3: evaluation{RESET}")
232+
if websocket:
233+
await websocket.send_json({
234+
"type": "step_start",
235+
"step": 3,
236+
"step_name": "evaluation",
237+
"iteration": i + 1,
238+
"timestamp": datetime.utcnow().isoformat()
239+
})
209240
await self.evaluate_node(node)
210241

211242
print(f"{GREEN}Tree:{RESET}")
212243
better_print(self.root_node)
213244
print(f"")
245+
## send tree update, since evaluation is added to the tree
246+
if websocket:
247+
tree_data = self._get_tree_data()
248+
await websocket.send_json({
249+
"type": "tree_update",
250+
"tree": tree_data,
251+
"timestamp": datetime.utcnow().isoformat()
252+
})
253+
214254

215255
# Step 4: Simulation
216256
print(f"{GREEN}Step 4: simulation{RESET}")
217-
# # Find the child with the highest value
257+
if websocket:
258+
await websocket.send_json({
259+
"type": "step_start",
260+
"step": 4,
261+
"step_name": "simulation",
262+
"iteration": i + 1,
263+
"timestamp": datetime.utcnow().isoformat()
264+
})
218265
## always = 1
219-
reward, terminal_node = await self.simulate(max(node.children, key=lambda child: child.value), max_depth=self.config.max_depth, num_simulations=1)
266+
reward, terminal_node = await self.simulate(max(node.children, key=lambda child: child.value), max_depth=self.config.max_depth, num_simulations=1, websocket=websocket)
220267
terminal_nodes.append(terminal_node)
221268

222269
if reward == 1:
223270
return terminal_node
224271

225272
# Step 5: Backpropagation
226273
print(f"{GREEN}Step 5: backpropagation{RESET}")
274+
if websocket:
275+
await websocket.send_json({
276+
"type": "step_start",
277+
"step": 5,
278+
"step_name": "backpropagation",
279+
"timestamp": datetime.utcnow().isoformat()
280+
})
227281
self.backpropagate(terminal_node, reward)
228282
print(f"{GREEN}Tree:{RESET}")
229283
better_print(self.root_node)
@@ -335,7 +389,8 @@ async def evaluate_node(self, node: LATSNode) -> None:
335389
child.value = score
336390
child.reward = score
337391

338-
async def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1) -> tuple[float, LATSNode]:
392+
## TODO: make number of simulations configurable
393+
async def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1, websocket=None) -> tuple[float, LATSNode]:
339394
"""
340395
Perform a rollout simulation from a node.
341396
@@ -351,13 +406,39 @@ async def simulate(self, node: LATSNode, max_depth: int = 2, num_simulations=1)
351406
print_trajectory(node)
352407
print("print the entire tree")
353408
print_entire_tree(self.root_node)
354-
return await self.rollout(node, max_depth=max_depth)
409+
if websocket:
410+
tree_data = self._get_tree_data()
411+
await websocket.send_json({
412+
"type": "tree_update",
413+
"tree": tree_data,
414+
"timestamp": datetime.utcnow().isoformat()
415+
})
416+
trajectory_data = self._get_trajectory_data(node)
417+
await websocket.send_json({
418+
"type": "trajectory_update",
419+
"trajectory": trajectory_data,
420+
"timestamp": datetime.utcnow().isoformat()
421+
})
422+
return await self.rollout(node, max_depth=max_depth, websocket=websocket)
355423

356-
async def send_completion_request(self, plan, depth, node, trajectory=[]):
424+
async def send_completion_request(self, plan, depth, node, trajectory=[], websocket=None):
357425
print("print the trajectory")
358426
print_trajectory(node)
359427
print("print the entire tree")
360428
print_entire_tree(self.root_node)
429+
if websocket:
430+
# tree_data = self._get_tree_data()
431+
# await websocket.send_json({
432+
# "type": "tree_update",
433+
# "tree": tree_data,
434+
# "timestamp": datetime.utcnow().isoformat()
435+
# })
436+
trajectory_data = self._get_trajectory_data(node)
437+
await websocket.send_json({
438+
"type": "trajectory_update",
439+
"trajectory": trajectory_data,
440+
"timestamp": datetime.utcnow().isoformat()
441+
})
361442

362443
if depth >= self.config.max_depth:
363444
return trajectory, node
@@ -420,20 +501,20 @@ async def send_completion_request(self, plan, depth, node, trajectory=[]):
420501
if goal_finished:
421502
return trajectory, new_node
422503

423-
return await self.send_completion_request(plan, depth + 1, new_node, trajectory)
504+
return await self.send_completion_request(plan, depth + 1, new_node, trajectory, websocket)
424505

425506
except Exception as e:
426507
print(f"Attempt {attempt + 1} failed with error: {e}")
427508
if attempt + 1 == retry_count:
428509
print("Max retries reached. Skipping this step and retrying the whole request.")
429510
# Retry the entire request from the same state
430-
return await self.send_completion_request(plan, depth, node, trajectory)
511+
return await self.send_completion_request(plan, depth, node, trajectory, websocket)
431512

432513
# If all retries and retries of retries fail, return the current trajectory and node
433514
return trajectory, node
434515

435516

436-
async def rollout(self, node: LATSNode, max_depth: int = 2)-> tuple[float, LATSNode]:
517+
async def rollout(self, node: LATSNode, max_depth: int = 2, websocket=None)-> tuple[float, LATSNode]:
437518
# Reset browser state
438519
await self._reset_browser()
439520
path = self.get_path_to_root(node)
@@ -467,11 +548,24 @@ async def rollout(self, node: LATSNode, max_depth: int = 2)-> tuple[float, LATSN
467548
## call the prompt agent
468549
print("current depth: ", len(path) - 1)
469550
print("max depth: ", self.config.max_depth)
470-
trajectory, node = await self.send_completion_request(self.goal, len(path) - 1, node=n, trajectory=trajectory)
551+
trajectory, node = await self.send_completion_request(self.goal, len(path) - 1, node=n, trajectory=trajectory, websocket=websocket)
471552
print("print the trajectory")
472553
print_trajectory(node)
473554
print("print the entire tree")
474555
print_entire_tree(self.root_node)
556+
if websocket:
557+
# tree_data = self._get_tree_data()
558+
# await websocket.send_json({
559+
# "type": "tree_update",
560+
# "tree": tree_data,
561+
# "timestamp": datetime.utcnow().isoformat()
562+
# })
563+
trajectory_data = self._get_trajectory_data(node)
564+
await websocket.send_json({
565+
"type": "trajectory_update",
566+
"trajectory": trajectory_data,
567+
"timestamp": datetime.utcnow().isoformat()
568+
})
475569

476570
page = await self.playwright_manager.get_page()
477571
page_info = await extract_page_info(page, self.config.fullpage, self.config.log_folder)
@@ -769,8 +863,46 @@ def _get_tree_data(self):
769863
"is_terminal": node.is_terminal,
770864
"value": node.value,
771865
"visits": node.visits,
866+
"feedback": node.feedback,
772867
"reward": node.reward
773868
}
774869
tree_data.append(node_data)
775870

776871
return tree_data
872+
873+
def _get_trajectory_data(self, terminal_node: LATSNode):
874+
"""Get trajectory data in a format suitable for visualization
875+
876+
Args:
877+
terminal_node: The leaf node to start the trajectory from
878+
879+
Returns:
880+
list: List of node data dictionaries representing the trajectory
881+
"""
882+
trajectory_data = []
883+
path = []
884+
885+
# Collect path from terminal to root
886+
current = terminal_node
887+
while current is not None:
888+
path.append(current)
889+
current = current.parent
890+
891+
# Process nodes in order from root to terminal
892+
for level, node in enumerate(reversed(path)):
893+
node_data = {
894+
"id": id(node),
895+
"level": level,
896+
"action": node.action if node.action else "ROOT",
897+
"description": node.natural_language_description,
898+
"visits": node.visits,
899+
"value": float(f"{node.value:.3f}") if hasattr(node, 'value') else None,
900+
"reward": float(f"{node.reward:.3f}") if hasattr(node, 'reward') else None,
901+
"is_terminal": node.is_terminal,
902+
"feedback": node.feedback if hasattr(node, 'feedback') else None,
903+
"is_root": not hasattr(node, 'parent') or node.parent is None,
904+
"is_terminal_node": node == terminal_node
905+
}
906+
trajectory_data.append(node_data)
907+
908+
return trajectory_data

0 commit comments

Comments
 (0)