77# pyre-unsafe
88
99import pickle
10- from typing import Any , Callable , Coroutine , Iterable , List , TYPE_CHECKING
10+ from typing import Any , Callable , cast , Coroutine , Iterable , List , TYPE_CHECKING
1111
1212import monarch
1313import pytest
@@ -57,6 +57,9 @@ async def allocate() -> ProcMesh:
5757
5858
5959class MyActor :
60+ def __init__ (self ) -> None :
61+ self ._root_rank : int = - 1
62+
6063 async def handle (
6164 self ,
6265 mailbox : Mailbox ,
@@ -68,8 +71,19 @@ async def handle(
6871 local_state : Iterable [Any ],
6972 response_port : "PortProtocol[Any]" ,
7073 ) -> None :
71- assert rank is not None
72- response_port .send (f"rank: { rank } " )
74+ match method :
75+ case MethodSpecifier .Init ():
76+ self ._root_rank = rank
77+ response_port .send (None )
78+ return None
79+ case MethodSpecifier .ReturnsResponse (name = _):
80+ response_port .send (self ._root_rank )
81+ return None
82+ case MethodSpecifier .ExplicitPort (name = _):
83+ response_port .exception (
84+ NotImplementedError ("ExplicitPort is not supported yet" )
85+ )
86+ return None
7387
7488
7589# TODO - re-enable after resolving T232206970
@@ -95,35 +109,70 @@ async def run() -> None:
95109 run ()
96110
97111
98- async def verify_cast (
112+ async def spawn_actor_mesh (proc_mesh : ProcMesh ) -> PythonActorMesh :
113+ actor_mesh = await proc_mesh .spawn_nonblocking ("test" , MyActor )
114+ # init actors to record their root ranks
115+ receiver : PortReceiver
116+ handle , receiver = proc_mesh .client .open_port ()
117+ port_ref = handle .bind ()
118+
119+ message = PythonMessage (
120+ PythonMessageKind .CallMethod (MethodSpecifier .Init (), port_ref ),
121+ pickle .dumps (None ),
122+ )
123+ actor_mesh .cast (Selection .all (), message )
124+ # wait for init to complete
125+ for _ in range (len (actor_mesh .shape .ndslice )):
126+ await receiver .recv_task ()
127+
128+ return actor_mesh
129+
130+
131+ async def cast_to_call (
132+ actor_mesh : PythonActorMesh | PythonActorMeshRef ,
133+ mailbox : Mailbox ,
134+ message : PythonMessage ,
135+ ) -> None :
136+ sel = Selection .all ()
137+ if isinstance (actor_mesh , PythonActorMesh ):
138+ actor_mesh .cast (sel , message )
139+ elif isinstance (actor_mesh , PythonActorMeshRef ):
140+ actor_mesh .cast (mailbox , sel , message )
141+
142+
143+ async def verify_cast_to_call (
99144 actor_mesh : PythonActorMesh | PythonActorMeshRef ,
100145 mailbox : Mailbox ,
101- cast_ranks : List [int ],
146+ root_ranks : List [int ],
102147) -> None :
103148 receiver : PortReceiver
104149 handle , receiver = mailbox .open_port ()
105150 port_ref = handle .bind ()
106151
152+ # Now send the real message
107153 message = PythonMessage (
108154 PythonMessageKind .CallMethod (MethodSpecifier .ReturnsResponse ("echo" ), port_ref ),
109155 pickle .dumps ("ping" ),
110156 )
111- sel = Selection .from_string ("*" )
112- if isinstance (actor_mesh , PythonActorMesh ):
113- actor_mesh .cast (sel , message )
114- elif isinstance (actor_mesh , PythonActorMeshRef ):
115- actor_mesh .cast (mailbox , sel , message )
157+ await cast_to_call (actor_mesh , mailbox , message )
116158
117159 rcv_ranks = []
118- for _ in range (len (cast_ranks )):
160+ for _ in range (len (root_ranks )):
119161 message = await receiver .recv_task ()
120162 result_kind = message .kind
121163 assert isinstance (result_kind , PythonMessageKind .Result )
122- rank = result_kind .rank
123- assert rank is not None
124- rcv_ranks .append (rank )
125- rcv_ranks .sort ()
126- assert rcv_ranks == cast_ranks
164+ cast_rank = result_kind .rank
165+ assert cast_rank is not None
166+ root_rank = cast (int , pickle .loads (message .message ))
167+ rcv_ranks .append ((cast_rank , root_rank ))
168+ rcv_ranks .sort (key = lambda pair : pair [0 ])
169+ recv_cast_ranks , recv_root_ranks = zip (* rcv_ranks )
170+ assert recv_root_ranks == tuple (
171+ root_ranks
172+ ), f"recv_root_ranks={ recv_root_ranks } , root_ranks={ tuple (root_ranks )} "
173+ assert recv_cast_ranks == tuple (
174+ range (len (root_ranks ))
175+ ), f"recv_cast_ranks={ recv_cast_ranks } , root_ranks={ tuple (root_ranks )} "
127176 # verify no more messages are received
128177 with pytest .raises (TimeoutError ):
129178 await receiver .recv_task ().with_timeout (1 )
@@ -136,8 +185,8 @@ async def test_cast_handle() -> None:
136185 @run_on_tokio
137186 async def run () -> None :
138187 proc_mesh = await allocate ()
139- actor_mesh = await proc_mesh . spawn_nonblocking ( "test" , MyActor )
140- await verify_cast (actor_mesh , proc_mesh .client , list (range (3 * 8 * 8 )))
188+ actor_mesh = await spawn_actor_mesh ( proc_mesh )
189+ await verify_cast_to_call (actor_mesh , proc_mesh .client , list (range (3 * 8 * 8 )))
141190
142191 await proc_mesh .stop_nonblocking ()
143192
@@ -151,9 +200,11 @@ async def test_cast_ref() -> None:
151200 @run_on_tokio
152201 async def run () -> None :
153202 proc_mesh = await allocate ()
154- actor_mesh = await proc_mesh . spawn_nonblocking ( "test" , MyActor )
203+ actor_mesh = await spawn_actor_mesh ( proc_mesh )
155204 actor_mesh_ref = actor_mesh .bind ()
156- await verify_cast (actor_mesh_ref , proc_mesh .client , list (range (3 * 8 * 8 )))
205+ await verify_cast_to_call (
206+ actor_mesh_ref , proc_mesh .client , list (range (3 * 8 * 8 ))
207+ )
157208
158209 await proc_mesh .stop_nonblocking ()
159210
@@ -184,7 +235,7 @@ async def verify_slice(
184235 assert (
185236 sliced_shape .ranks () == replica_0_ranks + replica_1_ranks
186237 ), f"left is { sliced_shape .ranks ()} "
187- await verify_cast (sliced_mesh , mailbox , sliced_shape .ranks ())
238+ await verify_cast_to_call (sliced_mesh , mailbox , sliced_shape .ranks ())
188239
189240 assert sliced_shape .labels == ["replicas" , "hosts" , "gpus" ]
190241 assert sliced_shape .ndslice .sizes == [2 , 4 , 3 ]
@@ -224,7 +275,8 @@ async def test_slice_actor_mesh_handle() -> None:
224275 @run_on_tokio
225276 async def run () -> None :
226277 proc_mesh = await allocate ()
227- actor_mesh = await proc_mesh .spawn_nonblocking ("test" , MyActor )
278+ actor_mesh = await spawn_actor_mesh (proc_mesh )
279+
228280 await verify_slice (actor_mesh , proc_mesh .client )
229281
230282 await proc_mesh .stop_nonblocking ()
@@ -239,7 +291,8 @@ async def test_slice_actor_mesh_ref() -> None:
239291 @run_on_tokio
240292 async def run () -> None :
241293 proc_mesh = await allocate ()
242- actor_mesh = await proc_mesh .spawn_nonblocking ("test" , MyActor )
294+ actor_mesh = await spawn_actor_mesh (proc_mesh )
295+
243296 actor_mesh_ref = actor_mesh .bind ()
244297 await verify_slice (actor_mesh_ref , proc_mesh .client )
245298
0 commit comments