@@ -72,12 +72,14 @@ class DeferredExecution:
7272 The execution input.
7373 func : callable or ObjectRefType
7474 A function to be executed.
75- args : list or tuple
75+ args : list or tuple, optional
7676 Additional positional arguments to be passed in `func`.
77- kwargs : dict
77+ kwargs : dict, optional
7878 Additional keyword arguments to be passed in `func`.
79- num_returns : int
79+ num_returns : int, default: 1
8080 The number of the return values.
81+ flat_data : bool
82+ True means that the data is neither DeferredExecution nor list.
8183 flat_args : bool
8284 True means that there are no lists or DeferredExecution objects in `args`.
8385 In this case, no arguments processing is performed and `args` is passed
@@ -88,26 +90,29 @@ class DeferredExecution:
8890
8991 def __init__ (
9092 self ,
91- data : Union [
92- ObjectRefType ,
93- "DeferredExecution" ,
94- List [Union [ObjectRefType , "DeferredExecution" ]],
95- ],
93+ data : Any ,
9694 func : Union [Callable , ObjectRefType ],
97- args : Union [List [Any ], Tuple [Any ]],
98- kwargs : Dict [str , Any ],
95+ args : Union [List [Any ], Tuple [Any ]] = None ,
96+ kwargs : Dict [str , Any ] = None ,
9997 num_returns = 1 ,
10098 ):
101- if isinstance (data , DeferredExecution ):
102- data .subscribe ()
99+ self .flat_data = self ._flat_args ((data ,))
103100 self .data = data
104101 self .func = func
105- self .args = args
106- self .kwargs = kwargs
107102 self .num_returns = num_returns
108- self .flat_args = self ._flat_args (args )
109- self .flat_kwargs = self ._flat_args (kwargs .values ())
110103 self .subscribers = 0
104+ if args is not None :
105+ self .args = args
106+ self .flat_args = self ._flat_args (args )
107+ else :
108+ self .args = ()
109+ self .flat_args = True
110+ if kwargs is not None :
111+ self .kwargs = kwargs
112+ self .flat_kwargs = self ._flat_args (kwargs .values ())
113+ else :
114+ self .kwargs = {}
115+ self .flat_kwargs = True
111116
112117 @classmethod
113118 def _flat_args (cls , args : Iterable ):
@@ -134,7 +139,7 @@ def _flat_args(cls, args: Iterable):
134139
135140 def exec (
136141 self ,
137- ) -> Tuple [ObjectRefOrListType , Union [ "MetaList" , List ] , Union [int , List [int ]]]:
142+ ) -> Tuple [ObjectRefOrListType , "MetaList" , Union [int , List [int ]]]:
138143 """
139144 Execute this task, if required.
140145
@@ -150,7 +155,7 @@ def exec(
150155 return self .data , self .meta , self .meta_offset
151156
152157 if (
153- not isinstance ( self .data , DeferredExecution )
158+ self .flat_data
154159 and self .flat_args
155160 and self .flat_kwargs
156161 and self .num_returns == 1
@@ -166,14 +171,16 @@ def exec(
166171 # it back. After the execution, the result is saved and the counter has no effect.
167172 self .subscribers += 2
168173 consumers , output = self ._deconstruct ()
174+ assert not any (isinstance (o , ListOrTuple ) for o in output )
169175 # The last result is the MetaList, so adding +1 here.
170176 num_returns = sum (c .num_returns for c in consumers ) + 1
171177 results = self ._remote_exec_chain (num_returns , * output )
172178 meta = MetaList (results .pop ())
173179 meta_offset = 0
174180 results = iter (results )
175181 for de in consumers :
176- if de .num_returns == 1 :
182+ num_returns = de .num_returns
183+ if num_returns == 1 :
177184 de ._set_result (next (results ), meta , meta_offset )
178185 meta_offset += 2
179186 else :
@@ -318,6 +325,7 @@ def _deconstruct_chain(
318325 break
319326 elif not isinstance (data := de .data , DeferredExecution ):
320327 if isinstance (data , ListOrTuple ):
328+ out_append (_Tag .LIST )
321329 yield cls ._deconstruct_list (
322330 data , output , stack , result_consumers , out_append
323331 )
@@ -394,7 +402,13 @@ def _deconstruct_list(
394402 if out_pos := getattr (obj , "out_pos" , None ):
395403 obj .unsubscribe ()
396404 if obj .has_result :
397- out_append (obj .data )
405+ if isinstance (obj .data , ListOrTuple ):
406+ out_append (_Tag .LIST )
407+ yield cls ._deconstruct_list (
408+ obj .data , output , stack , result_consumers , out_append
409+ )
410+ else :
411+ out_append (obj .data )
398412 else :
399413 out_append (_Tag .REF )
400414 out_append (out_pos )
@@ -432,13 +446,13 @@ def _remote_exec_chain(num_returns: int, *args: Tuple) -> List[Any]:
432446 list
433447 The execution results. The last element of this list is the ``MetaList``.
434448 """
435- # Prefer _remote_exec_single_chain(). It has fewer arguments and
436- # does not require the num_returns to be specified in options.
449+ # Prefer _remote_exec_single_chain(). It does not require the num_returns
450+ # to be specified in options.
437451 if num_returns == 2 :
438452 return _remote_exec_single_chain .remote (* args )
439453 else :
440454 return _remote_exec_multi_chain .options (num_returns = num_returns ).remote (
441- num_returns , * args
455+ * args
442456 )
443457
444458 def _set_result (
@@ -456,7 +470,7 @@ def _set_result(
456470 meta : MetaList
457471 meta_offset : int or list of int
458472 """
459- del self .func , self .args , self .kwargs , self . flat_args , self . flat_kwargs
473+ del self .func , self .args , self .kwargs
460474 self .data = result
461475 self .meta = meta
462476 self .meta_offset = meta_offset
@@ -564,7 +578,7 @@ def exec_func(fn: Callable, obj: Any, args: Tuple, kwargs: Dict) -> Any:
564578 raise err
565579
566580 @classmethod
567- def construct (cls , num_returns : int , args : Tuple ): # pragma: no cover
581+ def construct (cls , args : Tuple ): # pragma: no cover
568582 """
569583 Construct and execute the specified chain.
570584
@@ -574,7 +588,6 @@ def construct(cls, num_returns: int, args: Tuple): # pragma: no cover
574588
575589 Parameters
576590 ----------
577- num_returns : int
578591 args : tuple
579592
580593 Yields
@@ -646,7 +659,7 @@ def construct_chain(
646659
647660 while chain :
648661 fn = pop ()
649- if fn == tg_e :
662+ if fn is tg_e :
650663 lst .append (obj )
651664 break
652665
@@ -676,10 +689,10 @@ def construct_chain(
676689
677690 itr = iter ([obj ] if num_returns == 1 else obj )
678691 for _ in range (num_returns ):
679- obj = next (itr )
680- meta .append (len (obj ) if hasattr (obj , "__len__" ) else 0 )
681- meta .append (len (obj .columns ) if hasattr (obj , "columns" ) else 0 )
682- yield obj
692+ o = next (itr )
693+ meta .append (len (o ) if hasattr (o , "__len__" ) else 0 )
694+ meta .append (len (o .columns ) if hasattr (o , "columns" ) else 0 )
695+ yield o
683696
684697 @classmethod
685698 def construct_list (
@@ -793,20 +806,18 @@ def _remote_exec_single_chain(
793806 -------
794807 Generator
795808 """
796- return remote_executor .construct (num_returns = 2 , args = args )
809+ return remote_executor .construct (args = args )
797810
798811
799812@ray .remote
800813def _remote_exec_multi_chain (
801- num_returns : int , * args : Tuple , remote_executor = _REMOTE_EXEC
814+ * args : Tuple , remote_executor = _REMOTE_EXEC
802815) -> Generator : # pragma: no cover
803816 """
804817 Execute the deconstructed chain with a multiple return values in a worker process.
805818
806819 Parameters
807820 ----------
808- num_returns : int
809- The number of return values.
810821 *args : tuple
811822 A deconstructed chain to be executed.
812823 remote_executor : _RemoteExecutor, default: _REMOTE_EXEC
@@ -816,4 +827,4 @@ def _remote_exec_multi_chain(
816827 -------
817828 Generator
818829 """
819- return remote_executor .construct (num_returns , args )
830+ return remote_executor .construct (args )
0 commit comments