Skip to content

Polymorphic Future await? #757

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions monarch_hyperactor/src/pytokio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,20 @@ impl PyShared {
}
}

#[pyfunction]
fn is_tokio_thread() -> bool {
tokio::runtime::Handle::try_current().is_ok()
}

pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {
hyperactor_mod.add_class::<PyPythonTask>()?;
hyperactor_mod.add_class::<PyShared>()?;
let f = wrap_pyfunction!(is_tokio_thread, hyperactor_mod)?;
f.setattr(
"__module__",
"monarch._rust_bindings.monarch_hyperactor.pytokio",
)?;
hyperactor_mod.add_function(f)?;

Ok(())
}
6 changes: 6 additions & 0 deletions python/monarch/_rust_bindings/monarch_hyperactor/pytokio.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,9 @@ class Shared(Generic[T]):
Create a one-use Task that awaits on this if you want to use other PythonTask apis like with_timeout.
"""
...

def is_tokio_thread() -> bool:
"""
Returns true if the current thread is a tokio worker thread (and block_on will fail).
"""
...
83 changes: 60 additions & 23 deletions python/monarch/_src/actor/future.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
TypeVar,
)

from monarch._rust_bindings.monarch_hyperactor.pytokio import PythonTask, Shared
from monarch._rust_bindings.monarch_hyperactor.pytokio import (
is_tokio_thread,
PythonTask,
Shared,
)

from typing_extensions import deprecated, Self

Expand Down Expand Up @@ -79,7 +83,11 @@ class _Asyncio(NamedTuple):
fut: asyncio.Future


_Status = _Unawaited | _Complete | _Exception | _Asyncio
class _Tokio(NamedTuple):
shared: Shared


_Status = _Unawaited | _Complete | _Exception | _Asyncio | _Tokio


class Future(Generic[R]):
Expand Down Expand Up @@ -108,31 +116,60 @@ def get(self, timeout: Optional[float] = None) -> R:
return cast("R", value)
case _Exception(exe=exe):
raise exe
case _Tokio(_):
raise ValueError(
"already converted into a pytokio.Shared object, use 'await' from a PythonTask coroutine to get the value."
)
case _:
raise RuntimeError("unknown status")

def __await__(self) -> Generator[Any, Any, R]:
match self._status:
case _Unawaited(coro=coro):
loop = asyncio.get_running_loop()
fut = loop.create_future()
self._status = _Asyncio(fut)

async def mark_complete():
try:
func, value = fut.set_result, await coro
except Exception as e:
func, value = fut.set_exception, e
loop.call_soon_threadsafe(func, value)

PythonTask.from_coroutine(mark_complete()).spawn()
return fut.__await__()
case _Asyncio(fut=fut):
return fut.__await__()
case _:
raise ValueError(
"already converted into a synchronous future, use 'get' to get the value."
)
if asyncio._get_running_loop() is not None:
match self._status:
case _Unawaited(coro=coro):
loop = asyncio.get_running_loop()
fut = loop.create_future()
self._status = _Asyncio(fut)

async def mark_complete():
try:
func, value = fut.set_result, await coro
except Exception as e:
func, value = fut.set_exception, e
loop.call_soon_threadsafe(func, value)

PythonTask.from_coroutine(mark_complete()).spawn()
return fut.__await__()
case _Asyncio(fut=fut):
return fut.__await__()
case _Tokio(_):
raise ValueError(
"already converted into a tokio future, but being awaited from the asyncio loop."
)
case _:
raise ValueError(
"already converted into a synchronous future, use 'get' to get the value."
)
elif is_tokio_thread():
match self._status:
case _Unawaited(coro=coro):
shared = coro.spawn()
self._status = _Tokio(shared)
return shared.__await__()
case _Tokio(shared=shared):
return shared.__await__()
case _Asyncio(_):
raise ValueError(
"already converted into asyncio future, but being awaited from the tokio loop."
)
case _:
raise ValueError(
"already converted into a synchronous future, use 'get' to get the value."
)
else:
raise ValueError(
"__await__ with no active event loop (either asyncio or tokio)"
)

# compatibility with old tensor engine Future objects
# hopefully we do not need done(), add_callback because
Expand Down
3 changes: 1 addition & 2 deletions python/monarch/_src/actor/proc_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,7 @@ async def _init_manager_actors_coro(
setup_actor = await self._spawn_nonblocking_on(
proc_mesh, "setup", SetupActor, setup
)
# pyre-ignore
await setup_actor.setup.call()._status.coro
await setup_actor.setup.call()

return proc_mesh

Expand Down
Loading