Skip to content

Commit e1a75b1

Browse files
authored
Merge pull request #57 from PathOnAI/review-reflective-mcts
Review reflective mcts
2 parents bda4a67 + 8dc5d13 commit e1a75b1

File tree

9 files changed

+2075
-15
lines changed

9 files changed

+2075
-15
lines changed

visual-tree-search-backend/README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,4 +90,12 @@ python run_demo_treesearch_async.py \
9090
```
9191
uvicorn app.main:app --host 0.0.0.0 --port 3000
9292
python test/test-tree-search-ws-lats.py
93+
```
94+
95+
## 7. Add MCTS agent
96+
* test run_demo_treesearch_async.py
97+
* test web socket
98+
```
99+
uvicorn app.main:app --host 0.0.0.0 --port 3000
100+
python test/test-tree-search-ws-mcts.py
93101
```

visual-tree-search-backend/app/api/lwats/agents_async/SearchAgents/mcts_agent.py

Lines changed: 941 additions & 1 deletion
Large diffs are not rendered by default.

visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/mcts_agent.py

Lines changed: 941 additions & 1 deletion
Large diffs are not rendered by default.

visual-tree-search-backend/app/api/lwats/agents_async/SimpleSearchAgents/simple_search_agent.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1146,4 +1146,3 @@ def _get_tree_data(self):
11461146
tree_data.append(node_data)
11471147

11481148
return tree_data
1149-

