Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions src/integrations/prefect-dask/prefect_dask/task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,10 @@ class PrefectDaskFuture(PrefectWrappedFuture[R, distributed.Future]):
def wait(self, timeout: Optional[float] = None) -> None:
try:
result = self._wrapped_future.result(timeout=timeout)
except Exception:
# either the task failed or the timeout was reached
except distributed.TimeoutError:
return
except Exception as exc:
self._wrapped_future_error = exc
return
if isinstance(result, State):
self._final_state = result
Expand All @@ -137,11 +139,14 @@ def result(
raise TimeoutError(
f"Task run {self.task_run_id} did not complete within {timeout} seconds"
) from exc

if isinstance(future_result, State):
self._final_state = future_result
except Exception as exc:
self._wrapped_future_error = exc
self._final_state = self.state
else:
return future_result
if isinstance(future_result, State):
self._final_state = future_result
else:
return future_result

return self._final_state.result(raise_on_failure=raise_on_failure, _sync=True)

Expand Down
26 changes: 26 additions & 0 deletions src/integrations/prefect-dask/tests/test_task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
from typing import Generator, List

import dask
import dask.dataframe as dd
import distributed
import pandas as pd
Expand Down Expand Up @@ -303,6 +304,31 @@ def task_a():
future.wait()
assert future.state.type == StateType.CRASHED

def test_worker_loss_exhaustion_marks_task_run_crashed(self):
dask.config.set({"distributed.scheduler.allowed-failures": 1})

@task
def oom_task():
blocks = []
while True:
blocks.append(bytearray(64 * 1024 * 1024))
time.sleep(0.05)

task_runner = DaskTaskRunner(
cluster_kwargs={
"processes": True,
"n_workers": 1,
"threads_per_worker": 1,
"dashboard_address": None,
"memory_limit": "256 MiB",
}
)

with task_runner:
future = task_runner.submit(oom_task, parameters={}, wait_for=[])
future.wait()
assert future.state.type == StateType.CRASHED

def test_dask_task_key_has_prefect_task_name(self):
task_runner = DaskTaskRunner()

Expand Down
70 changes: 68 additions & 2 deletions src/prefect/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,16 @@

from prefect._waiters import FlowRunWaiter
from prefect.client.orchestration import get_client
from prefect.exceptions import ObjectNotFound
from prefect.client.schemas import OrchestrationResult
from prefect.client.schemas.responses import (
SetStateStatus,
StateAbortDetails,
StateRejectDetails,
StateWaitDetails,
)
from prefect.exceptions import Abort, ObjectNotFound
from prefect.logging.loggers import get_logger
from prefect.states import Pending, State
from prefect.states import Pending, State, exception_to_crashed_state
from prefect.task_runs import TaskRunWaiter
from prefect.utilities.annotations import quote
from prefect.utilities.asyncutils import run_coro_as_sync
Expand Down Expand Up @@ -168,6 +175,7 @@ class PrefectWrappedFuture(PrefectTaskRunFuture[R], abc.ABC, Generic[R, F]):

def __init__(self, task_run_id: uuid.UUID, wrapped_future: F):
self._wrapped_future: F = wrapped_future
self._wrapped_future_error: BaseException | None = None
super().__init__(task_run_id)

@property
Expand All @@ -187,6 +195,64 @@ def call_with_self(future: F):
return
fn(self)

@property
def state(self) -> State:
if self._final_state:
return self._final_state
if self._wrapped_future_error is not None:
self._final_state = self._resolve_wrapped_future_error(
self._wrapped_future_error
)
return self._final_state
return super().state

def _resolve_wrapped_future_error(self, exc: BaseException) -> State:
client = get_client(sync_client=True)
task_run = client.read_task_run(task_run_id=self.task_run_id)
if task_run.state and task_run.state.is_final():
return task_run.state

state = run_coro_as_sync(exception_to_crashed_state(exc))

def set_state_and_handle_waits(
set_state_func: Callable[[], OrchestrationResult[Any]],
) -> OrchestrationResult[Any]:
response = set_state_func()
while response.status == SetStateStatus.WAIT:
assert isinstance(response.details, StateWaitDetails)
time.sleep(response.details.delay_seconds)
response = set_state_func()
return response

set_state = partial(client.set_task_run_state, self.task_run_id, state)
response = set_state_and_handle_waits(set_state)

if response.status == SetStateStatus.ACCEPT:
assert response.state is not None
state.id = response.state.id
state.timestamp = response.state.timestamp
if response.state.state_details:
state.state_details = response.state.state_details
return state
if response.status == SetStateStatus.ABORT:
refreshed_task_run = client.read_task_run(task_run_id=self.task_run_id)
if refreshed_task_run.state and refreshed_task_run.state.is_final():
return refreshed_task_run.state
assert isinstance(response.details, StateAbortDetails)
raise Abort(response.details.reason)
if response.status == SetStateStatus.REJECT:
assert isinstance(response.details, StateRejectDetails)
if response.state and response.state.is_final():
return response.state
refreshed_task_run = client.read_task_run(task_run_id=self.task_run_id)
if refreshed_task_run.state and refreshed_task_run.state.is_final():
return refreshed_task_run.state
assert response.state is not None
return response.state
raise ValueError(
f"Received unexpected `SetStateStatus` from server: {response.status!r}"
)


class PrefectConcurrentFuture(PrefectWrappedFuture[R, concurrent.futures.Future[R]]):
"""
Expand Down
Loading