Skip to content

Commit fcce9ef

Browse files
pzhan9facebook-github-bot
authored andcommitted
Pass cast rank to python actor (meta-pytorch#747)
Summary: Rollback Plan: Differential Revision: D79530146
1 parent 5768bdd commit fcce9ef

File tree

5 files changed

+109
-30
lines changed

5 files changed

+109
-30
lines changed

hyperactor_mesh/src/actor_mesh.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ pub(crate) fn actor_mesh_cast<A, M>(
6767
selection_of_root: Selection,
6868
root_mesh_shape: &Shape,
6969
message: M,
70+
// None means cast to root mesh.
71+
cast_mesh_shape: Option<&Shape>,
7072
) -> Result<(), CastError>
7173
where
7274
A: RemoteActor + RemoteHandles<IndexedErasedUnbound<M>>,
@@ -80,13 +82,16 @@ where
8082
let message = CastMessageEnvelope::new::<A, M>(
8183
actor_mesh_id,
8284
sender.clone(),
83-
root_mesh_shape.clone(),
85+
cast_mesh_shape.unwrap_or(root_mesh_shape).clone(),
8486
message,
8587
)?;
8688

8789
comm_actor_ref.send(
8890
caps,
8991
CastMessage {
92+
// Note: `dest` is on the root mesh' shape, which could be different
93+
// from the cast mesh's shape if the cast is on a view, e.g. a sliced
94+
// mesh.
9095
dest: Uslice {
9196
slice: root_mesh_shape.slice().clone(),
9297
selection: selection_of_root,
@@ -136,6 +141,7 @@ where
136141
sel_of_root,
137142
root_mesh_shape,
138143
message,
144+
Some(sliced_shape),
139145
)
140146
}
141147

@@ -161,6 +167,7 @@ pub trait ActorMesh: Mesh<Id = ActorMeshId> {
161167
selection, // the selected actors
162168
self.shape(), // root mesh shape
163169
message, // the message
170+
None, // cast mesh shape
164171
)
165172
}
166173

@@ -407,7 +414,7 @@ impl<A: RemoteActor> ActorMesh for SlicedActorMesh<'_, A> {
407414
/*sel_of_sliced=*/ &sel,
408415
/*message=*/ message,
409416
/*sliced_shape=*/ self.shape(),
410-
/*base_shape=*/ self.0.shape(),
417+
/*roo_mesh_shape=*/ self.0.shape(),
411418
)
412419
}
413420
}

