65
65
Endpoint ,
66
66
EndpointProperty ,
67
67
Extent ,
68
+ NotAnEndpoint ,
68
69
Propagator ,
69
70
Selection ,
70
71
)
76
77
from monarch ._src .actor .shape import MeshTrait , NDSlice
77
78
from monarch ._src .actor .sync_state import fake_sync_state
78
79
79
- from monarch ._src .actor .tensor_engine_shim import actor_send
80
+ from monarch ._src .actor .tensor_engine_shim import actor_rref , actor_send
80
81
81
82
if TYPE_CHECKING :
82
83
from monarch ._src .actor .proc_mesh import ProcMesh
@@ -313,8 +314,7 @@ def _send(
313
314
"""
314
315
self ._signature .bind (None , * args , ** kwargs )
315
316
objects , bytes = flatten ((args , kwargs ), _is_ref_or_mailbox )
316
- refs = [obj for obj in objects if hasattr (obj , "__monarch_ref__" )]
317
- if not refs :
317
+ if all (not hasattr (obj , "__monarch_ref__" ) for obj in objects ):
318
318
message = PythonMessage (
319
319
PythonMessageKind .CallMethod (
320
320
self ._name , None if port is None else port ._port_ref
@@ -323,7 +323,7 @@ def _send(
323
323
)
324
324
self ._actor_mesh .cast (message , selection )
325
325
else :
326
- actor_send (self , bytes , refs , port , selection )
326
+ actor_send (self , bytes , objects , port , selection )
327
327
shape = self ._actor_mesh ._shape
328
328
return Extent (shape .labels , shape .ndslice .sizes )
329
329
@@ -335,6 +335,26 @@ def _port(self, once: bool = False) -> "PortTuple[R]":
335
335
), "unexpected receiver type"
336
336
return PortTuple (p , PortReceiver (self ._mailbox , self ._supervise (r ._receiver )))
337
337
338
+ def _rref (self , args , kwargs ):
339
+ self ._signature .bind (None , * args , ** kwargs )
340
+ refs , bytes = flatten ((args , kwargs ), _is_ref_or_mailbox )
341
+
342
+ return actor_rref (self , bytes , refs )
343
+
344
+
345
+ def as_endpoint (
346
+ not_an_endpoint : Callable [P , R ], * , propagate : Propagator = None
347
+ ) -> Endpoint [P , R ]:
348
+ if not isinstance (not_an_endpoint , NotAnEndpoint ):
349
+ raise ValueError ("expected an method of a spawned actor" )
350
+ return ActorEndpoint (
351
+ not_an_endpoint ._ref ._actor_mesh_ref ,
352
+ not_an_endpoint ._name ,
353
+ getattr (not_an_endpoint ._ref , not_an_endpoint ._name ),
354
+ not_an_endpoint ._ref ._mailbox ,
355
+ propagate ,
356
+ )
357
+
338
358
339
359
class Accumulator (Generic [P , R , A ]):
340
360
def __init__ (
@@ -625,18 +645,23 @@ async def handle(
625
645
f" This is likely due to an earlier error: { self ._saved_error } "
626
646
)
627
647
raise AssertionError (error_message )
628
- the_method = getattr (self .instance , method )._method
648
+ the_method = getattr (self .instance , method )
649
+ if isinstance (the_method , EndpointProperty ):
650
+ module = the_method ._method .__module__
651
+ the_method = functools .partial (the_method ._method , self .instance )
652
+ else :
653
+ module = the_method .__module__
629
654
630
655
if inspect .iscoroutinefunction (the_method ):
631
656
632
657
async def instrumented ():
633
658
enter_span (
634
- the_method . __module__ ,
659
+ module ,
635
660
method ,
636
661
str (ctx .mailbox .actor_id ),
637
662
)
638
663
try :
639
- result = await the_method (self . instance , * args , ** kwargs )
664
+ result = await the_method (* args , ** kwargs )
640
665
self ._maybe_exit_debugger ()
641
666
except Exception as e :
642
667
logging .critical (
@@ -649,9 +674,9 @@ async def instrumented():
649
674
650
675
result = await instrumented ()
651
676
else :
652
- enter_span (the_method . __module__ , method , str (ctx .mailbox .actor_id ))
677
+ enter_span (module , method , str (ctx .mailbox .actor_id ))
653
678
with fake_sync_state ():
654
- result = the_method (self . instance , * args , ** kwargs )
679
+ result = the_method (* args , ** kwargs )
655
680
self ._maybe_exit_debugger ()
656
681
exit_span ()
657
682
@@ -758,35 +783,14 @@ def __init__(
758
783
attr_name ,
759
784
attr_value ._method ,
760
785
self ._mailbox ,
786
+ attr_value ._propagator ,
761
787
),
762
788
)
763
789
764
- def __getattr__ (self , name : str ) -> Any :
765
- # This method is called when an attribute is not found
766
- # For linting purposes, we need to tell the type checker that any attribute
767
- # could be an endpoint that's dynamically added at runtime
768
- # At runtime, we still want to raise AttributeError for truly missing attributes
769
-
770
- # Check if this is a method on the underlying class
771
- if hasattr (self ._class , name ):
772
- attr = getattr (self ._class , name )
773
- if isinstance (attr , EndpointProperty ):
774
- # Dynamically create the endpoint
775
- endpoint = ActorEndpoint (
776
- self ._actor_mesh_ref ,
777
- name ,
778
- attr ._method ,
779
- self ._mailbox ,
780
- propagator = attr ._propagator ,
781
- )
782
- # Cache it for future use
783
- setattr (self , name , endpoint )
784
- return endpoint
785
-
786
- # If we get here, it's truly not found
787
- raise AttributeError (
788
- f"'{ self .__class__ .__name__ } ' object has no attribute '{ name } '"
789
- )
790
+ def __getattr__ (self , attr : str ) -> NotAnEndpoint :
791
+ if attr in dir (self ._class ):
792
+ return NotAnEndpoint (self , attr )
793
+ raise AttributeError (attr )
790
794
791
795
def _create (
792
796
self ,
0 commit comments