diff --git a/monarch_extension/src/code_sync.rs b/monarch_extension/src/code_sync.rs index b08ef6972..4b225b107 100644 --- a/monarch_extension/src/code_sync.rs +++ b/monarch_extension/src/code_sync.rs @@ -191,7 +191,7 @@ impl CodeSyncMeshClient { remote: RemoteWorkspace, auto_reload: bool, ) -> PyResult> { - pyo3_async_runtimes::tokio::future_into_py( + monarch_hyperactor::runtime::future_into_py( py, CodeSyncMeshClient::sync_workspace_( self.actor_mesh.clone(), @@ -211,7 +211,7 @@ impl CodeSyncMeshClient { auto_reload: bool, ) -> PyResult> { let actor_mesh = self.actor_mesh.clone(); - pyo3_async_runtimes::tokio::future_into_py( + monarch_hyperactor::runtime::future_into_py( py, try_join_all(workspaces.into_iter().map(|workspace| { CodeSyncMeshClient::sync_workspace_( diff --git a/monarch_extension/src/simulation_tools.rs b/monarch_extension/src/simulation_tools.rs index a73855e0c..f7bba2b9c 100644 --- a/monarch_extension/src/simulation_tools.rs +++ b/monarch_extension/src/simulation_tools.rs @@ -14,7 +14,7 @@ use pyo3::prelude::*; #[pyfunction] #[pyo3(name = "start_event_loop")] pub fn start_simnet_event_loop(py: Python) -> PyResult> { - pyo3_async_runtimes::tokio::future_into_py(py, async move { + monarch_hyperactor::runtime::future_into_py(py, async move { simnet::start(); Ok(()) }) @@ -24,7 +24,7 @@ pub fn start_simnet_event_loop(py: Python) -> PyResult> { #[pyo3(name="sleep",signature=(seconds))] pub fn py_sim_sleep<'py>(py: Python<'py>, seconds: f64) -> PyResult> { let millis = (seconds * 1000.0).ceil() as u64; - pyo3_async_runtimes::tokio::future_into_py(py, async move { + monarch_hyperactor::runtime::future_into_py(py, async move { let duration = tokio::time::Duration::from_millis(millis); SimClock.sleep(duration).await; Ok(()) diff --git a/monarch_extension/src/tensor_worker.rs b/monarch_extension/src/tensor_worker.rs index e1744ea8b..31cbe86cb 100644 --- a/monarch_extension/src/tensor_worker.rs +++ b/monarch_extension/src/tensor_worker.rs @@ -1389,7 +1389,6 @@ fn worker_main(py: Python<'_>) -> PyResult<()> { BinaryArgs::Pipe => bootstrap_pipe(), BinaryArgs::WorkerServer { rd, wr } => { worker_server( - get_tokio_runtime(), // SAFETY: Raw FD passed in from parent. BufReader::new(File::from(unsafe { OwnedFd::from_raw_fd(rd) })), // SAFETY: Raw FD passed in from parent. diff --git a/monarch_hyperactor/Cargo.toml b/monarch_hyperactor/Cargo.toml index 9ca462543..4138eaa94 100644 --- a/monarch_hyperactor/Cargo.toml +++ b/monarch_hyperactor/Cargo.toml @@ -29,6 +29,7 @@ lazy_static = "1.5" monarch_types = { version = "0.0.0", path = "../monarch_types" } ndslice = { version = "0.0.0", path = "../ndslice" } nix = { version = "0.30.1", features = ["dir", "event", "hostname", "inotify", "ioctl", "mman", "mount", "net", "poll", "ptrace", "reboot", "resource", "sched", "signal", "term", "time", "user", "zerocopy"] } +once_cell = "1.21" opentelemetry = "0.29" pyo3 = { version = "0.24", features = ["anyhow", "multiple-pymethods"] } pyo3-async-runtimes = { version = "0.24", features = ["attributes", "tokio-runtime"] } diff --git a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs index abf3b689a..84893fa37 100644 --- a/monarch_hyperactor/src/actor.rs +++ b/monarch_hyperactor/src/actor.rs @@ -502,26 +502,24 @@ impl Actor for PythonActor { /// Create a new TaskLocals with its own asyncio event loop in a dedicated thread. fn create_task_locals() -> pyo3_async_runtimes::TaskLocals { - let (tx, rx) = std::sync::mpsc::channel(); - let _ = std::thread::spawn(move || { - Python::with_gil(|py| { - let asyncio = Python::import(py, "asyncio").unwrap(); - let event_loop = asyncio.call_method0("new_event_loop").unwrap(); - asyncio - .call_method1("set_event_loop", (event_loop.clone(),)) - .unwrap(); - - let task_locals = pyo3_async_runtimes::TaskLocals::new(event_loop.clone()) - .copy_context(py) - .unwrap(); - tx.send(task_locals).unwrap(); - if let Err(e) = event_loop.call_method0("run_forever") { - eprintln!("Event loop stopped with error: {:?}", e); - } - let _ = event_loop.call_method0("close"); - }); - }); - rx.recv().unwrap() + Python::with_gil(|py| { + let asyncio = Python::import(py, "asyncio").unwrap(); + let event_loop = asyncio.call_method0("new_event_loop").unwrap(); + let task_locals = pyo3_async_runtimes::TaskLocals::new(event_loop.clone()) + .copy_context(py) + .unwrap(); + + let kwargs = PyDict::new(py); + let target = event_loop.getattr("run_forever").unwrap(); + kwargs.set_item("target", target).unwrap(); + let thread = py + .import("threading") + .unwrap() + .call_method("Thread", (), Some(&kwargs)) + .unwrap(); + thread.call_method0("start").unwrap(); + task_locals + }) } // [Panics in async endpoints] diff --git a/monarch_hyperactor/src/actor_mesh.rs b/monarch_hyperactor/src/actor_mesh.rs index 910b3ef20..85b242075 100644 --- a/monarch_hyperactor/src/actor_mesh.rs +++ b/monarch_hyperactor/src/actor_mesh.rs @@ -249,7 +249,7 @@ impl PythonActorMesh { fn stop<'py>(&self, py: Python<'py>) -> PyResult> { let actor_mesh = self.inner.clone(); - pyo3_async_runtimes::tokio::future_into_py(py, async move { + crate::runtime::future_into_py(py, async move { let actor_mesh = actor_mesh .take() .await diff --git a/monarch_hyperactor/src/alloc.rs b/monarch_hyperactor/src/alloc.rs index 5aae39a05..38c0ad1b6 100644 --- a/monarch_hyperactor/src/alloc.rs +++ b/monarch_hyperactor/src/alloc.rs @@ -318,17 +318,17 @@ impl PyRemoteProcessAllocInitializer { .call_method1("initialize_alloc", args) .map(|x| x.unbind()) })?; - get_tokio_runtime() - .spawn_blocking(move || -> PyResult> { - // call the function as implemented in python - Python::with_gil(|py| { - let asyncio = py.import("asyncio").unwrap(); - let addrs = asyncio.call_method1("run", (coro,))?; - let addrs: PyResult> = addrs.extract(); - addrs - }) + let r = get_tokio_runtime().spawn_blocking(move || -> PyResult> { + // call the function as implemented in python + Python::with_gil(|py| { + let asyncio = py.import("asyncio").unwrap(); + let addrs = asyncio.call_method1("run", (coro,))?; + let addrs: PyResult> = addrs.extract(); + addrs }) - .await + }); + + r.await .map_err(|err| PyRuntimeError::new_err(err.to_string()))? } diff --git a/monarch_hyperactor/src/bootstrap.rs b/monarch_hyperactor/src/bootstrap.rs index 22f0c594f..7e54ae759 100644 --- a/monarch_hyperactor/src/bootstrap.rs +++ b/monarch_hyperactor/src/bootstrap.rs @@ -26,7 +26,7 @@ pub fn bootstrap_main(py: Python) -> PyResult> { }; hyperactor::tracing::debug!("entering async bootstrap"); - pyo3_async_runtimes::tokio::future_into_py::<_, ()>(py, async move { + crate::runtime::future_into_py::<_, ()>(py, async move { // SAFETY: // - Only one of these is ever created. // - This is the entry point of this program, so this will be dropped when diff --git a/monarch_hyperactor/src/lib.rs b/monarch_hyperactor/src/lib.rs index cb57c0e61..909ff1b24 100644 --- a/monarch_hyperactor/src/lib.rs +++ b/monarch_hyperactor/src/lib.rs @@ -8,6 +8,8 @@ #![allow(unsafe_op_in_unsafe_fn)] #![feature(exit_status_error)] +#![feature(mapped_lock_guards)] +#![feature(rwlock_downgrade)] pub mod actor; pub mod actor_mesh; diff --git a/monarch_hyperactor/src/mailbox.rs b/monarch_hyperactor/src/mailbox.rs index 50035239e..8649c651b 100644 --- a/monarch_hyperactor/src/mailbox.rs +++ b/monarch_hyperactor/src/mailbox.rs @@ -436,7 +436,7 @@ pub(super) struct PythonUndeliverablePortReceiver { impl PythonUndeliverablePortReceiver { fn recv<'py>(&mut self, py: Python<'py>) -> PyResult> { let receiver = self.inner.clone(); - pyo3_async_runtimes::tokio::future_into_py(py, async move { + crate::runtime::future_into_py(py, async move { let message = receiver .lock() .await diff --git a/monarch_hyperactor/src/proc.rs b/monarch_hyperactor/src/proc.rs index 44e29739b..43f232b47 100644 --- a/monarch_hyperactor/src/proc.rs +++ b/monarch_hyperactor/src/proc.rs @@ -137,7 +137,7 @@ impl PyProc { ) -> PyResult> { let proc = self.inner.clone(); let pickled_type = PickledPyObject::pickle(actor.as_any())?; - pyo3_async_runtimes::tokio::future_into_py(py, async move { + crate::runtime::future_into_py(py, async move { Ok(PythonActorHandle { inner: proc .spawn(name.as_deref().unwrap_or("anon"), pickled_type) diff --git a/monarch_hyperactor/src/proc_mesh.rs b/monarch_hyperactor/src/proc_mesh.rs index 181d4d262..f66d68cc2 100644 --- a/monarch_hyperactor/src/proc_mesh.rs +++ b/monarch_hyperactor/src/proc_mesh.rs @@ -307,7 +307,7 @@ impl PyProcMesh { )); } let receiver = self.user_monitor_receiver.clone(); - Ok(pyo3_async_runtimes::tokio::future_into_py(py, async move { + Ok(crate::runtime::future_into_py(py, async move { // Create a new user monitor Ok(PyProcMeshMonitor { receiver }) })? @@ -385,7 +385,7 @@ impl PyProcMeshMonitor { fn __anext__(&self, py: Python<'_>) -> PyResult { let receiver = self.receiver.clone(); - Ok(pyo3_async_runtimes::tokio::future_into_py(py, async move { + Ok(crate::runtime::future_into_py(py, async move { let receiver = receiver .borrow() .map_err(|_| PyRuntimeError::new_err("`ProcEvent receiver` is shutdown"))?; diff --git a/monarch_hyperactor/src/runtime.rs b/monarch_hyperactor/src/runtime.rs index 9bfddb21a..bac4d467c 100644 --- a/monarch_hyperactor/src/runtime.rs +++ b/monarch_hyperactor/src/runtime.rs @@ -8,36 +8,72 @@ use std::cell::Cell; use std::future::Future; -use std::sync::OnceLock; +use std::pin::Pin; +use std::sync::RwLock; +use std::sync::RwLockReadGuard; use std::time::Duration; use anyhow::Result; -use anyhow::anyhow; use anyhow::ensure; +use once_cell::unsync::OnceCell as UnsyncOnceCell; use pyo3::PyResult; use pyo3::Python; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; use pyo3::types::PyAnyMethods; +use pyo3::types::PyCFunction; +use pyo3::types::PyDict; +use pyo3::types::PyTuple; +use pyo3_async_runtimes::TaskLocals; +use tokio::task; -pub fn get_tokio_runtime() -> &'static tokio::runtime::Runtime { - static INSTANCE: OnceLock = OnceLock::new(); - INSTANCE.get_or_init(|| { - tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .unwrap() +// this must be a RwLock and only return a guard for reading the runtime. +// Otherwise multiple threads can deadlock fighting for the Runtime object if they hold it +// while blocking on something. +static INSTANCE: std::sync::LazyLock>> = + std::sync::LazyLock::new(|| RwLock::new(None)); + +pub fn get_tokio_runtime<'l>() -> std::sync::MappedRwLockReadGuard<'l, tokio::runtime::Runtime> { + // First try to get a read lock and check if runtime exists + { + let read_guard = INSTANCE.read().unwrap(); + if read_guard.is_some() { + return RwLockReadGuard::map(read_guard, |lock: &Option| { + lock.as_ref().unwrap() + }); + } + // Drop the read lock by letting it go out of scope + } + + // Runtime doesn't exist, upgrade to write lock to initialize + let mut write_guard = INSTANCE.write().unwrap(); + if write_guard.is_none() { + *write_guard = Some( + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(), + ); + } + + // Downgrade write lock to read lock and return the reference + let read_guard = std::sync::RwLockWriteGuard::downgrade(write_guard); + RwLockReadGuard::map(read_guard, |lock: &Option| { + lock.as_ref().unwrap() }) } +pub fn shutdown_tokio_runtime() { + INSTANCE.write().unwrap().take().map(|x| { + x.shutdown_timeout(Duration::from_secs(1)); + }); +} + thread_local! { static IS_MAIN_THREAD: Cell = const { Cell::new(false) }; } pub fn initialize(py: Python) -> Result<()> { - pyo3_async_runtimes::tokio::init_with_runtime(get_tokio_runtime()) - .map_err(|_| anyhow!("failed to initialize py3 async runtime"))?; - // Initialize thread local state to identify the main Python thread. let threading = Python::import(py, "threading")?; let main_thread = threading.call_method0("main_thread")?; @@ -48,6 +84,17 @@ pub fn initialize(py: Python) -> Result<()> { ); IS_MAIN_THREAD.set(true); + let closure = PyCFunction::new_closure( + py, + None, + None, + |args: &Bound<'_, PyTuple>, _kwargs: Option<&Bound<'_, PyDict>>| { + shutdown_tokio_runtime(); + }, + ) + .unwrap(); + let atexit = py.import("atexit").unwrap(); + atexit.call_method1("register", (closure,)).unwrap(); Ok(()) } @@ -131,3 +178,52 @@ pub fn register_python_bindings(runtime_mod: &Bound<'_, PyModule>) -> PyResult<( runtime_mod.add_function(sleep_indefinitely_fn)?; Ok(()) } + +struct SimpleRuntime; + +impl pyo3_async_runtimes::generic::Runtime for SimpleRuntime { + type JoinError = task::JoinError; + type JoinHandle = task::JoinHandle<()>; + + fn spawn(fut: F) -> Self::JoinHandle + where + F: Future + Send + 'static, + { + get_tokio_runtime().spawn(async move { + fut.await; + }) + } +} + +tokio::task_local! { + static TASK_LOCALS: UnsyncOnceCell; +} + +impl pyo3_async_runtimes::generic::ContextExt for SimpleRuntime { + fn scope(locals: TaskLocals, fut: F) -> Pin + Send>> + where + F: Future + Send + 'static, + { + let cell = UnsyncOnceCell::new(); + cell.set(locals).unwrap(); + + Box::pin(TASK_LOCALS.scope(cell, fut)) + } + + fn get_task_locals() -> Option { + TASK_LOCALS + .try_with(|c| { + c.get() + .map(|locals| Python::with_gil(|py| locals.clone_ref(py))) + }) + .unwrap_or_default() + } +} + +pub fn future_into_py(py: Python, fut: F) -> PyResult> +where + F: Future> + Send + 'static, + T: for<'py> IntoPyObject<'py>, +{ + pyo3_async_runtimes::generic::future_into_py::(py, fut) +} diff --git a/monarch_tensor_worker/src/bootstrap.rs b/monarch_tensor_worker/src/bootstrap.rs index a6ad0d49c..f2773975f 100644 --- a/monarch_tensor_worker/src/bootstrap.rs +++ b/monarch_tensor_worker/src/bootstrap.rs @@ -21,6 +21,7 @@ use hyperactor::actor::ActorStatus; use hyperactor::channel::ChannelAddr; use hyperactor_multiprocess::proc_actor::ProcActor; use hyperactor_multiprocess::system_actor::ProcLifecycleMode; +use monarch_hyperactor::runtime::get_tokio_runtime; use pyo3::prelude::*; use pyo3::types::PyType; use serde::Deserialize; @@ -174,11 +175,7 @@ impl WorkerServerResponse { } } -pub fn worker_server( - rt: &tokio::runtime::Runtime, - inp: impl BufRead, - mut outp: impl Write, -) -> Result<()> { +pub fn worker_server(inp: impl BufRead, mut outp: impl Write) -> Result<()> { tracing::info!("running worker server on {}", std::process::id()); for line in inp.lines() { @@ -199,7 +196,7 @@ pub fn worker_server( supervision_update_interval_in_sec: 5, extra_proc_labels: Some(labels), }; - let res = rt + let res = get_tokio_runtime() .block_on(async move { anyhow::Ok(bootstrap_worker_proc(args).await?.await) }); WorkerServerResponse::Finished { error: match res { diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 5df210b64..3f95a8507 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -219,9 +219,8 @@ async def test_sync_actor_sync_client() -> None: async def test_proc_mesh_size() -> None: proc = local_proc_mesh(gpus=2) assert 2 == proc.size("gpus") - proc.initialized.get() - - await proc.stop() + # proc.initialized.get() + # await proc.stop() @pytest.mark.timeout(60)