Skip to content

Commit ffe7cd4

Browse files
committed
Fix Client.persist
1 parent ce5a886 commit ffe7cd4

2 files changed

Lines changed: 42 additions & 10 deletions

File tree

distributed/client.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3763,13 +3763,15 @@ def persist(
37633763
metadata = SpanMetadata(
37643764
collections=[get_collections_metadata(v) for v in collections]
37653765
)
3766-
expr = collections_to_expr(collections, optimize_graph, **kwargs)
3766+
expr = collections_to_expr(collections, optimize_graph)
37673767

3768-
names = {k for c in collections for k in flatten(c.__dask_keys__())}
3768+
expr2 = expr.optimize()
3769+
3770+
keys = expr2.__dask_keys__()
37693771

37703772
futures = self._graph_to_futures(
3771-
expr,
3772-
names,
3773+
expr2,
3774+
list(flatten(keys)),
37733775
workers=workers,
37743776
allow_other_workers=allow_other_workers,
37753777
resources=resources,
@@ -3781,15 +3783,15 @@ def persist(
37813783
)
37823784

37833785
postpersists = [c.__dask_postpersist__() for c in collections]
3784-
result = [
3785-
func({k: futures[k] for k in flatten(c.__dask_keys__())}, *args)
3786-
for (func, args), c in zip(postpersists, collections)
3787-
]
37883786

37893787
if singleton:
3790-
return first(result)
3788+
func, args = postpersists[0]
3789+
return func({k: futures[k] for k in keys}, *args)
37913790
else:
3792-
return result
3791+
return [
3792+
func({k: futures[k] for k in flatten(ks)}, *args)
3793+
for (func, args), ks in zip(postpersists, keys)
3794+
]
37933795

37943796
async def _restart(
37953797
self, timeout: str | int | float | NoDefault, wait_for_workers: bool

distributed/tests/test_dask_collections.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import asyncio
4+
35
import pytest
46

57
np = pytest.importorskip("numpy")
@@ -41,6 +43,34 @@ def assert_equal(a, b):
4143
assert a == b
4244

4345

46+
@gen_cluster(client=True)
async def test_persist(c, s, a, b):
    """Exercise ``Client.persist`` with one collection and with a tuple of two.

    Each dask DataFrame is built from a ten-row pandas frame split into two
    partitions.  After persisting a single collection the scheduler is
    expected to hold two tasks in memory; after persisting both collections
    together it is expected to hold four.  In both phases the persisted
    result must compute to the same value as the original collection, and
    once the persisted references are deleted the test polls until the
    scheduler reports no tasks in memory.
    """

    def in_memory_count():
        # Number of scheduler tasks whose state is currently "memory".
        return sum(ts.state == "memory" for ts in s.tasks.values())

    df = pd.DataFrame({"x": range(10), "y": range(10, 20)})
    ddf = dd.from_pandas(df, npartitions=2)
    df2 = pd.DataFrame({"x": range(20, 30), "y": range(30, 40)})
    ddf2 = dd.from_pandas(df2, npartitions=2)

    # Phase 1: persist a single collection.
    persisted = await c.persist(ddf)
    assert s.tasks
    assert in_memory_count() == 2
    assert_equal(await c.compute(persisted), await c.compute(ddf))
    del persisted

    # Dropping the only reference should let the scheduler forget the data.
    while in_memory_count() != 0:
        await asyncio.sleep(0.01)

    # Phase 2: persist two collections passed together as a tuple.
    persisted1, persisted2 = c.persist((ddf, ddf2))
    await wait((persisted1, persisted2))
    assert s.tasks
    assert in_memory_count() == 4

    assert_equal(await c.compute(persisted1), await c.compute(ddf))
    assert_equal(await c.compute(persisted2), await c.compute(ddf2))
    del persisted1, persisted2

    while in_memory_count() != 0:
        await asyncio.sleep(0.01)
72+
73+
4474
@ignore_single_machine_warning
4575
@gen_cluster(client=True)
4676
async def test_dataframes(c, s, a, b):

0 commit comments

Comments
 (0)