77# pyre-unsafe
88
99import pickle
10- from typing import Any , Callable , Coroutine , Iterable , List , TYPE_CHECKING
10+ from typing import Any , Callable , cast , Coroutine , Iterable , List , TYPE_CHECKING
1111
1212import monarch
1313import pytest
@@ -57,6 +57,12 @@ async def allocate() -> ProcMesh:
5757
5858
5959class MyActor :
60+ def __init__ (self ) -> None :
61+ # Note: for the same actor, its rank on the root mesh could be different
62+ # from its rank on the mesh it is cast to. This is because the cast
63+ # mesh could be a sliced mesh.
64+ self ._rank_on_root_mesh : int = - 1
65+
6066 async def handle (
6167 self ,
6268 mailbox : Mailbox ,
@@ -68,8 +74,21 @@ async def handle(
6874 local_state : Iterable [Any ],
6975 response_port : "PortProtocol[Any]" ,
7076 ) -> None :
71- assert rank is not None
72- response_port .send (f"rank: { rank } " )
77+ match method :
78+ case MethodSpecifier .Init ():
79+ # Since this actor is spawned from the root proc mesh, the rank
80+ # passed from init should be the rank on the root mesh.
81+ self ._rank_on_root_mesh = rank
82+ response_port .send (None )
83+ return None
84+ case MethodSpecifier .ReturnsResponse (name = _):
85+ response_port .send (self ._rank_on_root_mesh )
86+ return None
87+ case MethodSpecifier .ExplicitPort (name = _):
88+ response_port .exception (
89+ NotImplementedError ("ExplicitPort is not supported yet" )
90+ )
91+ return None
7392
7493
7594# TODO - re-enable after resolving T232206970
@@ -95,35 +114,70 @@ async def run() -> None:
95114 run ()
96115
97116
98- async def verify_cast (
async def spawn_actor_mesh(proc_mesh: ProcMesh) -> PythonActorMesh:
    """Spawn a ``MyActor`` mesh named "test" and wait for its Init round-trip.

    After spawning, an ``Init`` call is cast to ``Selection.all()`` so the
    actors can record their rank on the root mesh. The function then awaits
    ``len(mesh.shape.ndslice)`` replies on a port opened from
    ``proc_mesh.client`` (presumably one reply per actor) before returning.

    Args:
        proc_mesh: The proc mesh to spawn the actors on; its ``client`` is
            used to open the port that collects the Init replies.

    Returns:
        The spawned actor mesh.
    """
    mesh = await proc_mesh.spawn_nonblocking("test", MyActor)

    # Port on which the actors acknowledge the Init call.
    ack_receiver: PortReceiver
    ack_handle, ack_receiver = proc_mesh.client.open_port()
    ack_port_ref = ack_handle.bind()

    init_call = PythonMessageKind.CallMethod(MethodSpecifier.Init(), ack_port_ref)
    init_message = PythonMessage(init_call, pickle.dumps(None))
    mesh.cast(Selection.all(), init_message)

    # Block until every expected acknowledgement has arrived.
    remaining = len(mesh.shape.ndslice)
    while remaining > 0:
        await ack_receiver.recv_task()
        remaining -= 1

    return mesh
134+
135+
async def cast_to_call(
    actor_mesh: PythonActorMesh | PythonActorMeshRef,
    mailbox: Mailbox,
    message: PythonMessage,
) -> None:
    """Cast ``message`` to ``Selection.all()`` of ``actor_mesh``.

    The two mesh flavours take different ``cast`` arguments: a
    ``PythonActorMesh`` is called with ``(selection, message)``, while a
    ``PythonActorMeshRef`` is called with ``(mailbox, selection, message)``.
    An object of neither type is silently ignored.

    Args:
        actor_mesh: The mesh, or mesh reference, to cast to.
        mailbox: Only passed along when casting through a mesh reference.
        message: The message to deliver.
    """
    everyone = Selection.all()

    if isinstance(actor_mesh, PythonActorMesh):
        actor_mesh.cast(everyone, message)
        return

    if isinstance(actor_mesh, PythonActorMeshRef):
        # A reference carries no mailbox of its own, so the caller supplies one.
        actor_mesh.cast(mailbox, everyone, message)
146+
147+
148+ async def verify_cast_to_call (
99149 actor_mesh : PythonActorMesh | PythonActorMeshRef ,
100150 mailbox : Mailbox ,
101- cast_ranks : List [int ],
151+ root_ranks : List [int ],
102152) -> None :
103153 receiver : PortReceiver
104154 handle , receiver = mailbox .open_port ()
105155 port_ref = handle .bind ()
106156
157+ # Now send the real message
107158 message = PythonMessage (
108159 PythonMessageKind .CallMethod (MethodSpecifier .ReturnsResponse ("echo" ), port_ref ),
109160 pickle .dumps ("ping" ),
110161 )
111- sel = Selection .from_string ("*" )
112- if isinstance (actor_mesh , PythonActorMesh ):
113- actor_mesh .cast (sel , message )
114- elif isinstance (actor_mesh , PythonActorMeshRef ):
115- actor_mesh .cast (mailbox , sel , message )
162+ await cast_to_call (actor_mesh , mailbox , message )
116163
117164 rcv_ranks = []
118- for _ in range (len (cast_ranks )):
165+ for _ in range (len (root_ranks )):
119166 message = await receiver .recv_task ()
120167 result_kind = message .kind
121168 assert isinstance (result_kind , PythonMessageKind .Result )
122- rank = result_kind .rank
123- assert rank is not None
124- rcv_ranks .append (rank )
125- rcv_ranks .sort ()
126- assert rcv_ranks == cast_ranks
169+ cast_rank = result_kind .rank
170+ assert cast_rank is not None
171+ root_rank = cast (int , pickle .loads (message .message ))
172+ rcv_ranks .append ((cast_rank , root_rank ))
173+ rcv_ranks .sort (key = lambda pair : pair [0 ])
174+ recv_cast_ranks , recv_root_ranks = zip (* rcv_ranks )
175+ assert recv_root_ranks == tuple (
176+ root_ranks
177+ ), f"recv_root_ranks={ recv_root_ranks } , root_ranks={ tuple (root_ranks )} "
178+ assert recv_cast_ranks == tuple (
179+ range (len (root_ranks ))
180+ ), f"recv_cast_ranks={ recv_cast_ranks } , root_ranks={ tuple (root_ranks )} "
127181 # verify no more messages are received
128182 with pytest .raises (TimeoutError ):
129183 await receiver .recv_task ().with_timeout (1 )
@@ -136,8 +190,8 @@ async def test_cast_handle() -> None:
136190 @run_on_tokio
137191 async def run () -> None :
138192 proc_mesh = await allocate ()
139- actor_mesh = await proc_mesh . spawn_nonblocking ( "test" , MyActor )
140- await verify_cast (actor_mesh , proc_mesh .client , list (range (3 * 8 * 8 )))
193+ actor_mesh = await spawn_actor_mesh ( proc_mesh )
194+ await verify_cast_to_call (actor_mesh , proc_mesh .client , list (range (3 * 8 * 8 )))
141195
142196 await proc_mesh .stop_nonblocking ()
143197
@@ -151,9 +205,11 @@ async def test_cast_ref() -> None:
151205 @run_on_tokio
152206 async def run () -> None :
153207 proc_mesh = await allocate ()
154- actor_mesh = await proc_mesh . spawn_nonblocking ( "test" , MyActor )
208+ actor_mesh = await spawn_actor_mesh ( proc_mesh )
155209 actor_mesh_ref = actor_mesh .bind ()
156- await verify_cast (actor_mesh_ref , proc_mesh .client , list (range (3 * 8 * 8 )))
210+ await verify_cast_to_call (
211+ actor_mesh_ref , proc_mesh .client , list (range (3 * 8 * 8 ))
212+ )
157213
158214 await proc_mesh .stop_nonblocking ()
159215
@@ -184,7 +240,7 @@ async def verify_slice(
184240 assert (
185241 sliced_shape .ranks () == replica_0_ranks + replica_1_ranks
186242 ), f"left is { sliced_shape .ranks ()} "
187- await verify_cast (sliced_mesh , mailbox , sliced_shape .ranks ())
243+ await verify_cast_to_call (sliced_mesh , mailbox , sliced_shape .ranks ())
188244
189245 assert sliced_shape .labels == ["replicas" , "hosts" , "gpus" ]
190246 assert sliced_shape .ndslice .sizes == [2 , 4 , 3 ]
@@ -224,7 +280,8 @@ async def test_slice_actor_mesh_handle() -> None:
224280 @run_on_tokio
225281 async def run () -> None :
226282 proc_mesh = await allocate ()
227- actor_mesh = await proc_mesh .spawn_nonblocking ("test" , MyActor )
283+ actor_mesh = await spawn_actor_mesh (proc_mesh )
284+
228285 await verify_slice (actor_mesh , proc_mesh .client )
229286
230287 await proc_mesh .stop_nonblocking ()
@@ -239,7 +296,8 @@ async def test_slice_actor_mesh_ref() -> None:
239296 @run_on_tokio
240297 async def run () -> None :
241298 proc_mesh = await allocate ()
242- actor_mesh = await proc_mesh .spawn_nonblocking ("test" , MyActor )
299+ actor_mesh = await spawn_actor_mesh (proc_mesh )
300+
243301 actor_mesh_ref = actor_mesh .bind ()
244302 await verify_slice (actor_mesh_ref , proc_mesh .client )
245303
0 commit comments