3939
4040from packaging .version import parse as parse_version
4141from tlz import first , groupby , merge , partition_all , valmap
42+ from tornado import gen
43+ from tornado .ioloop import IOLoop
4244
4345import dask
46+ from dask ._expr import Expr , HLGExpr , LLGExpr
47+ from dask ._task_spec import DataNode , GraphNode , List , Task , TaskRef , parse_input
4448from dask .base import collections_to_dsk
4549from dask .core import flatten , validate_key
46- from dask .layers import Layer
50+ from dask .highlevelgraph import HighLevelGraph
4751from dask .tokenize import tokenize
4852from dask .typing import Key , NestedKeys , NoDefault , no_default
4953from dask .utils import (
5761)
5862from dask .widgets import get_template
5963
60- from distributed .core import OKMessage
61- from distributed .protocol .serialize import _is_dumpable
62- from distributed .utils import Deadline , wait_for
63-
64- try :
65- from dask .delayed import single_key
66- except ImportError :
67- single_key = first
68- from tornado import gen
69- from tornado .ioloop import IOLoop
70-
71- from dask ._task_spec import DataNode , GraphNode , List , Task , TaskRef , parse_input
72-
7364import distributed .utils
7465from distributed import cluster_dump , preloading
7566from distributed import versions as version_module
7970from distributed .core import (
8071 CommClosedError ,
8172 ConnectionPool ,
73+ OKMessage ,
8274 PooledRPCCall ,
8375 Status ,
8476 clean_exception ,
9890from distributed .objects import HasWhat , SchedulerInfo , WhoHas
9991from distributed .protocol import to_serialize
10092from distributed .protocol .pickle import dumps , loads
93+ from distributed .protocol .serialize import _is_dumpable
10194from distributed .publish import Datasets
10295from distributed .pubsub import PubSubClientExtension
10396from distributed .security import Security
10699from distributed .threadpoolexecutor import rejoin
107100from distributed .utils import (
108101 CancelledError ,
102+ Deadline ,
109103 LoopRunner ,
110104 NoOpAwaitable ,
111105 SyncMethodMixin ,
117111 nbytes ,
118112 sync ,
119113 thread_state ,
114+ wait_for ,
120115)
121116from distributed .utils_comm import (
122117 gather_from_workers ,
@@ -834,51 +829,32 @@ def _is_nested(iterable):
834829 return False
835830
836831
837- class _MapLayer ( Layer ):
832+ class _MapExpr ( Expr ):
838833 func : Callable
839- iterables : Iterable [ Any ]
840- key : str | Iterable [ str ] | None
834+ iterables : Iterable
835+ key : Key
841836 pure : bool
842- annotations : dict [str , Any ] | None
843-
844- def __init__ (
845- self ,
846- func : Callable ,
847- iterables : Iterable [Any ],
848- key : str | Iterable [str ] | None = None ,
849- pure : bool = True ,
850- annotations : dict [str , Any ] | None = None ,
851- ** kwargs ,
852- ):
853- self .func : Callable = func
854- self .iterables = [tuple (map (parse_input , iterable )) for iterable in iterables ]
855- self .key : str | Iterable [str ] | None = key
856- self .pure : bool = pure
857- self .kwargs = {k : parse_input (v ) for k , v in kwargs .items ()}
858- super ().__init__ (annotations = annotations )
859-
860- def __repr__ (self ) -> str :
861- return f"{ type (self ).__name__ } <func='{ funcname (self .func )} '>"
837+ annotations : dict
838+ kwargs : dict
839+ _cached_keys : Iterable [Key ] | None
840+ _parameters = [
841+ "func" ,
842+ "iterables" ,
843+ "key" ,
844+ "pure" ,
845+ "annotations" ,
846+ "kwargs" ,
847+ "_cached_keys" ,
848+ ]
849+ _defaults = {"_cached_keys" : None }
862850
863851 @property
864- def _dict (self ) -> _T_LowLevelGraph :
865- self ._cached_dict : _T_LowLevelGraph
866- dsk : _T_LowLevelGraph
867-
868- if hasattr (self , "_cached_dict" ):
869- return self ._cached_dict
870- else :
871- dsk = self ._construct_graph ()
872- self ._cached_dict = dsk
873- return self ._cached_dict
874-
875- @property
876- def _keys (self ) -> Iterable [Key ]:
877- if hasattr (self , "_cached_keys" ):
852+ def keys (self ) -> Iterable [Key ]:
853+ if self ._cached_keys is not None :
878854 return self ._cached_keys
879855 else :
880856 if isinstance (self .key , Iterable ) and not isinstance (self .key , str ):
881- self ._cached_keys : Iterable [ Key ] = self .key
857+ self .operands [ - 1 ] = self .key
882858 return self .key
883859
884860 else :
@@ -898,34 +874,19 @@ def _keys(self) -> Iterable[Key]:
898874 if self .iterables
899875 else []
900876 )
901- self ._cached_keys = keys
877+ self .operands [ - 1 ] = keys
902878 return keys
903879
904- def get_output_keys (self ) -> set [Key ]:
905- return set (self ._keys )
906-
907- def get_ordered_keys (self ):
908- return list (self ._keys )
909-
910- def is_materialized (self ) -> bool :
911- return hasattr (self , "_cached_dict" )
912-
913- def __getitem__ (self , key : Key ) -> GraphNode :
914- return self ._dict [key ]
880+ def _meta (self ):
881+ return []
915882
916- def __iter__ (self ) -> Iterator [Key ]:
917- return iter (self ._dict )
918-
919- def __len__ (self ) -> int :
920- return len (self ._dict )
921-
922- def _construct_graph (self ) -> _T_LowLevelGraph :
883+ def _layer (self ):
923884 dsk : _T_LowLevelGraph = {}
924885
925886 if not self .kwargs :
926887 dsk = {
927888 key : Task (key , self .func , * args )
928- for key , args in zip (self ._keys , zip (* self .iterables ))
889+ for key , args in zip (self .keys , zip (* self .iterables ))
929890 }
930891
931892 else :
@@ -937,12 +898,12 @@ def _construct_graph(self) -> _T_LowLevelGraph:
937898 kwargs2 [k ] = vv .ref ()
938899 dsk [vv .key ] = vv
939900 else :
940- kwargs2 [k ] = v
901+ kwargs2 [k ] = parse_input ( v )
941902
942903 dsk .update (
943904 {
944905 key : Task (key , self .func , * args , ** kwargs2 )
945- for key , args in zip (self ._keys , zip (* self .iterables ))
906+ for key , args in zip (self .keys , zip (* self .iterables ))
946907 }
947908 )
948909 return dsk
@@ -2162,16 +2123,19 @@ def submit(
21622123
21632124 if isinstance (workers , (str , Number )):
21642125 workers = [workers ]
2165- dsk = {
2166- key : Task (
2167- key ,
2168- func ,
2169- * (parse_input (a ) for a in args ),
2170- ** {k : parse_input (v ) for k , v in kwargs .items ()},
2171- )
2172- }
2126+
2127+ expr = LLGExpr (
2128+ {
2129+ key : Task (
2130+ key ,
2131+ func ,
2132+ * (parse_input (a ) for a in args ),
2133+ ** {k : parse_input (v ) for k , v in kwargs .items ()},
2134+ )
2135+ }
2136+ )
21732137 futures = self ._graph_to_futures (
2174- dsk ,
2138+ expr ,
21752139 [key ],
21762140 workers = workers ,
21772141 allow_other_workers = allow_other_workers ,
@@ -2331,14 +2295,16 @@ def map(
23312295 if allow_other_workers and workers is None :
23322296 raise ValueError ("Only use allow_other_workers= if using workers=" )
23332297
2334- dsk = _MapLayer (
2298+ expr = _MapExpr (
23352299 func ,
23362300 iterables ,
23372301 key = key ,
23382302 pure = pure ,
2339- ** kwargs ,
2303+ # FIXME: this doesn't look right
2304+ annotations = {},
2305+ kwargs = kwargs ,
23402306 )
2341- keys = dsk . get_ordered_keys ( )
2307+ keys = list ( expr . keys )
23422308 if isinstance (workers , (str , Number )):
23432309 workers = [workers ]
23442310 if workers is not None and not isinstance (workers , (list , set )):
@@ -2347,7 +2313,7 @@ def map(
23472313 internal_priority = dict (zip (keys , range (len (keys ))))
23482314
23492315 futures = self ._graph_to_futures (
2350- dsk ,
2316+ expr ,
23512317 keys ,
23522318 workers = workers ,
23532319 allow_other_workers = allow_other_workers ,
@@ -2361,7 +2327,6 @@ def map(
23612327 )
23622328
23632329 # make sure the graph is not materialized
2364- assert not dsk .is_materialized (), "Graph must be non-materialized"
23652330 logger .debug ("map(%s, ...)" , funcname (func ))
23662331 return [futures [k ] for k in keys ]
23672332
@@ -3464,8 +3429,12 @@ def get(
34643429 --------
34653430 Client.compute : Compute asynchronous collections
34663431 """
3432+ if isinstance (dsk , dict ):
3433+ dsk = LLGExpr (dsk )
3434+ elif isinstance (dsk , HighLevelGraph ):
3435+ dsk = HLGExpr (dsk )
34673436 futures = self ._graph_to_futures (
3468- dsk ,
3437+ expr = dsk ,
34693438 keys = set (flatten ([keys ])),
34703439 workers = workers ,
34713440 allow_other_workers = allow_other_workers ,
@@ -3667,7 +3636,6 @@ def compute(
36673636 expr = FinalizeCompute (expr )
36683637
36693638 expr = expr .optimize ()
3670- # FIXME: Is this actually required?
36713639 names = list (flatten (expr .__dask_keys__ ()))
36723640
36733641 futures_dict = self ._graph_to_futures (
0 commit comments