visual-tree-search-backend/app/api/routes/new_tree_search_websocket.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ async def handle_search_request(websocket: WebSocket, message: Dict[str, Any]):
128128
await agent.dfs_with_websocket(websocket)
129129
elif search_algorithm.lower() == "lats":
130130
await agent.run(websocket)
131+
elif search_algorithm.lower() == "mcts":
132+
await agent.run(websocket)
131133
else:
132134
await websocket.send_json({
133135
"type": "error",

visual-tree-search-backend/app/api/routes/tree_search_websocket.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,8 @@ async def handle_search_request(websocket: WebSocket, message: Dict[str, Any]):
128128
await agent.dfs_with_websocket(websocket)
129129
elif search_algorithm.lower() == "lats":
130130
await agent.run(websocket)
131+
elif search_algorithm.lower() == "mcts":
132+
await agent.run(websocket)
131133
else:
132134
await websocket.send_json({
133135
"type": "error",

visual-tree-search-backend/app/api/shopping.json

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
"value": "",
2525
"domain": "128.105.145.205",
2626
"path": "/",
27-
"expires": 1775272527,
27+
"expires": 1775370120,
2828
"httpOnly": false,
2929
"secure": false,
3030
"sameSite": "Strict"
@@ -81,20 +81,20 @@
8181
},
8282
{
8383
"name": "private_content_version",
84-
"value": "c514ad01bc9816ee30ab1f02510aa34a",
84+
"value": "ff4bba58081243f67b1adee2cc6974bd",
8585
"domain": "128.105.145.205",
8686
"path": "/",
87-
"expires": 1778296522.505323,
87+
"expires": 1778394118.522057,
8888
"httpOnly": false,
8989
"secure": false,
9090
"sameSite": "Lax"
9191
},
9292
{
9393
"name": "PHPSESSID",
94-
"value": "007c737fca0eb4173ab5362c2d9c8b09",
94+
"value": "30247306b1ad824f37f3e0384d86d991",
9595
"domain": "128.105.145.205",
9696
"path": "/",
97-
"expires": 1775272526.468528,
97+
"expires": 1775370122.385659,
9898
"httpOnly": true,
9999
"secure": false,
100100
"sameSite": "Lax"
@@ -104,17 +104,17 @@
104104
"value": "9bf9a599123e6402b85cde67144717a08b817412",
105105
"domain": "128.105.145.205",
106106
"path": "/",
107-
"expires": 1775272526.468754,
107+
"expires": 1775370122.385877,
108108
"httpOnly": true,
109109
"secure": false,
110110
"sameSite": "Lax"
111111
},
112112
{
113113
"name": "form_key",
114-
"value": "Hsr3n5ycGkOPfr1K",
114+
"value": "IEPmx1hKh4NWjeUa",
115115
"domain": "128.105.145.205",
116116
"path": "/",
117-
"expires": 1775272526.468692,
117+
"expires": 1775370122.385817,
118118
"httpOnly": false,
119119
"secure": false,
120120
"sameSite": "Lax"
@@ -124,17 +124,17 @@
124124
"value": "true",
125125
"domain": "128.105.145.205",
126126
"path": "/",
127-
"expires": 1775272527,
127+
"expires": 1775370120,
128128
"httpOnly": false,
129129
"secure": false,
130130
"sameSite": "Lax"
131131
},
132132
{
133133
"name": "section_data_ids",
134-
"value": "{%22messages%22:1743736525%2C%22customer%22:1743736525%2C%22compare-products%22:1743736525%2C%22last-ordered-items%22:1743736525%2C%22cart%22:1743736525%2C%22directory-data%22:1743736525%2C%22captcha%22:1743736525%2C%22instant-purchase%22:1743736525%2C%22loggedAsCustomer%22:1743736525%2C%22persistent%22:1743736525%2C%22review%22:1743736525%2C%22wishlist%22:1743736525%2C%22recently_viewed_product%22:1743736525%2C%22recently_compared_product%22:1743736525%2C%22product_data_storage%22:1743736525%2C%22paypal-billing-agreement%22:1743736525}",
134+
"value": "{%22messages%22:1743834120%2C%22customer%22:1743834120%2C%22compare-products%22:1743834120%2C%22last-ordered-items%22:1743834120%2C%22cart%22:1743834120%2C%22directory-data%22:1743834120%2C%22captcha%22:1743834120%2C%22instant-purchase%22:1743834120%2C%22loggedAsCustomer%22:1743834120%2C%22persistent%22:1743834120%2C%22review%22:1743834120%2C%22wishlist%22:1743834120%2C%22recently_viewed_product%22:1743834120%2C%22recently_compared_product%22:1743834120%2C%22product_data_storage%22:1743834120%2C%22paypal-billing-agreement%22:1743834120}",
135135
"domain": "128.105.145.205",
136136
"path": "/",
137-
"expires": 1775272524,
137+
"expires": 1775370120,
138138
"httpOnly": false,
139139
"secure": false,
140140
"sameSite": "Lax"

visual-tree-search-backend/test/test-tree-search-ws-lats.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def parse_arguments():
138138
parser.add_argument("--goal", type=str, default=DEFAULT_GOAL,
139139
help=f"Goal to achieve (default: {DEFAULT_GOAL})")
140140

141-
parser.add_argument("--algorithm", type=str, choices=["bfs", "dfs", "lats"], default="lats",
141+
parser.add_argument("--algorithm", type=str, choices=["bfs", "dfs", "lats", "mcts"], default="lats",
142142
help="Search algorithm to use (default: lats)")
143143

144144
parser.add_argument("--max-depth", type=int, default=3,
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import asyncio
2+
import json
3+
import websockets
4+
import argparse
5+
import logging
6+
from datetime import datetime
7+
8+
# Configure logging
9+
logging.basicConfig(
10+
level=logging.INFO,
11+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
12+
)
13+
logger = logging.getLogger(__name__)
14+
15+
# Default values
16+
DEFAULT_WS_URL = "ws://localhost:3000/new-tree-search-ws"
17+
DEFAULT_STARTING_URL = "http://128.105.145.205:7770/"
18+
DEFAULT_GOAL = "search running shoes, click on the first result"
19+
20+
async def connect_and_test_search(
21+
ws_url: str,
22+
starting_url: str,
23+
goal: str,
24+
search_algorithm: str = "bfs",
25+
max_depth: int = 3
26+
):
27+
"""
28+
Connect to the WebSocket endpoint and test the tree search functionality.
29+
30+
Args:
31+
ws_url: WebSocket URL to connect to
32+
starting_url: URL to start the search from
33+
goal: Goal to achieve
34+
search_algorithm: Search algorithm to use (bfs or dfs)
35+
max_depth: Maximum depth for the search tree
36+
"""
37+
logger.info(f"Connecting to WebSocket at {ws_url}")
38+
39+
async with websockets.connect(ws_url) as websocket:
40+
logger.info("Connected to WebSocket")
41+
42+
# Wait for connection established message
43+
response = await websocket.recv()
44+
data = json.loads(response)
45+
if data.get("type") == "connection_established":
46+
logger.info(f"Connection established with ID: {data.get('connection_id')}")
47+
48+
# Send search request
49+
request = {
50+
"type": "start_search",
51+
"agent_type": "MCTSAgent",
52+
"starting_url": starting_url,
53+
"goal": goal,
54+
"search_algorithm": search_algorithm,
55+
"max_depth": max_depth
56+
}
57+
58+
logger.info(f"Sending search request: {request}")
59+
await websocket.send(json.dumps(request))
60+
61+
# Process responses
62+
while True:
63+
try:
64+
response = await websocket.recv()
65+
data = json.loads(response)
66+
67+
# Log the message type and some key information
68+
msg_type = data.get("type", "unknown")
69+
70+
if msg_type == "status_update":
71+
logger.info(f"Status update: {data.get('status')} - {data.get('message')}")
72+
73+
elif msg_type == "iteration_start":
74+
logger.info(f"Iteration start: {data.get('iteration')}")
75+
76+
elif msg_type == "step_start":
77+
logger.info(f"Step start: {data.get('step')} - {data.get('step_name')}")
78+
79+
elif msg_type == "node_update":
80+
node_id = data.get("node_id")
81+
status = data.get("status")
82+
logger.info(f"Node update: {node_id} - {status}")
83+
84+
# If node was scored, log the score
85+
if status == "scored":
86+
logger.info(f"Node score: {data.get('score')}")
87+
88+
elif msg_type == "trajectory_update":
89+
logger.info(f"Trajectory update received with {data.get('trajectory')}")
90+
91+
elif msg_type == "tree_update":
92+
logger.info(f"Tree update received with {data.get('tree')}")
93+
94+
elif msg_type == "best_path_update":
95+
logger.info(f"Best path update: score={data.get('score')}, path length={len(data.get('path', []))}")
96+
97+
elif msg_type == "search_complete":
98+
status = data.get("status")
99+
score = data.get("score", "N/A")
100+
path_length = len(data.get("path", []))
101+
102+
logger.info(f"Search complete: {status}, score={score}, path length={path_length}")
103+
logger.info("Path actions:")
104+
105+
for i, node in enumerate(data.get("path", [])):
106+
logger.info(f" {i+1}. {node.get('action')}")
107+
108+
# Exit the loop when search is complete
109+
break
110+
111+
elif msg_type == "error":
112+
logger.error(f"Error: {data.get('message')}")
113+
break
114+
115+
else:
116+
logger.info(f"Received message of type {msg_type}")
117+
logger.info(f"Message: {data}")
118+
119+
except websockets.exceptions.ConnectionClosed:
120+
logger.warning("WebSocket connection closed")
121+
break
122+
except Exception as e:
123+
logger.error(f"Error processing message: {e}")
124+
break
125+
126+
logger.info("Test completed")
127+
128+
def parse_arguments():
129+
"""Parse command line arguments"""
130+
parser = argparse.ArgumentParser(description="Test the tree search WebSocket functionality")
131+
132+
parser.add_argument("--ws-url", type=str, default=DEFAULT_WS_URL,
133+
help=f"WebSocket URL (default: {DEFAULT_WS_URL})")
134+
135+
parser.add_argument("--starting-url", type=str, default=DEFAULT_STARTING_URL,
136+
help=f"Starting URL for the search (default: {DEFAULT_STARTING_URL})")
137+
138+
parser.add_argument("--goal", type=str, default=DEFAULT_GOAL,
139+
help=f"Goal to achieve (default: {DEFAULT_GOAL})")
140+
141+
parser.add_argument("--algorithm", type=str, choices=["bfs", "dfs", "lats", "mcts"], default="mcts",
142+
help="Search algorithm to use (default: lats)")
143+
144+
parser.add_argument("--max-depth", type=int, default=3,
145+
help="Maximum depth for the search tree (default: 3)")
146+
147+
return parser.parse_args()
148+
149+
async def main():
150+
"""Main entry point"""
151+
args = parse_arguments()
152+
153+
logger.info("Starting tree search WebSocket test")
154+
logger.info(f"WebSocket URL: {args.ws_url}")
155+
logger.info(f"Starting URL: {args.starting_url}")
156+
logger.info(f"Goal: {args.goal}")
157+
logger.info(f"Algorithm: {args.algorithm}")
158+
logger.info(f"Max depth: {args.max_depth}")
159+
160+
await connect_and_test_search(
161+
ws_url=args.ws_url,
162+
starting_url=args.starting_url,
163+
goal=args.goal,
164+
search_algorithm=args.algorithm,
165+
max_depth=args.max_depth
166+
)
167+
168+
if __name__ == "__main__":
169+
asyncio.run(main())

0 commit comments

Comments
 (0)