3434
3535from modin .core .execution .ray .common import MaterializationHook , RayWrapper
3636from modin .logging import get_logger
37+ from modin .utils import _inherit_docstrings
3738
38- ObjectRefType = Union [ray .ObjectRef , ClientObjectRef , None ]
39+ ObjectRefType = Union [ray .ObjectRef , ClientObjectRef ]
3940ObjectRefOrListType = Union [ObjectRefType , List [ObjectRefType ]]
4041ListOrTuple = (list , tuple )
4142
@@ -68,16 +69,18 @@ class DeferredExecution:
6869
6970 Attributes
7071 ----------
71- data : ObjectRefType or DeferredExecution
72+ data : object
7273 The execution input.
7374 func : callable or ObjectRefType
7475 A function to be executed.
75- args : list or tuple
76+ args : list or tuple, optional
7677 Additional positional arguments to be passed in `func`.
77- kwargs : dict
78+ kwargs : dict, optional
7879 Additional keyword arguments to be passed in `func`.
79- num_returns : int
80+ num_returns : int, default: 1
8081 The number of the return values.
82+ flat_data : bool
83+ True means that the data is neither DeferredExecution nor list.
8184 flat_args : bool
8285 True means that there are no lists or DeferredExecution objects in `args`.
8386 In this case, no arguments processing is performed and `args` is passed
@@ -88,26 +91,29 @@ class DeferredExecution:
8891
8992 def __init__ (
9093 self ,
91- data : Union [
92- ObjectRefType ,
93- "DeferredExecution" ,
94- List [Union [ObjectRefType , "DeferredExecution" ]],
95- ],
94+ data : Any ,
9695 func : Union [Callable , ObjectRefType ],
97- args : Union [List [Any ], Tuple [Any ]],
98- kwargs : Dict [str , Any ],
96+ args : Union [List [Any ], Tuple [Any ]] = None ,
97+ kwargs : Dict [str , Any ] = None ,
9998 num_returns = 1 ,
10099 ):
101- if isinstance (data , DeferredExecution ):
102- data .subscribe ()
100+ self .flat_data = self ._flat_args ((data ,))
103101 self .data = data
104102 self .func = func
105- self .args = args
106- self .kwargs = kwargs
107103 self .num_returns = num_returns
108- self .flat_args = self ._flat_args (args )
109- self .flat_kwargs = self ._flat_args (kwargs .values ())
110104 self .subscribers = 0
105+ if args is not None :
106+ self .args = args
107+ self .flat_args = self ._flat_args (args )
108+ else :
109+ self .args = ()
110+ self .flat_args = True
111+ if kwargs is not None :
112+ self .kwargs = kwargs
113+ self .flat_kwargs = self ._flat_args (kwargs .values ())
114+ else :
115+ self .kwargs = {}
116+ self .flat_kwargs = True
111117
112118 @classmethod
113119 def _flat_args (cls , args : Iterable ):
@@ -134,7 +140,7 @@ def _flat_args(cls, args: Iterable):
134140
135141 def exec (
136142 self ,
137- ) -> Tuple [ObjectRefOrListType , Union [ "MetaList" , List ] , Union [int , List [int ]]]:
143+ ) -> Tuple [ObjectRefOrListType , "MetaList" , Union [int , List [int ]]]:
138144 """
139145 Execute this task, if required.
140146
@@ -150,11 +156,29 @@ def exec(
150156 return self .data , self .meta , self .meta_offset
151157
152158 if (
153- not isinstance ( self .data , DeferredExecution )
159+ self .flat_data
154160 and self .flat_args
155161 and self .flat_kwargs
156162 and self .num_returns == 1
157163 ):
164+ # self.data = RayWrapper.materialize(self.data)
165+ # self.args = [
166+ # RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
167+ # for o in self.args
168+ # ]
169+ # self.kwargs = {
170+ # k: RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
171+ # for k, o in self.kwargs.items()
172+ # }
173+ # obj = _REMOTE_EXEC.exec_func(
174+ # RayWrapper.materialize(self.func), self.data, self.args, self.kwargs
175+ # )
176+ # result, length, width, ip = (
177+ # obj,
178+ # len(obj) if hasattr(obj, "__len__") else 0,
179+ # len(obj.columns) if hasattr(obj, "columns") else 0,
180+ # "",
181+ # )
158182 result , length , width , ip = remote_exec_func .remote (
159183 self .func , self .data , * self .args , ** self .kwargs
160184 )
@@ -166,19 +190,28 @@ def exec(
166190 # it back. After the execution, the result is saved and the counter has no effect.
167191 self .subscribers += 2
168192 consumers , output = self ._deconstruct ()
193+
194+ # assert not any(isinstance(o, ListOrTuple) for o in output)
195+ # tmp = [
196+ # RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
197+ # for o in output
198+ # ]
199+ # list(_REMOTE_EXEC.construct(tmp))
200+
169201 # The last result is the MetaList, so adding +1 here.
170202 num_returns = sum (c .num_returns for c in consumers ) + 1
171203 results = self ._remote_exec_chain (num_returns , * output )
172204 meta = MetaList (results .pop ())
173205 meta_offset = 0
174206 results = iter (results )
175207 for de in consumers :
176- if de .num_returns == 1 :
208+ num_returns = de .num_returns
209+ if num_returns == 1 :
177210 de ._set_result (next (results ), meta , meta_offset )
178211 meta_offset += 2
179212 else :
180213 res = list (islice (results , num_returns ))
181- offsets = list (range (0 , 2 * num_returns , 2 ))
214+ offsets = list (range (meta_offset , meta_offset + 2 * num_returns , 2 ))
182215 de ._set_result (res , meta , offsets )
183216 meta_offset += 2 * num_returns
184217 return self .data , self .meta , self .meta_offset
@@ -303,7 +336,7 @@ def _deconstruct_chain(
303336 out_extend = output .extend
304337 while True :
305338 de .unsubscribe ()
306- if (out_pos := getattr (de , "out_pos" , None )) and not de . has_result :
339+ if not de . has_result and (out_pos := getattr (de , "out_pos" , None )):
307340 out_append (_Tag .REF )
308341 out_append (out_pos )
309342 output [out_pos ] = out_pos
@@ -318,6 +351,7 @@ def _deconstruct_chain(
318351 break
319352 elif not isinstance (data := de .data , DeferredExecution ):
320353 if isinstance (data , ListOrTuple ):
354+ out_append (_Tag .LIST )
321355 yield cls ._deconstruct_list (
322356 data , output , stack , result_consumers , out_append
323357 )
@@ -394,7 +428,13 @@ def _deconstruct_list(
394428 if out_pos := getattr (obj , "out_pos" , None ):
395429 obj .unsubscribe ()
396430 if obj .has_result :
397- out_append (obj .data )
431+ if isinstance (obj .data , ListOrTuple ):
432+ out_append (_Tag .LIST )
433+ yield cls ._deconstruct_list (
434+ obj .data , output , stack , result_consumers , out_append
435+ )
436+ else :
437+ out_append (obj .data )
398438 else :
399439 out_append (_Tag .REF )
400440 out_append (out_pos )
@@ -432,13 +472,13 @@ def _remote_exec_chain(num_returns: int, *args: Tuple) -> List[Any]:
432472 list
433473 The execution results. The last element of this list is the ``MetaList``.
434474 """
435- # Prefer _remote_exec_single_chain(). It has fewer arguments and
436- # does not require the num_returns to be specified in options.
475+ # Prefer _remote_exec_single_chain(). It does not require the num_returns
476+ # to be specified in options.
437477 if num_returns == 2 :
438478 return _remote_exec_single_chain .remote (* args )
439479 else :
440480 return _remote_exec_multi_chain .options (num_returns = num_returns ).remote (
441- num_returns , * args
481+ * args
442482 )
443483
444484 def _set_result (
@@ -456,7 +496,7 @@ def _set_result(
456496 meta : MetaList
457497 meta_offset : int or list of int
458498 """
459- del self .func , self .args , self .kwargs , self . flat_args , self . flat_kwargs
499+ del self .func , self .args , self .kwargs
460500 self .data = result
461501 self .meta = meta
462502 self .meta_offset = meta_offset
@@ -466,6 +506,78 @@ def __reduce__(self):
466506 raise NotImplementedError ("DeferredExecution is not serializable!" )
467507
468508
509+ ObjectRefOrDeType = Union [ObjectRefType , DeferredExecution ]
510+
511+
class DeferredGetItem(DeferredExecution):
    """
    Deferred task that extracts a single item from its input by index.

    Parameters
    ----------
    data : ObjectRefOrDeType
        The object, or the deferred task producing the object, to index into.
    idx : int
        The index of the item to return.
    """

    def __init__(self, data: ObjectRefOrDeType, idx: int):
        getter = self._remote_fn()
        super().__init__(data, getter, [idx])
        # Kept separately from `args`, so the item can also be picked locally
        # from an already computed multi-result `data` (see `_data_exec()`).
        self.index = idx

    @_inherit_docstrings(DeferredExecution.exec)
    def exec(self) -> Tuple[ObjectRefType, "MetaList", int]:
        # NOTE: reading `has_result` may itself finalize this task, if `data`
        # is a multi-result task that has already been executed.
        if not self.has_result:
            source = self.data
            if not isinstance(source, DeferredExecution) or source.num_returns == 1:
                return super().exec()

            # `data` is a `DeferredExecution` returning multiple results, so
            # there is no need to run `_remote_fn()`: it is enough to execute
            # `data` and pick the result by index.
            self._data_exec()
        return self.data, self.meta, self.meta_offset

    @property
    @_inherit_docstrings(DeferredExecution.has_result)
    def has_result(self):
        if super().has_result:
            return True

        source = self.data
        if not isinstance(source, DeferredExecution):
            return False
        if not source.has_result or source.num_returns == 1:
            return False

        # The multi-result source is ready: take the item right away.
        self._data_exec()
        return True

    def _data_exec(self):
        """Execute the `data` task and store its item at `index` as the result."""
        results, meta, offsets = self.data.exec()
        position = self.index
        self._set_result(results[position], meta, offsets[position])

    @classmethod
    def _remote_fn(cls) -> ObjectRefType:
        """
        Return the reference to the remote item getter, creating it on first use.

        Returns
        -------
        ObjectRefType
        """
        cached = getattr(cls, "_GET_ITEM", None)
        if cached is not None:
            return cached

        def get_item(obj, index):  # pragma: no cover
            return obj[index]

        cached = RayWrapper.put(get_item)
        cls._GET_ITEM = cached
        return cached
578+ return fn
579+
580+
469581class MetaList :
470582 """
471583 Meta information, containing the result lengths and the worker address.
@@ -478,6 +590,10 @@ class MetaList:
478590 def __init__ (self , obj : Union [ray .ObjectID , ClientObjectRef , List ]):
479591 self ._obj = obj
480592
593+ def materialize (self ):
594+ """Materialize the list, if required."""
595+ self ._obj = RayWrapper .materialize (self ._obj )
596+
481597 def __getitem__ (self , index ):
482598 """
483599 Get item at the specified index.
@@ -508,7 +624,7 @@ def __setitem__(self, index, value):
508624 obj [index ] = value
509625
510626
511- class MetaListHook (MaterializationHook ):
627+ class MetaListHook (MaterializationHook , DeferredGetItem ):
512628 """
513629 Used by MetaList.__getitem__() for lazy materialization and getting a single value from the list.
514630
@@ -521,6 +637,7 @@ class MetaListHook(MaterializationHook):
521637 """
522638
523639 def __init__ (self , meta : MetaList , idx : int ):
640+ super ().__init__ (meta ._obj , idx )
524641 self .meta = meta
525642 self .idx = idx
526643
@@ -605,7 +722,7 @@ def exec_func(fn: Callable, obj: Any, args: Tuple, kwargs: Dict) -> Any:
605722 raise err
606723
607724 @classmethod
608- def construct (cls , num_returns : int , args : Tuple ): # pragma: no cover
725+ def construct (cls , args : Tuple ): # pragma: no cover
609726 """
610727 Construct and execute the specified chain.
611728
@@ -615,7 +732,6 @@ def construct(cls, num_returns: int, args: Tuple): # pragma: no cover
615732
616733 Parameters
617734 ----------
618- num_returns : int
619735 args : tuple
620736
621737 Yields
@@ -687,7 +803,7 @@ def construct_chain(
687803
688804 while chain :
689805 fn = pop ()
690- if fn == tg_e :
806+ if fn is tg_e :
691807 lst .append (obj )
692808 break
693809
@@ -717,10 +833,10 @@ def construct_chain(
717833
718834 itr = iter ([obj ] if num_returns == 1 else obj )
719835 for _ in range (num_returns ):
720- obj = next (itr )
721- meta .append (len (obj ) if hasattr (obj , "__len__" ) else 0 )
722- meta .append (len (obj .columns ) if hasattr (obj , "columns" ) else 0 )
723- yield obj
836+ o = next (itr )
837+ meta .append (len (o ) if hasattr (o , "__len__" ) else 0 )
838+ meta .append (len (o .columns ) if hasattr (o , "columns" ) else 0 )
839+ yield o
724840
725841 @classmethod
726842 def construct_list (
@@ -834,20 +950,18 @@ def _remote_exec_single_chain(
834950 -------
835951 Generator
836952 """
837- return remote_executor .construct (num_returns = 2 , args = args )
953+ return remote_executor .construct (args = args )
838954
839955
840956@ray .remote
841957def _remote_exec_multi_chain (
842- num_returns : int , * args : Tuple , remote_executor = _REMOTE_EXEC
958+ * args : Tuple , remote_executor = _REMOTE_EXEC
843959) -> Generator : # pragma: no cover
844960 """
845961 Execute the deconstructed chain with multiple return values in a worker process.
846962
847963 Parameters
848964 ----------
849- num_returns : int
850- The number of return values.
851965 *args : tuple
852966 A deconstructed chain to be executed.
853967 remote_executor : _RemoteExecutor, default: _REMOTE_EXEC
@@ -857,4 +971,4 @@ def _remote_exec_multi_chain(
857971 -------
858972 Generator
859973 """
860- return remote_executor .construct (num_returns , args )
974+ return remote_executor .construct (args )
0 commit comments