Skip to content

Pass cast rank to python actor #747

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions controller/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ use ndslice::Selection;
use ndslice::Shape;
use ndslice::Slice;
use ndslice::reshape::Limit;
use ndslice::reshape::ReshapeSliceExt;
use ndslice::reshape::ReshapeShapeExt;
use ndslice::selection::dsl;
use ndslice::shape::Range;
use serde::Deserialize;
Expand Down Expand Up @@ -425,6 +425,14 @@ impl ControllerMessageHandler for ControllerActor {
}),
};

let slice = Slice::new(0usize, vec![self.world_size], vec![1])?;
// Use a made-up label to create a fake shape. This shape is used by
// comm actor to determine the cast rank. Cast rank is not used by
// DeviceMesh, but we still need a shape there to make the logic happy.
let made_up_shape = Shape::new(vec!["fake_in_controller".to_string()], slice.clone())?
.reshape(Limit::from(CASTING_FANOUT_SIZE))
.shape;

let message = CastMessageEnvelope::from_serialized(
ActorMeshId(
ProcMeshId(self.worker_gang_ref.gang_id().world_id().to_string()),
Expand All @@ -439,15 +447,10 @@ impl ControllerMessageHandler for ControllerActor {
.name()
.to_string(),
),
// Not reflective of the actual shape, but this is never actually used.
Shape::unity(),
made_up_shape,
message,
);

let slice = Slice::new(0usize, vec![self.world_size], vec![1])
.unwrap()
.reshape_with_limit(Limit::from(CASTING_FANOUT_SIZE));

self.comm_actor_ref.send(
cx,
CastMessage {
Expand Down
10 changes: 8 additions & 2 deletions hyperactor_mesh/src/actor_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ pub(crate) fn actor_mesh_cast<A, M>(
comm_actor_ref: &ActorRef<CommActor>,
selection_of_root: Selection,
root_mesh_shape: &Shape,
cast_mesh_shape: &Shape,
message: M,
) -> Result<(), CastError>
where
Expand All @@ -89,10 +90,13 @@ where
let message = CastMessageEnvelope::new::<A, M>(
actor_mesh_id.clone(),
sender.clone(),
root_mesh_shape.clone(),
cast_mesh_shape.clone(),
message,
)?;
let cast_message = CastMessage {
// Note: `dest` is on the root mesh's shape, which could be different
// from the cast mesh's shape if the cast is on a view, e.g. a sliced
// mesh.
dest: Uslice {
slice: root_mesh_shape.slice().clone(),
selection: selection_of_root,
Expand Down Expand Up @@ -147,6 +151,7 @@ where
comm_actor_ref,
sel_of_root,
root_mesh_shape,
sliced_shape,
message,
)
}
Expand All @@ -172,6 +177,7 @@ pub trait ActorMesh: Mesh<Id = ActorMeshId> {
self.proc_mesh().comm_actor(), // comm actor
selection, // the selected actors
self.shape(), // root mesh shape
self.shape(), // cast mesh shape
message, // the message
)
}
Expand Down Expand Up @@ -419,7 +425,7 @@ impl<A: RemoteActor> ActorMesh for SlicedActorMesh<'_, A> {
/*sel_of_sliced=*/ &sel,
/*message=*/ message,
/*sliced_shape=*/ self.shape(),
/*base_shape=*/ self.0.shape(),
/*root_mesh_shape=*/ self.0.shape(),
)
}
}
Expand Down
7 changes: 5 additions & 2 deletions hyperactor_mesh/src/comm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,14 @@ impl CommActor {

// Deliver message here, if necessary.
if deliver_here {
let rank_on_root_mesh = mode.self_rank(cx.self_id());
let cast_rank = message.relative_rank(rank_on_root_mesh)?;
let cast_shape = message.shape();
let mut headers = cx.headers().clone();
set_cast_info_on_headers(
&mut headers,
mode.self_rank(cx.self_id()),
message.shape().clone(),
cast_rank,
cast_shape.clone(),
message.sender().clone(),
);
cx.post(
Expand Down
45 changes: 42 additions & 3 deletions hyperactor_mesh/src/comm/multicast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use hyperactor::message::Castable;
use hyperactor::message::ErasedUnbound;
use hyperactor::message::IndexedErasedUnbound;
use hyperactor::reference::ActorId;
use ndslice::Extent;
use ndslice::Shape;
use ndslice::Slice;
use ndslice::selection::Selection;
Expand Down Expand Up @@ -120,6 +121,39 @@ impl CastMessageEnvelope {
&self.shape
}

/// Given a rank in the root shape, return the corresponding point in the
/// provided shape, which is a view of the root shape.
pub(crate) fn relative_rank(&self, rank_on_root_mesh: usize) -> anyhow::Result<usize> {
let shape = self.shape();
let coords = shape.slice().coordinates(rank_on_root_mesh).map_err(|e| {
anyhow::anyhow!(
"fail to calculate coords for root rank {} due to error: {}; shape is {:?}",
rank_on_root_mesh,
e,
shape,
)
})?;
let extent =
Extent::new(shape.labels().to_vec(), shape.slice().sizes().to_vec()).map_err(|e| {
anyhow::anyhow!(
"fail to calculate extent for root rank {} due to error: {}; shape is {}",
rank_on_root_mesh,
e,
shape,
)
})?;
let point = extent.point(coords).map_err(|e| {
anyhow::anyhow!(
"fail to calculate point for root rank {} due to error: {}; extent is {}, shape is {}",
rank_on_root_mesh,
e,
extent,
shape,
)
})?;
Ok(point.rank())
}

/// The unique key used to indicate the stream to which to deliver this message.
/// Concretely, the comm actors along the path should use this key to manage
/// sequence numbers and reorder buffers.
Expand Down Expand Up @@ -203,9 +237,14 @@ declare_attrs! {
pub attr CAST_ORIGINATING_SENDER: ActorId;
}

pub fn set_cast_info_on_headers(headers: &mut Attrs, rank: usize, shape: Shape, sender: ActorId) {
headers.set(CAST_RANK, rank);
headers.set(CAST_SHAPE, shape);
pub fn set_cast_info_on_headers(
headers: &mut Attrs,
cast_rank: usize,
cast_shape: Shape,
sender: ActorId,
) {
headers.set(CAST_RANK, cast_rank);
headers.set(CAST_SHAPE, cast_shape);
headers.set(CAST_ORIGINATING_SENDER, sender);
}

Expand Down
1 change: 1 addition & 0 deletions hyperactor_mesh/src/reference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ impl<A: RemoteActor> ActorMeshRef<A> {
&self.comm_actor_ref,
selection,
&self.root,
&self.root,
message,
),
}
Expand Down
104 changes: 81 additions & 23 deletions python/tests/_monarch/test_actor_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# pyre-unsafe

import pickle
from typing import Any, Callable, Coroutine, Iterable, List, TYPE_CHECKING
from typing import Any, Callable, cast, Coroutine, Iterable, List, TYPE_CHECKING

import monarch
import pytest
Expand Down Expand Up @@ -57,6 +57,12 @@ async def allocate() -> ProcMesh:


class MyActor:
def __init__(self) -> None:
# Note: for the same actor, its rank on the root mesh could be different
# from its rank on the mesh it is cast to. This is because the cast
# mesh could be a sliced mesh.
self._rank_on_root_mesh: int = -1

async def handle(
self,
mailbox: Mailbox,
Expand All @@ -68,8 +74,21 @@ async def handle(
local_state: Iterable[Any],
response_port: "PortProtocol[Any]",
) -> None:
assert rank is not None
response_port.send(f"rank: {rank}")
match method:
case MethodSpecifier.Init():
# Since this actor is spawned from the root proc mesh, the rank
# passed from init should be the rank on the root mesh.
self._rank_on_root_mesh = rank
response_port.send(None)
return None
case MethodSpecifier.ReturnsResponse(name=_):
response_port.send(self._rank_on_root_mesh)
return None
case MethodSpecifier.ExplicitPort(name=_):
response_port.exception(
NotImplementedError("ExplicitPort is not supported yet")
)
return None


# TODO - re-enable after resolving T232206970
Expand All @@ -95,35 +114,70 @@ async def run() -> None:
run()


async def verify_cast(
async def spawn_actor_mesh(proc_mesh: ProcMesh) -> PythonActorMesh:
actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor)
# init actors to record their root ranks
receiver: PortReceiver
handle, receiver = proc_mesh.client.open_port()
port_ref = handle.bind()

message = PythonMessage(
PythonMessageKind.CallMethod(MethodSpecifier.Init(), port_ref),
pickle.dumps(None),
)
actor_mesh.cast(Selection.all(), message)
# wait for init to complete
for _ in range(len(actor_mesh.shape.ndslice)):
await receiver.recv_task()

return actor_mesh


async def cast_to_call(
actor_mesh: PythonActorMesh | PythonActorMeshRef,
mailbox: Mailbox,
message: PythonMessage,
) -> None:
sel = Selection.all()
if isinstance(actor_mesh, PythonActorMesh):
actor_mesh.cast(sel, message)
elif isinstance(actor_mesh, PythonActorMeshRef):
actor_mesh.cast(mailbox, sel, message)


async def verify_cast_to_call(
actor_mesh: PythonActorMesh | PythonActorMeshRef,
mailbox: Mailbox,
cast_ranks: List[int],
root_ranks: List[int],
) -> None:
receiver: PortReceiver
handle, receiver = mailbox.open_port()
port_ref = handle.bind()

# Now send the real message
message = PythonMessage(
PythonMessageKind.CallMethod(MethodSpecifier.ReturnsResponse("echo"), port_ref),
pickle.dumps("ping"),
)
sel = Selection.from_string("*")
if isinstance(actor_mesh, PythonActorMesh):
actor_mesh.cast(sel, message)
elif isinstance(actor_mesh, PythonActorMeshRef):
actor_mesh.cast(mailbox, sel, message)
await cast_to_call(actor_mesh, mailbox, message)

rcv_ranks = []
for _ in range(len(cast_ranks)):
for _ in range(len(root_ranks)):
message = await receiver.recv_task()
result_kind = message.kind
assert isinstance(result_kind, PythonMessageKind.Result)
rank = result_kind.rank
assert rank is not None
rcv_ranks.append(rank)
rcv_ranks.sort()
assert rcv_ranks == cast_ranks
cast_rank = result_kind.rank
assert cast_rank is not None
root_rank = cast(int, pickle.loads(message.message))
rcv_ranks.append((cast_rank, root_rank))
rcv_ranks.sort(key=lambda pair: pair[0])
recv_cast_ranks, recv_root_ranks = zip(*rcv_ranks)
assert recv_root_ranks == tuple(
root_ranks
), f"recv_root_ranks={recv_root_ranks}, root_ranks={tuple(root_ranks)}"
assert recv_cast_ranks == tuple(
range(len(root_ranks))
), f"recv_cast_ranks={recv_cast_ranks}, root_ranks={tuple(root_ranks)}"
# verify no more messages are received
with pytest.raises(TimeoutError):
await receiver.recv_task().with_timeout(1)
Expand All @@ -136,8 +190,8 @@ async def test_cast_handle() -> None:
@run_on_tokio
async def run() -> None:
proc_mesh = await allocate()
actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor)
await verify_cast(actor_mesh, proc_mesh.client, list(range(3 * 8 * 8)))
actor_mesh = await spawn_actor_mesh(proc_mesh)
await verify_cast_to_call(actor_mesh, proc_mesh.client, list(range(3 * 8 * 8)))

await proc_mesh.stop_nonblocking()

Expand All @@ -151,9 +205,11 @@ async def test_cast_ref() -> None:
@run_on_tokio
async def run() -> None:
proc_mesh = await allocate()
actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor)
actor_mesh = await spawn_actor_mesh(proc_mesh)
actor_mesh_ref = actor_mesh.bind()
await verify_cast(actor_mesh_ref, proc_mesh.client, list(range(3 * 8 * 8)))
await verify_cast_to_call(
actor_mesh_ref, proc_mesh.client, list(range(3 * 8 * 8))
)

await proc_mesh.stop_nonblocking()

Expand Down Expand Up @@ -184,7 +240,7 @@ async def verify_slice(
assert (
sliced_shape.ranks() == replica_0_ranks + replica_1_ranks
), f"left is {sliced_shape.ranks()}"
await verify_cast(sliced_mesh, mailbox, sliced_shape.ranks())
await verify_cast_to_call(sliced_mesh, mailbox, sliced_shape.ranks())

assert sliced_shape.labels == ["replicas", "hosts", "gpus"]
assert sliced_shape.ndslice.sizes == [2, 4, 3]
Expand Down Expand Up @@ -224,7 +280,8 @@ async def test_slice_actor_mesh_handle() -> None:
@run_on_tokio
async def run() -> None:
proc_mesh = await allocate()
actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor)
actor_mesh = await spawn_actor_mesh(proc_mesh)

await verify_slice(actor_mesh, proc_mesh.client)

await proc_mesh.stop_nonblocking()
Expand All @@ -239,7 +296,8 @@ async def test_slice_actor_mesh_ref() -> None:
@run_on_tokio
async def run() -> None:
proc_mesh = await allocate()
actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor)
actor_mesh = await spawn_actor_mesh(proc_mesh)

actor_mesh_ref = actor_mesh.bind()
await verify_slice(actor_mesh_ref, proc_mesh.client)

Expand Down