Skip to content

Commit ceb6598

Browse files
SyHeeeTataKKKL
authored andcommitted
checkin buggy algo
1 parent 8c39e14 commit ceb6598

File tree

1 file changed

+90
-48
lines changed
  • visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents

1 file changed

+90
-48
lines changed

visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/mcts_agent.py

Lines changed: 90 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,8 @@ async def rmcts(self) -> List[Dict[str, Any]]:
217217
logger.info(f"New best score: {score:.3f}")
218218

219219
# Reflection-based backpropagation
220-
if score < 0.75: # If the path is not satisfactory
221-
logger.info(f"\nReflection Step (Score {score:.3f} < 0.5):")
220+
if score < 0.25: # If the path is not satisfactory
221+
logger.info(f"\nReflection Step (Score {score:.3f} < 0.25):")
222222

223223
# Generate reflection prompt
224224
reflection_prompt = f"""Analyze the current trajectory and suggest improvements.
@@ -500,8 +500,8 @@ async def rmcts_with_websocket(self, websocket) -> List[Dict[str, Any]]:
500500
})
501501

502502
# Reflection-based backpropagation
503-
if score < 0.75: # If the path is not satisfactory
504-
logger.info(f"\nReflection Step (Score {score:.3f} < 0.75):")
503+
if score < 0.25: # If the path is not satisfactory
504+
logger.info(f"\nReflection Step (Score {score:.3f} < 0.25):")
505505

506506
await websocket.send_json({
507507
"type": "reflection_start",
@@ -761,35 +761,63 @@ async def expand(self, node: LATSNode, websocket=None) -> None:
761761
node: Node to expand
762762
websocket: Optional WebSocket connection to send updates to
763763
"""
764-
children_state = await self.generate_children(node, websocket)
765-
logger.info(f"Generated {len(children_state)} children for node: {node.action}")
764+
try:
765+
children_state = await self.generate_children(node, websocket)
766+
logger.info(f"Generated {len(children_state)} children for node: {node.action}")
767+
except Exception as e:
768+
logger.error(f"Exception during generation of children for node {node.action}: {e}")
769+
children_state = []
770+
771+
# if not children_state:
772+
# logger.warning(f"No valid children found for node: {node.action}")
773+
# # Mark the node as terminal but don't halt the entire search
774+
# node.is_terminal = True
775+
# return
766776
if not children_state:
767-
logger.warning(f"No valid children found for node: {node.action}")
768-
# Mark the node as terminal but don't halt the entire search
769-
node.is_terminal = True
770-
return
771-
777+
logger.warning("No valid children returned, creating fallback children")
778+
children_state = [
779+
{
780+
"natural_language_description": "Navigate back to try a different approach",
781+
"action": "navigate_backward()",
782+
"prob": 0.15,
783+
"element": None
784+
},
785+
{
786+
"natural_language_description": "Refresh the page to reinitialize search",
787+
"action": "refresh()",
788+
"prob": 0.1,
789+
"element": None
790+
},
791+
{
792+
"natural_language_description": "Click a random element for exploration",
793+
"action": "click('random')",
794+
"prob": 0.05,
795+
"element": None
796+
}
797+
]
772798
for child_state in children_state:
773-
child = LATSNode(
774-
natural_language_description=child_state["natural_language_description"],
775-
action=child_state["action"],
776-
prob=child_state["prob"],
777-
element=child_state["element"],
778-
goal=node.goal,
779-
parent=node
780-
)
781-
node.children.append(child)
782-
783-
# Send child creation update if websocket is provided
784-
if websocket:
785-
await websocket.send_json({
786-
"type": "node_created",
787-
"node_id": id(child),
788-
"parent_id": id(node),
789-
"action": child.action,
790-
"description": child.natural_language_description,
791-
"timestamp": datetime.utcnow().isoformat()
792-
})
799+
try:
800+
child = LATSNode(
801+
natural_language_description=child_state.get("natural_language_description", ""),
802+
action=child_state.get("action", ""),
803+
prob=child_state.get("prob", 0.0),
804+
element=child_state.get("element", None),
805+
goal=node.goal,
806+
parent=node
807+
)
808+
node.children.append(child)
809+
810+
if websocket:
811+
await websocket.send_json({
812+
"type": "node_created",
813+
"node_id": id(child),
814+
"parent_id": id(node),
815+
"action": child.action,
816+
"description": child.natural_language_description,
817+
"timestamp": datetime.utcnow().isoformat()
818+
})
819+
except Exception as e:
820+
logger.error(f"Error creating child node from state {child_state}: {e}")
793821

794822
async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
795823
"""
@@ -880,7 +908,7 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
880908
for action in next_actions:
881909
if action["action"] == "FINISH":
882910
logger.info(f"Found FINISH action with probability: {action['prob']}")
883-
if action["prob"] > 0.8:
911+
if action["prob"] > 0.99:
884912
node.is_terminal = True
885913
if websocket:
886914
await websocket.send_json({
@@ -916,24 +944,38 @@ async def generate_children(self, node: LATSNode, websocket=None) -> list[dict]:
916944
children.append(action)
917945

918946
if not children:
919-
node.is_terminal = True
920-
if websocket:
921-
await websocket.send_json({
922-
"type": "node_terminal",
923-
"node_id": id(node),
924-
"reason": "no_valid_actions",
925-
"timestamp": datetime.utcnow().isoformat()
926-
})
927-
logger.warning("No children generated")
928-
# logger.warning("No children generated, creating a dummy 'retry' child to keep search alive")
947+
# node.is_terminal = True
948+
# if websocket:
949+
# await websocket.send_json({
950+
# "type": "node_terminal",
951+
# "node_id": id(node),
952+
# "reason": "no_valid_actions",
953+
# "timestamp": datetime.utcnow().isoformat()
954+
# })
955+
# logger.warning("No children generated")
956+
logger.warning("No viable children, creating fallback exploration actions")
929957

930958
# # If empty list would terminate search, create a "fallback" child
931-
# children.append({
932-
# "natural_language_description": "Retry with different approach",
933-
# "action": "refresh()", # Or some other generic action
934-
# "prob": 0.1,
935-
# "element": None
936-
# })
959+
children.extend([
960+
{
961+
"natural_language_description": "Navigate back to try a different approach",
962+
"action": "navigate_backward()",
963+
"prob": 0.15,
964+
"element": None
965+
},
966+
{
967+
"natural_language_description": "Try refreshing the page",
968+
"action": "refresh()",
969+
"prob": 0.1,
970+
"element": None
971+
},
972+
{
973+
"natural_language_description": "Try clicking on a random element",
974+
"action": "click('random')",
975+
"prob": 0.05,
976+
"element": None
977+
}
978+
])
937979
print(f"****** Generated children: {children}")
938980
return children
939981

0 commit comments

Comments
 (0)