Skip to content

Commit ce90c4a

Browse files
authored
Properly forward cancellation reason (#9028)
1 parent e931ccd commit ce90c4a

3 files changed

Lines changed: 45 additions & 19 deletions

File tree

distributed/client.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1839,11 +1839,11 @@ def _handle_lost_data(self, key=None):
18391839
if state is not None:
18401840
state.lose()
18411841

1842-
def _handle_cancelled_keys(self, keys):
1842+
def _handle_cancelled_keys(self, keys, reason=None, msg=None):
18431843
for key in keys:
18441844
state = self.futures.get(key)
18451845
if state is not None:
1846-
state.cancel()
1846+
state.cancel(reason=reason, msg=msg)
18471847

18481848
def _handle_retried_key(self, key=None):
18491849
state = self.futures.get(key)
@@ -2796,7 +2796,15 @@ def scatter(
27962796
async def _cancel(self, futures, reason=None, msg=None, force=False):
27972797
# FIXME: This method is asynchronous since interacting with the FutureState below requires an event loop.
27982798
keys = list({f.key for f in futures_of(futures)})
2799-
self._send_to_scheduler({"op": "cancel-keys", "keys": keys, "force": force})
2799+
self._send_to_scheduler(
2800+
{
2801+
"op": "cancel-keys",
2802+
"keys": keys,
2803+
"force": force,
2804+
"reason": reason,
2805+
"msg": msg,
2806+
}
2807+
)
28002808
for k in keys:
28012809
st = self.futures.pop(k, None)
28022810
if st is not None:
@@ -2823,7 +2831,12 @@ def cancel(self, futures, asynchronous=None, force=False, reason=None, msg=None)
28232831
Message that will be attached to the cancelled future
28242832
"""
28252833
return self.sync(
2826-
self._cancel, futures, asynchronous=asynchronous, force=force, msg=msg
2834+
self._cancel,
2835+
futures,
2836+
asynchronous=asynchronous,
2837+
force=force,
2838+
msg=msg,
2839+
reason=reason,
28272840
)
28282841

28292842
async def _retry(self, futures):

distributed/scheduler.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4865,7 +4865,14 @@ async def update_graph(
48654865
lost_keys = self._find_lost_dependencies(dsk, dependencies, keys)
48664866

48674867
if lost_keys:
4868-
self.report({"op": "cancelled-keys", "keys": lost_keys}, client=client)
4868+
self.report(
4869+
{
4870+
"op": "cancelled-keys",
4871+
"keys": lost_keys,
4872+
"reason": "lost dependencies",
4873+
},
4874+
client=client,
4875+
)
48694876
self.client_releases_keys(
48704877
keys=lost_keys, client=client, stimulus_id=stimulus_id
48714878
)
@@ -5572,7 +5579,7 @@ async def remove_worker_from_events() -> None:
55725579
return "OK"
55735580

55745581
def stimulus_cancel(
5575-
self, keys: Collection[Key], client: str, force: bool = False
5582+
self, keys: Collection[Key], client: str, force: bool, reason: str, msg: str
55765583
) -> None:
55775584
"""Stop execution on a list of keys"""
55785585
logger.info("Client %s requests to cancel %d keys", client, len(keys))
@@ -5591,7 +5598,11 @@ def stimulus_cancel(
55915598
if force or ts.who_wants == {cs}: # no one else wants this key
55925599
if ts.dependents:
55935600
self.stimulus_cancel(
5594-
[dts.key for dts in ts.dependents], client, force=force
5601+
[dts.key for dts in ts.dependents],
5602+
client,
5603+
force=force,
5604+
reason=reason,
5605+
msg=msg,
55955606
)
55965607
logger.info("Scheduler cancels key %s. Force=%s", key, force)
55975608
cancelled_keys.append(key)
@@ -5603,7 +5614,14 @@ def stimulus_cancel(
56035614
client=cs.client_key,
56045615
stimulus_id=f"cancel-key-{time()}",
56055616
)
5606-
self.report({"op": "cancelled-keys", "keys": cancelled_keys})
5617+
self.report(
5618+
{
5619+
"op": "cancelled-keys",
5620+
"keys": cancelled_keys,
5621+
"reason": reason,
5622+
"msg": msg,
5623+
}
5624+
)
56075625

56085626
def client_desires_keys(self, keys: Collection[Key], client: str) -> None:
56095627
cs = self.clients.get(client)
@@ -8948,7 +8966,7 @@ def request_remove_replicas(
89488966

89498967
def _task_to_report_msg(ts: TaskState) -> dict[str, Any] | None:
89508968
if ts.state == "forgotten":
8951-
return {"op": "cancelled-keys", "keys": [ts.key]}
8969+
return {"op": "cancelled-keys", "keys": [ts.key], "reason": "already forgotten"}
89528970
elif ts.state == "memory":
89538971
return {"op": "key-in-memory", "key": ts.key}
89548972
elif ts.state == "erred":

distributed/tests/test_client.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2641,28 +2641,23 @@ async def test_futures_of_cancelled_raises(c, s, a, b):
26412641
await asyncio.sleep(0.01)
26422642
await c.cancel([x], reason="testreason")
26432643

2644-
# Note: The scheduler currently doesn't remember the reason but rather
2645-
# forgets the task immediately. The reason is currently. only raised if the
2646-
# client checks on it. Therefore, we expect an unknown reason and definitely
2647-
# not a scheduler disconnected which would otherwise indicate a bug, e.g. an
2648-
# AssertionError during transitioning.
2649-
with pytest.raises(CancelledError, match="(reason: unknown|testreason)"):
2644+
with pytest.raises(CancelledError, match="reason: testreason"):
26502645
await x
26512646
while x.key in s.tasks:
26522647
await asyncio.sleep(0.01)
26532648

2654-
with pytest.raises(CancelledError, match="(reason: unknown|testreason)"):
2649+
with pytest.raises(CancelledError, match="reason: lost dependencies"):
26552650
get_obj = c.get({"x": (inc, x), "y": (inc, 2)}, ["x", "y"], sync=False)
26562651
gather_obj = c.gather(get_obj)
26572652
await gather_obj
26582653

2659-
with pytest.raises(CancelledError, match="(reason: unknown|testreason)"):
2654+
with pytest.raises(CancelledError, match="reason: lost dependencies"):
26602655
await c.submit(inc, x)
26612656

2662-
with pytest.raises(CancelledError, match="(reason: unknown|testreason)"):
2657+
with pytest.raises(CancelledError, match="reason: lost dependencies"):
26632658
await c.submit(add, 1, y=x)
26642659

2665-
with pytest.raises(CancelledError, match="(reason: unknown|testreason)"):
2660+
with pytest.raises(CancelledError, match="reason: lost dependencies"):
26662661
await c.gather(c.map(add, [1], y=x))
26672662

26682663

0 commit comments

Comments
 (0)