diff --git a/controller/src/lib.rs b/controller/src/lib.rs index 37a7691a3..82a15a89c 100644 --- a/controller/src/lib.rs +++ b/controller/src/lib.rs @@ -175,7 +175,11 @@ impl ControllerActor { ) -> Result<(ActorHandle, ActorRef), anyhow::Error> { let bootstrap = ProcActor::bootstrap( controller_id.proc_id().clone(), - controller_id.proc_id().world_id().clone(), // REFACTOR(marius): make world_id a parameter of ControllerActor::bootstrap + controller_id + .proc_id() + .world_id() + .expect("multiprocess supports only ranked procs") + .clone(), // REFACTOR(marius): make world_id a parameter of ControllerActor::bootstrap listen_addr, bootstrap_addr.clone(), supervision_update_interval, @@ -685,7 +689,12 @@ mod tests { world_size: 1, comm_actor_ref: comm_handle.bind(), worker_gang_ref: GangId( - WorldId(proc.proc_id().world_name().to_string()), + WorldId( + proc.proc_id() + .world_name() + .expect("only ranked actors are supported in the controller tests") + .to_string(), + ), "worker".to_string(), ) .into(), @@ -869,7 +878,12 @@ mod tests { world_size: 1, comm_actor_ref: comm_handle.bind(), worker_gang_ref: GangId( - WorldId(proc.proc_id().world_name().to_string()), + WorldId( + proc.proc_id() + .world_name() + .expect("only ranked actors are supported in the controller tests") + .to_string(), + ), "worker".to_string(), ) .into(), @@ -975,7 +989,12 @@ mod tests { .await .unwrap(); - let world_id = WorldId(proc.proc_id().world_name().to_string()); + let world_id = WorldId( + proc.proc_id() + .world_name() + .expect("only ranked actors are supported in the controller tests") + .to_string(), + ); let controller_handle = proc .spawn::( "controller", @@ -1500,7 +1519,12 @@ mod tests { world_size: 1, comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)), worker_gang_ref: GangId( - WorldId(proc_id.world_name().to_string()), + WorldId( + proc_id + .world_name() + .expect("only ranked actors are supported in the controller tests") + .to_string(), + ), "worker".to_string(), ) .into(), @@ -1591,7 +1615,12 @@ mod tests { world_size: 1, comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)), worker_gang_ref: GangId( - WorldId(proc_id.world_name().to_string()), + WorldId( + proc_id + .world_name() + .expect("only ranked actors are supported in the controller tests") + .to_string(), + ), "worker".to_string(), ) .into(), @@ -1692,7 +1721,12 @@ mod tests { world_size: 1, comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)), worker_gang_ref: GangId( - WorldId(proc_id.world_name().to_string()), + WorldId( + proc_id + .world_name() + .expect("only ranked actors are supported in the controller tests") + .to_string(), + ), "worker".to_string(), ) .into(), @@ -1835,7 +1869,12 @@ mod tests { world_size: 1, comm_actor_ref: ActorRef::attest(controller_id.proc_id().actor_id("comm", 0)), worker_gang_ref: GangId( - WorldId(proc_id.world_name().to_string()), + WorldId( + proc_id + .world_name() + .expect("only ranked actors are supported in the controller tests") + .to_string(), + ), "worker".to_string(), ) .into(), diff --git a/hyper/src/commands/demo.rs b/hyper/src/commands/demo.rs index 6a0e12055..202e8afed 100644 --- a/hyper/src/commands/demo.rs +++ b/hyper/src/commands/demo.rs @@ -61,7 +61,10 @@ impl DemoCommand { let proc_actor = ProcActor::bootstrap( proc_id.clone(), - proc_id.0.clone(), + proc_id + .world_id() + .expect("unranked proc not supported") + .clone(), addr, system_addr, Duration::from_secs(5), diff --git a/hyperactor/src/mailbox.rs b/hyperactor/src/mailbox.rs index 9032ea5a2..3c2650ad4 100644 --- a/hyperactor/src/mailbox.rs +++ b/hyperactor/src/mailbox.rs @@ -19,7 +19,7 @@ //! # use hyperactor::mailbox::Mailbox; //! # use hyperactor::reference::{ActorId, ProcId, WorldId}; //! # tokio_test::block_on(async { -//! # let proc_id = ProcId(WorldId("world".to_string()), 0); +//! # let proc_id = ProcId::Ranked(WorldId("world".to_string()), 0); //! # let actor_id = ActorId(proc_id, "actor".to_string(), 0); //! let mbox = Mailbox::new_detached(actor_id); //! let (port, mut receiver) = mbox.open_port::(); @@ -36,7 +36,7 @@ //! # use hyperactor::mailbox::Mailbox; //! # use hyperactor::reference::{ActorId, ProcId, WorldId}; //! # tokio_test::block_on(async { -//! # let proc_id = ProcId(WorldId("world".to_string()), 0); +//! # let proc_id = ProcId::Ranked(WorldId("world".to_string()), 0); //! # let actor_id = ActorId(proc_id, "actor".to_string(), 0); //! let mbox = Mailbox::new_detached(actor_id); //! @@ -2517,7 +2517,11 @@ mod tests { #[test] fn test_error() { let err = MailboxError::new( - ActorId(ProcId(WorldId("myworld".into()), 2), "myactor".into(), 5), + ActorId( + ProcId::Ranked(WorldId("myworld".to_string()), 2), + "myactor".to_string(), + 5, + ), MailboxErrorKind::Closed, ); assert_eq!(format!("{}", err), "myworld[2].myactor[5]: mailbox closed"); diff --git a/hyperactor/src/proc.rs b/hyperactor/src/proc.rs index b0d82c60a..62c781a39 100644 --- a/hyperactor/src/proc.rs +++ b/hyperactor/src/proc.rs @@ -376,7 +376,7 @@ impl Proc { pub fn local() -> Self { // TODO: name these something that is ~ globally unique, e.g., incorporate // the hostname, some GUID, etc. - let proc_id = ProcId(id!(local), NEXT_LOCAL_RANK.fetch_add(1, Ordering::Relaxed)); + let proc_id = ProcId::Ranked(id!(local), NEXT_LOCAL_RANK.fetch_add(1, Ordering::Relaxed)); // TODO: make it so that local procs can talk to each other. Proc::new(proc_id, BoxedMailboxSender::new(PanickingMailboxSender)) } diff --git a/hyperactor/src/reference.rs b/hyperactor/src/reference.rs index 59cd2a6bd..8521207c5 100644 --- a/hyperactor/src/reference.rs +++ b/hyperactor/src/reference.rs @@ -34,6 +34,7 @@ use std::num::ParseIntError; use std::str::FromStr; use derivative::Derivative; +use enum_as_inner::EnumAsInner; use rand::Rng; use serde::Deserialize; use serde::Serialize; @@ -48,6 +49,7 @@ use crate::accum::ReducerSpec; use crate::actor::RemoteActor; use crate::attrs::Attrs; use crate::cap; +use crate::channel::ChannelAddr; use crate::data::Serialized; use crate::mailbox::MailboxSenderError; use crate::mailbox::MailboxSenderErrorKind; @@ -94,9 +96,7 @@ impl Reference { pub fn is_prefix_of(&self, other: &Reference) -> bool { match self { Self::World(_) => self.world_id() == other.world_id(), - Self::Proc(_) => { - self.world_id() == other.world_id() && self.proc_id() == other.proc_id() - } + Self::Proc(_) => self.proc_id() == other.proc_id(), Self::Actor(_) => self == other, Self::Port(_) => self == other, Self::Gang(_) => self == other, @@ -104,13 +104,13 @@ impl Reference { } /// The world id of the reference. - pub fn world_id(&self) -> &WorldId { + pub fn world_id(&self) -> Option<&WorldId> { match self { - Self::World(world_id) => world_id, - Self::Proc(ProcId(world_id, _)) => world_id, - Self::Actor(ActorId(ProcId(world_id, _), _, _)) => world_id, - Self::Port(PortId(ActorId(ProcId(world_id, _), _, _), _)) => world_id, - Self::Gang(GangId(world_id, _)) => world_id, + Self::World(world_id) => Some(world_id), + Self::Proc(proc_id) => proc_id.world_id(), + Self::Actor(ActorId(proc_id, _, _)) => proc_id.world_id(), + Self::Port(PortId(ActorId(proc_id, _, _), _)) => proc_id.world_id(), + Self::Gang(GangId(world_id, _)) => Some(world_id), } } @@ -127,7 +127,7 @@ impl Reference { /// The rank of the reference, if any. fn rank(&self) -> Option { - self.proc_id().map(|proc_id| proc_id.rank()) + self.proc_id().and_then(|proc_id| proc_id.rank()) } /// The actor id of the reference, if any. @@ -222,14 +222,22 @@ impl fmt::Display for Reference { /// # use hyperactor::reference::ActorId; /// # use hyperactor::reference::GangId; /// assert_eq!(id!(hello), WorldId("hello".into())); -/// assert_eq!(id!(hello[0]), ProcId(WorldId("hello".into()), 0)); +/// assert_eq!(id!(hello[0]), ProcId::Ranked(WorldId("hello".into()), 0)); /// assert_eq!( /// id!(hello[0].actor), -/// ActorId(ProcId(WorldId("hello".into()), 0), "actor".into(), 0) +/// ActorId( +/// ProcId::Ranked(WorldId("hello".into()), 0), +/// "actor".into(), +/// 0 +/// ) /// ); /// assert_eq!( /// id!(hello[0].actor[1]), -/// ActorId(ProcId(WorldId("hello".into()), 0), "actor".into(), 1) +/// ActorId( +/// ProcId::Ranked(WorldId("hello".into()), 0), +/// "actor".into(), +/// 1 +/// ) /// ); /// assert_eq!( /// id!(hello.actor), @@ -246,14 +254,14 @@ macro_rules! id { $crate::reference::WorldId(stringify!($world).to_string()) }; ($world:ident [$rank:expr_2021]) => { - $crate::reference::ProcId( + $crate::reference::ProcId::Ranked( $crate::reference::WorldId(stringify!($world).to_string()), $rank, ) }; ($world:ident [$rank:expr_2021] . $actor:ident) => { $crate::reference::ActorId( - $crate::reference::ProcId( + $crate::reference::ProcId::Ranked( $crate::reference::WorldId(stringify!($world).to_string()), $rank, ), @@ -263,7 +271,7 @@ macro_rules! id { }; ($world:ident [$rank:expr_2021] . $actor:ident [$pid:expr_2021]) => { $crate::reference::ActorId( - $crate::reference::ProcId( + $crate::reference::ProcId::Ranked( $crate::reference::WorldId(stringify!($world).to_string()), $rank, ), @@ -280,7 +288,7 @@ macro_rules! id { ($world:ident [$rank:expr_2021] . $actor:ident [$pid:expr_2021] [$port:expr_2021]) => { $crate::reference::PortId( $crate::reference::ActorId( - $crate::reference::ProcId( + $crate::reference::ProcId::Ranked( $crate::reference::WorldId(stringify!($world).to_string()), $rank, ), @@ -329,25 +337,25 @@ impl FromStr for Reference { // world[rank] Token::Elem(world) Token::LeftBracket Token::Uint(rank) Token::RightBracket => - Self::Proc(ProcId(WorldId(world.into()), rank)), + Self::Proc(ProcId::Ranked(WorldId(world.into()), rank)), // world[rank].actor (implied pid=0) Token::Elem(world) Token::LeftBracket Token::Uint(rank) Token::RightBracket Token::Dot Token::Elem(actor) => - Self::Actor(ActorId(ProcId(WorldId(world.into()), rank), actor.into(), 0)), + Self::Actor(ActorId(ProcId::Ranked(WorldId(world.into()), rank), actor.into(), 0)), // world[rank].actor[pid] Token::Elem(world) Token::LeftBracket Token::Uint(rank) Token::RightBracket Token::Dot Token::Elem(actor) Token::LeftBracket Token::Uint(pid) Token::RightBracket => - Self::Actor(ActorId(ProcId(WorldId(world.into()), rank), actor.into(), pid)), + Self::Actor(ActorId(ProcId::Ranked(WorldId(world.into()), rank), actor.into(), pid)), // world[rank].actor[pid][port] Token::Elem(world) Token::LeftBracket Token::Uint(rank) Token::RightBracket Token::Dot Token::Elem(actor) Token::LeftBracket Token::Uint(pid) Token::RightBracket Token::LeftBracket Token::Uint(index) Token::RightBracket => - Self::Port(PortId(ActorId(ProcId(WorldId(world.into()), rank), actor.into(), pid), index as u64)), + Self::Port(PortId(ActorId(ProcId::Ranked(WorldId(world.into()), rank), actor.into(), pid), index as u64)), // world.actor Token::Elem(world) Token::Dot Token::Elem(actor) => @@ -410,7 +418,7 @@ pub struct WorldId(pub String); impl WorldId { /// Create a proc ID with the provided index in this world. pub fn proc_id(&self, index: Index) -> ProcId { - ProcId(self.clone(), index) + ProcId::Ranked(self.clone(), index) } /// The world index. @@ -421,7 +429,7 @@ impl WorldId { /// Return a randomly selected user proc in this world. pub fn random_user_proc(&self) -> ProcId { let mask = 1usize << (std::mem::size_of::() * 8 - 1); - ProcId(self.clone(), rand::thread_rng().r#gen::() | mask) + ProcId::Ranked(self.clone(), rand::thread_rng().r#gen::() | mask) } } @@ -443,8 +451,8 @@ impl FromStr for WorldId { } } -/// Procs are identified by their _rank_ within a world. Each proc -/// represents an actor runtime that can locally route to all of its +/// Procs are identified by their _rank_ within a world or by a direct channel address. +/// Each proc represents an actor runtime that can locally route to all of its /// constituent actors. /// /// Ranks >= 1usize << (no. bits in usize - 1) (i.e., with the high bit set) are "user" @@ -460,9 +468,15 @@ impl FromStr for WorldId { PartialOrd, Hash, Ord, - Named + Named, + EnumAsInner )] -pub struct ProcId(pub WorldId, pub Index); +pub enum ProcId { + /// A ranked proc within a world + Ranked(WorldId, Index), + /// A proc reachable via a direct channel address + Direct(ChannelAddr), +} impl ProcId { /// Create an actor ID with the provided name, pid within this proc. @@ -470,26 +484,34 @@ impl ProcId { ActorId(self.clone(), name.into(), pid) } - /// The proc's world id. - pub fn world_id(&self) -> &WorldId { - &self.0 + /// The proc's world id, if this is a ranked proc. + pub fn world_id(&self) -> Option<&WorldId> { + match self { + ProcId::Ranked(world_id, _) => Some(world_id), + ProcId::Direct(_) => None, + } } - /// The world index. - pub fn world_name(&self) -> &str { - self.0.name() + /// The world name, if this is a ranked proc. + pub fn world_name(&self) -> Option<&str> { + self.world_id().map(|world_id| world_id.name()) } - /// The proc's rank. - pub fn rank(&self) -> Index { - self.1 + /// The proc's rank, if this is a ranked proc. + pub fn rank(&self) -> Option { + match self { + ProcId::Ranked(_, rank) => Some(*rank), + ProcId::Direct(_) => None, + } } } impl fmt::Display for ProcId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let ProcId(world_id, rank) = self; - write!(f, "{}[{}]", world_id, rank) + match self { + ProcId::Ranked(world_id, rank) => write!(f, "{}[{}]", world_id, rank), + ProcId::Direct(addr) => write!(f, "{}", addr), + } } } @@ -497,6 +519,15 @@ impl FromStr for ProcId { type Err = ReferenceParsingError; fn from_str(addr: &str) -> Result { + // We first try to parse the proc id as a channel address; otherwise + // as a ranked reference. These grammars are currently non-overlapping, + // but we need to be careful when we add new channel address types. + // + // Over time, we will deprecate ranked references and provide a robustly + // unambiguous syntax. + if let Ok(channel_addr) = addr.parse::() { + return Ok(ProcId::Direct(channel_addr)); + } match addr.parse()? { Reference::Proc(proc_id) => Ok(proc_id), _ => Err(ReferenceParsingError::WrongType("proc".into())), @@ -540,14 +571,16 @@ impl ActorId { &self.0 } - /// The world index. + /// The world name. Panics if this is a direct proc. pub fn world_name(&self) -> &str { - self.0.world_name() + self.0 + .world_name() + .expect("world_name() called on direct proc") } - /// The actor's proc's rank. + /// The actor's proc's rank. Panics if this is a direct proc. pub fn rank(&self) -> Index { - self.0.rank() + self.0.rank().expect("rank() called on direct proc") } /// The actor's name. @@ -1050,7 +1083,7 @@ pub struct GangId(pub WorldId, pub String); impl GangId { pub(crate) fn expand(&self, world_size: usize) -> impl Iterator + '_ { - (0..world_size).map(|rank| ActorId(ProcId(self.0.clone(), rank), self.1.clone(), 0)) + (0..world_size).map(|rank| ActorId(ProcId::Ranked(self.0.clone(), rank), self.1.clone(), 0)) } /// The world id of the gang. @@ -1067,7 +1100,7 @@ impl GangId { /// actor because the root actor is the public interface of a gang. pub fn actor_id(&self, rank: Index) -> ActorId { ActorId( - ProcId(self.world_id().clone(), rank), + ProcId::Ranked(self.world_id().clone(), rank), self.name().to_string(), 0, ) @@ -1140,7 +1173,11 @@ impl GangRef { gang_id: GangId(world_id, name), .. } = self; - ActorRef::attest(ActorId(ProcId(world_id.clone(), rank), name.clone(), 0)) + ActorRef::attest(ActorId( + ProcId::Ranked(world_id.clone(), rank), + name.clone(), + 0, + )) } /// Return the gang ID. @@ -1191,15 +1228,27 @@ mod tests { fn test_reference_parse() { let cases: Vec<(&str, Reference)> = vec![ ("test", WorldId("test".into()).into()), - ("test[234]", ProcId(WorldId("test".into()), 234).into()), + ( + "test[234]", + ProcId::Ranked(WorldId("test".into()), 234).into(), + ), ( "test[234].testactor[6]", - ActorId(ProcId(WorldId("test".into()), 234), "testactor".into(), 6).into(), + ActorId( + ProcId::Ranked(WorldId("test".into()), 234), + "testactor".into(), + 6, + ) + .into(), ), ( "test[234].testactor[6][1]", PortId( - ActorId(ProcId(WorldId("test".into()), 234), "testactor".into(), 6), + ActorId( + ProcId::Ranked(WorldId("test".into()), 234), + "testactor".into(), + 6, + ), 1, ) .into(), @@ -1229,14 +1278,22 @@ mod tests { #[test] fn test_id_macro() { assert_eq!(id!(hello), WorldId("hello".into())); - assert_eq!(id!(hello[0]), ProcId(WorldId("hello".into()), 0)); + assert_eq!(id!(hello[0]), ProcId::Ranked(WorldId("hello".into()), 0)); assert_eq!( id!(hello[0].actor), - ActorId(ProcId(WorldId("hello".into()), 0), "actor".into(), 0) + ActorId( + ProcId::Ranked(WorldId("hello".into()), 0), + "actor".into(), + 0 + ) ); assert_eq!( id!(hello[0].actor[1]), - ActorId(ProcId(WorldId("hello".into()), 0), "actor".into(), 1) + ActorId( + ProcId::Ranked(WorldId("hello".into()), 0), + "actor".into(), + 1 + ) ); assert_eq!( id!(hello.actor), diff --git a/hyperactor_mesh/src/actor_mesh.rs b/hyperactor_mesh/src/actor_mesh.rs index ea985f2d9..b0f18a5b4 100644 --- a/hyperactor_mesh/src/actor_mesh.rs +++ b/hyperactor_mesh/src/actor_mesh.rs @@ -987,7 +987,7 @@ mod tests { // Send a message to a non-existent actor (the proc however exists). let unmonitored_reply_to = mesh.client().open_port::().0.bind(); - let bad_actor = ActorRef::::attest(ActorId(ProcId(WorldId(name.clone()), 0), "foo".into(), 0)); + let bad_actor = ActorRef::::attest(ActorId(ProcId::Ranked(WorldId(name.clone()), 0), "foo".into(), 0)); bad_actor.send(mesh.client(), GetRank(true, unmonitored_reply_to)).unwrap(); // The message will be returned! diff --git a/hyperactor_mesh/src/alloc.rs b/hyperactor_mesh/src/alloc.rs index 3cf5afb26..3aeba8295 100644 --- a/hyperactor_mesh/src/alloc.rs +++ b/hyperactor_mesh/src/alloc.rs @@ -461,7 +461,7 @@ pub(crate) mod testing { DialMailboxRouter::new_with_default((UndeliverableMailboxSender {}).into_boxed()); router.clone().serve(router_rx); - let client_proc_id = ProcId(WorldId("test_stuck".to_string()), 0); + let client_proc_id = ProcId::Ranked(WorldId("test_stuck".to_string()), 0); let (client_proc_addr, client_rx) = channel::serve(ChannelAddr::any(transport)).await.unwrap(); let client_proc = Proc::new( diff --git a/hyperactor_mesh/src/alloc/local.rs b/hyperactor_mesh/src/alloc/local.rs index 50ef5b26e..7f143805e 100644 --- a/hyperactor_mesh/src/alloc/local.rs +++ b/hyperactor_mesh/src/alloc/local.rs @@ -152,7 +152,7 @@ impl Alloc for LocalAlloc { match self.todo_rx.recv().await? { Action::Start(rank) => { - let proc_id = ProcId(self.world_id.clone(), rank); + let proc_id = ProcId::Ranked(self.world_id.clone(), rank); let bspan = tracing::info_span!("mesh_agent_bootstrap"); let (proc, mesh_agent) = match MeshAgent::bootstrap(proc_id.clone()).await { Ok(proc_and_agent) => proc_and_agent, diff --git a/hyperactor_mesh/src/alloc/process.rs b/hyperactor_mesh/src/alloc/process.rs index 91624a149..75fdad706 100644 --- a/hyperactor_mesh/src/alloc/process.rs +++ b/hyperactor_mesh/src/alloc/process.rs @@ -339,12 +339,18 @@ impl ProcessAlloc { fn index(&self, proc_id: &ProcId) -> Result { anyhow::ensure!( - proc_id.world_name().parse::()? == self.name, + proc_id + .world_name() + .expect("proc must be ranked for allocation index") + .parse::()? + == self.name, "proc {} does not belong to alloc {}", proc_id, self.name ); - Ok(proc_id.rank()) + Ok(proc_id + .rank() + .expect("proc must be ranked for allocation index")) } #[hyperactor::instrument_infallible] @@ -366,7 +372,7 @@ impl ProcessAlloc { cmd.stdout(Stdio::piped()); cmd.stderr(Stdio::piped()); - let proc_id = ProcId(WorldId(self.name.to_string()), index); + let proc_id = ProcId::Ranked(WorldId(self.name.to_string()), index); tracing::debug!("Spawning process {:?}", cmd); match cmd.spawn() { Err(err) => { @@ -446,7 +452,7 @@ impl Alloc for ProcessAlloc { } child.post(Allocator2Process::StartProc( - ProcId(WorldId(self.name.to_string()), index), + ProcId::Ranked(WorldId(self.name.to_string()), index), transport, )); } @@ -480,7 +486,7 @@ impl Alloc for ProcessAlloc { tracing::info!("child stopped with ProcStopReason::{:?}", reason); break Some(ProcState::Stopped { - proc_id: ProcId(WorldId(self.name.to_string()), index), + proc_id: ProcId::Ranked(WorldId(self.name.to_string()), index), reason }); }, @@ -564,7 +570,11 @@ mod tests { } }; - if let Some(child) = alloc.active.get(&proc_id.rank()) { + if let Some(child) = alloc.active.get( + &proc_id + .rank() + .expect("proc must be ranked for allocation lookup"), + ) { child.fail_group(); } diff --git a/hyperactor_mesh/src/alloc/remoteprocess.rs b/hyperactor_mesh/src/alloc/remoteprocess.rs index e68799742..3bdf35173 100644 --- a/hyperactor_mesh/src/alloc/remoteprocess.rs +++ b/hyperactor_mesh/src/alloc/remoteprocess.rs @@ -807,7 +807,11 @@ impl RemoteProcessAlloc { &mut self, proc_id: &ProcId, ) -> Result<&mut RemoteProcessAllocHostState, anyhow::Error> { - self.host_state_for_world_id(proc_id.world_id()) + self.host_state_for_world_id( + proc_id + .world_id() + .expect("proc must be ranked for host state lookup"), + ) } fn add_proc_id_to_host_state(&mut self, proc_id: &ProcId) -> Result<(), anyhow::Error> { @@ -835,7 +839,9 @@ impl RemoteProcessAlloc { proc_id: &ProcId, point: &Point, ) -> Result { - let world_id = proc_id.world_id(); + let world_id = proc_id + .world_id() + .expect("proc must be ranked for point mapping"); let offset = self .world_offsets .get(world_id) diff --git a/hyperactor_mesh/src/comm.rs b/hyperactor_mesh/src/comm.rs index be9f6ff89..d31f3e5c9 100644 --- a/hyperactor_mesh/src/comm.rs +++ b/hyperactor_mesh/src/comm.rs @@ -131,7 +131,10 @@ impl CommActorMode { .cloned() .ok_or_else(|| anyhow::anyhow!("no peer for rank {}", rank)), Self::Implicit => { - let world_id = self_id.proc_id().world_id(); + let world_id = self_id + .proc_id() + .world_id() + .ok_or_else(|| anyhow::anyhow!("comm actor must be on a ranked proc"))?; let proc_id = world_id.proc_id(rank); let actor_id = ActorId::root(proc_id, self_id.name().to_string()); Ok(ActorRef::::attest(actor_id)) @@ -145,10 +148,13 @@ impl CommActorMode { } /// Return the rank of the comm actor, given a self id. - fn self_rank(&self, self_id: &ActorId) -> usize { + fn self_rank(&self, self_id: &ActorId) -> Result { match self { - Self::Mesh(rank, _) => *rank, - Self::Implicit | Self::ImplicitWithWorldId(_) => self_id.proc_id().rank(), + Self::Mesh(rank, _) => Ok(*rank), + Self::Implicit | Self::ImplicitWithWorldId(_) => self_id + .proc_id() + .rank() + .ok_or_else(|| anyhow::anyhow!("comm actor must be on a ranked proc")), } } } @@ -249,7 +255,7 @@ impl CommActor { // Deliver message here, if necessary. if deliver_here { - let rank_on_root_mesh = mode.self_rank(cx.self_id()); + let rank_on_root_mesh = mode.self_rank(cx.self_id())?; let cast_rank = message.relative_rank(rank_on_root_mesh)?; let cast_shape = message.shape(); let mut headers = cx.headers().clone(); @@ -346,7 +352,7 @@ impl Handler for CommActor { } = fwd_message; // Resolve/dedup routing frames. - let rank = self.mode.self_rank(cx.self_id()); + let rank = self.mode.self_rank(cx.self_id())?; let (deliver_here, next_steps) = ndslice::selection::routing::resolve_routing(rank, dests, &mut |_| { panic!("Choice encountered in CommActor routing") diff --git a/hyperactor_mesh/src/proc_mesh.rs b/hyperactor_mesh/src/proc_mesh.rs index fd322e286..39ba3fdc2 100644 --- a/hyperactor_mesh/src/proc_mesh.rs +++ b/hyperactor_mesh/src/proc_mesh.rs @@ -87,7 +87,7 @@ pub fn global_mailbox() -> Mailbox { GLOBAL_MAILBOX .get_or_init(|| { let world_id = WorldId(ShortUuid::generate().to_string()); - let client_proc_id = ProcId(world_id.clone(), 0); + let client_proc_id = ProcId::Ranked(world_id.clone(), 0); let client_proc = Proc::new( client_proc_id.clone(), BoxedMailboxSender::new(global_router().clone()), @@ -207,14 +207,20 @@ impl ProcMesh { let proc_id = proc_ids.get(rank).unwrap().clone(); router.bind(Reference::Proc(proc_id.clone()), addr.clone()); // Work around for Allocs that have more than one world. - world_ids.insert(proc_id.world_id().clone()); + world_ids.insert( + proc_id + .world_id() + .expect("proc in running state must be ranked") + .clone(), + ); } router.clone().serve(router_rx); // Set up a client proc for the mesh itself, so that we can attach ourselves // to it, and communicate with the agents. We wire it into the same router as // everything else, so now the whole mesh should be able to communicate. - let client_proc_id = ProcId(WorldId(format!("{}_manager", alloc.world_id().name())), 0); + let client_proc_id = + ProcId::Ranked(WorldId(format!("{}_manager", alloc.world_id().name())), 0); let (client_proc_addr, client_rx) = channel::serve(ChannelAddr::any(alloc.transport())) .await .map_err(|err| AllocatorError::Other(err.into()))?; @@ -229,7 +235,7 @@ impl ProcMesh { // Bind this router to the global router, to enable cross-mesh routing. // TODO: unbind this when we incorporate mesh destruction too. for world_id in world_ids { - global_router().bind(world_id.clone().into(), router.clone()); + global_router().bind(world_id.into(), router.clone()); } global_router().bind(alloc.world_id().clone().into(), router.clone()); global_router().bind(client_proc_id.into(), router.clone()); @@ -611,7 +617,7 @@ impl ProcEvents { // TODO(T231868026): find a better way to represent all actors in an actor // mesh for supervision event event.actor_id = ActorId( - ProcId(WorldId(actor_mesh_id.0.0.clone()), 0), + ProcId::Ranked(WorldId(actor_mesh_id.0.0.clone()), 0), actor_mesh_id.1.clone(), 0, ); @@ -799,7 +805,7 @@ mod tests { let name = alloc.name().to_string(); let mesh = ProcMesh::allocate(alloc).await.unwrap(); - assert_eq!(mesh.get(0).unwrap().world_name(), &name); + assert_eq!(mesh.get(0).unwrap().world_name(), Some(name.as_str())); } #[tokio::test] diff --git a/hyperactor_multiprocess/src/proc_actor.rs b/hyperactor_multiprocess/src/proc_actor.rs index 4b61ce4cd..65f5d6bd0 100644 --- a/hyperactor_multiprocess/src/proc_actor.rs +++ b/hyperactor_multiprocess/src/proc_actor.rs @@ -550,7 +550,11 @@ impl Actor for ProcActor { impl ProcActor { /// This proc's rank in the world. fn rank(&self) -> Index { - self.params.proc.proc_id().rank() + self.params + .proc + .proc_id() + .rank() + .expect("proc must be ranked") } } @@ -624,10 +628,25 @@ impl ProcMessageHandler for ProcActor { world_size: usize, ) -> Result<(), anyhow::Error> { for (index, proc_id) in proc_ids.into_iter().enumerate() { - let proc_world_id = proc_id.world_id().clone(); + let proc_world_id = proc_id + .world_id() + .expect("proc must be ranked for world_id access") + .clone(); // Check world id isn't the same as this proc's world id. - if &proc_world_id == self.params.proc.proc_id().world_id() - || &world_id == self.params.proc.proc_id().world_id() + if &proc_world_id + == self + .params + .proc + .proc_id() + .world_id() + .expect("proc must be ranked for world_id access") + || &world_id + == self + .params + .proc + .proc_id() + .world_id() + .expect("proc must be ranked for world_id access") { return Err(anyhow::anyhow!( "cannot spawn proc in same world {}", @@ -658,7 +677,13 @@ impl ProcMessageHandler for ProcActor { self.params.bootstrap_channel_addr.to_string(), ) .env(HYPERACTOR_WORLD_SIZE, world_size.to_string()) - .env(HYPERACTOR_RANK, proc_id.rank().to_string()) + .env( + HYPERACTOR_RANK, + proc_id + .rank() + .expect("proc must be ranked for rank env var") + .to_string(), + ) .env(HYPERACTOR_LOCAL_RANK, index.to_string()) .stdin(Stdio::null()) .stdout(Stdio::inherit()) @@ -831,7 +856,11 @@ where .await?; // Wait for the spawned actor to join. - while spawned_receiver.recv().await? != proc_id.rank() {} + while spawned_receiver.recv().await? + != proc_id + .rank() + .expect("proc must be ranked for rank comparison") + {} // Gspawned actors are always exported. Ok(ActorRef::attest(proc_id.actor_id(actor_name, 0))) @@ -1575,7 +1604,7 @@ mod tests { // Ping gets Pong's address let expected_1 = r#"UpdateAddress { - proc_id: ProcId( + proc_id: Ranked( WorldId( "world", ), @@ -1585,7 +1614,7 @@ mod tests { // Pong gets Ping's address let expected_2 = r#"UpdateAddress { - proc_id: ProcId( + proc_id: Ranked( WorldId( "world", ), @@ -1595,7 +1624,7 @@ mod tests { // Ping gets "user"'s address let expected_3 = r#"UpdateAddress { - proc_id: ProcId( + proc_id: Ranked( WorldId( "user", ),"#; @@ -1701,7 +1730,11 @@ mod tests { let listen_addr = ChannelAddr::any(ChannelTransport::Tcp); let bootstrap = ProcActor::bootstrap( actor_id.proc_id().clone(), - actor_id.proc_id().world_id().clone(), + actor_id + .proc_id() + .world_id() + .expect("proc must be ranked for bootstrap world_id") + .clone(), listen_addr.clone(), system_addr.clone(), Duration::from_secs(3), diff --git a/hyperactor_multiprocess/src/system.rs b/hyperactor_multiprocess/src/system.rs index 12383e4a9..a3ed00fa9 100644 --- a/hyperactor_multiprocess/src/system.rs +++ b/hyperactor_multiprocess/src/system.rs @@ -307,7 +307,7 @@ mod tests { let test_labels = HashMap::from([("test_name".to_string(), "test_value".to_string())]); let listen_addr = ChannelAddr::any(ChannelTransport::Local); - let proc_id = ProcId(foo_world_id.clone(), 1); + let proc_id = ProcId::Ranked(foo_world_id.clone(), 1); ProcActor::try_bootstrap( proc_id.clone(), foo_world_id.clone(), @@ -392,7 +392,7 @@ mod tests { let host_world_id = WorldId(("hostworker_world").to_string()); let listen_addr: ChannelAddr = ChannelAddr::any(ChannelTransport::Local); // Join a host proc to the system first with no worker_world yet. - let host_proc_id_1 = ProcId(host_world_id.clone(), 1); + let host_proc_id_1 = ProcId::Ranked(host_world_id.clone(), 1); ProcActor::try_bootstrap( host_proc_id_1.clone(), host_world_id.clone(), @@ -470,7 +470,7 @@ mod tests { host_procs: HashSet::from([host_proc_id_1.clone()]), procs: (8..12) .map(|i| ( - ProcId(worker_world_id.clone(), i), + ProcId::Ranked(worker_world_id.clone(), i), WorldSnapshotProcInfo { labels: HashMap::new() } @@ -486,7 +486,7 @@ mod tests { ); } - let host_proc_id_0 = ProcId(host_world_id.clone(), 0); + let host_proc_id_0 = ProcId::Ranked(host_world_id.clone(), 0); ProcActor::try_bootstrap( host_proc_id_0.clone(), host_world_id.clone(), @@ -517,7 +517,7 @@ mod tests { WorldSnapshot { host_procs: HashSet::from([host_proc_id_0, host_proc_id_1]), procs: HashMap::from_iter((0..12).map(|i| ( - ProcId(worker_world_id.clone(), i), + ProcId::Ranked(worker_world_id.clone(), i), WorldSnapshotProcInfo { labels: HashMap::new() } @@ -579,7 +579,7 @@ mod tests { // Bootstrap the host procs, which will lead to work procs being spawned. let futs = (0..2).map(|i| { - let host_proc_id = ProcId(host_world_id.clone(), i); + let host_proc_id = ProcId::Ranked(host_world_id.clone(), i); ProcActor::try_bootstrap( host_proc_id.clone(), host_world_id.clone(), @@ -613,7 +613,7 @@ mod tests { // Join a non-worker proc to the "foo" world. let foo_futs = (0..2).map(|i| { let listen_addr = ChannelAddr::any(ChannelTransport::Local); - let proc_id = ProcId(foo_world_id.clone(), i); + let proc_id = ProcId::Ranked(foo_world_id.clone(), i); ProcActor::try_bootstrap( proc_id.clone(), foo_world_id.clone(), @@ -700,7 +700,7 @@ mod tests { // Bootstrap the host procs, which will lead to work procs being spawned. let futs = (0..2).map(|i| { - let host_proc_id = ProcId(host_world_id.clone(), i); + let host_proc_id = ProcId::Ranked(host_world_id.clone(), i); ProcActor::try_bootstrap( host_proc_id.clone(), host_world_id.clone(), @@ -733,7 +733,7 @@ mod tests { // Join a non-worker proc to the "foo" world. let foo_futs = (0..2).map(|i| { let listen_addr = ChannelAddr::any(ChannelTransport::Local); - let proc_id = ProcId(foo_world_id.clone(), i); + let proc_id = ProcId::Ranked(foo_world_id.clone(), i); ProcActor::try_bootstrap( proc_id.clone(), foo_world_id.clone(), diff --git a/hyperactor_multiprocess/src/system_actor.rs b/hyperactor_multiprocess/src/system_actor.rs index eb67ca521..ce75926f6 100644 --- a/hyperactor_multiprocess/src/system_actor.rs +++ b/hyperactor_multiprocess/src/system_actor.rs @@ -385,7 +385,7 @@ impl Host { // interval [H*N, (H+1)*N). let rank = self.host_rank * scheduler_params.num_procs_per_host + self.num_procs_assigned; - let proc_id = ProcId(world_id.clone(), rank); + let proc_id = ProcId::Ranked(world_id.clone(), rank); proc_ids.push(proc_id); self.num_procs_assigned += 1; } @@ -432,7 +432,11 @@ pub struct HostId(ProcId); impl HostId { /// Creates a new HostId from a proc_id. pub fn new(proc_id: ProcId) -> Result { - if !proc_id.world_name().starts_with(SHADOW_PREFIX) { + if !proc_id + .world_name() + .expect("proc must be ranked for world_name check") + .starts_with(SHADOW_PREFIX) + { anyhow::bail!( "proc_id {} is not a valid HostId because it does not start with {}", proc_id, @@ -447,7 +451,11 @@ impl TryFrom for HostId { type Error = anyhow::Error; fn try_from(proc_id: ProcId) -> Result { - if !proc_id.world_name().starts_with(SHADOW_PREFIX) { + if !proc_id + .world_name() + .expect("proc must be ranked for world_name check") + .starts_with(SHADOW_PREFIX) + { anyhow::bail!( "proc_id {} is not a valid HostId because it does not start with {}", proc_id, @@ -616,9 +624,13 @@ impl World { Entry::Occupied(_) => { return Err(SystemActorError::DuplicatedHostId(host_id)); } - Entry::Vacant(entry) => { - entry.insert_entry(Host::new(proc_message_port.clone(), host_id.0.rank())) - } + Entry::Vacant(entry) => entry.insert_entry(Host::new( + proc_message_port.clone(), + host_id + .0 + .rank() + .expect("host proc must be ranked for rank access"), + )), }; if self.state.status == WorldStatus::AwaitingCreation { @@ -670,7 +682,14 @@ impl World { } // REFACTOR(marius): remove - let world_id = procs_ids.first().unwrap().0.clone(); + let world_id = procs_ids + .first() + .unwrap() + .clone() + .into_ranked() + .expect("proc must be ranked for world_id access") + .0 + .clone(); // Open port ref tracing::info!("spawning procs for host {:?}", host_id); router.serialize_and_send( @@ -736,7 +755,7 @@ impl ReportingRouter { // - The sender and the destination are on the same proc (it // doesn't make sense to be dialing connections between them). if envelope.sender().proc_id() == &id!(unknown[0]) - || envelope.sender().proc_id().world_id() == &id!(user) + || envelope.sender().proc_id().world_id() == Some(&id!(user)) || envelope.sender().proc_id() == &system_proc_id || envelope.dest().actor_id().proc_id() == &system_proc_id || envelope.sender().proc_id() == envelope.dest().actor_id().proc_id() @@ -911,7 +930,10 @@ impl HeartbeatRecord { now > *last_update_time + supervision_update_timeout }) .for_each(|(_, proc_id)| { - if let Some(proc_state) = state.procs.get_mut(&proc_id.1) { + if let Some(proc_state) = state + .procs + .get_mut(&proc_id.rank().expect("proc must be ranked for rank access")) + { match proc_state.proc_health { ProcStatus::Alive => proc_state.proc_health = ProcStatus::Expired, // Do not overwrite the health of a proc already known to be unhealthy. @@ -989,7 +1011,12 @@ impl SystemSupervisionState { world.heartbeat_record.update(&proc_state.proc_id, clock); // Update supervision map. - if let Some(info) = world.state.procs.get_mut(&proc_state.proc_id.rank()) { + if let Some(info) = world.state.procs.get_mut( + &proc_state + .proc_id + .rank() + .expect("proc must be ranked for proc state update"), + ) { match info.proc_health { ProcStatus::Alive => info.proc_health = proc_state.proc_health, // Do not overwrite the health of a proc already known to be unhealthy. @@ -997,10 +1024,13 @@ impl SystemSupervisionState { } info.failed_actors.extend(proc_state.failed_actors); } else { - world - .state - .procs - .insert(proc_state.proc_id.rank(), proc_state); + world.state.procs.insert( + proc_state + .proc_id + .rank() + .expect("proc must be ranked for rank access"), + proc_state, + ); } } @@ -1017,7 +1047,7 @@ impl SystemSupervisionState { .get_mut() .state .procs - .entry(proc_id.1) + .entry(proc_id.rank().expect("proc must be ranked for rank access")) { Entry::Occupied(_) => { self.update(proc_state, clock); @@ -1209,9 +1239,15 @@ impl Actor for SystemActor { // established. Update the proc's supervision status // accordingly. let proc_id = to.actor_id().proc_id(); - let world_id = proc_id.world_id(); + let world_id = proc_id + .world_id() + .expect("proc must be ranked for world_id access"); if let Some(world) = &mut self.supervision_state.supervision_map.get_mut(world_id) { - if let Some(proc) = world.state.procs.get_mut(&proc_id.rank()) { + if let Some(proc) = world + .state + .procs + .get_mut(&proc_id.rank().expect("proc must be ranked for rank access")) + { match proc.proc_health { ProcStatus::Alive => proc.proc_health = ProcStatus::ConnectionFailure, // Do not overwrite the health of a proc already @@ -1516,7 +1552,11 @@ impl SystemMessageHandler for SystemActor { for (proc_id, port) in all_procs.into_iter() { let stopping_state = self .worlds_to_stop - .get_mut(&World::get_real_world_id(proc_id.world_id())) + .get_mut(&World::get_real_world_id( + proc_id + .world_id() + .expect("proc must be ranked for world_id access"), + )) .unwrap(); if !stopping_state.stopping_procs.insert(proc_id) { continue; @@ -1722,7 +1762,12 @@ impl Handler for SystemActor { } } let mut world_stopped = false; - let world_id = &msg.proc_id.0; + let world_id = &msg + .proc_id + .clone() + .into_ranked() + .expect("proc must be ranked for world_id access") + .0; if let Some(stopping_state) = self.worlds_to_stop.get_mut(world_id) { stopping_state.stopped_procs.insert(msg.proc_id.clone()); tracing::debug!( @@ -1833,7 +1878,7 @@ mod tests { async fn spawn_mock_host_actor(proc_world_id: WorldId, host_id: usize) -> MockHostActor { // Set up a local actor. - let local_proc_id = ProcId( + let local_proc_id = ProcId::Ranked( WorldId(format!("{}{}", SHADOW_PREFIX, proc_world_id.name())), host_id, ); @@ -1866,7 +1911,7 @@ mod tests { ) { let world_id = WorldId(name.to_string()); // Proc ID: world[idx] - let local_proc_id = ProcId(world_id.clone(), idx); + let local_proc_id = ProcId::Ranked(world_id.clone(), idx); let (local_proc_addr, local_proc_rx) = channel::serve(ChannelAddr::any(ChannelTransport::Local)) .await @@ -1989,7 +2034,13 @@ mod tests { let failures = sv.get_world_with_failures(&world_id, &clock); let procs = failures.unwrap().procs; assert_eq!(procs.len(), 1); - assert!(procs.contains_key(&proc_id_0.1)); + assert!( + procs.contains_key( + &proc_id_0 + .rank() + .expect("proc must be ranked for rank access") + ) + ); // Actor failure happened to proc_1 sv.report( @@ -2007,8 +2058,20 @@ mod tests { let failures = sv.get_world_with_failures(&world_id, &clock); let procs = failures.unwrap().procs; assert_eq!(procs.len(), 2); - assert!(procs.contains_key(&proc_id_0.1)); - assert!(procs.contains_key(&proc_id_1.1)); + assert!( + procs.contains_key( + &proc_id_0 + .rank() + .expect("proc must be ranked for rank access") + ) + ); + assert!( + procs.contains_key( + &proc_id_1 + .rank() + .expect("proc must be ranked for rank access") + ) + ); } #[tokio::test] @@ -2056,7 +2119,14 @@ mod tests { let ret = client_rx.recv().await.unwrap(); assert_eq!(ret.worlds.len(), 1); assert_eq!( - ret.worlds.get(&client_proc_id.0).unwrap().status, + ret.worlds + .get( + client_proc_id + .world_id() + .expect("proc must be ranked for world_id access") + ) + .unwrap() + .status, WorldStatus::AwaitingCreation ); @@ -2122,7 +2192,9 @@ mod tests { msg.unwrap(), Some(WorldSupervisionState { procs: HashMap::from([( - local_proc_id.1, + local_proc_id + .rank() + .expect("proc must be ranked for rank access"), ProcSupervisionState { world_id: world_id.clone(), proc_addr: local_proc_addr.clone(), @@ -2182,7 +2254,10 @@ mod tests { // Create a world system_actor_handle .send(SystemMessage::UpsertWorld { - world_id: local_proc_id.0.clone(), + world_id: local_proc_id + .world_id() + .expect("proc must be ranked for world_id access") + .clone(), shape: Shape::Definite(vec![1]), num_procs_per_host: 1, env: Environment::Local, @@ -2214,7 +2289,15 @@ mod tests { .unwrap(); assert_eq!(snapshot.worlds.len(), 1); assert_eq!( - snapshot.worlds.get(&local_proc_id.0).unwrap().status, + snapshot + .worlds + .get( + local_proc_id + .world_id() + .expect("proc must be ranked for world_id access") + ) + .unwrap() + .status, WorldStatus::Live ); @@ -2224,7 +2307,13 @@ mod tests { let mut iter = 0; // Wait for the world to be unhealthy let mut state = system_actor_handle - .state(&client_mailbox, local_proc_id.0.clone()) + .state( + &client_mailbox, + local_proc_id + .world_id() + .expect("proc must be ranked for world_id access") + .clone(), + ) .await .unwrap() .unwrap(); @@ -2236,7 +2325,13 @@ mod tests { // Don't query too frequently RealClock.sleep(Duration::from_millis(100)).await; state = system_actor_handle - .state(&client_mailbox, local_proc_id.0.clone()) + .state( + &client_mailbox, + local_proc_id + .world_id() + .expect("proc must be ranked for world_id access") + .clone(), + ) .await .unwrap() .unwrap(); @@ -2254,7 +2349,12 @@ mod tests { let _ = system_actor_handle .stop( &client_mailbox, - Some(vec![local_proc_id.0]), + Some(vec![ + local_proc_id + .world_id() + .expect("proc must be ranked for world_id access") + .clone(), + ]), Duration::from_secs(2), client_tx.bind(), ) @@ -2325,7 +2425,10 @@ mod tests { // Create a world system_actor_handle .send(SystemMessage::UpsertWorld { - world_id: local_proc_id.0.clone(), + world_id: local_proc_id + .world_id() + .expect("proc must be ranked for world_id access") + .clone(), shape: Shape::Definite(vec![1]), num_procs_per_host: 1, env: Environment::Local, @@ -2358,7 +2461,13 @@ mod tests { let mut iter = 0; // Wait for the world to be unhealthy let mut state = system_actor_handle - .state(&client_mailbox, local_proc_id.0.clone()) + .state( + &client_mailbox, + local_proc_id + .world_id() + .expect("proc must be ranked for world_id access") + .clone(), + ) .await .unwrap() .unwrap(); @@ -2370,7 +2479,13 @@ mod tests { // Don't query too frequently RealClock.sleep(Duration::from_millis(100)).await; state = system_actor_handle - .state(&client_mailbox, local_proc_id.0.clone()) + .state( + &client_mailbox, + local_proc_id + .world_id() + .expect("proc must be ranked for world_id access") + .clone(), + ) .await .unwrap() .unwrap(); @@ -2382,7 +2497,12 @@ mod tests { let _ = system_actor_handle .stop( &client_mailbox, - Some(vec![local_proc_id.0]), + Some(vec![ + local_proc_id + .world_id() + .expect("proc must be ranked for world_id access") + .clone(), + ]), Duration::from_secs(2), client_tx.bind(), ) @@ -2468,7 +2588,7 @@ mod tests { assert_eq!(all_procs.len(), num_procs); all_procs.sort(); for (i, proc) in all_procs.iter().enumerate() { - assert_eq!(*proc, ProcId(WorldId(world_name.clone()), i)); + assert_eq!(*proc, ProcId::Ranked(WorldId(world_name.clone()), i)); } } @@ -2542,7 +2662,7 @@ mod tests { assert_eq!(all_procs.len(), num_procs); all_procs.sort(); for (i, proc) in all_procs.iter().enumerate() { - assert_eq!(*proc, ProcId(WorldId(world_name.clone()), i)); + assert_eq!(*proc, ProcId::Ranked(WorldId(world_name.clone()), i)); } } diff --git a/hyperactor_telemetry/src/lib.rs b/hyperactor_telemetry/src/lib.rs index 8d12d86ab..8767e5864 100644 --- a/hyperactor_telemetry/src/lib.rs +++ b/hyperactor_telemetry/src/lib.rs @@ -134,7 +134,7 @@ fn writer() -> Box { lazy_static! { static ref TELEMETRY_CLOCK: Arc>> = - { Arc::new(Mutex::new(Box::new(DefaultTelemetryClock {}))) }; + Arc::new(Mutex::new(Box::new(DefaultTelemetryClock {}))); } /// The recorder singleton that is configured as a layer in the the default tracing diff --git a/monarch_hyperactor/src/bin/process_allocator/common.rs b/monarch_hyperactor/src/bin/process_allocator/common.rs index c189918dc..fa43ba5d9 100644 --- a/monarch_hyperactor/src/bin/process_allocator/common.rs +++ b/monarch_hyperactor/src/bin/process_allocator/common.rs @@ -150,10 +150,20 @@ mod tests { alloc::ProcState::Created { proc_id, .. } => { // alloc.next() will keep creating procs and incrementing rank id // so we mod the rank by world_size to map it to its logical rank - created_ranks.insert(proc_id.rank() % world_size); + created_ranks.insert( + proc_id + .rank() + .expect("process allocator currently supports only ranked procs") + % world_size, + ); } alloc::ProcState::Stopped { proc_id, .. } => { - stopped_ranks.insert(proc_id.rank() % world_size); + stopped_ranks.insert( + proc_id + .rank() + .expect("process allocator currently supports only ranked procs") + % world_size, + ); } _ => {} } @@ -353,7 +363,11 @@ mod tests { let proc_state = alloc.next().await.unwrap(); match proc_state { alloc::ProcState::Created { proc_id, .. } => { - created_ranks.insert(proc_id.rank()); + created_ranks.insert( + proc_id + .rank() + .expect("process allocator currently supports only ranked procs"), + ); } _ => { panic!("Unexpected message: {:?}", proc_state) @@ -373,7 +387,12 @@ mod tests { // ignore } alloc::ProcState::Stopped { proc_id, .. } => { - stopped_ranks.insert(proc_id.rank() % world_size); + stopped_ranks.insert( + proc_id + .rank() + .expect("process allocator currently supports only ranked procs") + % world_size, + ); } _ => { panic!("Unexpected message: {:?}", proc_state) diff --git a/monarch_hyperactor/src/proc.rs b/monarch_hyperactor/src/proc.rs index 43f232b47..53f3f6476 100644 --- a/monarch_hyperactor/src/proc.rs +++ b/monarch_hyperactor/src/proc.rs @@ -89,12 +89,19 @@ impl PyProc { #[getter] fn world_name(&self) -> String { - self.inner.proc_id().world_name().to_string() + self.inner + .proc_id() + .world_name() + .expect("proc must be ranked for world name") + .to_string() } #[getter] fn rank(&self) -> usize { - self.inner.proc_id().rank() + self.inner + .proc_id() + .rank() + .expect("proc must be ranked for rank access") } #[getter] @@ -201,7 +208,11 @@ impl PyProc { let bootstrap = ProcActor::bootstrap_for_proc( proc.clone().clone(), - proc.clone().proc_id().world_id().clone(), // REFACTOR(marius): factor out world id + proc.clone() + .proc_id() + .world_id() + .expect("proc must be ranked for world id") + .clone(), // REFACTOR(marius): factor out world id listen_addr, bootstrap_addr.clone(), system_supervision_ref, @@ -281,7 +292,7 @@ impl PyActorId { fn new(world_name: &str, rank: Index, actor_name: &str, pid: Index) -> Self { Self { inner: ActorId( - ProcId(WorldId(world_name.to_string()), rank), + ProcId::Ranked(WorldId(world_name.to_string()), rank), actor_name.to_string(), pid, ), diff --git a/monarch_simulator/src/bootstrap.rs b/monarch_simulator/src/bootstrap.rs index 40b01d119..58fc31c36 100644 --- a/monarch_simulator/src/bootstrap.rs +++ b/monarch_simulator/src/bootstrap.rs @@ -105,7 +105,7 @@ pub async fn spawn_sim_worker( rank: usize, ) -> anyhow::Result> { let listen_addr = ChannelAddr::any(bootstrap_addr.transport()); - let worker_proc_id = ProcId(worker_world_id.clone(), rank); + let worker_proc_id = ProcId::Ranked(worker_world_id.clone(), rank); let worker_actor_id = ActorId(worker_proc_id.clone(), "worker".into(), 0); let ChannelAddr::Sim(bootstrap_addr) = bootstrap_addr else { diff --git a/monarch_simulator/src/controller.rs b/monarch_simulator/src/controller.rs index be92aa584..db741f258 100644 --- a/monarch_simulator/src/controller.rs +++ b/monarch_simulator/src/controller.rs @@ -217,7 +217,11 @@ impl SimControllerActor { ) -> Result<(ActorHandle, ActorRef), anyhow::Error> { let bootstrap = ProcActor::bootstrap( controller_id.proc_id().clone(), - controller_id.proc_id().world_id().clone(), // REFACTOR(marius): plumb world id through SimControllerActor::bootstrap + controller_id + .proc_id() + .world_id() + .expect("sim controller only works on ranked procs") + .clone(), // REFACTOR(marius): plumb world id through SimControllerActor::bootstrap listen_addr, bootstrap_addr.clone(), supervision_update_interval, diff --git a/monarch_simulator/src/simulator.rs b/monarch_simulator/src/simulator.rs index aa9f9da6d..174e0d133 100644 --- a/monarch_simulator/src/simulator.rs +++ b/monarch_simulator/src/simulator.rs @@ -134,12 +134,12 @@ mod tests { let controller_world_name = format!("controller_world_{}", i); let worker_world_name = format!("worker_world_{}", i); controller_actor_ids.push(ActorId( - ProcId(WorldId(controller_world_name), 0), + ProcId::Ranked(WorldId(controller_world_name), 0), "root".into(), 0, )); worker_actor_ids.push(ActorId( - ProcId(WorldId(worker_world_name.clone()), 0), + ProcId::Ranked(WorldId(worker_world_name.clone()), 0), "root".into(), 0, ));