hyperactor_mesh/src/comm.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,11 +249,14 @@ impl CommActor {
249249

250250
// Deliver message here, if necessary.
251251
if deliver_here {
252+
let rank_on_root_mesh = mode.self_rank(cx.self_id());
253+
let cast_rank = message.relative_rank(rank_on_root_mesh)?;
254+
let cast_shape = message.shape();
252255
let mut headers = cx.headers().clone();
253256
set_cast_info_on_headers(
254257
&mut headers,
255-
mode.self_rank(cx.self_id()),
256-
message.shape().clone(),
258+
cast_rank,
259+
cast_shape.clone(),
257260
message.sender().clone(),
258261
);
259262
cx.post(

hyperactor_mesh/src/comm/multicast.rs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use hyperactor::message::Castable;
2121
use hyperactor::message::ErasedUnbound;
2222
use hyperactor::message::IndexedErasedUnbound;
2323
use hyperactor::reference::ActorId;
24+
use ndslice::Extent;
2425
use ndslice::Shape;
2526
use ndslice::Slice;
2627
use ndslice::selection::Selection;
@@ -120,6 +121,15 @@ impl CastMessageEnvelope {
120121
&self.shape
121122
}
122123

124+
/// Given a rank in the root shape, return the corresponding point in the
125+
/// provided shape, which is a view of the root shape.
126+
pub(crate) fn relative_rank(&self, rank_on_root_mesh: usize) -> anyhow::Result<usize> {
127+
let shape = self.shape();
128+
let coords = shape.slice().coordinates(rank_on_root_mesh)?;
129+
let extent = Extent::new(shape.labels().to_vec(), shape.slice().sizes().to_vec())?;
130+
Ok(extent.point(coords)?.rank())
131+
}
132+
123133
/// The unique key used to indicate the stream to which to deliver this message.
124134
/// Concretely, the comm actors along the path should use this key to manage
125135
/// sequence numbers and reorder buffers.
@@ -203,9 +213,14 @@ declare_attrs! {
203213
pub attr CAST_ORIGINATING_SENDER: ActorId;
204214
}
205215

206-
pub fn set_cast_info_on_headers(headers: &mut Attrs, rank: usize, shape: Shape, sender: ActorId) {
207-
headers.set(CAST_RANK, rank);
208-
headers.set(CAST_SHAPE, shape);
216+
pub fn set_cast_info_on_headers(
217+
headers: &mut Attrs,
218+
cast_rank: usize,
219+
cast_shape: Shape,
220+
sender: ActorId,
221+
) {
222+
headers.set(CAST_RANK, cast_rank);
223+
headers.set(CAST_SHAPE, cast_shape);
209224
headers.set(CAST_ORIGINATING_SENDER, sender);
210225
}
211226

hyperactor_mesh/src/reference.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ impl<A: RemoteActor> ActorMeshRef<A> {
148148
selection,
149149
&self.root,
150150
message,
151+
None,
151152
),
152153
}
153154
}

python/tests/_monarch/test_actor_mesh.py

Lines changed: 76 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# pyre-unsafe
88

99
import pickle
10-
from typing import Any, Callable, Coroutine, Iterable, List, TYPE_CHECKING
10+
from typing import Any, Callable, cast, Coroutine, Iterable, List, TYPE_CHECKING
1111

1212
import monarch
1313
import pytest
@@ -57,6 +57,9 @@ async def allocate() -> ProcMesh:
5757

5858

5959
class MyActor:
60+
def __init__(self) -> None:
61+
self._root_rank: int = -1
62+
6063
async def handle(
6164
self,
6265
mailbox: Mailbox,
@@ -68,8 +71,19 @@ async def handle(
6871
local_state: Iterable[Any],
6972
response_port: "PortProtocol[Any]",
7073
) -> None:
71-
assert rank is not None
72-
response_port.send(f"rank: {rank}")
74+
match method:
75+
case MethodSpecifier.Init():
76+
self._root_rank = rank
77+
response_port.send(None)
78+
return None
79+
case MethodSpecifier.ReturnsResponse(name=_):
80+
response_port.send(self._root_rank)
81+
return None
82+
case MethodSpecifier.ExplicitPort(name=_):
83+
response_port.exception(
84+
NotImplementedError("ExplicitPort is not supported yet")
85+
)
86+
return None
7387

7488

7589
# TODO - re-enable after resolving T232206970
@@ -95,35 +109,70 @@ async def run() -> None:
95109
run()
96110

97111

98-
async def verify_cast(
112+
async def spawn_actor_mesh(proc_mesh: ProcMesh) -> PythonActorMesh:
113+
actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor)
114+
# init actors to record their root ranks
115+
receiver: PortReceiver
116+
handle, receiver = proc_mesh.client.open_port()
117+
port_ref = handle.bind()
118+
119+
message = PythonMessage(
120+
PythonMessageKind.CallMethod(MethodSpecifier.Init(), port_ref),
121+
pickle.dumps(None),
122+
)
123+
actor_mesh.cast(Selection.all(), message)
124+
# wait for init to complete
125+
for _ in range(len(actor_mesh.shape.ndslice)):
126+
await receiver.recv_task()
127+
128+
return actor_mesh
129+
130+
131+
async def cast_to_call(
132+
actor_mesh: PythonActorMesh | PythonActorMeshRef,
133+
mailbox: Mailbox,
134+
message: PythonMessage,
135+
) -> None:
136+
sel = Selection.all()
137+
if isinstance(actor_mesh, PythonActorMesh):
138+
actor_mesh.cast(sel, message)
139+
elif isinstance(actor_mesh, PythonActorMeshRef):
140+
actor_mesh.cast(mailbox, sel, message)
141+
142+
143+
async def verify_cast_to_call(
99144
actor_mesh: PythonActorMesh | PythonActorMeshRef,
100145
mailbox: Mailbox,
101-
cast_ranks: List[int],
146+
root_ranks: List[int],
102147
) -> None:
103148
receiver: PortReceiver
104149
handle, receiver = mailbox.open_port()
105150
port_ref = handle.bind()
106151

152+
# Now send the real message
107153
message = PythonMessage(
108154
PythonMessageKind.CallMethod(MethodSpecifier.ReturnsResponse("echo"), port_ref),
109155
pickle.dumps("ping"),
110156
)
111-
sel = Selection.from_string("*")
112-
if isinstance(actor_mesh, PythonActorMesh):
113-
actor_mesh.cast(sel, message)
114-
elif isinstance(actor_mesh, PythonActorMeshRef):
115-
actor_mesh.cast(mailbox, sel, message)
157+
await cast_to_call(actor_mesh, mailbox, message)
116158

117159
rcv_ranks = []
118-
for _ in range(len(cast_ranks)):
160+
for _ in range(len(root_ranks)):
119161
message = await receiver.recv_task()
120162
result_kind = message.kind
121163
assert isinstance(result_kind, PythonMessageKind.Result)
122-
rank = result_kind.rank
123-
assert rank is not None
124-
rcv_ranks.append(rank)
125-
rcv_ranks.sort()
126-
assert rcv_ranks == cast_ranks
164+
cast_rank = result_kind.rank
165+
assert cast_rank is not None
166+
root_rank = cast(int, pickle.loads(message.message))
167+
rcv_ranks.append((cast_rank, root_rank))
168+
rcv_ranks.sort(key=lambda pair: pair[0])
169+
recv_cast_ranks, recv_root_ranks = zip(*rcv_ranks)
170+
assert recv_root_ranks == tuple(
171+
root_ranks
172+
), f"recv_root_ranks={recv_root_ranks}, root_ranks={tuple(root_ranks)}"
173+
assert recv_cast_ranks == tuple(
174+
range(len(root_ranks))
175+
), f"recv_cast_ranks={recv_cast_ranks}, root_ranks={tuple(root_ranks)}"
127176
# verify no more messages are received
128177
with pytest.raises(TimeoutError):
129178
await receiver.recv_task().with_timeout(1)
@@ -136,8 +185,8 @@ async def test_cast_handle() -> None:
136185
@run_on_tokio
137186
async def run() -> None:
138187
proc_mesh = await allocate()
139-
actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor)
140-
await verify_cast(actor_mesh, proc_mesh.client, list(range(3 * 8 * 8)))
188+
actor_mesh = await spawn_actor_mesh(proc_mesh)
189+
await verify_cast_to_call(actor_mesh, proc_mesh.client, list(range(3 * 8 * 8)))
141190

142191
await proc_mesh.stop_nonblocking()
143192

@@ -151,9 +200,11 @@ async def test_cast_ref() -> None:
151200
@run_on_tokio
152201
async def run() -> None:
153202
proc_mesh = await allocate()
154-
actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor)
203+
actor_mesh = await spawn_actor_mesh(proc_mesh)
155204
actor_mesh_ref = actor_mesh.bind()
156-
await verify_cast(actor_mesh_ref, proc_mesh.client, list(range(3 * 8 * 8)))
205+
await verify_cast_to_call(
206+
actor_mesh_ref, proc_mesh.client, list(range(3 * 8 * 8))
207+
)
157208

158209
await proc_mesh.stop_nonblocking()
159210

@@ -184,7 +235,7 @@ async def verify_slice(
184235
assert (
185236
sliced_shape.ranks() == replica_0_ranks + replica_1_ranks
186237
), f"left is {sliced_shape.ranks()}"
187-
await verify_cast(sliced_mesh, mailbox, sliced_shape.ranks())
238+
await verify_cast_to_call(sliced_mesh, mailbox, sliced_shape.ranks())
188239

189240
assert sliced_shape.labels == ["replicas", "hosts", "gpus"]
190241
assert sliced_shape.ndslice.sizes == [2, 4, 3]
@@ -224,7 +275,8 @@ async def test_slice_actor_mesh_handle() -> None:
224275
@run_on_tokio
225276
async def run() -> None:
226277
proc_mesh = await allocate()
227-
actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor)
278+
actor_mesh = await spawn_actor_mesh(proc_mesh)
279+
228280
await verify_slice(actor_mesh, proc_mesh.client)
229281

230282
await proc_mesh.stop_nonblocking()
@@ -239,7 +291,8 @@ async def test_slice_actor_mesh_ref() -> None:
239291
@run_on_tokio
240292
async def run() -> None:
241293
proc_mesh = await allocate()
242-
actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor)
294+
actor_mesh = await spawn_actor_mesh(proc_mesh)
295+
243296
actor_mesh_ref = actor_mesh.bind()
244297
await verify_slice(actor_mesh_ref, proc_mesh.client)
245298

0 commit comments

Comments
 (0)