1
+ import asyncio
2
+ import json
3
+ import websockets
4
+ import argparse
5
+ import logging
6
+ from datetime import datetime
7
+
8
+ # Configure logging
9
+ logging .basicConfig (
10
+ level = logging .INFO ,
11
+ format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
12
+ )
13
+ logger = logging .getLogger (__name__ )
14
+
15
+ # Default values
16
+ DEFAULT_WS_URL = "ws://localhost:3000/new-tree-search-ws"
17
+ DEFAULT_STARTING_URL = "http://128.105.145.205:7770/"
18
+ DEFAULT_GOAL = "search running shoes, click on the first result"
19
+
20
+ async def connect_and_test_search (
21
+ ws_url : str ,
22
+ starting_url : str ,
23
+ goal : str ,
24
+ search_algorithm : str = "bfs" ,
25
+ max_depth : int = 3
26
+ ):
27
+ """
28
+ Connect to the WebSocket endpoint and test the tree search functionality.
29
+
30
+ Args:
31
+ ws_url: WebSocket URL to connect to
32
+ starting_url: URL to start the search from
33
+ goal: Goal to achieve
34
+ search_algorithm: Search algorithm to use (bfs or dfs)
35
+ max_depth: Maximum depth for the search tree
36
+ """
37
+ logger .info (f"Connecting to WebSocket at { ws_url } " )
38
+
39
+ async with websockets .connect (ws_url ) as websocket :
40
+ logger .info ("Connected to WebSocket" )
41
+
42
+ # Wait for connection established message
43
+ response = await websocket .recv ()
44
+ data = json .loads (response )
45
+ if data .get ("type" ) == "connection_established" :
46
+ logger .info (f"Connection established with ID: { data .get ('connection_id' )} " )
47
+
48
+ # Send search request
49
+ request = {
50
+ "type" : "start_search" ,
51
+ "agent_type" : "MCTSAgent" ,
52
+ "starting_url" : starting_url ,
53
+ "goal" : goal ,
54
+ "search_algorithm" : search_algorithm ,
55
+ "max_depth" : max_depth
56
+ }
57
+
58
+ logger .info (f"Sending search request: { request } " )
59
+ await websocket .send (json .dumps (request ))
60
+
61
+ # Process responses
62
+ while True :
63
+ try :
64
+ response = await websocket .recv ()
65
+ data = json .loads (response )
66
+
67
+ # Log the message type and some key information
68
+ msg_type = data .get ("type" , "unknown" )
69
+
70
+ if msg_type == "status_update" :
71
+ logger .info (f"Status update: { data .get ('status' )} - { data .get ('message' )} " )
72
+
73
+ elif msg_type == "iteration_start" :
74
+ logger .info (f"Iteration start: { data .get ('iteration' )} " )
75
+
76
+ elif msg_type == "step_start" :
77
+ logger .info (f"Step start: { data .get ('step' )} - { data .get ('step_name' )} " )
78
+
79
+ elif msg_type == "node_update" :
80
+ node_id = data .get ("node_id" )
81
+ status = data .get ("status" )
82
+ logger .info (f"Node update: { node_id } - { status } " )
83
+
84
+ # If node was scored, log the score
85
+ if status == "scored" :
86
+ logger .info (f"Node score: { data .get ('score' )} " )
87
+
88
+ elif msg_type == "trajectory_update" :
89
+ logger .info (f"Trajectory update received with { data .get ('trajectory' )} " )
90
+
91
+ elif msg_type == "tree_update" :
92
+ logger .info (f"Tree update received with { data .get ('tree' )} " )
93
+
94
+ elif msg_type == "best_path_update" :
95
+ logger .info (f"Best path update: score={ data .get ('score' )} , path length={ len (data .get ('path' , []))} " )
96
+
97
+ elif msg_type == "search_complete" :
98
+ status = data .get ("status" )
99
+ score = data .get ("score" , "N/A" )
100
+ path_length = len (data .get ("path" , []))
101
+
102
+ logger .info (f"Search complete: { status } , score={ score } , path length={ path_length } " )
103
+ logger .info ("Path actions:" )
104
+
105
+ for i , node in enumerate (data .get ("path" , [])):
106
+ logger .info (f" { i + 1 } . { node .get ('action' )} " )
107
+
108
+ # Exit the loop when search is complete
109
+ break
110
+
111
+ elif msg_type == "error" :
112
+ logger .error (f"Error: { data .get ('message' )} " )
113
+ break
114
+
115
+ else :
116
+ logger .info (f"Received message of type { msg_type } " )
117
+ logger .info (f"Message: { data } " )
118
+
119
+ except websockets .exceptions .ConnectionClosed :
120
+ logger .warning ("WebSocket connection closed" )
121
+ break
122
+ except Exception as e :
123
+ logger .error (f"Error processing message: { e } " )
124
+ break
125
+
126
+ logger .info ("Test completed" )
127
+
128
+ def parse_arguments ():
129
+ """Parse command line arguments"""
130
+ parser = argparse .ArgumentParser (description = "Test the tree search WebSocket functionality" )
131
+
132
+ parser .add_argument ("--ws-url" , type = str , default = DEFAULT_WS_URL ,
133
+ help = f"WebSocket URL (default: { DEFAULT_WS_URL } )" )
134
+
135
+ parser .add_argument ("--starting-url" , type = str , default = DEFAULT_STARTING_URL ,
136
+ help = f"Starting URL for the search (default: { DEFAULT_STARTING_URL } )" )
137
+
138
+ parser .add_argument ("--goal" , type = str , default = DEFAULT_GOAL ,
139
+ help = f"Goal to achieve (default: { DEFAULT_GOAL } )" )
140
+
141
+ parser .add_argument ("--algorithm" , type = str , choices = ["bfs" , "dfs" , "lats" , "mcts" ], default = "mcts" ,
142
+ help = "Search algorithm to use (default: lats)" )
143
+
144
+ parser .add_argument ("--max-depth" , type = int , default = 3 ,
145
+ help = "Maximum depth for the search tree (default: 3)" )
146
+
147
+ return parser .parse_args ()
148
+
149
+ async def main ():
150
+ """Main entry point"""
151
+ args = parse_arguments ()
152
+
153
+ logger .info ("Starting tree search WebSocket test" )
154
+ logger .info (f"WebSocket URL: { args .ws_url } " )
155
+ logger .info (f"Starting URL: { args .starting_url } " )
156
+ logger .info (f"Goal: { args .goal } " )
157
+ logger .info (f"Algorithm: { args .algorithm } " )
158
+ logger .info (f"Max depth: { args .max_depth } " )
159
+
160
+ await connect_and_test_search (
161
+ ws_url = args .ws_url ,
162
+ starting_url = args .starting_url ,
163
+ goal = args .goal ,
164
+ search_algorithm = args .algorithm ,
165
+ max_depth = args .max_depth
166
+ )
167
+
168
+ if __name__ == "__main__" :
169
+ asyncio .run (main ())
0 commit comments