7
7
# pyre-unsafe
8
8
9
9
import pickle
10
- from typing import Any , Callable , Coroutine , Iterable , List , TYPE_CHECKING
10
+ from typing import Any , Callable , cast , Coroutine , Iterable , List , TYPE_CHECKING
11
11
12
12
import monarch
13
13
import pytest
@@ -57,6 +57,9 @@ async def allocate() -> ProcMesh:
57
57
58
58
59
59
class MyActor :
60
+ def __init__ (self ) -> None :
61
+ self ._root_rank : int = - 1
62
+
60
63
async def handle (
61
64
self ,
62
65
mailbox : Mailbox ,
@@ -68,8 +71,19 @@ async def handle(
68
71
local_state : Iterable [Any ],
69
72
response_port : "PortProtocol[Any]" ,
70
73
) -> None :
71
- assert rank is not None
72
- response_port .send (f"rank: { rank } " )
74
+ match method :
75
+ case MethodSpecifier .Init ():
76
+ self ._root_rank = rank
77
+ response_port .send (None )
78
+ return None
79
+ case MethodSpecifier .ReturnsResponse (name = _):
80
+ response_port .send (self ._root_rank )
81
+ return None
82
+ case MethodSpecifier .ExplicitPort (name = _):
83
+ response_port .exception (
84
+ NotImplementedError ("ExplicitPort is not supported yet" )
85
+ )
86
+ return None
73
87
74
88
75
89
# TODO - re-enable after resolving T232206970
@@ -95,35 +109,70 @@ async def run() -> None:
95
109
run ()
96
110
97
111
98
- async def verify_cast (
112
async def spawn_actor_mesh(proc_mesh: ProcMesh) -> PythonActorMesh:
    """Spawn a ``MyActor`` mesh named "test" and run ``Init`` on every actor.

    An ``Init`` call is cast to the whole mesh so that each actor records the
    rank it is given at that point. One reply per element of the mesh's
    ndslice is awaited before returning, so callers receive a mesh whose
    actors have all finished initialization.

    Args:
        proc_mesh: The proc mesh to spawn the actors on; its client mailbox
            is also used to open the port that collects the Init replies.

    Returns:
        The spawned actor mesh.
    """
    mesh = await proc_mesh.spawn_nonblocking("test", MyActor)

    # Open a reply port on the proc mesh client and bind it so it can be
    # embedded in the outgoing message.
    init_receiver: PortReceiver
    port_handle, init_receiver = proc_mesh.client.open_port()
    reply_port = port_handle.bind()

    init_message = PythonMessage(
        PythonMessageKind.CallMethod(MethodSpecifier.Init(), reply_port),
        pickle.dumps(None),
    )
    mesh.cast(Selection.all(), init_message)

    # Drain one reply per actor so that init is known to have completed.
    remaining = len(mesh.shape.ndslice)
    while remaining > 0:
        await init_receiver.recv_task()
        remaining -= 1

    return mesh
129
+
130
+
131
async def cast_to_call(
    actor_mesh: PythonActorMesh | PythonActorMeshRef,
    mailbox: Mailbox,
    message: PythonMessage,
) -> None:
    """Cast ``message`` to all actors of ``actor_mesh``.

    The two mesh flavours expose different ``cast`` signatures: an owned
    ``PythonActorMesh`` takes ``(selection, message)``, while a
    ``PythonActorMeshRef`` additionally needs the mailbox to send from.

    Args:
        actor_mesh: The mesh (or mesh reference) to cast to.
        mailbox: Mailbox used as the sender when casting through a reference;
            unused for an owned mesh.
        message: The message delivered to every selected actor.

    Raises:
        TypeError: If ``actor_mesh`` is neither of the supported mesh types.
    """
    sel = Selection.all()
    if isinstance(actor_mesh, PythonActorMesh):
        actor_mesh.cast(sel, message)
    elif isinstance(actor_mesh, PythonActorMeshRef):
        actor_mesh.cast(mailbox, sel, message)
    else:
        # Fail loudly: silently sending nothing would leave callers blocked
        # forever waiting for replies that can never arrive.
        raise TypeError(
            f"unsupported actor mesh type: {type(actor_mesh).__name__}"
        )
141
+
142
+
143
async def verify_cast_to_call(
    actor_mesh: PythonActorMesh | PythonActorMeshRef,
    mailbox: Mailbox,
    root_ranks: List[int],
) -> None:
    """Cast a call to every actor and check the ranks reported back.

    Each reply carries two ranks: the rank attached to the result message
    (the cast rank) and the pickled payload, which is the root rank the actor
    stored when it handled ``Init``. After sorting replies by cast rank, the
    root ranks must equal ``root_ranks`` and the cast ranks must be exactly
    ``0 .. len(root_ranks) - 1``.

    Args:
        actor_mesh: The mesh (or mesh reference) to cast to.
        mailbox: Mailbox used to open the reply port and to send from.
        root_ranks: Expected root ranks, listed in cast-rank order.

    Raises:
        AssertionError: If a reply is malformed or the ranks do not match.
    """
    receiver: PortReceiver
    handle, receiver = mailbox.open_port()
    port_ref = handle.bind()

    request = PythonMessage(
        PythonMessageKind.CallMethod(MethodSpecifier.ReturnsResponse("echo"), port_ref),
        pickle.dumps("ping"),
    )
    await cast_to_call(actor_mesh, mailbox, request)

    # Collect exactly one (cast_rank, root_rank) pair per expected actor.
    rcv_ranks = []
    for _ in range(len(root_ranks)):
        reply = await receiver.recv_task()
        result_kind = reply.kind
        assert isinstance(result_kind, PythonMessageKind.Result)
        cast_rank = result_kind.rank
        assert cast_rank is not None
        # The payload is produced by the test actor itself, so unpickling it
        # here is safe.
        root_rank = cast(int, pickle.loads(reply.message))
        rcv_ranks.append((cast_rank, root_rank))
    rcv_ranks.sort(key=lambda pair: pair[0])

    # Build the two columns explicitly instead of unpacking zip(*rcv_ranks),
    # which raises ValueError when no replies are expected.
    recv_cast_ranks = tuple(pair[0] for pair in rcv_ranks)
    recv_root_ranks = tuple(pair[1] for pair in rcv_ranks)
    assert recv_root_ranks == tuple(
        root_ranks
    ), f"recv_root_ranks={recv_root_ranks}, root_ranks={tuple(root_ranks)}"
    assert recv_cast_ranks == tuple(
        range(len(root_ranks))
    ), f"recv_cast_ranks={recv_cast_ranks}, root_ranks={tuple(root_ranks)}"

    # verify no more messages are received
    with pytest.raises(TimeoutError):
        await receiver.recv_task().with_timeout(1)
@@ -136,8 +185,8 @@ async def test_cast_handle() -> None:
136
185
@run_on_tokio
137
186
async def run () -> None :
138
187
proc_mesh = await allocate ()
139
- actor_mesh = await proc_mesh . spawn_nonblocking ( "test" , MyActor )
140
- await verify_cast (actor_mesh , proc_mesh .client , list (range (3 * 8 * 8 )))
188
+ actor_mesh = await spawn_actor_mesh ( proc_mesh )
189
+ await verify_cast_to_call (actor_mesh , proc_mesh .client , list (range (3 * 8 * 8 )))
141
190
142
191
await proc_mesh .stop_nonblocking ()
143
192
@@ -151,9 +200,11 @@ async def test_cast_ref() -> None:
151
200
@run_on_tokio
152
201
async def run () -> None :
153
202
proc_mesh = await allocate ()
154
- actor_mesh = await proc_mesh . spawn_nonblocking ( "test" , MyActor )
203
+ actor_mesh = await spawn_actor_mesh ( proc_mesh )
155
204
actor_mesh_ref = actor_mesh .bind ()
156
- await verify_cast (actor_mesh_ref , proc_mesh .client , list (range (3 * 8 * 8 )))
205
+ await verify_cast_to_call (
206
+ actor_mesh_ref , proc_mesh .client , list (range (3 * 8 * 8 ))
207
+ )
157
208
158
209
await proc_mesh .stop_nonblocking ()
159
210
@@ -184,7 +235,7 @@ async def verify_slice(
184
235
assert (
185
236
sliced_shape .ranks () == replica_0_ranks + replica_1_ranks
186
237
), f"left is { sliced_shape .ranks ()} "
187
- await verify_cast (sliced_mesh , mailbox , sliced_shape .ranks ())
238
+ await verify_cast_to_call (sliced_mesh , mailbox , sliced_shape .ranks ())
188
239
189
240
assert sliced_shape .labels == ["replicas" , "hosts" , "gpus" ]
190
241
assert sliced_shape .ndslice .sizes == [2 , 4 , 3 ]
@@ -224,7 +275,8 @@ async def test_slice_actor_mesh_handle() -> None:
224
275
@run_on_tokio
225
276
async def run () -> None :
226
277
proc_mesh = await allocate ()
227
- actor_mesh = await proc_mesh .spawn_nonblocking ("test" , MyActor )
278
+ actor_mesh = await spawn_actor_mesh (proc_mesh )
279
+
228
280
await verify_slice (actor_mesh , proc_mesh .client )
229
281
230
282
await proc_mesh .stop_nonblocking ()
@@ -239,7 +291,8 @@ async def test_slice_actor_mesh_ref() -> None:
239
291
@run_on_tokio
240
292
async def run () -> None :
241
293
proc_mesh = await allocate ()
242
- actor_mesh = await proc_mesh .spawn_nonblocking ("test" , MyActor )
294
+ actor_mesh = await spawn_actor_mesh (proc_mesh )
295
+
243
296
actor_mesh_ref = actor_mesh .bind ()
244
297
await verify_slice (actor_mesh_ref , proc_mesh .client )
245
298
0 commit comments