Skip to content

To fix the shutdown errors, we need to stop the tokio loop... #750

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions monarch_extension/src/code_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ impl CodeSyncMeshClient {
remote: RemoteWorkspace,
auto_reload: bool,
) -> PyResult<Bound<'py, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(
monarch_hyperactor::runtime::future_into_py(
py,
CodeSyncMeshClient::sync_workspace_(
self.actor_mesh.clone(),
Expand All @@ -211,7 +211,7 @@ impl CodeSyncMeshClient {
auto_reload: bool,
) -> PyResult<Bound<'py, PyAny>> {
let actor_mesh = self.actor_mesh.clone();
pyo3_async_runtimes::tokio::future_into_py(
monarch_hyperactor::runtime::future_into_py(
py,
try_join_all(workspaces.into_iter().map(|workspace| {
CodeSyncMeshClient::sync_workspace_(
Expand Down
4 changes: 2 additions & 2 deletions monarch_extension/src/simulation_tools.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use pyo3::prelude::*;
#[pyfunction]
#[pyo3(name = "start_event_loop")]
pub fn start_simnet_event_loop(py: Python) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
monarch_hyperactor::runtime::future_into_py(py, async move {
simnet::start();
Ok(())
})
Expand All @@ -24,7 +24,7 @@ pub fn start_simnet_event_loop(py: Python) -> PyResult<Bound<'_, PyAny>> {
#[pyo3(name="sleep",signature=(seconds))]
pub fn py_sim_sleep<'py>(py: Python<'py>, seconds: f64) -> PyResult<Bound<'py, PyAny>> {
let millis = (seconds * 1000.0).ceil() as u64;
pyo3_async_runtimes::tokio::future_into_py(py, async move {
monarch_hyperactor::runtime::future_into_py(py, async move {
let duration = tokio::time::Duration::from_millis(millis);
SimClock.sleep(duration).await;
Ok(())
Expand Down
1 change: 0 additions & 1 deletion monarch_extension/src/tensor_worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1389,7 +1389,6 @@ fn worker_main(py: Python<'_>) -> PyResult<()> {
BinaryArgs::Pipe => bootstrap_pipe(),
BinaryArgs::WorkerServer { rd, wr } => {
worker_server(
get_tokio_runtime(),
// SAFETY: Raw FD passed in from parent.
BufReader::new(File::from(unsafe { OwnedFd::from_raw_fd(rd) })),
// SAFETY: Raw FD passed in from parent.
Expand Down
1 change: 1 addition & 0 deletions monarch_hyperactor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ lazy_static = "1.5"
monarch_types = { version = "0.0.0", path = "../monarch_types" }
ndslice = { version = "0.0.0", path = "../ndslice" }
nix = { version = "0.30.1", features = ["dir", "event", "hostname", "inotify", "ioctl", "mman", "mount", "net", "poll", "ptrace", "reboot", "resource", "sched", "signal", "term", "time", "user", "zerocopy"] }
once_cell = "1.21"
opentelemetry = "0.29"
pyo3 = { version = "0.24", features = ["anyhow", "multiple-pymethods"] }
pyo3-async-runtimes = { version = "0.24", features = ["attributes", "tokio-runtime"] }
Expand Down
38 changes: 18 additions & 20 deletions monarch_hyperactor/src/actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -502,26 +502,24 @@ impl Actor for PythonActor {

/// Create a new TaskLocals with its own asyncio event loop in a dedicated thread.
fn create_task_locals() -> pyo3_async_runtimes::TaskLocals {
let (tx, rx) = std::sync::mpsc::channel();
let _ = std::thread::spawn(move || {
Python::with_gil(|py| {
let asyncio = Python::import(py, "asyncio").unwrap();
let event_loop = asyncio.call_method0("new_event_loop").unwrap();
asyncio
.call_method1("set_event_loop", (event_loop.clone(),))
.unwrap();

let task_locals = pyo3_async_runtimes::TaskLocals::new(event_loop.clone())
.copy_context(py)
.unwrap();
tx.send(task_locals).unwrap();
if let Err(e) = event_loop.call_method0("run_forever") {
eprintln!("Event loop stopped with error: {:?}", e);
}
let _ = event_loop.call_method0("close");
});
});
rx.recv().unwrap()
Python::with_gil(|py| {
let asyncio = Python::import(py, "asyncio").unwrap();
let event_loop = asyncio.call_method0("new_event_loop").unwrap();
let task_locals = pyo3_async_runtimes::TaskLocals::new(event_loop.clone())
.copy_context(py)
.unwrap();

let kwargs = PyDict::new(py);
let target = event_loop.getattr("run_forever").unwrap();
kwargs.set_item("target", target).unwrap();
let thread = py
.import("threading")
.unwrap()
.call_method("Thread", (), Some(&kwargs))
.unwrap();
thread.call_method0("start").unwrap();
task_locals
})
}

// [Panics in async endpoints]
Expand Down
2 changes: 1 addition & 1 deletion monarch_hyperactor/src/actor_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ impl PythonActorMesh {

fn stop<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let actor_mesh = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::runtime::future_into_py(py, async move {
let actor_mesh = actor_mesh
.take()
.await
Expand Down
20 changes: 10 additions & 10 deletions monarch_hyperactor/src/alloc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,17 +318,17 @@ impl PyRemoteProcessAllocInitializer {
.call_method1("initialize_alloc", args)
.map(|x| x.unbind())
})?;
get_tokio_runtime()
.spawn_blocking(move || -> PyResult<Vec<String>> {
// call the function as implemented in python
Python::with_gil(|py| {
let asyncio = py.import("asyncio").unwrap();
let addrs = asyncio.call_method1("run", (coro,))?;
let addrs: PyResult<Vec<String>> = addrs.extract();
addrs
})
let r = get_tokio_runtime().spawn_blocking(move || -> PyResult<Vec<String>> {
// call the function as implemented in python
Python::with_gil(|py| {
let asyncio = py.import("asyncio").unwrap();
let addrs = asyncio.call_method1("run", (coro,))?;
let addrs: PyResult<Vec<String>> = addrs.extract();
addrs
})
.await
});

r.await
.map_err(|err| PyRuntimeError::new_err(err.to_string()))?
}

Expand Down
2 changes: 1 addition & 1 deletion monarch_hyperactor/src/bootstrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ pub fn bootstrap_main(py: Python) -> PyResult<Bound<PyAny>> {
};

hyperactor::tracing::debug!("entering async bootstrap");
pyo3_async_runtimes::tokio::future_into_py::<_, ()>(py, async move {
crate::runtime::future_into_py::<_, ()>(py, async move {
// SAFETY:
// - Only one of these is ever created.
// - This is the entry point of this program, so this will be dropped when
Expand Down
2 changes: 2 additions & 0 deletions monarch_hyperactor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#![allow(unsafe_op_in_unsafe_fn)]
#![feature(exit_status_error)]
#![feature(mapped_lock_guards)]
#![feature(rwlock_downgrade)]

pub mod actor;
pub mod actor_mesh;
Expand Down
2 changes: 1 addition & 1 deletion monarch_hyperactor/src/mailbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ pub(super) struct PythonUndeliverablePortReceiver {
impl PythonUndeliverablePortReceiver {
fn recv<'py>(&mut self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
let receiver = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::runtime::future_into_py(py, async move {
let message = receiver
.lock()
.await
Expand Down
2 changes: 1 addition & 1 deletion monarch_hyperactor/src/proc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ impl PyProc {
) -> PyResult<Bound<'py, PyAny>> {
let proc = self.inner.clone();
let pickled_type = PickledPyObject::pickle(actor.as_any())?;
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::runtime::future_into_py(py, async move {
Ok(PythonActorHandle {
inner: proc
.spawn(name.as_deref().unwrap_or("anon"), pickled_type)
Expand Down
4 changes: 2 additions & 2 deletions monarch_hyperactor/src/proc_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ impl PyProcMesh {
));
}
let receiver = self.user_monitor_receiver.clone();
Ok(pyo3_async_runtimes::tokio::future_into_py(py, async move {
Ok(crate::runtime::future_into_py(py, async move {
// Create a new user monitor
Ok(PyProcMeshMonitor { receiver })
})?
Expand Down Expand Up @@ -385,7 +385,7 @@ impl PyProcMeshMonitor {

fn __anext__(&self, py: Python<'_>) -> PyResult<PyObject> {
let receiver = self.receiver.clone();
Ok(pyo3_async_runtimes::tokio::future_into_py(py, async move {
Ok(crate::runtime::future_into_py(py, async move {
let receiver = receiver
.borrow()
.map_err(|_| PyRuntimeError::new_err("`ProcEvent receiver` is shutdown"))?;
Expand Down
120 changes: 108 additions & 12 deletions monarch_hyperactor/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,36 +8,72 @@

use std::cell::Cell;
use std::future::Future;
use std::sync::OnceLock;
use std::pin::Pin;
use std::sync::RwLock;
use std::sync::RwLockReadGuard;
use std::time::Duration;

use anyhow::Result;
use anyhow::anyhow;
use anyhow::ensure;
use once_cell::unsync::OnceCell as UnsyncOnceCell;
use pyo3::PyResult;
use pyo3::Python;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::PyAnyMethods;
use pyo3::types::PyCFunction;
use pyo3::types::PyDict;
use pyo3::types::PyTuple;
use pyo3_async_runtimes::TaskLocals;
use tokio::task;

pub fn get_tokio_runtime() -> &'static tokio::runtime::Runtime {
static INSTANCE: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
INSTANCE.get_or_init(|| {
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
// Process-wide slot for the shared Tokio runtime. It starts out empty, is
// filled on demand by `get_tokio_runtime`, and is emptied again by
// `shutdown_tokio_runtime`.
//
// This must be an `RwLock`, and accessors must only hand out guards for
// *reading* the runtime. Otherwise multiple threads can deadlock fighting over
// the `Runtime` object if one of them holds it while blocking on something.
//
// `RwLock::new` is a `const fn`, so the static is initialized directly and
// needs no lazy-initialization wrapper.
static INSTANCE: RwLock<Option<tokio::runtime::Runtime>> = RwLock::new(None);

/// Returns a read guard that dereferences to the shared Tokio runtime,
/// building a multi-threaded runtime first if the global slot is empty.
///
/// The guard keeps a read lock on the global slot for as long as it is alive.
/// `shutdown_tokio_runtime` takes the write lock on the same slot, so it has
/// to wait for outstanding guards to be dropped.
///
/// # Panics
///
/// Panics if the lock is poisoned or if the runtime cannot be built.
pub fn get_tokio_runtime<'l>() -> std::sync::MappedRwLockReadGuard<'l, tokio::runtime::Runtime> {
    // Projects the slot onto the runtime stored in it. Both call sites only
    // reach this after the slot has been observed (or made) non-empty.
    fn runtime_of(slot: &Option<tokio::runtime::Runtime>) -> &tokio::runtime::Runtime {
        slot.as_ref().unwrap()
    }

    // Fast path: the runtime already exists, so a read lock is sufficient.
    let existing = INSTANCE.read().unwrap();
    if existing.is_some() {
        return RwLockReadGuard::map(existing, runtime_of);
    }
    // Release the read lock before asking for the write lock below.
    drop(existing);

    // Slow path: take the write lock and re-check, since another thread may
    // have filled the slot between the two lock acquisitions.
    let mut slot = INSTANCE.write().unwrap();
    if slot.is_none() {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .unwrap();
        *slot = Some(runtime);
    }

    // Downgrade to a read lock without releasing it, then expose the runtime.
    let shared = std::sync::RwLockWriteGuard::downgrade(slot);
    RwLockReadGuard::map(shared, runtime_of)
}

/// Shuts down the shared Tokio runtime, if one currently exists.
///
/// The runtime is taken out of the global slot, leaving it empty, and is then
/// shut down with a one-second timeout. The write lock on the slot is held
/// for the whole call. Does nothing when the slot is already empty.
///
/// # Panics
///
/// Panics if the lock guarding the slot is poisoned.
pub fn shutdown_tokio_runtime() {
    let mut slot = INSTANCE.write().unwrap();
    if let Some(runtime) = slot.take() {
        runtime.shutdown_timeout(Duration::from_secs(1));
    }
}

// Per-thread flag, `false` by default on every thread. `initialize` sets it to
// `true` on the thread it runs on; per the comment there, this is meant to
// identify the main Python thread.
thread_local! {
static IS_MAIN_THREAD: Cell<bool> = const { Cell::new(false) };
}

pub fn initialize(py: Python) -> Result<()> {
pyo3_async_runtimes::tokio::init_with_runtime(get_tokio_runtime())
.map_err(|_| anyhow!("failed to initialize py3 async runtime"))?;

// Initialize thread local state to identify the main Python thread.
let threading = Python::import(py, "threading")?;
let main_thread = threading.call_method0("main_thread")?;
Expand All @@ -48,6 +84,17 @@ pub fn initialize(py: Python) -> Result<()> {
);
IS_MAIN_THREAD.set(true);

let closure = PyCFunction::new_closure(
py,
None,
None,
|args: &Bound<'_, PyTuple>, _kwargs: Option<&Bound<'_, PyDict>>| {
shutdown_tokio_runtime();
},
)
.unwrap();
let atexit = py.import("atexit").unwrap();
atexit.call_method1("register", (closure,)).unwrap();
Ok(())
}

Expand Down Expand Up @@ -131,3 +178,52 @@ pub fn register_python_bindings(runtime_mod: &Bound<'_, PyModule>) -> PyResult<(
runtime_mod.add_function(sleep_indefinitely_fn)?;
Ok(())
}

/// Marker type that implements the `pyo3_async_runtimes::generic` runtime
/// traits on top of the runtime returned by `get_tokio_runtime`.
struct SimpleRuntime;

impl pyo3_async_runtimes::generic::Runtime for SimpleRuntime {
    type JoinError = task::JoinError;
    type JoinHandle = task::JoinHandle<()>;

    /// Spawns `fut` on the shared Tokio runtime and returns its join handle.
    fn spawn<F>(fut: F) -> Self::JoinHandle
    where
        F: Future<Output = ()> + Send + 'static,
    {
        // `fut` already resolves to `()`, so it is handed to the runtime as is.
        let runtime = get_tokio_runtime();
        runtime.spawn(fut)
    }
}

// Task-local slot for the `TaskLocals` associated with the running task.
// `SimpleRuntime::scope` installs a pre-filled cell for the duration of a
// future, and `SimpleRuntime::get_task_locals` reads it back.
tokio::task_local! {
static TASK_LOCALS: UnsyncOnceCell<TaskLocals>;
}

impl pyo3_async_runtimes::generic::ContextExt for SimpleRuntime {
    /// Wraps `fut` so that `locals` is stored in the `TASK_LOCALS` task-local
    /// while the future runs, and returns the wrapped future boxed and pinned.
    fn scope<F, R>(locals: TaskLocals, fut: F) -> Pin<Box<dyn Future<Output = R> + Send>>
    where
        F: Future<Output = R> + Send + 'static,
    {
        // The cell is freshly created, so this first `set` cannot be rejected.
        let slot = UnsyncOnceCell::new();
        slot.set(locals).unwrap();

        let scoped = TASK_LOCALS.scope(slot, fut);
        Box::pin(scoped)
    }

    /// Returns a clone of the `TaskLocals` stored for the current task.
    ///
    /// Yields `None` when the task-local cannot be accessed or when its cell
    /// is empty. Cloning acquires the GIL.
    fn get_task_locals() -> Option<TaskLocals> {
        let lookup = TASK_LOCALS.try_with(|cell| {
            let locals = cell.get()?;
            Some(Python::with_gil(|py| locals.clone_ref(py)))
        });
        match lookup {
            Ok(found) => found,
            Err(_) => None,
        }
    }
}

/// Passes `fut` to `pyo3_async_runtimes::generic::future_into_py` with
/// `SimpleRuntime` as the runtime and returns the Python object it produces.
///
/// `SimpleRuntime` spawns onto the runtime returned by `get_tokio_runtime`.
///
/// # Errors
///
/// Returns whatever error the underlying `generic::future_into_py` call
/// reports.
pub fn future_into_py<F, T>(py: Python, fut: F) -> PyResult<Bound<PyAny>>
where
    F: Future<Output = PyResult<T>> + Send + 'static,
    T: for<'py> IntoPyObject<'py>,
{
    // Only the runtime needs naming; the future and output types are inferred.
    pyo3_async_runtimes::generic::future_into_py::<SimpleRuntime, _, _>(py, fut)
}
9 changes: 3 additions & 6 deletions monarch_tensor_worker/src/bootstrap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use hyperactor::actor::ActorStatus;
use hyperactor::channel::ChannelAddr;
use hyperactor_multiprocess::proc_actor::ProcActor;
use hyperactor_multiprocess::system_actor::ProcLifecycleMode;
use monarch_hyperactor::runtime::get_tokio_runtime;
use pyo3::prelude::*;
use pyo3::types::PyType;
use serde::Deserialize;
Expand Down Expand Up @@ -174,11 +175,7 @@ impl WorkerServerResponse {
}
}

pub fn worker_server(
rt: &tokio::runtime::Runtime,
inp: impl BufRead,
mut outp: impl Write,
) -> Result<()> {
pub fn worker_server(inp: impl BufRead, mut outp: impl Write) -> Result<()> {
tracing::info!("running worker server on {}", std::process::id());

for line in inp.lines() {
Expand All @@ -199,7 +196,7 @@ pub fn worker_server(
supervision_update_interval_in_sec: 5,
extra_proc_labels: Some(labels),
};
let res = rt
let res = get_tokio_runtime()
.block_on(async move { anyhow::Ok(bootstrap_worker_proc(args).await?.await) });
WorkerServerResponse::Finished {
error: match res {
Expand Down
5 changes: 2 additions & 3 deletions python/tests/test_python_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,8 @@ async def test_sync_actor_sync_client() -> None:
async def test_proc_mesh_size() -> None:
proc = local_proc_mesh(gpus=2)
assert 2 == proc.size("gpus")
proc.initialized.get()

await proc.stop()
# proc.initialized.get()
# await proc.stop()


@pytest.mark.timeout(60)
Expand Down