diff --git a/monarch_hyperactor/src/pytokio.rs b/monarch_hyperactor/src/pytokio.rs index 4a6b44ad7..6537e0a88 100644 --- a/monarch_hyperactor/src/pytokio.rs +++ b/monarch_hyperactor/src/pytokio.rs @@ -282,8 +282,20 @@ impl PyShared { } } +#[pyfunction] +fn is_tokio_thread() -> bool { + tokio::runtime::Handle::try_current().is_ok() +} + pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> { hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; + let f = wrap_pyfunction!(is_tokio_thread, hyperactor_mod)?; + f.setattr( + "__module__", + "monarch._rust_bindings.monarch_hyperactor.pytokio", + )?; + hyperactor_mod.add_function(f)?; + Ok(()) } diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/pytokio.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/pytokio.pyi index b2b8911ca..3611a3263 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/pytokio.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/pytokio.pyi @@ -73,3 +73,9 @@ class Shared(Generic[T]): Create a one-use Task that awaits on this if you want to use other PythonTask apis like with_timeout. """ ... + +def is_tokio_thread() -> bool: + """ + Returns true if the current thread is a tokio worker thread (and block_on will fail). + """ + ... diff --git a/python/monarch/_src/actor/future.py b/python/monarch/_src/actor/future.py index f42f28c63..ec60fb23a 100644 --- a/python/monarch/_src/actor/future.py +++ b/python/monarch/_src/actor/future.py @@ -20,7 +20,11 @@ TypeVar, ) -from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared +from monarch._rust_bindings.monarch_hyperactor.pytokio import ( + is_tokio_thread, + PythonTask, + Shared, +) from typing_extensions import deprecated, Self @@ -79,7 +83,11 @@ class _Asyncio(NamedTuple): fut: asyncio.Future -_Status = _Unawaited | _Complete | _Exception | _Asyncio +class _Tokio(NamedTuple): + shared: Shared + + +_Status = _Unawaited | _Complete | _Exception | _Asyncio | _Tokio class Future(Generic[R]): @@ -108,31 +116,60 @@ def get(self, timeout: Optional[float] = None) -> R: return cast("R", value) case _Exception(exe=exe): raise exe + case _Tokio(_): + raise ValueError( + "already converted into a pytokio.Shared object, use 'await' from a PythonTask coroutine to get the value." + ) case _: raise RuntimeError("unknown status") def __await__(self) -> Generator[Any, Any, R]: - match self._status: - case _Unawaited(coro=coro): - loop = asyncio.get_running_loop() - fut = loop.create_future() - self._status = _Asyncio(fut) - - async def mark_complete(): - try: - func, value = fut.set_result, await coro - except Exception as e: - func, value = fut.set_exception, e - loop.call_soon_threadsafe(func, value) - - PythonTask.from_coroutine(mark_complete()).spawn() - return fut.__await__() - case _Asyncio(fut=fut): - return fut.__await__() - case _: - raise ValueError( - "already converted into a synchronous future, use 'get' to get the value." - ) + if asyncio._get_running_loop() is not None: + match self._status: + case _Unawaited(coro=coro): + loop = asyncio.get_running_loop() + fut = loop.create_future() + self._status = _Asyncio(fut) + + async def mark_complete(): + try: + func, value = fut.set_result, await coro + except Exception as e: + func, value = fut.set_exception, e + loop.call_soon_threadsafe(func, value) + + PythonTask.from_coroutine(mark_complete()).spawn() + return fut.__await__() + case _Asyncio(fut=fut): + return fut.__await__() + case _Tokio(_): + raise ValueError( + "already converted into a tokio future, but being awaited from the asyncio loop." + ) + case _: + raise ValueError( + "already converted into a synchronous future, use 'get' to get the value." + ) + elif is_tokio_thread(): + match self._status: + case _Unawaited(coro=coro): + shared = coro.spawn() + self._status = _Tokio(shared) + return shared.__await__() + case _Tokio(shared=shared): + return shared.__await__() + case _Asyncio(_): + raise ValueError( + "already converted into asyncio future, but being awaited from the tokio loop." + ) + case _: + raise ValueError( + "already converted into a synchronous future, use 'get' to get the value." + ) + else: + raise ValueError( + "__await__ with no active event loop (either asyncio or tokio)" + ) # compatibility with old tensor engine Future objects # hopefully we do not need done(), add_callback because diff --git a/python/monarch/_src/actor/proc_mesh.py b/python/monarch/_src/actor/proc_mesh.py index ff94d0e35..6b49528ad 100644 --- a/python/monarch/_src/actor/proc_mesh.py +++ b/python/monarch/_src/actor/proc_mesh.py @@ -197,8 +197,7 @@ async def _init_manager_actors_coro( setup_actor = await self._spawn_nonblocking_on( proc_mesh, "setup", SetupActor, setup ) - # pyre-ignore - await setup_actor.setup.call()._status.coro + await setup_actor.setup.call() return proc_mesh