Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions src/prefect/deployments/steps/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@
import re
import subprocess
import warnings
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from contextvars import ContextVar
from copy import deepcopy
from importlib import import_module
from pathlib import Path
from typing import Any
from uuid import UUID

Expand All @@ -39,13 +43,38 @@

RESERVED_KEYWORDS = {"id", "requires"}

_StepCompletionObserver = Callable[
[dict[str, Any], Any, Path | None, Path | None],
None,
]
_STEP_COMPLETION_OBSERVER: ContextVar[_StepCompletionObserver | None] = ContextVar(
"step_completion_observer",
default=None,
)


class StepExecutionError(Exception):
"""
Raised when a step fails to execute.
"""


def _safe_current_working_directory() -> Path | None:
try:
return Path.cwd().resolve()
except OSError:
return None


@contextmanager
def _observe_step_completion(callback: _StepCompletionObserver) -> Iterator[None]:
token = _STEP_COMPLETION_OBSERVER.set(callback)
try:
yield
finally:
_STEP_COMPLETION_OBSERVER.reset(token)


def _strip_version(requirement: str) -> str:
"""
Strips the version from a requirement string.
Expand Down Expand Up @@ -152,6 +181,7 @@ async def run_steps(
logger: Any | None = None,
) -> dict[str, Any]:
upstream_outputs = deepcopy(upstream_outputs) if upstream_outputs else {}
step_completion_observer = _STEP_COMPLETION_OBSERVER.get()
for step_index, step in enumerate(steps):
if not step:
continue
Expand All @@ -177,6 +207,11 @@ async def run_steps(

try:
# catch warnings to ensure deprecation warnings are printed
step_start_cwd = (
_safe_current_working_directory()
if step_completion_observer is not None
else None
)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter(
"always",
Expand All @@ -201,6 +236,14 @@ async def run_steps(
print_function(message)
printed_messages.append(message)

if step_completion_observer is not None:
step_completion_observer(
step,
step_output,
step_start_cwd,
_safe_current_working_directory(),
)

if not isinstance(step_output, dict):
if PREFECT_DEBUG_MODE:
get_logger().warning(
Expand Down
Loading
Loading