Skip to content

Commit 39a8094

Browse files
SyHeeeTataKKKL
authored and committed
adding rmcts
1 parent ceb6598 commit 39a8094

File tree

2 files changed

+239
-219
lines changed

2 files changed

+239
-219
lines changed

visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/mcts_agent.py

Lines changed: 98 additions & 78 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
from ...webagent_utils_async.utils.utils import urls_to_images
2828

2929
logger = logging.getLogger(__name__)
30+
logger.setLevel(logging.INFO)
3031
openai_client = OpenAI()
3132

3233
class MCTSAgent:
@@ -183,6 +184,23 @@ async def rmcts(self) -> List[Dict[str, Any]]:
183184
except Exception as e:
184185
logger.error(f"Error expanding node: {str(e)}")
185186
current_node.is_terminal = True
187+
# Expansion Step: Expand the selected node if possible
188+
if not current_node.is_terminal and current_node.depth < self.config.max_depth:
189+
logger.info(f"\nExpansion Step:")
190+
logger.info(f"Expanding node: {current_node.action}")
191+
192+
expansion_success = await self.expand(current_node, None)
193+
if not expansion_success:
194+
# No children were generated; backtrack if possible.
195+
if len(path) > 1:
196+
logger.info("Backtracking due to expansion failure (no children generated).")
197+
path.pop() # Remove the current dead-end node.
198+
current_node = path[-1] # Set current_node to its parent.
199+
else:
200+
logger.warning("Expansion failed at root; no further backtracking possible.")
201+
break
202+
else:
203+
logger.info(f"Successfully expanded node with {len(current_node.children)} children")
186204

187205
# Simulation: Evaluate the current path
188206
logger.info(f"\nSimulation Step:")
@@ -217,8 +235,8 @@ async def rmcts(self) -> List[Dict[str, Any]]:
217235
logger.info(f"New best score: {score:.3f}")
218236

219237
# Reflection-based backpropagation
220-
if score < 0.25: # If the path is not satisfactory
221-
logger.info(f"\nReflection Step (Score {score:.3f} < 0.25):")
238+
if score < 0.75: # If the path is not satisfactory
239+
logger.info(f"\nReflection Step (Score {score:.3f} < 0.75):")
222240

