3232from airflow .configuration import conf
3333from airflow .executors import workloads
3434from airflow .executors .executor_loader import ExecutorLoader
35+ from airflow .executors .workloads .task import TaskInstanceDTO
3536from airflow .models import Log
37+ from airflow .models .callback import CallbackKey
3638from airflow .observability .metrics import stats_utils
3739from airflow .observability .trace import Trace
3840from airflow .utils .log .logging_mixin import LoggingMixin
5254 from airflow .callbacks .callback_requests import CallbackRequest
5355 from airflow .cli .cli_config import GroupCommand
5456 from airflow .executors .executor_utils import ExecutorName
57+ from airflow .executors .workloads .types import WorkloadKey
5558 from airflow .models .taskinstance import TaskInstance
5659 from airflow .models .taskinstancekey import TaskInstanceKey
5760
@@ -143,6 +146,7 @@ class BaseExecutor(LoggingMixin):
143146 active_spans = ThreadSafeDict ()
144147
145148 supports_ad_hoc_ti_run : bool = False
149+ supports_callbacks : bool = False
146150 supports_multi_team : bool = False
147151 sentry_integration : str = ""
148152
@@ -186,8 +190,9 @@ def __init__(self, parallelism: int = PARALLELISM, team_name: str | None = None)
186190 self .parallelism : int = parallelism
187191 self .team_name : str | None = team_name
188192 self .queued_tasks : dict [TaskInstanceKey , workloads .ExecuteTask ] = {}
189- self .running : set [TaskInstanceKey ] = set ()
190- self .event_buffer : dict [TaskInstanceKey , EventBufferValueType ] = {}
193+ self .queued_callbacks : dict [str , workloads .ExecuteCallback ] = {}
194+ self .running : set [WorkloadKey ] = set ()
195+ self .event_buffer : dict [WorkloadKey , EventBufferValueType ] = {}
191196 self ._task_event_logs : deque [Log ] = deque ()
192197 self .conf = ExecutorConf (team_name )
193198
@@ -203,7 +208,7 @@ def __init__(self, parallelism: int = PARALLELISM, team_name: str | None = None)
203208 :meta private:
204209 """
205210
206- self .attempts : dict [TaskInstanceKey , RunningRetryAttemptType ] = defaultdict (RunningRetryAttemptType )
211+ self .attempts : dict [WorkloadKey , RunningRetryAttemptType ] = defaultdict (RunningRetryAttemptType )
207212
208213 def __repr__ (self ):
209214 _repr = f"{ self .__class__ .__name__ } (parallelism={ self .parallelism } "
@@ -224,10 +229,47 @@ def log_task_event(self, *, event: str, extra: str, ti_key: TaskInstanceKey):
224229 self ._task_event_logs .append (Log (event = event , task_instance = ti_key , extra = extra ))
225230
226231 def queue_workload (self , workload : workloads .All , session : Session ) -> None :
227- if not isinstance (workload , workloads .ExecuteTask ):
228- raise ValueError (f"Un-handled workload kind { type (workload ).__name__ !r} in { type (self ).__name__ } " )
229- ti = workload .ti
230- self .queued_tasks [ti .key ] = workload
232+ if isinstance (workload , workloads .ExecuteTask ):
233+ ti = workload .ti
234+ self .queued_tasks [ti .key ] = workload
235+ elif isinstance (workload , workloads .ExecuteCallback ):
236+ if not self .supports_callbacks :
237+ raise NotImplementedError (
238+ f"{ type (self ).__name__ } does not support ExecuteCallback workloads. "
239+ f"Set supports_callbacks = True and implement callback handling in _process_workloads(). "
240+ f"See LocalExecutor or CeleryExecutor for reference implementation."
241+ )
242+ self .queued_callbacks [workload .callback .id ] = workload
243+ else :
244+ raise ValueError (
245+ f"Un-handled workload type { type (workload ).__name__ !r} in { type (self ).__name__ } . "
246+ f"Workload must be one of: ExecuteTask, ExecuteCallback."
247+ )
248+
249+ def _get_workloads_to_schedule (self , open_slots : int ) -> list [tuple [WorkloadKey , workloads .All ]]:
250+ """
251+ Select and return the next batch of workloads to schedule, respecting priority policy.
252+
253+ Priority Policy: Callbacks are scheduled before tasks (callbacks complete existing work).
254+ Callbacks are processed in FIFO order. Tasks are sorted by priority_weight (higher priority first).
255+
256+ :param open_slots: Number of available execution slots
257+ """
258+ workloads_to_schedule : list [tuple [WorkloadKey , workloads .All ]] = []
259+
260+ if self .queued_callbacks :
261+ for key , workload in self .queued_callbacks .items ():
262+ if len (workloads_to_schedule ) >= open_slots :
263+ break
264+ workloads_to_schedule .append ((key , workload ))
265+
266+ if open_slots > len (workloads_to_schedule ) and self .queued_tasks :
267+ for task_key , task_workload in self .order_queued_tasks_by_priority ():
268+ if len (workloads_to_schedule ) >= open_slots :
269+ break
270+ workloads_to_schedule .append ((task_key , task_workload ))
271+
272+ return workloads_to_schedule
231273
232274 def _process_workloads (self , workloads : Sequence [workloads .All ]) -> None :
233275 """
@@ -266,10 +308,10 @@ def heartbeat(self) -> None:
266308 """Heartbeat sent to trigger new jobs."""
267309 open_slots = self .parallelism - len (self .running )
268310
269- num_running_tasks = len (self .running )
270- num_queued_tasks = len (self .queued_tasks )
311+ num_running_workloads = len (self .running )
312+ num_queued_workloads = len (self .queued_tasks ) + len ( self . queued_callbacks )
271313
272- self ._emit_metrics (open_slots , num_running_tasks , num_queued_tasks )
314+ self ._emit_metrics (open_slots , num_running_workloads , num_queued_workloads )
273315 self .trigger_tasks (open_slots )
274316
275317 # Calling child class sync method
@@ -350,16 +392,16 @@ def order_queued_tasks_by_priority(self) -> list[tuple[TaskInstanceKey, workload
350392
351393 def trigger_tasks (self , open_slots : int ) -> None :
352394 """
353- Initiate async execution of the queued tasks, up to the number of available slots.
395+ Initiate async execution of queued workloads (tasks and callbacks), up to the number of available slots.
396+
397+ Callbacks are prioritized over tasks to complete existing work before starting new work.
354398
355399 :param open_slots: Number of open slots
356400 """
357- sorted_queue = self .order_queued_tasks_by_priority ( )
401+ workloads_to_schedule = self ._get_workloads_to_schedule ( open_slots )
358402 workload_list = []
359403
360- for _ in range (min ((open_slots , len (self .queued_tasks )))):
361- key , item = sorted_queue .pop ()
362-
404+ for key , workload in workloads_to_schedule :
363405 # If a task makes it here but is still understood by the executor
364406 # to be running, it generally means that the task has been killed
365407 # externally and not yet been marked as failed.
@@ -373,12 +415,12 @@ def trigger_tasks(self, open_slots: int) -> None:
373415 if key in self .attempts :
374416 del self .attempts [key ]
375417
376- if isinstance (item , workloads .ExecuteTask ) and hasattr (item , "ti" ):
377- ti = item .ti
418+ if isinstance (workload , workloads .ExecuteTask ) and hasattr (workload , "ti" ):
419+ ti = workload .ti
378420
379421 # If it's None, then the span for the current id hasn't been started.
380422 if self .active_spans is not None and self .active_spans .get ("ti:" + str (ti .id )) is None :
381- if isinstance (ti , workloads . TaskInstance ):
423+ if isinstance (ti , TaskInstanceDTO ):
382424 parent_context = Trace .extract (ti .parent_context_carrier )
383425 else :
384426 parent_context = Trace .extract (ti .dag_run .context_carrier )
@@ -397,7 +439,8 @@ def trigger_tasks(self, open_slots: int) -> None:
397439 carrier = Trace .inject ()
398440 ti .context_carrier = carrier
399441
400- workload_list .append (item )
442+ workload_list .append (workload )
443+
401444 if workload_list :
402445 self ._process_workloads (workload_list )
403446
@@ -459,24 +502,25 @@ def running_state(self, key: TaskInstanceKey, info=None) -> None:
459502 """
460503 self .change_state (key , TaskInstanceState .RUNNING , info , remove_running = False )
461504
462- def get_event_buffer (self , dag_ids = None ) -> dict [TaskInstanceKey , EventBufferValueType ]:
505+ def get_event_buffer (self , dag_ids = None ) -> dict [WorkloadKey , EventBufferValueType ]:
463506 """
464507 Return and flush the event buffer.
465508
466509 In case dag_ids is specified it will only return and flush events
467510 for the given dag_ids. Otherwise, it returns and flushes all events.
511+ Note: Callback events (with string keys) are always included regardless of dag_ids filter.
468512
469513 :param dag_ids: the dag_ids to return events for; returns all if given ``None``.
470514 :return: a dict of events
471515 """
472- cleared_events : dict [TaskInstanceKey , EventBufferValueType ] = {}
516+ cleared_events : dict [WorkloadKey , EventBufferValueType ] = {}
473517 if dag_ids is None :
474518 cleared_events = self .event_buffer
475519 self .event_buffer = {}
476520 else :
477- for ti_key in list (self .event_buffer .keys ()):
478- if ti_key .dag_id in dag_ids :
479- cleared_events [ti_key ] = self .event_buffer .pop (ti_key )
521+ for key in list (self .event_buffer .keys ()):
522+ if isinstance ( key , CallbackKey ) or key .dag_id in dag_ids :
523+ cleared_events [key ] = self .event_buffer .pop (key )
480524
481525 return cleared_events
482526
@@ -529,21 +573,26 @@ def try_adopt_task_instances(self, tis: Sequence[TaskInstance]) -> Sequence[Task
529573
530574 @property
531575 def slots_available (self ):
532- """Number of new tasks this executor instance can accept."""
533- return self .parallelism - len (self .running ) - len (self .queued_tasks )
576+ """Number of new workloads ( tasks and callbacks) this executor instance can accept."""
577+ return self .parallelism - len (self .running ) - len (self .queued_tasks ) - len ( self . queued_callbacks )
534578
535579 @property
536580 def slots_occupied (self ):
537- """Number of tasks this executor instance is currently managing."""
538- return len (self .running ) + len (self .queued_tasks )
581+ """Number of workloads ( tasks and callbacks) this executor instance is currently managing."""
582+ return len (self .running ) + len (self .queued_tasks ) + len ( self . queued_callbacks )
539583
540584 def debug_dump (self ):
541585 """Get called in response to SIGUSR2 by the scheduler."""
542586 self .log .info (
543- "executor.queued (%d)\n \t %s" ,
587+ "executor.queued_tasks (%d)\n \t %s" ,
544588 len (self .queued_tasks ),
545589 "\n \t " .join (map (repr , self .queued_tasks .items ())),
546590 )
591+ self .log .info (
592+ "executor.queued_callbacks (%d)\n \t %s" ,
593+ len (self .queued_callbacks ),
594+ "\n \t " .join (map (repr , self .queued_callbacks .items ())),
595+ )
547596 self .log .info ("executor.running (%d)\n \t %s" , len (self .running ), "\n \t " .join (map (repr , self .running )))
548597 self .log .info (
549598 "executor.event_buffer (%d)\n \t %s" ,
0 commit comments