Skip to content

Commit e579c25

Browse files
zdevito authored and facebook-github-bot committed
tensor engine, implement rref on actor endpoints (#625)
Summary: Pull Request resolved: #625. Allows actor endpoints to take tensors as arguments and return them as monarch.Tensor objects. Actor endpoints that do this must have a propagator function defined. "cached" does not currently work because the caching machinery cannot make the actor call. Introduces `as_endpoint()`, which can take something that was not annotated on a class and turn it into an endpoint. This is useful when needing to define a per-actor propagation function for rref calls, because there is otherwise nowhere to put it. ghstack-source-id: 298650755 Reviewed By: suo Differential Revision: D78782587 fbshipit-source-id: 079fa58143fc32e29edb8147b5f226b810a787f7
1 parent bd0aa21 commit e579c25

File tree

14 files changed

+306
-112
lines changed

14 files changed

+306
-112
lines changed

monarch_extension/src/convert.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,18 @@ fn create_map(py: Python) -> HashMap<u64, FnType> {
443443
},
444444
))
445445
});
446+
m.insert(key("CallActorMethod"), |p| {
447+
Ok(WorkerMessage::CallActorMethod(worker::ActorMethodParams {
448+
call: worker::ActorCallParams {
449+
seq: p.parseSeq("seq")?,
450+
broker_id: p.parse("broker_id")?,
451+
local_state: p.parseRefList("local_state")?,
452+
mutates: p.parseRefList("mutates")?,
453+
stream: p.parseStreamRef("stream")?,
454+
},
455+
results: p.parseFlatReferences("result")?,
456+
}))
457+
});
446458
m
447459
}
448460

monarch_extension/src/tensor_worker.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,6 +1373,7 @@ pub(crate) fn worker_message_to_py(py: Python<'_>, message: &WorkerMessage) -> P
13731373
WorkerMessage::SetRefUnitTestsOnly { .. } => unimplemented!(),
13741374
WorkerMessage::GetRefUnitTestsOnly { .. } => unimplemented!(),
13751375
WorkerMessage::SendResultOfActorCall { .. } => unimplemented!(),
1376+
WorkerMessage::CallActorMethod { .. } => unimplemented!(),
13761377
}
13771378
}
13781379