223241
# Generate reflection prompt
224242
reflection_prompt = f"""Analyze the current trajectory and suggest improvements.
@@ -265,10 +283,10 @@ async def rmcts(self) -> List[Dict[str, Any]]:
265283
except Exception as e:
266284
logger.error(f"Error in reflection: {str(e)}")
267285

268-
# # If we've found a satisfactory solution, return it
269-
# if score >= 0.75:
270-
# logger.info(f"\nFound satisfactory solution with score {score:.3f}")
271-
# return [{"action": node.action} for node in path[1:]]
286+
# If we've found a satisfactory solution, return it
287+
if score >= 0.75:
288+
logger.info(f"\nFound satisfactory solution with score {score:.3f}")
289+
return [{"action": node.action} for node in path[1:]]
272290

273291
except Exception as e:
274292
logger.error(f"Error in simulation: {str(e)}")
@@ -286,13 +304,13 @@ async def rmcts(self) -> List[Dict[str, Any]]:
286304

287305
# If we've exhausted all iterations and haven't found a perfect solution,
288306
# return the best path we found
289-
if best_path:
307+
if best_path and len(best_path) > 1:
290308
logger.info(f"\nSearch complete. Returning best path found with score {best_score:.3f}")
291309
return [{"action": node.action} for node in best_path[1:]]
292-
293-
# If no path was found at all
294-
logger.warning("\nNo valid path found")
295-
return []
310+
311+
# If no valid path was found or path was just the root, return a default action
312+
logger.warning("\nNo valid path found, returning fallback action")
313+
return [{"action": "refresh()", "description": "Fallback action - no valid path found"}]
296314

297315
except Exception as e:
298316
error_msg = f"Error in RMCTS search: {str(e)}"
@@ -500,8 +518,8 @@ async def rmcts_with_websocket(self, websocket) -> List[Dict[str, Any]]:
500518
})
501519

502520
# Reflection-based backpropagation
503-
if score < 0.25: # If the path is not satisfactory
504-
logger.info(f"\nReflection Step (Score {score:.3f} < 0.25):")
521+
if score < 0.75: # If the path is not satisfactory
522+
logger.info(f"\nReflection Step (Score {score:.3f} < 0.75):")
505523

506524
await websocket.send_json({
507525
"type": "reflection_start",
@@ -568,20 +586,20 @@ async def rmcts_with_websocket(self, websocket) -> List[Dict[str, Any]]:
568586
"timestamp": datetime.utcnow().isoformat()
569587
})
570588

571-
# # If we've found a satisfactory solution, return it
572-
# if score >= 0.75:
573-
# logger.info(f"\nFound satisfactory solution with score {score:.3f}")
589+
# If we've found a satisfactory solution, return it
590+
if score >= 0.75:
591+
logger.info(f"\nFound satisfactory solution with score {score:.3f}")
574592

575-
# # Send completion update if websocket is provided
576-
# await websocket.send_json({
577-
# "type": "search_complete",
578-
# "status": "success",
579-
# "score": score,
580-
# "path": [{"id": id(node), "action": node.action} for node in path[1:]],
581-
# "timestamp": datetime.utcnow().isoformat()
582-
# })
593+
# Send completion update if websocket is provided
594+
await websocket.send_json({
595+
"type": "search_complete",
596+
"status": "success",
597+
"score": score,
598+
"path": [{"id": id(node), "action": node.action} for node in path[1:]],
599+
"timestamp": datetime.utcnow().isoformat()
600+
})
583601

584-
# return [{"action": node.action} for node in path[1:]]
602+
return [{"action": node.action} for node in path[1:]]
585603

586604
except Exception as e:
587605
logger.error(f"Error in simulation: {str(e)}")
@@ -611,7 +629,7 @@ async def rmcts_with_websocket(self, websocket) -> List[Dict[str, Any]]:
611629

612630
# If we've exhausted all iterations and haven't found a perfect solution,
613631
# return the best path we found
614-
if best_path:
632+
if best_path and len(best_path) > 1:
615633
logger.info(f"\nSearch complete. Returning best path found with score {best_score:.3f}")
616634

617635
# Send completion update if websocket is provided
@@ -635,8 +653,10 @@ async def rmcts_with_websocket(self, websocket) -> List[Dict[str, Any]]:
635653
"message": "No valid path found",
636654
"timestamp": datetime.utcnow().isoformat()
637655
})
638-
639-
return []
656+
657+
# If no valid path was found or path was just the root, return a default action
658+
logger.warning("\nNo valid path found, returning fallback action")
659+
return [{"action": "refresh()", "description": "Fallback action - no valid path found"}]
640660

641661
except Exception as e:
642662
error_msg = f"Error in RMCTS search: {str(e)}"
@@ -753,48 +773,29 @@ async def _reset_browser(self, websocket=None) -> Optional[tuple]:
753773
})
754774
return None, None
755775

756-
async def expand(self, node: LATSNode, websocket=None) -> None:
776+
async def expand(self, node: LATSNode, websocket=None) -> bool:
757777
"""
758-
Expand a node by generating its children.
778+
Expand a node by generating its children. If no children are generated,
779+
mark the node as terminal and return False to trigger backtracking.
759780
760781
Args:
761-
node: Node to expand
762-
websocket: Optional WebSocket connection to send updates to
782+
node: Node to expand.
783+
websocket: Optional WebSocket connection to send updates.
784+
785+
Returns:
786+
bool: True if expansion succeeded (children generated), False otherwise.
763787
"""
764788
try:
765789
children_state = await self.generate_children(node, websocket)
766-
logger.info(f"Generated {len(children_state)} children for node: {node.action}")
767790
except Exception as e:
768791
logger.error(f"Exception during generation of children for node {node.action}: {e}")
769792
children_state = []
770-
771-
# if not children_state:
772-
# logger.warning(f"No valid children found for node: {node.action}")
773-
# # Mark the node as terminal but don't halt the entire search
774-
# node.is_terminal = True
775-
# return
793+
776794
if not children_state:
777-
logger.warning("No valid children returned, creating fallback children")
778-
children_state = [
779-
{
780-
"natural_language_description": "Navigate back to try a different approach",
781-
"action": "navigate_backward()",
782-
"prob": 0.15,
783-
"element": None
784-
},
785-
{
786-
"natural_language_description": "Refresh the page to reinitialize search",
787-
"action": "refresh()",
788-
"prob": 0.1,
789-
"element": None
790-
},
791-
{
792-
"natural_language_description": "Click a random element for exploration",
793-
"action": "click('random')",
794-
"prob": 0.05,
795-
"element": None
796-
}
797-
]
795+
logger.warning("No children generated. Marking node as terminal and triggering backtracking.")
796+
node.is_terminal = True
797+
return False # Indicate that expansion did not generate children.
798+
798799
for child_state in children_state:
799800
try:
800801
child = LATSNode(
@@ -818,6 +819,7 @@ async def expand(self, node: LATSNode, websocket=None) -> None:
818819
})
819820
except Exception as e:
820821
logger.error(f"Error creating child node from state {child_state}: {e}")
822+
return True # Expansion succeeded (children were generated).
821823

822824
async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
823825
"""
@@ -833,7 +835,7 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
833835
# Reset browser and get live URL
834836
live_browser_url, session_id = await self._reset_browser(websocket)
835837
path = self.get_path_to_root(node)
836-
838+
logger.info(f"######### Generating children for path with {len(path)} nodes")
837839
# Execute path
838840
for n in path[1:]: # Skip root node
839841
if websocket:
@@ -843,23 +845,41 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
843845
"action": n.action,
844846
"timestamp": datetime.utcnow().isoformat()
845847
})
846-
847-
success = await playwright_step_execution(
848-
n,
849-
self.goal,
850-
self.playwright_manager,
851-
is_replay=False,
852-
log_folder=self.config.log_folder
853-
)
854-
if not success:
855-
n.is_terminal = True
856-
if websocket:
857-
await websocket.send_json({
858-
"type": "replay_failed",
859-
"node_id": id(n),
860-
"timestamp": datetime.utcnow().isoformat()
861-
})
862-
return []
848+
try:
849+
success = await playwright_step_execution(
850+
n,
851+
self.goal,
852+
self.playwright_manager,
853+
is_replay=False,
854+
log_folder=self.config.log_folder
855+
)
856+
logger.info(f"#########Success: {success}")
857+
858+
if not success:
859+
logger.warning(f"Action execution failed: {n.action}")
860+
n.is_terminal = True
861+
if websocket:
862+
await websocket.send_json({
863+
"type": "replay_failed",
864+
"node_id": id(n),
865+
"timestamp": datetime.utcnow().isoformat()
866+
})
867+
return [{
868+
"natural_language_description": "Recover from failed action",
869+
"action": "refresh()",
870+
"prob": 0.1,
871+
"element": None
872+
}]
873+
except Exception as e:
874+
logger.error(f"Error executing action {n.action}: {str(e)}")
875+
# Provide fallback actions instead of bubbling up the exception
876+
return [{
877+
"natural_language_description": "Recover from action error",
878+
"action": "refresh()",
879+
"prob": 0.1,
880+
"element": None
881+
}]
882+
863883

864884
if not n.feedback:
865885
n.feedback = await generate_feedback(

0 commit comments

Comments
 (0)