From 24d475c96c9957cf7829005bcc407e885f360c25 Mon Sep 17 00:00:00 2001 From: Pablo Ruiz Fischer Bennetts Date: Mon, 4 Aug 2025 17:13:12 -0700 Subject: [PATCH 1/2] Support configurable heartbeat timeout (#698) Summary: Move RemoteAlloc HB timeout as part of configs Monarch runs with 1k+ workers timeout on allocation with the default heartbeat setting (1s). This is not a fix but simply extending configurability to Python side. Reviewed By: vidhyav Differential Revision: D79064585 --- hyperactor/src/config.rs | 17 +++ hyperactor_mesh/src/alloc/remoteprocess.rs | 113 ++++++++++-------- hyperactor_mesh/test/remote_process_alloc.rs | 1 - monarch_hyperactor/src/alloc.rs | 12 +- .../src/bin/process_allocator/common.rs | 9 -- python/tests/test_allocator.py | 17 --- 6 files changed, 82 insertions(+), 87 deletions(-) diff --git a/hyperactor/src/config.rs b/hyperactor/src/config.rs index 8842f8ee3..fc661a317 100644 --- a/hyperactor/src/config.rs +++ b/hyperactor/src/config.rs @@ -46,6 +46,9 @@ declare_attrs! { /// Timeout used by proc mesh for stopping an actor. pub attr STOP_ACTOR_TIMEOUT: Duration = Duration::from_secs(1); + + /// Heartbeat interval for remote allocator + pub attr REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); } /// Load configuration from environment variables @@ -87,6 +90,13 @@ pub fn from_env() -> Attrs { } } + // Load remote allocator heartbeat interval + if let Ok(val) = env::var("HYPERACTOR_REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL_SECS") { + if let Ok(parsed) = val.parse::() { + config[REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL] = Duration::from_secs(parsed); + } + } + config } @@ -122,6 +132,9 @@ pub fn merge(config: &mut Attrs, other: &Attrs) { if other.contains_key(SPLIT_MAX_BUFFER_SIZE) { config[SPLIT_MAX_BUFFER_SIZE] = other[SPLIT_MAX_BUFFER_SIZE]; } + if other.contains_key(REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL) { + config[REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL] = other[REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL]; + } } /// Global configuration functions @@ -292,6 +305,10 @@ mod tests { ); assert_eq!(config[MESSAGE_ACK_EVERY_N_MESSAGES], 1000); assert_eq!(config[SPLIT_MAX_BUFFER_SIZE], 5); + assert_eq!( + config[REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL], + Duration::from_secs(5) + ); } #[test] diff --git a/hyperactor_mesh/src/alloc/remoteprocess.rs b/hyperactor_mesh/src/alloc/remoteprocess.rs index 4f9abb02f..86bbcc4e4 100644 --- a/hyperactor_mesh/src/alloc/remoteprocess.rs +++ b/hyperactor_mesh/src/alloc/remoteprocess.rs @@ -32,6 +32,7 @@ use hyperactor::channel::TxStatus; use hyperactor::clock; use hyperactor::clock::Clock; use hyperactor::clock::RealClock; +use hyperactor::config; use hyperactor::mailbox::DialMailboxRouter; use hyperactor::mailbox::MailboxServer; use hyperactor::mailbox::monitored_return_handle; @@ -72,8 +73,6 @@ pub enum RemoteProcessAllocatorMessage { /// Ordered list of hosts in this allocation. Can be used to /// pre-populate the any local configurations such as torch.dist. hosts: Vec, - /// How often to send heartbeat messages to check if client is alive. - heartbeat_interval: Duration, }, /// Stop allocation. Stop, @@ -196,7 +195,6 @@ impl RemoteProcessAllocator { spec, bootstrap_addr, hosts, - heartbeat_interval, }) => { tracing::info!("received allocation request: {:?}", spec); @@ -212,7 +210,6 @@ impl RemoteProcessAllocator { Box::new(alloc) as Box, bootstrap_addr, hosts, - heartbeat_interval, cancel_token, )), }) @@ -262,7 +259,6 @@ impl RemoteProcessAllocator { alloc: Box, bootstrap_addr: ChannelAddr, hosts: Vec, - heartbeat_interval: Duration, cancel_token: CancellationToken, ) { tracing::info!("handle allocation request, bootstrap_addr: {bootstrap_addr}"); @@ -309,15 +305,8 @@ impl RemoteProcessAllocator { } } - Self::handle_allocation_loop( - alloc, - bootstrap_addr, - router, - forwarder_addr, - heartbeat_interval, - cancel_token, - ) - .await; + Self::handle_allocation_loop(alloc, bootstrap_addr, router, forwarder_addr, cancel_token) + .await; mailbox_handle.stop("alloc stopped"); if let Err(e) = mailbox_handle.await { @@ -330,7 +319,6 @@ impl RemoteProcessAllocator { bootstrap_addr: ChannelAddr, router: DialMailboxRouter, forward_addr: ChannelAddr, - heartbeat_interval: Duration, cancel_token: CancellationToken, ) { tracing::info!("starting handle allocation loop"); @@ -419,7 +407,7 @@ impl RemoteProcessAllocator { } } } - _ = RealClock.sleep(heartbeat_interval) => { + _ = RealClock.sleep(config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL)) => { tracing::trace!("sending heartbeat"); tx.post(RemoteProcessProcStateMessage::HeartBeat); } @@ -537,7 +525,6 @@ pub struct RemoteProcessAlloc { initializer: Box, spec: AllocSpec, remote_allocator_port: u16, - remote_allocator_heartbeat_interval: Duration, transport: ChannelTransport, world_id: WorldId, ordered_hosts: Vec, @@ -569,7 +556,6 @@ impl RemoteProcessAlloc { world_id: WorldId, transport: ChannelTransport, remote_allocator_port: u16, - remote_allocator_heartbeat_interval: Duration, initializer: impl RemoteProcessAllocInitializer + Send + Sync + 'static, ) -> Result { let (bootstrap_addr, rx) = channel::serve(ChannelAddr::any(transport.clone())) @@ -611,7 +597,6 @@ impl RemoteProcessAlloc { world_id, transport, remote_allocator_port, - remote_allocator_heartbeat_interval, initializer: Box::new(initializer), world_shapes: HashMap::new(), ordered_hosts: Vec::new(), @@ -753,7 +738,6 @@ impl RemoteProcessAlloc { constraints: self.spec.constraints.clone(), }, hosts: hostnames.clone(), - heartbeat_interval: self.remote_allocator_heartbeat_interval, }); let offset = host_shape.slice().offset(); @@ -916,8 +900,9 @@ impl Alloc for RemoteProcessAlloc { }); } - let mut heartbeat_time = - hyperactor::clock::RealClock.now() + self.remote_allocator_heartbeat_interval; + let heartbeat_interval = + config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL); + let mut heartbeat_time = hyperactor::clock::RealClock.now() + heartbeat_interval; // rerun outer loop in case we pushed new items to the event queue let mut reloop = false; let update = loop { @@ -1012,7 +997,7 @@ impl Alloc for RemoteProcessAlloc { _ = clock::RealClock.sleep_until(heartbeat_time) => { self.host_states.iter().for_each(|(_, host_state)| host_state.tx.post(RemoteProcessAllocatorMessage::HeartBeat)); - heartbeat_time = hyperactor::clock::RealClock.now() + self.remote_allocator_heartbeat_interval; + heartbeat_time = hyperactor::clock::RealClock.now() + heartbeat_interval; } closed_host_id = self.comm_watcher_rx.recv() => { @@ -1223,6 +1208,11 @@ mod test { #[timed_test::async_timed_test(timeout_secs = 5)] async fn test_simple() { + let config = hyperactor::config::global::lock(); + let _guard = config.override_key( + hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, + Duration::from_millis(100), + ); hyperactor_telemetry::initialize_logging(ClockKind::default()); let serve_addr = ChannelAddr::any(ChannelTransport::Unix); let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix); @@ -1272,7 +1262,6 @@ mod test { spec: spec.clone(), bootstrap_addr, hosts: vec![], - heartbeat_interval: Duration::from_secs(1), }) .await .unwrap(); @@ -1361,6 +1350,11 @@ mod test { #[timed_test::async_timed_test(timeout_secs = 15)] async fn test_normal_stop() { + let config = hyperactor::config::global::lock(); + let _guard = config.override_key( + hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, + Duration::from_millis(100), + ); hyperactor_telemetry::initialize_logging(ClockKind::default()); let serve_addr = ChannelAddr::any(ChannelTransport::Unix); let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix); @@ -1409,7 +1403,6 @@ mod test { spec: spec.clone(), bootstrap_addr, hosts: vec![], - heartbeat_interval: Duration::from_millis(200), }) .await .unwrap(); @@ -1435,6 +1428,11 @@ mod test { #[timed_test::async_timed_test(timeout_secs = 15)] async fn test_realloc() { + let config = hyperactor::config::global::lock(); + let _guard = config.override_key( + hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, + Duration::from_millis(100), + ); hyperactor_telemetry::initialize_logging(ClockKind::default()); let serve_addr = ChannelAddr::any(ChannelTransport::Unix); let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix); @@ -1505,7 +1503,6 @@ mod test { spec: spec.clone(), bootstrap_addr: bootstrap_addr.clone(), hosts: vec![], - heartbeat_interval: Duration::from_millis(200), }) .await .unwrap(); @@ -1522,7 +1519,6 @@ mod test { spec: spec.clone(), bootstrap_addr, hosts: vec![], - heartbeat_interval: Duration::from_millis(200), }) .await .unwrap(); @@ -1553,10 +1549,14 @@ mod test { async fn test_upstream_closed() { // Use temporary config for this test let config = hyperactor::config::global::lock(); - let _guard = config.override_key( + let _guard1 = config.override_key( hyperactor::config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(1), ); + let _guard2 = config.override_key( + hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, + Duration::from_millis(100), + ); hyperactor_telemetry::initialize_logging(ClockKind::default()); let serve_addr = ChannelAddr::any(ChannelTransport::Unix); @@ -1612,7 +1612,6 @@ mod test { spec: spec.clone(), bootstrap_addr, hosts: vec![], - heartbeat_interval: Duration::from_millis(200), }) .await .unwrap(); @@ -1640,6 +1639,11 @@ mod test { #[timed_test::async_timed_test(timeout_secs = 15)] async fn test_inner_alloc_failure() { + let config = hyperactor::config::global::lock(); + let _guard = config.override_key( + hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, + Duration::from_secs(60), + ); hyperactor_telemetry::initialize_logging(ClockKind::default()); let serve_addr = ChannelAddr::any(ChannelTransport::Unix); let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix); @@ -1695,7 +1699,6 @@ mod test { spec: spec.clone(), bootstrap_addr, hosts: vec![], - heartbeat_interval: Duration::from_secs(60), }) .await .unwrap(); @@ -1725,6 +1728,7 @@ mod test_alloc { use std::os::unix::process::ExitStatusExt; use hyperactor::clock::ClockKind; + use hyperactor::config; use ndslice::shape; use nix::sys::signal; use nix::unistd::Pid; @@ -1736,10 +1740,14 @@ mod test_alloc { async fn test_alloc_simple() { // Use temporary config for this test let config = hyperactor::config::global::lock(); - let _guard = config.override_key( + let _guard1 = config.override_key( hyperactor::config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(1), ); + let _guard2 = config.override_key( + hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, + Duration::from_millis(100), + ); hyperactor_telemetry::initialize_logging(ClockKind::default()); let spec = AllocSpec { @@ -1748,7 +1756,6 @@ mod test_alloc { }; let world_id = WorldId("test_world_id".to_string()); let transport = ChannelTransport::Unix; - let heartbeat = Duration::from_millis(100); let task1_allocator = RemoteProcessAllocator::new(); let task1_addr = ChannelAddr::any(ChannelTransport::Unix); @@ -1789,10 +1796,9 @@ mod test_alloc { }, ]) }); - let mut alloc = - RemoteProcessAlloc::new(spec.clone(), world_id, transport, 0, heartbeat, initializer) - .await - .unwrap(); + let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, transport, 0, initializer) + .await + .unwrap(); let mut procs = HashSet::new(); let mut started_procs = HashSet::new(); let mut proc_coords = HashSet::new(); @@ -1859,10 +1865,14 @@ mod test_alloc { async fn test_alloc_host_failure() { // Use temporary config for this test let config = hyperactor::config::global::lock(); - let _guard = config.override_key( + let _guard1 = config.override_key( hyperactor::config::MESSAGE_DELIVERY_TIMEOUT, Duration::from_secs(1), ); + let _guard2 = config.override_key( + hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, + Duration::from_millis(100), + ); hyperactor_telemetry::initialize_logging(ClockKind::default()); let spec = AllocSpec { @@ -1871,7 +1881,6 @@ mod test_alloc { }; let world_id = WorldId("test_world_id".to_string()); let transport = ChannelTransport::Unix; - let heartbeat = Duration::from_millis(100); let task1_allocator = RemoteProcessAllocator::new(); let task1_addr = ChannelAddr::any(ChannelTransport::Unix); @@ -1914,10 +1923,9 @@ mod test_alloc { }, ]) }); - let mut alloc = - RemoteProcessAlloc::new(spec.clone(), world_id, transport, 0, heartbeat, initializer) - .await - .unwrap(); + let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, transport, 0, initializer) + .await + .unwrap(); let alloc_len = spec.shape.slice().len(); for _ in 0..alloc_len * 2 { match alloc.next().await { @@ -1937,7 +1945,9 @@ mod test_alloc { // now we kill task1 and wait for timeout tracing::info!("aborting task1 allocator"); task1_allocator_handle.abort(); - RealClock.sleep(heartbeat * 2).await; + RealClock + .sleep(config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL) * 2) + .await; for _ in 0..spec.shape.slice().len() / 2 { let proc_state = alloc.next().await.unwrap(); tracing::info!("test received next proc_state: {:?}", proc_state); @@ -1959,7 +1969,9 @@ mod test_alloc { // abort the second host tracing::info!("aborting task2 allocator"); task2_allocator_handle.abort(); - RealClock.sleep(heartbeat * 2).await; + RealClock + .sleep(config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL) * 2) + .await; for _ in 0..spec.shape.slice().len() / 2 { let proc_state = alloc.next().await.unwrap(); tracing::info!("test received next proc_state: {:?}", proc_state); @@ -1984,6 +1996,11 @@ mod test_alloc { unsafe { std::env::set_var("MONARCH_MESSAGE_DELIVERY_TIMEOUT_SECS", "1"); } + let config = hyperactor::config::global::lock(); + let _guard = config.override_key( + hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, + Duration::from_millis(100), + ); hyperactor_telemetry::initialize_logging(ClockKind::default()); let spec = AllocSpec { @@ -1992,7 +2009,6 @@ mod test_alloc { }; let world_id = WorldId("test_world_id".to_string()); let transport = ChannelTransport::Unix; - let heartbeat = Duration::from_millis(100); let task1_allocator = RemoteProcessAllocator::new(); let task1_addr = ChannelAddr::any(ChannelTransport::Unix); @@ -2033,10 +2049,9 @@ mod test_alloc { }, ]) }); - let mut alloc = - RemoteProcessAlloc::new(spec.clone(), world_id, transport, 0, heartbeat, initializer) - .await - .unwrap(); + let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, transport, 0, initializer) + .await + .unwrap(); let mut procs = HashSet::new(); let mut started_procs = HashSet::new(); let mut proc_coords = HashSet::new(); diff --git a/hyperactor_mesh/test/remote_process_alloc.rs b/hyperactor_mesh/test/remote_process_alloc.rs index b9a286dc2..b0b1563a4 100644 --- a/hyperactor_mesh/test/remote_process_alloc.rs +++ b/hyperactor_mesh/test/remote_process_alloc.rs @@ -95,7 +95,6 @@ async fn main() { WorldId("test_world_id".to_string()), ChannelTransport::Unix, 0, - Duration::from_millis(100), initializer, ) .await diff --git a/monarch_hyperactor/src/alloc.rs b/monarch_hyperactor/src/alloc.rs index 565814247..5aae39a05 100644 --- a/monarch_hyperactor/src/alloc.rs +++ b/monarch_hyperactor/src/alloc.rs @@ -10,7 +10,6 @@ use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; use std::sync::Mutex; -use std::time::Duration; use anyhow::anyhow; use async_trait::async_trait; @@ -389,7 +388,6 @@ impl RemoteProcessAllocInitializer for PyRemoteProcessAllocInitializer { pub struct PyRemoteAllocator { world_id: String, initializer: Py, - heartbeat_interval: Duration, } impl Clone for PyRemoteAllocator { @@ -397,7 +395,6 @@ impl Clone for PyRemoteAllocator { Self { world_id: self.world_id.clone(), initializer: Python::with_gil(|py| Py::clone_ref(&self.initializer, py)), - heartbeat_interval: self.heartbeat_interval.clone(), } } } @@ -423,7 +420,6 @@ impl Allocator for PyRemoteAllocator { WorldId(self.world_id.clone()), transport, port, - self.heartbeat_interval, initializer, ) .await?; @@ -437,17 +433,11 @@ impl PyRemoteAllocator { #[pyo3(signature = ( world_id, initializer, - heartbeat_interval = Duration::from_secs(5), ))] - fn new( - world_id: String, - initializer: Py, - heartbeat_interval: Duration, - ) -> PyResult { + fn new(world_id: String, initializer: Py) -> PyResult { Ok(Self { world_id, initializer, - heartbeat_interval, }) } diff --git a/monarch_hyperactor/src/bin/process_allocator/common.rs b/monarch_hyperactor/src/bin/process_allocator/common.rs index 7211d89fb..61c4a4db2 100644 --- a/monarch_hyperactor/src/bin/process_allocator/common.rs +++ b/monarch_hyperactor/src/bin/process_allocator/common.rs @@ -127,7 +127,6 @@ mod tests { }]) }); - let heartbeat = std::time::Duration::from_millis(100); let world_id = WorldId("__unused__".to_string()); let mut alloc = remoteprocess::RemoteProcessAlloc::new( @@ -135,7 +134,6 @@ mod tests { world_id, ChannelTransport::Unix, 0, - heartbeat, initializer, ) .await @@ -195,7 +193,6 @@ mod tests { }]) }); - let heartbeat = std::time::Duration::from_millis(100); let world_id = WorldId("__unused__".to_string()); // Wait at least as long as the timeout before sending any messages. @@ -207,7 +204,6 @@ mod tests { world_id.clone(), ChannelTransport::Unix, 0, - heartbeat, initializer, ) .await @@ -257,7 +253,6 @@ mod tests { .expect_initialize_alloc() .return_once(move || Ok(vec![alloc_host_clone])); - let heartbeat = std::time::Duration::from_millis(100); let world_id = WorldId("__unused__".to_string()); // Attempt to allocate, it should succeed because a timeout happens before @@ -266,7 +261,6 @@ mod tests { world_id.clone(), ChannelTransport::Unix, 0, - heartbeat, initializer, ) .await @@ -289,7 +283,6 @@ mod tests { world_id.clone(), ChannelTransport::Unix, 0, - heartbeat, initializer, ) .await @@ -338,7 +331,6 @@ mod tests { .expect_initialize_alloc() .return_once(move || Ok(vec![alloc_host])); - let heartbeat = std::time::Duration::from_millis(100); let world_id = WorldId("__unused__".to_string()); // Attempt to allocate, it should succeed because a timeout happens before @@ -347,7 +339,6 @@ mod tests { world_id.clone(), ChannelTransport::Unix, 0, - heartbeat, initializer, ) .await diff --git a/python/tests/test_allocator.py b/python/tests/test_allocator.py index 5a37fd96e..44b732072 100644 --- a/python/tests/test_allocator.py +++ b/python/tests/test_allocator.py @@ -252,7 +252,6 @@ async def test_allocate_failure_message(self) -> None: allocator = RemoteAllocator( world_id="test_remote_allocator", initializer=StaticRemoteAllocInitializer(host1, host2), - heartbeat_interval=_100_MILLISECONDS, ) alloc = allocator.allocate(spec) await alloc.initialized @@ -276,7 +275,6 @@ async def initialize_alloc(self, match_labels: dict[str, str]) -> list[str]: allocator = RemoteAllocator( world_id="test_remote_allocator", initializer=initializer, - heartbeat_interval=_100_MILLISECONDS, ) spec = AllocSpec(AllocConstraints(), host=1, gpu=1) @@ -305,7 +303,6 @@ async def initialize_alloc(self, match_labels: dict[str, str]) -> list[str]: allocator = RemoteAllocator( world_id="test_remote_allocator", initializer=empty_initializer, - heartbeat_interval=_100_MILLISECONDS, ) await allocator.allocate( AllocSpec(AllocConstraints(), host=1, gpu=1) @@ -322,7 +319,6 @@ async def test_allocate_2d_mesh(self) -> None: allocator = RemoteAllocator( world_id="test_remote_allocator", initializer=StaticRemoteAllocInitializer(host1, host2), - heartbeat_interval=_100_MILLISECONDS, ) alloc = allocator.allocate(spec) proc_mesh = ProcMesh.from_alloc(alloc) @@ -341,7 +337,6 @@ async def test_stop_proc_mesh_blocking(self) -> None: allocator = RemoteAllocator( world_id="test_remote_allocator", initializer=StaticRemoteAllocInitializer(host1, host2), - heartbeat_interval=_100_MILLISECONDS, ) alloc = allocator.allocate(spec) @@ -368,7 +363,6 @@ async def test_wrong_address(self) -> None: allocator = RemoteAllocator( world_id="test_remote_allocator", initializer=StaticRemoteAllocInitializer(wrong_host), - heartbeat_interval=_100_MILLISECONDS, ) alloc = allocator.allocate(spec) await alloc.initialized @@ -392,7 +386,6 @@ def dummy(self) -> None: allocator = RemoteAllocator( world_id="helloworld", initializer=StaticRemoteAllocInitializer(host1, host2), - heartbeat_interval=_100_MILLISECONDS, ) spec = AllocSpec(AllocConstraints(), host=2, gpu=2) proc_mesh = ProcMesh.from_alloc(allocator.allocate(spec)) @@ -412,7 +405,6 @@ async def test_stop_proc_mesh(self) -> None: allocator = RemoteAllocator( world_id="test_remote_allocator", initializer=StaticRemoteAllocInitializer(host1, host2), - heartbeat_interval=_100_MILLISECONDS, ) alloc = allocator.allocate(spec) proc_mesh = ProcMesh.from_alloc(alloc) @@ -438,7 +430,6 @@ async def test_stop_proc_mesh_context_manager(self) -> None: allocator = RemoteAllocator( world_id="test_remote_allocator", initializer=StaticRemoteAllocInitializer(host1, host2), - heartbeat_interval=_100_MILLISECONDS, ) alloc = allocator.allocate(spec) proc_mesh = ProcMesh.from_alloc(alloc) @@ -474,7 +465,6 @@ def setup_env_vars() -> None: allocator = RemoteAllocator( world_id="test_remote_allocator", initializer=StaticRemoteAllocInitializer(host1, host2), - heartbeat_interval=_100_MILLISECONDS, ) alloc = allocator.allocate(spec) proc_mesh = ProcMesh.from_alloc(alloc, setup=setup_env_vars) @@ -501,7 +491,6 @@ async def test_stop_proc_mesh_context_manager_multiple_times(self) -> None: allocator = RemoteAllocator( world_id="test_remote_allocator", initializer=StaticRemoteAllocInitializer(host1, host2), - heartbeat_interval=_100_MILLISECONDS, ) alloc = allocator.allocate(spec) proc_mesh = ProcMesh.from_alloc(alloc) @@ -531,7 +520,6 @@ async def test_remote_allocator_with_no_connection(self) -> None: allocator = RemoteAllocator( world_id="test_remote_allocator", initializer=StaticRemoteAllocInitializer(host1), - heartbeat_interval=_100_MILLISECONDS, ) with self.assertRaisesRegex( Exception, "no process has ever been allocated on" @@ -547,12 +535,10 @@ async def test_stacked_1d_meshes(self) -> None: allocator_a = RemoteAllocator( world_id="a", initializer=StaticRemoteAllocInitializer(host1_a), - heartbeat_interval=_100_MILLISECONDS, ) allocator_b = RemoteAllocator( world_id="b", initializer=StaticRemoteAllocInitializer(host1_b), - heartbeat_interval=_100_MILLISECONDS, ) spec_a = AllocSpec(AllocConstraints(), host=1, gpu=2) @@ -634,7 +620,6 @@ async def test_torchx_remote_alloc_initializer_no_match_label_1_mesh(self) -> No allocator = RemoteAllocator( world_id="test", initializer=initializer, - heartbeat_interval=_100_MILLISECONDS, ) alloc = allocator.allocate(AllocSpec(AllocConstraints(), host=1, gpu=4)) proc_mesh = ProcMesh.from_alloc(alloc) @@ -666,7 +651,6 @@ async def test_torchx_remote_alloc_initializer_with_match_label(self) -> None: allocator = RemoteAllocator( world_id="test", initializer=initializer, - heartbeat_interval=_100_MILLISECONDS, ) alloc = allocator.allocate( AllocSpec( @@ -726,7 +710,6 @@ async def test_log(self) -> None: allocator = RemoteAllocator( world_id="test_actor_logger", initializer=StaticRemoteAllocInitializer(host), - heartbeat_interval=_100_MILLISECONDS, ) spec = AllocSpec(AllocConstraints(), host=1, gpu=2) From d9c7d50caebc9b4678e1f523728dbf14daf621bf Mon Sep 17 00:00:00 2001 From: Pablo Ruiz Fischer Bennetts Date: Mon, 4 Aug 2025 17:13:12 -0700 Subject: [PATCH 2/2] Create hyperactor_mesh config file (#758) Summary: Split config for hyperactor and hyperactor_mesh Rollback Plan: Differential Revision: D79599694 --- hyperactor/src/config.rs | 34 ++-- hyperactor_mesh/src/alloc/remoteprocess.rs | 36 ++-- hyperactor_mesh/src/config.rs | 188 +++++++++++++++++++++ hyperactor_mesh/src/lib.rs | 1 + 4 files changed, 222 insertions(+), 37 deletions(-) create mode 100644 hyperactor_mesh/src/config.rs diff --git a/hyperactor/src/config.rs b/hyperactor/src/config.rs index fc661a317..76c5569f2 100644 --- a/hyperactor/src/config.rs +++ b/hyperactor/src/config.rs @@ -46,9 +46,6 @@ declare_attrs! { /// Timeout used by proc mesh for stopping an actor. pub attr STOP_ACTOR_TIMEOUT: Duration = Duration::from_secs(1); - - /// Heartbeat interval for remote allocator - pub attr REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); } /// Load configuration from environment variables @@ -90,13 +87,6 @@ pub fn from_env() -> Attrs { } } - // Load remote allocator heartbeat interval - if let Ok(val) = env::var("HYPERACTOR_REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL_SECS") { - if let Ok(parsed) = val.parse::() { - config[REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL] = Duration::from_secs(parsed); - } - } - config } @@ -132,9 +122,6 @@ pub fn merge(config: &mut Attrs, other: &Attrs) { if other.contains_key(SPLIT_MAX_BUFFER_SIZE) { config[SPLIT_MAX_BUFFER_SIZE] = other[SPLIT_MAX_BUFFER_SIZE]; } - if other.contains_key(REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL) { - config[REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL] = other[REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL]; - } } /// Global configuration functions @@ -180,6 +167,16 @@ pub mod global { static MUTEX: LazyLock> = LazyLock::new(|| std::sync::Mutex::new(())); ConfigLock { _guard: MUTEX.lock().unwrap(), + config: CONFIG.clone(), + } + } + + /// Create a new ConfigLock with a specific config. + pub fn new(config: Arc>) -> ConfigLock { + static MUTEX: LazyLock> = LazyLock::new(|| std::sync::Mutex::new(())); + ConfigLock { + _guard: MUTEX.lock().unwrap(), + config, } } @@ -235,6 +232,7 @@ pub mod global { /// this ConfigLock, ensuring proper synchronization. pub struct ConfigLock { _guard: std::sync::MutexGuard<'static, ()>, + config: Arc>, } impl ConfigLock { @@ -256,7 +254,7 @@ pub mod global { value: T, ) -> ConfigValueGuard<'a, T> { let orig = { - let mut config = CONFIG.write().unwrap(); + let mut config = self.config.write().unwrap(); let orig = config.take_value(key); config.set(key, value); orig @@ -265,6 +263,7 @@ pub mod global { ConfigValueGuard { key, orig, + config: self.config.clone(), _phantom: PhantomData, } } @@ -274,13 +273,14 @@ pub mod global { pub struct ConfigValueGuard<'a, T: 'static> { key: crate::attrs::Key, orig: Option>, + config: Arc>, // This is here so we can hold onto a 'a lifetime. _phantom: PhantomData<&'a ()>, } impl Drop for ConfigValueGuard<'_, T> { fn drop(&mut self) { - let mut config = CONFIG.write().unwrap(); + let mut config = self.config.write().unwrap(); if let Some(orig) = self.orig.take() { config.restore_value(self.key, orig); } else { @@ -305,10 +305,6 @@ mod tests { ); assert_eq!(config[MESSAGE_ACK_EVERY_N_MESSAGES], 1000); assert_eq!(config[SPLIT_MAX_BUFFER_SIZE], 5); - assert_eq!( - config[REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL], - Duration::from_secs(5) - ); } #[test] diff --git a/hyperactor_mesh/src/alloc/remoteprocess.rs b/hyperactor_mesh/src/alloc/remoteprocess.rs index 86bbcc4e4..693de6294 100644 --- a/hyperactor_mesh/src/alloc/remoteprocess.rs +++ b/hyperactor_mesh/src/alloc/remoteprocess.rs @@ -32,7 +32,6 @@ use hyperactor::channel::TxStatus; use hyperactor::clock; use hyperactor::clock::Clock; use hyperactor::clock::RealClock; -use hyperactor::config; use hyperactor::mailbox::DialMailboxRouter; use hyperactor::mailbox::MailboxServer; use hyperactor::mailbox::monitored_return_handle; @@ -59,6 +58,8 @@ use crate::alloc::AllocatorError; use crate::alloc::ProcState; use crate::alloc::ProcStopReason; use crate::alloc::ProcessAllocator; +use crate::config; +use crate::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL; /// Control messages sent from remote process allocator to local allocator. #[derive(Debug, Clone, Serialize, Deserialize, Named)] @@ -407,7 +408,7 @@ impl RemoteProcessAllocator { } } } - _ = RealClock.sleep(config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL)) => { + _ = RealClock.sleep(crate::config::global::get(REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL)) => { tracing::trace!("sending heartbeat"); tx.post(RemoteProcessProcStateMessage::HeartBeat); } @@ -1208,9 +1209,9 @@ mod test { #[timed_test::async_timed_test(timeout_secs = 5)] async fn test_simple() { - let config = hyperactor::config::global::lock(); + let config = crate::config::global::lock(); let _guard = config.override_key( - hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, + REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, Duration::from_millis(100), ); hyperactor_telemetry::initialize_logging(ClockKind::default()); @@ -1350,9 +1351,9 @@ mod test { #[timed_test::async_timed_test(timeout_secs = 15)] async fn test_normal_stop() { - let config = hyperactor::config::global::lock(); + let config = crate::config::global::lock(); let _guard = config.override_key( - hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, + REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, Duration::from_millis(100), ); hyperactor_telemetry::initialize_logging(ClockKind::default()); @@ -1428,9 +1429,9 @@ mod test { #[timed_test::async_timed_test(timeout_secs = 15)] async fn test_realloc() { - let config = hyperactor::config::global::lock(); + let config = crate::config::global::lock(); let _guard = config.override_key( - hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, + REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, Duration::from_millis(100), ); hyperactor_telemetry::initialize_logging(ClockKind::default()); @@ -1554,7 +1555,7 @@ mod test { Duration::from_secs(1), ); let _guard2 = config.override_key( - hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, + REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, Duration::from_millis(100), ); @@ -1640,10 +1641,8 @@ mod test { #[timed_test::async_timed_test(timeout_secs = 15)] async fn test_inner_alloc_failure() { let config = hyperactor::config::global::lock(); - let _guard = config.override_key( - hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, - Duration::from_secs(60), - ); + let _guard = + config.override_key(REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, Duration::from_secs(60)); hyperactor_telemetry::initialize_logging(ClockKind::default()); let serve_addr = ChannelAddr::any(ChannelTransport::Unix); let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix); @@ -1727,6 +1726,7 @@ mod test { mod test_alloc { use std::os::unix::process::ExitStatusExt; + use REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL; use hyperactor::clock::ClockKind; use hyperactor::config; use ndslice::shape; @@ -1745,7 +1745,7 @@ mod test_alloc { Duration::from_secs(1), ); let _guard2 = config.override_key( - hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, + REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, Duration::from_millis(100), ); hyperactor_telemetry::initialize_logging(ClockKind::default()); @@ -1870,7 +1870,7 @@ mod test_alloc { Duration::from_secs(1), ); let _guard2 = config.override_key( - hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, + REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, Duration::from_millis(100), ); hyperactor_telemetry::initialize_logging(ClockKind::default()); @@ -1946,7 +1946,7 @@ mod test_alloc { tracing::info!("aborting task1 allocator"); task1_allocator_handle.abort(); RealClock - .sleep(config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL) * 2) + .sleep(crate::config::global::get(REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL) * 2) .await; for _ in 0..spec.shape.slice().len() / 2 { let proc_state = alloc.next().await.unwrap(); @@ -1970,7 +1970,7 @@ mod test_alloc { tracing::info!("aborting task2 allocator"); task2_allocator_handle.abort(); RealClock - .sleep(config::global::get(config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL) * 2) + .sleep(config::global::get(REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL) * 2) .await; for _ in 0..spec.shape.slice().len() / 2 { let proc_state = alloc.next().await.unwrap(); @@ -1998,7 +1998,7 @@ mod test_alloc { } let config = hyperactor::config::global::lock(); let _guard = config.override_key( - hyperactor::config::REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, + REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, Duration::from_millis(100), ); hyperactor_telemetry::initialize_logging(ClockKind::default()); diff --git a/hyperactor_mesh/src/config.rs b/hyperactor_mesh/src/config.rs new file mode 100644 index 000000000..6026c72ad --- /dev/null +++ b/hyperactor_mesh/src/config.rs @@ -0,0 +1,188 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +//! Configuration for Hyperactor Mesh. +//! +//! This module provides a centralized way to manage configuration settings for Hyperactor Mesh. +//! It uses the attrs system for type-safe, flexible configuration management that supports +//! environment variables, YAML files, and temporary modifications for tests. + +use std::env; +use std::time::Duration; + +use hyperactor::attrs::Attrs; +use hyperactor::attrs::declare_attrs; + +// Declare configuration keys using the attrs system with defaults +declare_attrs! { + /// Heartbeat interval for remote allocator + pub attr REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); +} + +/// Load configuration from environment variables +pub fn from_env() -> Attrs { + let mut config = Attrs::new(); + + // Load remote allocator heartbeat interval + if let Ok(val) = env::var("HYPERACTOR_MESH_REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL_SECS") { + if let Ok(parsed) = val.parse::() { + config[REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL] = Duration::from_secs(parsed); + } + } + + config +} + +/// Merge with another configuration, with the other taking precedence +pub fn merge(config: &mut Attrs, other: &Attrs) { + if other.contains_key(REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL) { + config[REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL] = other[REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL]; + } +} + +/// Global configuration functions +pub mod global { + use std::sync::Arc; + use std::sync::LazyLock; + use std::sync::RwLock; + + use hyperactor::attrs::Key; + use hyperactor::config::global::ConfigLock; + + use super::*; + + /// Global configuration instance, initialized from environment variables. + static CONFIG: LazyLock>> = + LazyLock::new(|| Arc::new(RwLock::new(from_env()))); + + /// Get a key from the global configuration. + pub fn get< + T: Send + + Sync + + Copy + + serde::Serialize + + serde::de::DeserializeOwned + + hyperactor::data::Named + + 'static, + >( + key: Key, + ) -> T { + *CONFIG.read().unwrap().get(key).unwrap() + } + + /// Reset the global configuration to defaults (for testing only) + pub fn reset_to_defaults() { + let mut config = CONFIG.write().unwrap(); + *config = Attrs::new(); + } + + /// Acquire the global configuration lock for testing. + pub fn lock() -> ConfigLock { + hyperactor::config::global::new(CONFIG.clone()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_config() { + let config = Attrs::new(); + assert_eq!( + config[REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL], + Duration::from_secs(5) + ); + } + + #[test] + fn test_from_env() { + // Set environment variables + // SAFETY: TODO: Audit that the environment access only happens in single-threaded code. + unsafe { + std::env::set_var( + "HYPERACTOR_MESH_REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL_SECS", + "30", + ) + }; + + let config = from_env(); + + assert_eq!( + config[REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL], + Duration::from_secs(30) + ); + + // Clean up + // SAFETY: TODO: Audit that the environment access only happens in single-threaded code. + unsafe { std::env::remove_var("HYPERACTOR_MESH_REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL_SECS") }; + } + + #[test] + fn test_merge() { + let mut config1 = Attrs::new(); + let mut config2 = Attrs::new(); + config2[REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL] = Duration::from_secs(30); + + merge(&mut config1, &config2); + + assert_eq!( + config1[REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL], + Duration::from_secs(30) + ); + } + + #[test] + fn test_global_config() { + let config = global::lock(); + + // Reset global config to defaults to avoid interference from other tests + global::reset_to_defaults(); + + assert_eq!( + global::get(REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL), + Duration::from_secs(5) + ); + { + let _guard = + config.override_key(REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL, Duration::from_secs(30)); + assert_eq!( + global::get(REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL), + Duration::from_secs(30) + ); + } + assert_eq!( + global::get(REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL), + Duration::from_secs(5) + ); + } + + #[test] + fn test_defaults() { + // Test that empty config now returns defaults via get_or_default + let config = Attrs::new(); + + // Verify that the config is empty (no values explicitly set) + assert!(config.is_empty()); + + // But getters should still return the defaults from the keys + assert_eq!( + config[REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL], + Duration::from_secs(5) + ); + + // Verify the keys have defaults + assert!(REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL.has_default()); + + // Verify we can get defaults directly from keys + assert_eq!( + REMOTE_ALLOCATOR_HEARTBEAT_INTERVAL.default(), + Some(&Duration::from_secs(5)) + ); + } +} diff --git a/hyperactor_mesh/src/lib.rs b/hyperactor_mesh/src/lib.rs index 4a874d17f..147bb1a45 100644 --- a/hyperactor_mesh/src/lib.rs +++ b/hyperactor_mesh/src/lib.rs @@ -17,6 +17,7 @@ pub mod alloc; mod assign; pub mod bootstrap; pub mod comm; +pub mod config; pub mod connect; pub mod logging; pub mod mesh;