monarch_messages/src/worker.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,13 @@ pub struct ActorCallParams {
415415
pub mutates: Vec<Ref>,
416416
pub stream: StreamRef,
417417
}
418+
419+
#[derive(Serialize, Deserialize, Debug, Clone)]
420+
pub struct ActorMethodParams {
421+
pub results: Vec<Option<Ref>>,
422+
pub call: ActorCallParams,
423+
}
424+
418425
/// Type of reduction for [`WorkerMessage::Reduce`].
419426
#[derive(Debug, Clone, Serialize, Deserialize)]
420427
pub enum Reduction {
@@ -817,6 +824,7 @@ pub enum WorkerMessage {
817824
},
818825

819826
SendResultOfActorCall(ActorCallParams),
827+
CallActorMethod(ActorMethodParams),
820828
PipeRecv {
821829
seq: Seq,
822830
/// Result refs.

monarch_simulator/src/worker.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,14 @@ impl WorkerMessageHandler for WorkerActor {
319319
bail!("unimplemented: send_result_of_actor_call");
320320
}
321321

322+
async fn call_actor_method(
323+
&mut self,
324+
_cx: &hyperactor::Context<Self>,
325+
_params: ActorMethodParams,
326+
) -> Result<()> {
327+
bail!("unimplemented: call_actor_method");
328+
}
329+
322330
async fn command_group(
323331
&mut self,
324332
cx: &hyperactor::Context<Self>,

monarch_tensor_worker/src/lib.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ use monarch_messages::controller::ControllerMessageClient;
7272
use monarch_messages::controller::Seq;
7373
use monarch_messages::wire_value::WireValue;
7474
use monarch_messages::worker::ActorCallParams;
75+
use monarch_messages::worker::ActorMethodParams;
7576
use monarch_messages::worker::CallFunctionError;
7677
use monarch_messages::worker::CallFunctionParams;
7778
use monarch_messages::worker::Factory;
@@ -860,6 +861,15 @@ impl WorkerMessageHandler for WorkerActor {
860861
.await?;
861862
Ok(())
862863
}
864+
async fn call_actor_method(
865+
&mut self,
866+
cx: &hyperactor::Context<Self>,
867+
params: ActorMethodParams,
868+
) -> Result<()> {
869+
let stream = self.try_get_stream(params.call.stream)?;
870+
stream.call_actor_method(cx, params).await?;
871+
Ok(())
872+
}
863873
async fn split_comm(
864874
&mut self,
865875
cx: &hyperactor::Context<Self>,

monarch_tensor_worker/src/stream.rs

Lines changed: 71 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ use monarch_messages::controller::ControllerMessageClient;
4545
use monarch_messages::controller::Seq;
4646
use monarch_messages::controller::WorkerError;
4747
use monarch_messages::worker::ActorCallParams;
48+
use monarch_messages::worker::ActorMethodParams;
4849
use monarch_messages::worker::CallFunctionError;
4950
use monarch_messages::worker::CallFunctionParams;
5051
use monarch_messages::worker::SeqError;
@@ -265,6 +266,7 @@ pub enum StreamMessage {
265266
GetTensorRefUnitTestsOnly(Ref, #[reply] OncePortHandle<Option<TensorCellResult>>),
266267

267268
SendResultOfActorCall(ActorId, ActorCallParams),
269+
CallActorMethod(ActorMethodParams),
268270
}
269271

270272
impl StreamMessage {
@@ -1062,6 +1064,44 @@ impl StreamActor {
10621064
self.env.insert(dest, rvalue.clone());
10631065
Ok(())
10641066
}
1067+
async fn call_actor(
1068+
&mut self,
1069+
cx: &Context<'_, Self>,
1070+
params: ActorCallParams,
1071+
) -> Result<PyObject, CallFunctionError> {
1072+
let local_state: Result<Vec<PyObject>> = Python::with_gil(|py| {
1073+
params
1074+
.local_state
1075+
.into_iter()
1076+
.map(|elem| {
1077+
// SAFETY: python is gonna make unsafe copies of this stuff anyway
1078+
unsafe {
1079+
let x = self.ref_to_rvalue(&elem)?.try_to_object_unsafe(py)?.into();
1080+
Ok(x)
1081+
}
1082+
})
1083+
.collect()
1084+
});
1085+
1086+
let (send, recv) = cx.open_once_port();
1087+
let state = LocalState {
1088+
response_port: send,
1089+
state: local_state?,
1090+
};
1091+
let x: u64 = params.seq.into();
1092+
let message = LocalStateBrokerMessage::Set(x as usize, state);
1093+
1094+
let broker = BrokerId::new(params.broker_id).resolve(cx).unwrap();
1095+
broker
1096+
.send(message)
1097+
.map_err(|e| CallFunctionError::Error(e.into()))?;
1098+
let result = recv
1099+
.recv()
1100+
.await
1101+
.map_err(|e| CallFunctionError::Error(e.into()))?;
1102+
1103+
result.map_err(|pyerr| anyhow::Error::msg(pyerr.to_string()).into())
1104+
}
10651105
}
10661106

10671107
#[async_trait]
@@ -1669,62 +1709,38 @@ impl StreamMessageHandler for StreamActor {
16691709
worker_actor_id: ActorId,
16701710
params: ActorCallParams,
16711711
) -> anyhow::Result<()> {
1672-
// TODO: handle mutates
1673-
let local_state: Result<Vec<PyObject>> = Python::with_gil(|py| {
1674-
params
1675-
.local_state
1676-
.into_iter()
1677-
.map(|elem| {
1678-
// SAFETY: python is gonna make unsafe copies of this stuff anyway
1679-
unsafe {
1680-
let x = self.ref_to_rvalue(&elem)?.try_to_object_unsafe(py)?.into();
1681-
Ok(x)
1682-
}
1683-
})
1684-
.collect()
1685-
});
1686-
1687-
let (send, recv) = cx.open_once_port();
1688-
1689-
let state = LocalState {
1690-
response_port: send,
1691-
state: local_state?,
1692-
};
1693-
let x: u64 = params.seq.into();
1694-
let message = LocalStateBrokerMessage::Set(x as usize, state);
1712+
let seq = params.seq;
1713+
let mutates = params.mutates.clone();
1714+
self.try_define(cx, seq, vec![], &mutates, async |self| {
1715+
let value = self.call_actor(cx, params).await?;
1716+
let result = Python::with_gil(|py| {
1717+
pickle_python_result(py, value.into_bound(py), worker_actor_id)
1718+
})?;
1719+
let result = Serialized::serialize(&result).unwrap();
1720+
self.controller_actor
1721+
.fetch_result(cx, seq, Ok(result))
1722+
.await?;
1723+
Ok(vec![])
1724+
})
1725+
.await
1726+
}
16951727

1696-
let broker = BrokerId::new(params.broker_id).resolve(cx).unwrap();
1697-
broker.send(message)?;
1698-
let result = recv.recv().await?;
1699-
1700-
match result {
1701-
Err(pyerr) => {
1702-
// If result has "exception" as its kind, then
1703-
// we need to unpickle and turn it into a WorkerError
1704-
// and call remote_function_failed otherwise the
1705-
// controller assumes the object is correct and doesn't handle
1706-
// dependency tracking correctly.
1707-
let err = Python::with_gil(|py| -> Result<WorkerError, SerializablePyErr> {
1708-
Ok(WorkerError {
1709-
worker_actor_id,
1710-
backtrace: pyerr.to_string(),
1711-
})
1712-
})?;
1713-
self.controller_actor
1714-
.remote_function_failed(cx, params.seq, err)
1715-
.await?;
1716-
}
1717-
Ok(value) => {
1718-
let result = Python::with_gil(|py| {
1719-
pickle_python_result(py, value.into_bound(py), worker_actor_id)
1720-
})?;
1721-
let result = Serialized::serialize(&result).unwrap();
1722-
self.controller_actor
1723-
.fetch_result(cx, params.seq, Ok(result))
1724-
.await?;
1725-
}
1726-
}
1727-
Ok(())
1728+
async fn call_actor_method(
1729+
&mut self,
1730+
cx: &Context<Self>,
1731+
params: ActorMethodParams,
1732+
) -> anyhow::Result<()> {
1733+
let seq = params.call.seq;
1734+
let mutates = params.call.mutates.clone();
1735+
self.try_define(cx, seq, params.results, &mutates, async |self| {
1736+
let result = self.call_actor(cx, params.call).await?;
1737+
let result = Python::with_gil(|py| {
1738+
PyTree::<RValue>::extract_bound(&result.into_bound(py))
1739+
.map_err(SerializablePyErr::from_fn(py))
1740+
})?;
1741+
Ok(result.into_leaves())
1742+
})
1743+
.await
17281744
}
17291745

17301746
async fn set_value(

python/monarch/_src/actor/actor_mesh.py

Lines changed: 39 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
Endpoint,
6666
EndpointProperty,
6767
Extent,
68+
NotAnEndpoint,
6869
Propagator,
6970
Selection,
7071
)
@@ -76,7 +77,7 @@
7677
from monarch._src.actor.shape import MeshTrait, NDSlice
7778
from monarch._src.actor.sync_state import fake_sync_state
7879

79-
from monarch._src.actor.tensor_engine_shim import actor_send
80+
from monarch._src.actor.tensor_engine_shim import actor_rref, actor_send
8081

8182
if TYPE_CHECKING:
8283
from monarch._src.actor.proc_mesh import ProcMesh
@@ -313,8 +314,7 @@ def _send(
313314
"""
314315
self._signature.bind(None, *args, **kwargs)
315316
objects, bytes = flatten((args, kwargs), _is_ref_or_mailbox)
316-
refs = [obj for obj in objects if hasattr(obj, "__monarch_ref__")]
317-
if not refs:
317+
if all(not hasattr(obj, "__monarch_ref__") for obj in objects):
318318
message = PythonMessage(
319319
PythonMessageKind.CallMethod(
320320
self._name, None if port is None else port._port_ref
@@ -323,7 +323,7 @@ def _send(
323323
)
324324
self._actor_mesh.cast(message, selection)
325325
else:
326-
actor_send(self, bytes, refs, port, selection)
326+
actor_send(self, bytes, objects, port, selection)
327327
shape = self._actor_mesh._shape
328328
return Extent(shape.labels, shape.ndslice.sizes)
329329

@@ -335,6 +335,26 @@ def _port(self, once: bool = False) -> "PortTuple[R]":
335335
), "unexpected receiver type"
336336
return PortTuple(p, PortReceiver(self._mailbox, self._supervise(r._receiver)))
337337

338+
def _rref(self, args, kwargs):
339+
self._signature.bind(None, *args, **kwargs)
340+
refs, bytes = flatten((args, kwargs), _is_ref_or_mailbox)
341+
342+
return actor_rref(self, bytes, refs)
343+
344+
345+
def as_endpoint(
346+
not_an_endpoint: Callable[P, R], *, propagate: Propagator = None
347+
) -> Endpoint[P, R]:
348+
if not isinstance(not_an_endpoint, NotAnEndpoint):
349+
raise ValueError("expected an method of a spawned actor")
350+
return ActorEndpoint(
351+
not_an_endpoint._ref._actor_mesh_ref,
352+
not_an_endpoint._name,
353+
getattr(not_an_endpoint._ref, not_an_endpoint._name),
354+
not_an_endpoint._ref._mailbox,
355+
propagate,
356+
)
357+
338358

339359
class Accumulator(Generic[P, R, A]):
340360
def __init__(
@@ -625,18 +645,23 @@ async def handle(
625645
f" This is likely due to an earlier error: {self._saved_error}"
626646
)
627647
raise AssertionError(error_message)
628-
the_method = getattr(self.instance, method)._method
648+
the_method = getattr(self.instance, method)
649+
if isinstance(the_method, EndpointProperty):
650+
module = the_method._method.__module__
651+
the_method = functools.partial(the_method._method, self.instance)
652+
else:
653+
module = the_method.__module__
629654

630655
if inspect.iscoroutinefunction(the_method):
631656

632657
async def instrumented():
633658
enter_span(
634-
the_method.__module__,
659+
module,
635660
method,
636661
str(ctx.mailbox.actor_id),
637662
)
638663
try:
639-
result = await the_method(self.instance, *args, **kwargs)
664+
result = await the_method(*args, **kwargs)
640665
self._maybe_exit_debugger()
641666
except Exception as e:
642667
logging.critical(
@@ -649,9 +674,9 @@ async def instrumented():
649674

650675
result = await instrumented()
651676
else:
652-
enter_span(the_method.__module__, method, str(ctx.mailbox.actor_id))
677+
enter_span(module, method, str(ctx.mailbox.actor_id))
653678
with fake_sync_state():
654-
result = the_method(self.instance, *args, **kwargs)
679+
result = the_method(*args, **kwargs)
655680
self._maybe_exit_debugger()
656681
exit_span()
657682

@@ -758,35 +783,14 @@ def __init__(
758783
attr_name,
759784
attr_value._method,
760785
self._mailbox,
786+
attr_value._propagator,
761787
),
762788
)
763789

764-
def __getattr__(self, name: str) -> Any:
765-
# This method is called when an attribute is not found
766-
# For linting purposes, we need to tell the type checker that any attribute
767-
# could be an endpoint that's dynamically added at runtime
768-
# At runtime, we still want to raise AttributeError for truly missing attributes
769-
770-
# Check if this is a method on the underlying class
771-
if hasattr(self._class, name):
772-
attr = getattr(self._class, name)
773-
if isinstance(attr, EndpointProperty):
774-
# Dynamically create the endpoint
775-
endpoint = ActorEndpoint(
776-
self._actor_mesh_ref,
777-
name,
778-
attr._method,
779-
self._mailbox,
780-
propagator=attr._propagator,
781-
)
782-
# Cache it for future use
783-
setattr(self, name, endpoint)
784-
return endpoint
785-
786-
# If we get here, it's truly not found
787-
raise AttributeError(
788-
f"'{self.__class__.__name__}' object has no attribute '{name}'"
789-
)
790+
def __getattr__(self, attr: str) -> NotAnEndpoint:
791+
if attr in dir(self._class):
792+
return NotAnEndpoint(self, attr)
793+
raise AttributeError(attr)
790794

791795
def _create(
792796
self,

0 commit comments

Comments
 (0)