Skip to content

Commit 17290c6

Browse files
authored
AIP-67 - Multi Team: Pass args/kwargs to super in CeleryExecutor (#56006)
* pass args/kwargs to super in Celery executors This allows the team_name to be passed to the super class. This is low hanging fruit to allow the Celery executor to be used for multi team testing. More changes will be needed to allow the Celery executor to use team-based config, but that will be done at a future time. * Fix compat tests
1 parent c18288f commit 17290c6

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

providers/celery/src/airflow/providers/celery/executors/celery_executor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,8 @@ class CeleryExecutor(BaseExecutor):
290290
# TODO: TaskSDK: move this type change into BaseExecutor
291291
queued_tasks: dict[TaskInstanceKey, workloads.All] # type: ignore[assignment]
292292

293-
def __init__(self):
294-
super().__init__()
293+
def __init__(self, *args, **kwargs):
294+
super().__init__(*args, **kwargs)
295295

296296
# Celery doesn't support bulk sending the tasks (which can become a bottleneck on bigger clusters)
297297
# so we use a multiprocessing pool to speed this up.

providers/celery/tests/unit/celery/executors/test_celery_executor.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,24 @@ def test_supports_sentry(self):
126126
def test_cli_commands_vended(self):
127127
assert CeleryExecutor.get_cli_commands()
128128

129+
def test_celery_executor_init_with_args_kwargs(self):
130+
"""Test that CeleryExecutor properly passes args and kwargs to BaseExecutor."""
131+
parallelism = 50
132+
team_name = "test_team"
133+
134+
if AIRFLOW_V_3_1_PLUS:
135+
# team_name was added in Airflow 3.1
136+
executor = celery_executor.CeleryExecutor(parallelism=parallelism, team_name=team_name)
137+
else:
138+
executor = celery_executor.CeleryExecutor(parallelism)
139+
140+
assert executor.parallelism == parallelism
141+
142+
if AIRFLOW_V_3_1_PLUS:
143+
# team_name was added in Airflow 3.1
144+
assert executor.team_name == team_name
145+
assert executor.conf.team_name == team_name
146+
129147
@pytest.mark.backend("mysql", "postgres")
130148
def test_exception_propagation(self, caplog):
131149
caplog.set_level(

0 commit comments

Comments
 (0)