Skip to content

Commit 9fa13c8

Browse files
ferruzzi, seanghaeli, ramitkataria
authored
Executor Synchronous callback workload (#61153)
* Synchronous callback support for BaseExecutor, LocalExecutor, and CeleryExecutor Add support for the Callback workload to be run in the executors. Other executors will need to be updated before they can support the workload, but I tried to make it as non-invasive as I could. Co-authored-by: Sean Ghaeli <seanghaeli@gmail.com> Co-authored-by: Ramit Kataria <ramitkat@amazon.com>
1 parent 0ddf517 commit 9fa13c8

File tree

28 files changed

+1221
-454
lines changed

28 files changed

+1221
-454
lines changed

airflow-core/src/airflow/executors/base_executor.py

Lines changed: 78 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@
3232
from airflow.configuration import conf
3333
from airflow.executors import workloads
3434
from airflow.executors.executor_loader import ExecutorLoader
35+
from airflow.executors.workloads.task import TaskInstanceDTO
3536
from airflow.models import Log
37+
from airflow.models.callback import CallbackKey
3638
from airflow.observability.metrics import stats_utils
3739
from airflow.observability.trace import Trace
3840
from airflow.utils.log.logging_mixin import LoggingMixin
@@ -52,6 +54,7 @@
5254
from airflow.callbacks.callback_requests import CallbackRequest
5355
from airflow.cli.cli_config import GroupCommand
5456
from airflow.executors.executor_utils import ExecutorName
57+
from airflow.executors.workloads.types import WorkloadKey
5558
from airflow.models.taskinstance import TaskInstance
5659
from airflow.models.taskinstancekey import TaskInstanceKey
5760

@@ -143,6 +146,7 @@ class BaseExecutor(LoggingMixin):
143146
active_spans = ThreadSafeDict()
144147

145148
supports_ad_hoc_ti_run: bool = False
149+
supports_callbacks: bool = False
146150
supports_multi_team: bool = False
147151
sentry_integration: str = ""
148152

@@ -186,8 +190,9 @@ def __init__(self, parallelism: int = PARALLELISM, team_name: str | None = None)
186190
self.parallelism: int = parallelism
187191
self.team_name: str | None = team_name
188192
self.queued_tasks: dict[TaskInstanceKey, workloads.ExecuteTask] = {}
189-
self.running: set[TaskInstanceKey] = set()
190-
self.event_buffer: dict[TaskInstanceKey, EventBufferValueType] = {}
193+
self.queued_callbacks: dict[str, workloads.ExecuteCallback] = {}
194+
self.running: set[WorkloadKey] = set()
195+
self.event_buffer: dict[WorkloadKey, EventBufferValueType] = {}
191196
self._task_event_logs: deque[Log] = deque()
192197
self.conf = ExecutorConf(team_name)
193198

@@ -203,7 +208,7 @@ def __init__(self, parallelism: int = PARALLELISM, team_name: str | None = None)
203208
:meta private:
204209
"""
205210

206-
self.attempts: dict[TaskInstanceKey, RunningRetryAttemptType] = defaultdict(RunningRetryAttemptType)
211+
self.attempts: dict[WorkloadKey, RunningRetryAttemptType] = defaultdict(RunningRetryAttemptType)
207212

208213
def __repr__(self):
209214
_repr = f"{self.__class__.__name__}(parallelism={self.parallelism}"
@@ -224,10 +229,47 @@ def log_task_event(self, *, event: str, extra: str, ti_key: TaskInstanceKey):
224229
self._task_event_logs.append(Log(event=event, task_instance=ti_key, extra=extra))
225230

226231
def queue_workload(self, workload: workloads.All, session: Session) -> None:
227-
if not isinstance(workload, workloads.ExecuteTask):
228-
raise ValueError(f"Un-handled workload kind {type(workload).__name__!r} in {type(self).__name__}")
229-
ti = workload.ti
230-
self.queued_tasks[ti.key] = workload
232+
if isinstance(workload, workloads.ExecuteTask):
233+
ti = workload.ti
234+
self.queued_tasks[ti.key] = workload
235+
elif isinstance(workload, workloads.ExecuteCallback):
236+
if not self.supports_callbacks:
237+
raise NotImplementedError(
238+
f"{type(self).__name__} does not support ExecuteCallback workloads. "
239+
f"Set supports_callbacks = True and implement callback handling in _process_workloads(). "
240+
f"See LocalExecutor or CeleryExecutor for reference implementation."
241+
)
242+
self.queued_callbacks[workload.callback.id] = workload
243+
else:
244+
raise ValueError(
245+
f"Un-handled workload type {type(workload).__name__!r} in {type(self).__name__}. "
246+
f"Workload must be one of: ExecuteTask, ExecuteCallback."
247+
)
248+
249+
def _get_workloads_to_schedule(self, open_slots: int) -> list[tuple[WorkloadKey, workloads.All]]:
250+
"""
251+
Select and return the next batch of workloads to schedule, respecting priority policy.
252+
253+
Priority Policy: Callbacks are scheduled before tasks (callbacks complete existing work).
254+
Callbacks are processed in FIFO order. Tasks are sorted by priority_weight (higher priority first).
255+
256+
:param open_slots: Number of available execution slots
257+
"""
258+
workloads_to_schedule: list[tuple[WorkloadKey, workloads.All]] = []
259+
260+
if self.queued_callbacks:
261+
for key, workload in self.queued_callbacks.items():
262+
if len(workloads_to_schedule) >= open_slots:
263+
break
264+
workloads_to_schedule.append((key, workload))
265+
266+
if open_slots > len(workloads_to_schedule) and self.queued_tasks:
267+
for task_key, task_workload in self.order_queued_tasks_by_priority():
268+
if len(workloads_to_schedule) >= open_slots:
269+
break
270+
workloads_to_schedule.append((task_key, task_workload))
271+
272+
return workloads_to_schedule
231273

232274
def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
233275
"""
@@ -266,10 +308,10 @@ def heartbeat(self) -> None:
266308
"""Heartbeat sent to trigger new jobs."""
267309
open_slots = self.parallelism - len(self.running)
268310

269-
num_running_tasks = len(self.running)
270-
num_queued_tasks = len(self.queued_tasks)
311+
num_running_workloads = len(self.running)
312+
num_queued_workloads = len(self.queued_tasks) + len(self.queued_callbacks)
271313

272-
self._emit_metrics(open_slots, num_running_tasks, num_queued_tasks)
314+
self._emit_metrics(open_slots, num_running_workloads, num_queued_workloads)
273315
self.trigger_tasks(open_slots)
274316

275317
# Calling child class sync method
@@ -350,16 +392,16 @@ def order_queued_tasks_by_priority(self) -> list[tuple[TaskInstanceKey, workload
350392

351393
def trigger_tasks(self, open_slots: int) -> None:
352394
"""
353-
Initiate async execution of the queued tasks, up to the number of available slots.
395+
Initiate async execution of queued workloads (tasks and callbacks), up to the number of available slots.
396+
397+
Callbacks are prioritized over tasks to complete existing work before starting new work.
354398
355399
:param open_slots: Number of open slots
356400
"""
357-
sorted_queue = self.order_queued_tasks_by_priority()
401+
workloads_to_schedule = self._get_workloads_to_schedule(open_slots)
358402
workload_list = []
359403

360-
for _ in range(min((open_slots, len(self.queued_tasks)))):
361-
key, item = sorted_queue.pop()
362-
404+
for key, workload in workloads_to_schedule:
363405
# If a task makes it here but is still understood by the executor
364406
# to be running, it generally means that the task has been killed
365407
# externally and not yet been marked as failed.
@@ -373,12 +415,12 @@ def trigger_tasks(self, open_slots: int) -> None:
373415
if key in self.attempts:
374416
del self.attempts[key]
375417

376-
if isinstance(item, workloads.ExecuteTask) and hasattr(item, "ti"):
377-
ti = item.ti
418+
if isinstance(workload, workloads.ExecuteTask) and hasattr(workload, "ti"):
419+
ti = workload.ti
378420

379421
# If it's None, then the span for the current id hasn't been started.
380422
if self.active_spans is not None and self.active_spans.get("ti:" + str(ti.id)) is None:
381-
if isinstance(ti, workloads.TaskInstance):
423+
if isinstance(ti, TaskInstanceDTO):
382424
parent_context = Trace.extract(ti.parent_context_carrier)
383425
else:
384426
parent_context = Trace.extract(ti.dag_run.context_carrier)
@@ -397,7 +439,8 @@ def trigger_tasks(self, open_slots: int) -> None:
397439
carrier = Trace.inject()
398440
ti.context_carrier = carrier
399441

400-
workload_list.append(item)
442+
workload_list.append(workload)
443+
401444
if workload_list:
402445
self._process_workloads(workload_list)
403446

@@ -459,24 +502,25 @@ def running_state(self, key: TaskInstanceKey, info=None) -> None:
459502
"""
460503
self.change_state(key, TaskInstanceState.RUNNING, info, remove_running=False)
461504

462-
def get_event_buffer(self, dag_ids=None) -> dict[TaskInstanceKey, EventBufferValueType]:
505+
def get_event_buffer(self, dag_ids=None) -> dict[WorkloadKey, EventBufferValueType]:
463506
"""
464507
Return and flush the event buffer.
465508
466509
In case dag_ids is specified it will only return and flush events
467510
for the given dag_ids. Otherwise, it returns and flushes all events.
511+
Note: Callback events (with string keys) are always included regardless of dag_ids filter.
468512
469513
:param dag_ids: the dag_ids to return events for; returns all if given ``None``.
470514
:return: a dict of events
471515
"""
472-
cleared_events: dict[TaskInstanceKey, EventBufferValueType] = {}
516+
cleared_events: dict[WorkloadKey, EventBufferValueType] = {}
473517
if dag_ids is None:
474518
cleared_events = self.event_buffer
475519
self.event_buffer = {}
476520
else:
477-
for ti_key in list(self.event_buffer.keys()):
478-
if ti_key.dag_id in dag_ids:
479-
cleared_events[ti_key] = self.event_buffer.pop(ti_key)
521+
for key in list(self.event_buffer.keys()):
522+
if isinstance(key, CallbackKey) or key.dag_id in dag_ids:
523+
cleared_events[key] = self.event_buffer.pop(key)
480524

481525
return cleared_events
482526

@@ -529,21 +573,26 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task
529573

530574
@property
531575
def slots_available(self):
532-
"""Number of new tasks this executor instance can accept."""
533-
return self.parallelism - len(self.running) - len(self.queued_tasks)
576+
"""Number of new workloads (tasks and callbacks) this executor instance can accept."""
577+
return self.parallelism - len(self.running) - len(self.queued_tasks) - len(self.queued_callbacks)
534578

535579
@property
536580
def slots_occupied(self):
537-
"""Number of tasks this executor instance is currently managing."""
538-
return len(self.running) + len(self.queued_tasks)
581+
"""Number of workloads (tasks and callbacks) this executor instance is currently managing."""
582+
return len(self.running) + len(self.queued_tasks) + len(self.queued_callbacks)
539583

540584
def debug_dump(self):
541585
"""Get called in response to SIGUSR2 by the scheduler."""
542586
self.log.info(
543-
"executor.queued (%d)\n\t%s",
587+
"executor.queued_tasks (%d)\n\t%s",
544588
len(self.queued_tasks),
545589
"\n\t".join(map(repr, self.queued_tasks.items())),
546590
)
591+
self.log.info(
592+
"executor.queued_callbacks (%d)\n\t%s",
593+
len(self.queued_callbacks),
594+
"\n\t".join(map(repr, self.queued_callbacks.items())),
595+
)
547596
self.log.info("executor.running (%d)\n\t%s", len(self.running), "\n\t".join(map(repr, self.running)))
548597
self.log.info(
549598
"executor.event_buffer (%d)\n\t%s",

airflow-core/src/airflow/executors/local_executor.py

Lines changed: 61 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@
3737

3838
from airflow.executors import workloads
3939
from airflow.executors.base_executor import BaseExecutor
40-
from airflow.utils.state import TaskInstanceState
40+
from airflow.executors.workloads.callback import execute_callback_workload
41+
from airflow.utils.state import CallbackState, TaskInstanceState
4142

4243
# add logger to parameter of setproctitle to support logging
4344
if sys.platform == "darwin":
@@ -50,13 +51,23 @@
5051
if TYPE_CHECKING:
5152
from structlog.typing import FilteringBoundLogger as Logger
5253

53-
TaskInstanceStateType = tuple[workloads.TaskInstance, TaskInstanceState, Exception | None]
54+
from airflow.executors.workloads.types import WorkloadResultType
55+
56+
57+
def _get_executor_process_title_prefix(team_name: str | None) -> str:
58+
"""
59+
Build the process title prefix for LocalExecutor workers.
60+
61+
:param team_name: Team name from executor configuration
62+
"""
63+
team_suffix = f" [{team_name}]" if team_name else ""
64+
return f"airflow worker -- LocalExecutor{team_suffix}:"
5465

5566

5667
def _run_worker(
5768
logger_name: str,
5869
input: SimpleQueue[workloads.All | None],
59-
output: Queue[TaskInstanceStateType],
70+
output: Queue[WorkloadResultType],
6071
unread_messages: multiprocessing.sharedctypes.Synchronized[int],
6172
team_conf,
6273
):
@@ -68,11 +79,8 @@ def _run_worker(
6879
log = structlog.get_logger(logger_name)
6980
log.info("Worker starting up pid=%d", os.getpid())
7081

71-
# Create team suffix for process title
72-
team_suffix = f" [{team_conf.team_name}]" if team_conf.team_name else ""
73-
7482
while True:
75-
setproctitle(f"airflow worker -- LocalExecutor{team_suffix}: <idle>", log)
83+
setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)} <idle>", log)
7684
try:
7785
workload = input.get()
7886
except EOFError:
@@ -87,25 +95,30 @@ def _run_worker(
8795
# Received poison pill, no more tasks to run
8896
return
8997

90-
if not isinstance(workload, workloads.ExecuteTask):
91-
raise ValueError(f"LocalExecutor does not know how to handle {type(workload)}")
92-
9398
# Decrement this as soon as we pick up a message off the queue
9499
with unread_messages:
95100
unread_messages.value -= 1
96-
key = None
97-
if ti := getattr(workload, "ti", None):
98-
key = ti.key
99-
else:
100-
raise TypeError(f"Don't know how to get ti key from {type(workload).__name__}")
101101

102-
try:
103-
_execute_work(log, workload, team_conf)
102+
# Handle different workload types
103+
if isinstance(workload, workloads.ExecuteTask):
104+
try:
105+
_execute_work(log, workload, team_conf)
106+
output.put((workload.ti.key, TaskInstanceState.SUCCESS, None))
107+
except Exception as e:
108+
log.exception("Task execution failed.")
109+
output.put((workload.ti.key, TaskInstanceState.FAILED, e))
110+
111+
elif isinstance(workload, workloads.ExecuteCallback):
112+
output.put((workload.callback.id, CallbackState.RUNNING, None))
113+
try:
114+
_execute_callback(log, workload, team_conf)
115+
output.put((workload.callback.id, CallbackState.SUCCESS, None))
116+
except Exception as e:
117+
log.exception("Callback execution failed")
118+
output.put((workload.callback.id, CallbackState.FAILED, e))
104119

105-
output.put((key, TaskInstanceState.SUCCESS, None))
106-
except Exception as e:
107-
log.exception("uhoh")
108-
output.put((key, TaskInstanceState.FAILED, e))
120+
else:
121+
raise ValueError(f"LocalExecutor does not know how to handle {type(workload)}")
109122

110123

111124
def _execute_work(log: Logger, workload: workloads.ExecuteTask, team_conf) -> None:
@@ -118,9 +131,7 @@ def _execute_work(log: Logger, workload: workloads.ExecuteTask, team_conf) -> No
118131
"""
119132
from airflow.sdk.execution_time.supervisor import supervise
120133

121-
# Create team suffix for process title
122-
team_suffix = f" [{team_conf.team_name}]" if team_conf.team_name else ""
123-
setproctitle(f"airflow worker -- LocalExecutor{team_suffix}: {workload.ti.id}", log)
134+
setproctitle(f"{_get_executor_process_title_prefix(team_conf.team_name)} {workload.ti.id}", log)
124135

125136
base_url = team_conf.get("api", "base_url", fallback="/")
126137
# If it's a relative URL, use localhost:8080 as the default
@@ -141,6 +152,22 @@ def _execute_work(log: Logger, workload: workloads.ExecuteTask, team_conf) -> No
141152
)
142153

143154

155+
def _execute_callback(log: Logger, workload: workloads.ExecuteCallback, team_conf) -> None:
    """
    Execute a callback workload in the worker process.

    Updates the worker's process title to show the callback id, then delegates
    to ``execute_callback_workload``; a reported failure is surfaced as an
    exception so the worker loop records the callback as FAILED.

    :param log: Logger instance
    :param workload: The ExecuteCallback workload to execute
    :param team_conf: Team-specific executor configuration
    :raises RuntimeError: If the callback reports failure
    """
    title = f"{_get_executor_process_title_prefix(team_conf.team_name)} {workload.callback.id}"
    setproctitle(title, log)

    ok, failure_message = execute_callback_workload(workload.callback, log)

    if not ok:
        raise RuntimeError(failure_message or "Callback execution failed")
169+
170+
144171
class LocalExecutor(BaseExecutor):
145172
"""
146173
LocalExecutor executes tasks locally in parallel.
@@ -155,9 +182,10 @@ class LocalExecutor(BaseExecutor):
155182

156183
supports_multi_team: bool = True
157184
serve_logs: bool = True
185+
supports_callbacks: bool = True
158186

159187
activity_queue: SimpleQueue[workloads.All | None]
160-
result_queue: SimpleQueue[TaskInstanceStateType]
188+
result_queue: SimpleQueue[WorkloadResultType]
161189
workers: dict[int, multiprocessing.Process]
162190
_unread_messages: multiprocessing.sharedctypes.Synchronized[int]
163191

@@ -300,10 +328,14 @@ def end(self) -> None:
300328
def terminate(self):
301329
"""Terminate the executor is not doing anything."""
302330

303-
def _process_workloads(self, workloads):
304-
for workload in workloads:
331+
def _process_workloads(self, workload_list):
332+
for workload in workload_list:
305333
self.activity_queue.put(workload)
306-
del self.queued_tasks[workload.ti.key]
334+
# Remove from appropriate queue based on workload type
335+
if isinstance(workload, workloads.ExecuteTask):
336+
del self.queued_tasks[workload.ti.key]
337+
elif isinstance(workload, workloads.ExecuteCallback):
338+
del self.queued_callbacks[workload.callback.id]
307339
with self._unread_messages:
308-
self._unread_messages.value += len(workloads)
340+
self._unread_messages.value += len(workload_list)
309341
self._check_workers()

0 commit comments

Comments
 (0)