Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 43 additions & 41 deletions airflow-core/src/airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,7 @@
from typing import TYPE_CHECKING, Any
from uuid import UUID

from sqlalchemy import (
and_,
delete,
exists,
func,
inspect,
or_,
select,
text,
tuple_,
update,
)
from sqlalchemy import CTE, and_, delete, exists, func, inspect, or_, select, text, tuple_, update
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import joinedload, lazyload, load_only, make_transient, selectinload
from sqlalchemy.sql import expression
Expand Down Expand Up @@ -799,7 +788,11 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -
task_instance,
)
starved_tasks_task_dagrun_concurrency.add(
(task_instance.dag_id, task_instance.run_id, task_instance.task_id)
(
task_instance.dag_id,
task_instance.run_id,
task_instance.task_id,
)
)
continue

Expand Down Expand Up @@ -3070,44 +3063,41 @@ def _update_asset_orphanage(self, session: Session = NEW_SESSION) -> None:
)
== 0
).label("orphaned")
asset_reference_query = session.execute(
select(orphaned, AssetModel)
asset_reference_query = (
select(AssetModel)
.outerjoin(DagScheduleAssetReference)
.outerjoin(TaskOutletAssetReference)
.outerjoin(TaskInletAssetReference)
.group_by(AssetModel.id)
.order_by(orphaned)
Comment on lines +3067 to -3079

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you consider extracting this to a helper function in asset.py, like many others located there?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not see a benefit from doing it, so I did not. Do you have a reason for the request? I might have missed something.

)
asset_orphanation: dict[bool, Collection[AssetModel]] = {
orphaned: [asset for _, asset in group]
for orphaned, group in itertools.groupby(asset_reference_query, key=operator.itemgetter(0))
}
self._orphan_unreferenced_assets(asset_orphanation.get(True, ()), session=session)
self._activate_referenced_assets(asset_orphanation.get(False, ()), session=session)

orphan_query = asset_reference_query.having(orphaned).cte()
activate_query = asset_reference_query.having(~orphaned).cte()

self._orphan_unreferenced_assets(orphan_query, session=session)
self._activate_referenced_assets(activate_query, session=session)

@staticmethod
def _orphan_unreferenced_assets(assets: Collection[AssetModel], *, session: Session) -> None:
if assets:
session.execute(
delete(AssetActive).where(
tuple_(AssetActive.name, AssetActive.uri).in_((a.name, a.uri) for a in assets)
def _orphan_unreferenced_assets(assets_query: CTE, *, session: Session) -> None:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can avoid passing a CTE as an argument (which is not intuitive) by using the helper function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what do you suggest then?
it had the least amount of duplicated code, if there are any suggestions, I would be happy to hear

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

asset_reference_query is a static query that never changes. If it's referenced in two places, maybe it's worth extracting it as a helper function, again?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This way we won't be passing CTEs as method parameters

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is harder to track that way in my opinion

way simpler to just see a query passed rather than go to a different method

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's harder to track. As it's a constant, reusable CTE, I would put it as a cached util function in the corresponding module instead of generating it in the scheduler code.

deleted_orphaned_assets = session.execute(
delete(AssetActive).where(
exists().where(
and_(AssetActive.name == assets_query.c.name, AssetActive.uri == assets_query.c.uri)
)
)
Stats.gauge("asset.orphaned", len(assets))
)

@staticmethod
def _activate_referenced_assets(assets: Collection[AssetModel], *, session: Session) -> None:
if not assets:
return
Stats.gauge("asset.orphaned", max(getattr(deleted_orphaned_assets, "rowcount", 0), 0))

active_assets = set(
session.execute(
select(AssetActive.name, AssetActive.uri).where(
tuple_(AssetActive.name, AssetActive.uri).in_((a.name, a.uri) for a in assets)
)
)
@staticmethod
def _activate_referenced_assets(assets_query: CTE, *, session: Session) -> None:
active_assets_query = select(AssetActive.name, AssetActive.uri).join(
assets_query,
and_(AssetActive.name == assets_query.c.name, AssetActive.uri == assets_query.c.uri),
Comment on lines +3094 to +3096

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be a helper function in asset.py too?
Just to avoid adding even more logic into the scheduler.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same as above

)

active_assets = session.execute(active_assets_query).all()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there are users with thousands of active assets, I wonder if this may explode one day.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a good point, but it may be out of scope for this PR. I might open a follow-up PR to handle large scale (e.g. batching). Since it is not an issue right now, I will leave it as is for now.


active_name_to_uri: dict[str, str] = {name: uri for name, uri in active_assets}
active_uri_to_name: dict[str, str] = {uri: name for name, uri in active_assets}

Expand All @@ -3132,9 +3122,21 @@ def _generate_warning_message(
def _activate_assets_generate_warnings() -> Iterator[tuple[str, str]]:
incoming_name_to_uri: dict[str, str] = {}
incoming_uri_to_name: dict[str, str] = {}
for asset in assets:
if (asset.name, asset.uri) in active_assets:
continue
for asset in session.scalars(
select(AssetModel)
.join(
assets_query,
and_(
assets_query.c.name == AssetModel.name,
assets_query.c.uri == AssetModel.uri,
),
)
.where(
~active_assets_query.where(
and_(AssetActive.name == AssetModel.name, AssetActive.uri == AssetModel.uri)
).exists()
)
):
existing_uri = active_name_to_uri.get(asset.name) or incoming_name_to_uri.get(asset.name)
if existing_uri is not None and existing_uri != asset.uri:
yield from _generate_warning_message(asset, "name", existing_uri)
Expand Down
41 changes: 36 additions & 5 deletions airflow-core/tests/unit/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import os
from collections import Counter, deque
from collections.abc import Callable, Generator, Iterator
from contextlib import ExitStack
from contextlib import ExitStack, contextmanager
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -7464,6 +7464,37 @@ def test_misconfigured_dags_doesnt_crash_scheduler(self, mock_create, session, d
job_runner._create_dag_runs([dm1], session)
assert "Failed creating DagRun" in caplog.text

def test_activate_referenced_assets_no_in_check_inside_query(self, session, testing_dag_bundle):
    """Regression guard for the tuple-IN expansion removed in PR #62114.

    Registers a SQLAlchemy ``do_orm_execute`` listener and fails the test if
    ``_activate_referenced_assets`` executes any statement containing an ``IN``
    clause (the old implementation expanded every ``(name, uri)`` pair into a
    potentially huge tuple-IN bind list).
    """
    dag_id1 = "test_asset_dag1"
    asset1_name = "asset1"
    asset_extra = {"foo": "bar"}

    asset1 = Asset(name=asset1_name, uri="s3://bucket/key/1", extra=asset_extra)
    dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[asset1])
    sync_dag_to_db(dag1, session=session)

    @contextmanager
    def assert_no_in_clause(session):
        from sqlalchemy import event

        def fail_on_in_clause_found(execute_statement):
            # ``do_orm_execute`` passes an ORMExecuteState, which has no custom
            # __str__; stringify the SQL statement itself — str() of the event
            # object would only yield its default repr and never match " IN ".
            compiled_sql = str(execute_statement.statement).upper()
            if " IN " in compiled_sql:
                pytest.fail(
                    f"Query contains IN clause which was removed in PR #62114, query: {compiled_sql}"
                )

        event.listen(session, "do_orm_execute", fail_on_in_clause_found)
        try:
            yield
        finally:
            # Always detach the listener so later tests on the same session
            # are not affected, even if the body raises.
            event.remove(session, "do_orm_execute", fail_on_in_clause_found)

    asset_models = select(AssetModel).cte()

    with assert_no_in_clause(session):
        SchedulerJobRunner._activate_referenced_assets(asset_models, session=session)

def test_activate_referenced_assets_with_no_existing_warning(self, session, testing_dag_bundle):
dag_warnings = session.scalars(select(DagWarning)).all()
assert dag_warnings == []
Expand All @@ -7478,8 +7509,8 @@ def test_activate_referenced_assets_with_no_existing_warning(self, session, test
dag1 = DAG(dag_id=dag_id1, start_date=DEFAULT_DATE, schedule=[asset1, asset1_1, asset1_2])
sync_dag_to_db(dag1, session=session)

asset_models = session.scalars(select(AssetModel)).all()
assert len(asset_models) == 3
asset_models = select(AssetModel).cte()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you add an explicit regression assertion for #61453's failure mode (large tuple-IN bind expansion)? These tests now validate behavior with a CTE input, but they don't directly guard against reintroducing a huge (name, uri) IN (...) path in scheduler asset activation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do you think this can be added? As it does not cause failure when using in, rather just cause some slowdown

The only thing I can think of is to check for the keyword 'in' in the str of the query

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

found a way to make it work with event listeners in sqlalchemy, added the test

assert len(session.execute(select(asset_models)).all()) == 3

SchedulerJobRunner._activate_referenced_assets(asset_models, session=session)
session.flush()
Expand Down Expand Up @@ -7516,7 +7547,7 @@ def test_activate_referenced_assets_with_existing_warnings(self, session, testin
)
session.flush()

asset_models = session.scalars(select(AssetModel)).all()
asset_models = select(AssetModel).cte()

SchedulerJobRunner._activate_referenced_assets(asset_models, session=session)
session.flush()
Expand Down Expand Up @@ -7565,7 +7596,7 @@ def test_activate_referenced_assets_with_multiple_conflict_asset_in_one_dag(
session.add(DagWarning(dag_id=dag_id, warning_type="asset conflict", message="will not exist"))
session.flush()

asset_models = session.scalars(select(AssetModel)).all()
asset_models = select(AssetModel).cte()

SchedulerJobRunner._activate_referenced_assets(asset_models, session=session)
session.flush()
Expand Down
8 changes: 7 additions & 1 deletion devel-common/src/tests_common/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,7 +972,13 @@ def _activate_assets(self):
AssetModel.producing_tasks.any(TaskOutletAssetReference.dag_id == self.dag.dag_id),
)

assets = self.session.scalars(select(AssetModel).where(assets_select_condition)).all()
assets = select(AssetModel).where(assets_select_condition).cte()

if not AIRFLOW_V_3_2_PLUS:
assets = self.session.scalars(
select(AssetModel).join(assets, AssetModel.id == AssetModel.id)
).all()

SchedulerJobRunner._activate_referenced_assets(assets, session=self.session)

def __exit__(self, type, value, traceback):
Expand Down
Loading