From caa2fa46e127fd998b974b4aa25008e718e90cb4 Mon Sep 17 00:00:00 2001 From: Peng Zhang Date: Tue, 5 Aug 2025 17:44:57 -0700 Subject: [PATCH] Pass cast rank to python actor (#747) Summary: When casting to a sliced mesh, the actors rank on the sliced mesh is different from the rank on the root mesh. Currently, the root mesh's rank is passed to Python actor. [That is wrong](https://www.internalfb.com/diff/D78355743?dst_version_fbid=1460199675405905&transaction_fbid=1279216557057466). We need to pass the rank on the cast mesh. If the cast mesh is a sliced mesh, then it should be sliced mesh rank. This diff fixes this. Reviewed By: mariusae Differential Revision: D79530146 --- controller/src/lib.rs | 17 ++-- hyperactor_mesh/src/actor_mesh.rs | 10 ++- hyperactor_mesh/src/comm.rs | 7 +- hyperactor_mesh/src/comm/multicast.rs | 45 +++++++++- hyperactor_mesh/src/reference.rs | 1 + python/tests/_monarch/test_actor_mesh.py | 104 ++++++++++++++++++----- 6 files changed, 147 insertions(+), 37 deletions(-) diff --git a/controller/src/lib.rs b/controller/src/lib.rs index ce21cc50..7f534788 100644 --- a/controller/src/lib.rs +++ b/controller/src/lib.rs @@ -65,7 +65,7 @@ use ndslice::Selection; use ndslice::Shape; use ndslice::Slice; use ndslice::reshape::Limit; -use ndslice::reshape::ReshapeSliceExt; +use ndslice::reshape::ReshapeShapeExt; use ndslice::selection::dsl; use ndslice::shape::Range; use serde::Deserialize; @@ -425,6 +425,14 @@ impl ControllerMessageHandler for ControllerActor { }), }; + let slice = Slice::new(0usize, vec![self.world_size], vec![1])?; + // Use a made-up label to create a fake shape. This shape is used by + // comm actor to determine the cast rank. Cast rank is not used by + // DeviceMesh, but we still need a shape there to make the logic happy. + let made_up_shape = Shape::new(vec!["fake_in_controller".to_string()], slice.clone())? + .reshape(Limit::from(CASTING_FANOUT_SIZE)) + .shape; + let message = CastMessageEnvelope::from_serialized( ActorMeshId( ProcMeshId(self.worker_gang_ref.gang_id().world_id().to_string()), @@ -439,15 +447,10 @@ impl ControllerMessageHandler for ControllerActor { .name() .to_string(), ), - // Not reflective of the actual shape, but this is never actually used. - Shape::unity(), + made_up_shape, message, ); - let slice = Slice::new(0usize, vec![self.world_size], vec![1]) - .unwrap() - .reshape_with_limit(Limit::from(CASTING_FANOUT_SIZE)); - self.comm_actor_ref.send( cx, CastMessage { diff --git a/hyperactor_mesh/src/actor_mesh.rs b/hyperactor_mesh/src/actor_mesh.rs index f97f79e1..9995fcf9 100644 --- a/hyperactor_mesh/src/actor_mesh.rs +++ b/hyperactor_mesh/src/actor_mesh.rs @@ -75,6 +75,7 @@ pub(crate) fn actor_mesh_cast( comm_actor_ref: &ActorRef, selection_of_root: Selection, root_mesh_shape: &Shape, + cast_mesh_shape: &Shape, message: M, ) -> Result<(), CastError> where @@ -89,10 +90,13 @@ where let message = CastMessageEnvelope::new::( actor_mesh_id.clone(), sender.clone(), - root_mesh_shape.clone(), + cast_mesh_shape.clone(), message, )?; let cast_message = CastMessage { + // Note: `dest` is on the root mesh' shape, which could be different + // from the cast mesh's shape if the cast is on a view, e.g. a sliced + // mesh. dest: Uslice { slice: root_mesh_shape.slice().clone(), selection: selection_of_root, @@ -147,6 +151,7 @@ where comm_actor_ref, sel_of_root, root_mesh_shape, + sliced_shape, message, ) } @@ -172,6 +177,7 @@ pub trait ActorMesh: Mesh { self.proc_mesh().comm_actor(), // comm actor selection, // the selected actors self.shape(), // root mesh shape + self.shape(), // cast mesh shape message, // the message ) } @@ -419,7 +425,7 @@ impl ActorMesh for SlicedActorMesh<'_, A> { /*sel_of_sliced=*/ &sel, /*message=*/ message, /*sliced_shape=*/ self.shape(), - /*base_shape=*/ self.0.shape(), + /*root_mesh_shape=*/ self.0.shape(), ) } } diff --git a/hyperactor_mesh/src/comm.rs b/hyperactor_mesh/src/comm.rs index c1ec901d..bf8408ea 100644 --- a/hyperactor_mesh/src/comm.rs +++ b/hyperactor_mesh/src/comm.rs @@ -249,11 +249,14 @@ impl CommActor { // Deliver message here, if necessary. if deliver_here { + let rank_on_root_mesh = mode.self_rank(cx.self_id()); + let cast_rank = message.relative_rank(rank_on_root_mesh)?; + let cast_shape = message.shape(); let mut headers = cx.headers().clone(); set_cast_info_on_headers( &mut headers, - mode.self_rank(cx.self_id()), - message.shape().clone(), + cast_rank, + cast_shape.clone(), message.sender().clone(), ); cx.post( diff --git a/hyperactor_mesh/src/comm/multicast.rs b/hyperactor_mesh/src/comm/multicast.rs index 4f9792f3..95be2aa1 100644 --- a/hyperactor_mesh/src/comm/multicast.rs +++ b/hyperactor_mesh/src/comm/multicast.rs @@ -21,6 +21,7 @@ use hyperactor::message::Castable; use hyperactor::message::ErasedUnbound; use hyperactor::message::IndexedErasedUnbound; use hyperactor::reference::ActorId; +use ndslice::Extent; use ndslice::Shape; use ndslice::Slice; use ndslice::selection::Selection; @@ -120,6 +121,39 @@ impl CastMessageEnvelope { &self.shape } + /// Given a rank in the root shape, return the corresponding point in the + /// provided shape, which is a view of the root shape. + pub(crate) fn relative_rank(&self, rank_on_root_mesh: usize) -> anyhow::Result { + let shape = self.shape(); + let coords = shape.slice().coordinates(rank_on_root_mesh).map_err(|e| { + anyhow::anyhow!( + "fail to calculate coords for root rank {} due to error: {}; shape is {:?}", + rank_on_root_mesh, + e, + shape, + ) + })?; + let extent = + Extent::new(shape.labels().to_vec(), shape.slice().sizes().to_vec()).map_err(|e| { + anyhow::anyhow!( + "fail to calculate extent for root rank {} due to error: {}; shape is {}", + rank_on_root_mesh, + e, + shape, + ) + })?; + let point = extent.point(coords).map_err(|e| { + anyhow::anyhow!( + "fail to calculate point for root rank {} due to error: {}; extent is {}, shape is {}", + rank_on_root_mesh, + e, + extent, + shape, + ) + })?; + Ok(point.rank()) + } + /// The unique key used to indicate the stream to which to deliver this message. /// Concretely, the comm actors along the path should use this key to manage /// sequence numbers and reorder buffers. @@ -203,9 +237,14 @@ declare_attrs! { pub attr CAST_ORIGINATING_SENDER: ActorId; } -pub fn set_cast_info_on_headers(headers: &mut Attrs, rank: usize, shape: Shape, sender: ActorId) { - headers.set(CAST_RANK, rank); - headers.set(CAST_SHAPE, shape); +pub fn set_cast_info_on_headers( + headers: &mut Attrs, + cast_rank: usize, + cast_shape: Shape, + sender: ActorId, +) { + headers.set(CAST_RANK, cast_rank); + headers.set(CAST_SHAPE, cast_shape); headers.set(CAST_ORIGINATING_SENDER, sender); } diff --git a/hyperactor_mesh/src/reference.rs b/hyperactor_mesh/src/reference.rs index 2e464a1e..62b73708 100644 --- a/hyperactor_mesh/src/reference.rs +++ b/hyperactor_mesh/src/reference.rs @@ -148,6 +148,7 @@ impl ActorMeshRef { &self.comm_actor_ref, selection, &self.root, + &self.root, message, ), } diff --git a/python/tests/_monarch/test_actor_mesh.py b/python/tests/_monarch/test_actor_mesh.py index 5975b365..026cac09 100644 --- a/python/tests/_monarch/test_actor_mesh.py +++ b/python/tests/_monarch/test_actor_mesh.py @@ -7,7 +7,7 @@ # pyre-unsafe import pickle -from typing import Any, Callable, Coroutine, Iterable, List, TYPE_CHECKING +from typing import Any, Callable, cast, Coroutine, Iterable, List, TYPE_CHECKING import monarch import pytest @@ -57,6 +57,12 @@ async def allocate() -> ProcMesh: class MyActor: + def __init__(self) -> None: + # Note: for the same actor, its rank on the root mesh could be different + # from its rank on the mesh it is cast to. This is because the cast + # mesh could be a sliced mesh. + self._rank_on_root_mesh: int = -1 + async def handle( self, mailbox: Mailbox, @@ -68,8 +74,21 @@ async def handle( local_state: Iterable[Any], response_port: "PortProtocol[Any]", ) -> None: - assert rank is not None - response_port.send(f"rank: {rank}") + match method: + case MethodSpecifier.Init(): + # Since this actor is spawn from the root proc mesh, the rank + # passed from init should be the rank on the root mesh. + self._rank_on_root_mesh = rank + response_port.send(None) + return None + case MethodSpecifier.ReturnsResponse(name=_): + response_port.send(self._rank_on_root_mesh) + return None + case MethodSpecifier.ExplicitPort(name=_): + response_port.exception( + NotImplementedError("ExplicitPort is not supported yet") + ) + return None # TODO - re-enable after resolving T232206970 @@ -95,35 +114,70 @@ async def run() -> None: run() -async def verify_cast( +async def spawn_actor_mesh(proc_mesh: ProcMesh) -> PythonActorMesh: + actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor) + # init actors to record their root ranks + receiver: PortReceiver + handle, receiver = proc_mesh.client.open_port() + port_ref = handle.bind() + + message = PythonMessage( + PythonMessageKind.CallMethod(MethodSpecifier.Init(), port_ref), + pickle.dumps(None), + ) + actor_mesh.cast(Selection.all(), message) + # wait for init to complete + for _ in range(len(actor_mesh.shape.ndslice)): + await receiver.recv_task() + + return actor_mesh + + +async def cast_to_call( + actor_mesh: PythonActorMesh | PythonActorMeshRef, + mailbox: Mailbox, + message: PythonMessage, +) -> None: + sel = Selection.all() + if isinstance(actor_mesh, PythonActorMesh): + actor_mesh.cast(sel, message) + elif isinstance(actor_mesh, PythonActorMeshRef): + actor_mesh.cast(mailbox, sel, message) + + +async def verify_cast_to_call( actor_mesh: PythonActorMesh | PythonActorMeshRef, mailbox: Mailbox, - cast_ranks: List[int], + root_ranks: List[int], ) -> None: receiver: PortReceiver handle, receiver = mailbox.open_port() port_ref = handle.bind() + # Now send the real message message = PythonMessage( PythonMessageKind.CallMethod(MethodSpecifier.ReturnsResponse("echo"), port_ref), pickle.dumps("ping"), ) - sel = Selection.from_string("*") - if isinstance(actor_mesh, PythonActorMesh): - actor_mesh.cast(sel, message) - elif isinstance(actor_mesh, PythonActorMeshRef): - actor_mesh.cast(mailbox, sel, message) + await cast_to_call(actor_mesh, mailbox, message) rcv_ranks = [] - for _ in range(len(cast_ranks)): + for _ in range(len(root_ranks)): message = await receiver.recv_task() result_kind = message.kind assert isinstance(result_kind, PythonMessageKind.Result) - rank = result_kind.rank - assert rank is not None - rcv_ranks.append(rank) - rcv_ranks.sort() - assert rcv_ranks == cast_ranks + cast_rank = result_kind.rank + assert cast_rank is not None + root_rank = cast(int, pickle.loads(message.message)) + rcv_ranks.append((cast_rank, root_rank)) + rcv_ranks.sort(key=lambda pair: pair[0]) + recv_cast_ranks, recv_root_ranks = zip(*rcv_ranks) + assert recv_root_ranks == tuple( + root_ranks + ), f"recv_root_ranks={recv_root_ranks}, root_ranks={tuple(root_ranks)}" + assert recv_cast_ranks == tuple( + range(len(root_ranks)) + ), f"recv_cast_ranks={recv_cast_ranks}, root_ranks={tuple(root_ranks)}" # verify no more messages are received with pytest.raises(TimeoutError): await receiver.recv_task().with_timeout(1) @@ -136,8 +190,8 @@ async def test_cast_handle() -> None: @run_on_tokio async def run() -> None: proc_mesh = await allocate() - actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor) - await verify_cast(actor_mesh, proc_mesh.client, list(range(3 * 8 * 8))) + actor_mesh = await spawn_actor_mesh(proc_mesh) + await verify_cast_to_call(actor_mesh, proc_mesh.client, list(range(3 * 8 * 8))) await proc_mesh.stop_nonblocking() @@ -151,9 +205,11 @@ async def test_cast_ref() -> None: @run_on_tokio async def run() -> None: proc_mesh = await allocate() - actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor) + actor_mesh = await spawn_actor_mesh(proc_mesh) actor_mesh_ref = actor_mesh.bind() - await verify_cast(actor_mesh_ref, proc_mesh.client, list(range(3 * 8 * 8))) + await verify_cast_to_call( + actor_mesh_ref, proc_mesh.client, list(range(3 * 8 * 8)) + ) await proc_mesh.stop_nonblocking() @@ -184,7 +240,7 @@ async def verify_slice( assert ( sliced_shape.ranks() == replica_0_ranks + replica_1_ranks ), f"left is {sliced_shape.ranks()}" - await verify_cast(sliced_mesh, mailbox, sliced_shape.ranks()) + await verify_cast_to_call(sliced_mesh, mailbox, sliced_shape.ranks()) assert sliced_shape.labels == ["replicas", "hosts", "gpus"] assert sliced_shape.ndslice.sizes == [2, 4, 3] @@ -224,7 +280,8 @@ async def test_slice_actor_mesh_handle() -> None: @run_on_tokio async def run() -> None: proc_mesh = await allocate() - actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor) + actor_mesh = await spawn_actor_mesh(proc_mesh) + await verify_slice(actor_mesh, proc_mesh.client) await proc_mesh.stop_nonblocking() @@ -239,7 +296,8 @@ async def test_slice_actor_mesh_ref() -> None: @run_on_tokio async def run() -> None: proc_mesh = await allocate() - actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor) + actor_mesh = await spawn_actor_mesh(proc_mesh) + actor_mesh_ref = actor_mesh.bind() await verify_slice(actor_mesh_ref, proc_mesh.client)