Skip to content

Commit 07945cb

Browse files
Add @task.analytics Decorator (#62648)
* Add @task.analytics decorator * Fix publish-docs to respect not ready packages when requested to build
1 parent 0ffb28e commit 07945cb

File tree

7 files changed

+291
-3
lines changed

7 files changed

+291
-3
lines changed

dev/breeze/src/airflow_breeze/utils/packages.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,10 @@ def find_matching_long_package_names(
419419
removed_packages: list[str] = [
420420
f"apache-airflow-providers-{provider.replace('.', '-')}" for provider in get_removed_provider_ids()
421421
]
422-
all_packages_including_removed: list[str] = available_doc_packages + removed_packages
422+
not_ready_packages: list[str] = [
423+
f"apache-airflow-providers-{provider.replace('.', '-')}" for provider in get_not_ready_provider_ids()
424+
]
425+
all_packages_including_removed: list[str] = available_doc_packages + removed_packages + not_ready_packages
423426
invalid_filters = [
424427
f
425428
for f in processed_package_filters

providers/common/sql/docs/operators.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,3 +317,14 @@ Local File System Storage
317317
:dedent: 4
318318
:start-after: [START howto_analytics_operator_with_local]
319319
:end-before: [END howto_analytics_operator_with_local]
320+
321+
Analytics TaskFlow Decorator
322+
----------------------------
323+
324+
The ``@task.analytics`` decorator lets you write a function that returns the
325+
analytics SQL queries:
326+
327+
.. exampleinclude:: /../../sql/src/airflow/providers/common/sql/example_dags/example_analytics.py
328+
:language: python
329+
:start-after: [START howto_analytics_decorator]
330+
:end-before: [END howto_analytics_decorator]

providers/common/sql/provider.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,3 +129,5 @@ sensors:
129129
task-decorators:
130130
- class-name: airflow.providers.common.sql.decorators.sql.sql_task
131131
name: sql
132+
- class-name: airflow.providers.common.sql.decorators.analytics.analytics_task
133+
name: analytics
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
from collections.abc import Callable, Collection, Mapping, Sequence
21+
from typing import TYPE_CHECKING, Any, ClassVar
22+
23+
from airflow.providers.common.compat.sdk import (
24+
AIRFLOW_V_3_0_PLUS,
25+
DecoratedOperator,
26+
TaskDecorator,
27+
context_merge,
28+
task_decorator_factory,
29+
)
30+
from airflow.providers.common.sql.operators.analytics import AnalyticsOperator
31+
from airflow.utils.operator_helpers import determine_kwargs
32+
33+
if AIRFLOW_V_3_0_PLUS:
34+
from airflow.sdk.definitions._internal.types import SET_DURING_EXECUTION
35+
else:
36+
from airflow.utils.types import NOTSET as SET_DURING_EXECUTION # type: ignore[attr-defined,no-redef]
37+
38+
39+
if TYPE_CHECKING:
40+
from airflow.providers.common.compat.sdk import Context
41+
42+
43+
class _AnalyticsDecoratedOperator(DecoratedOperator, AnalyticsOperator):
    """
    Operator produced by ``@task.analytics``.

    The decorated Python callable is invoked at execution time and its return
    value (a non-empty SQL string, or a non-empty list of non-empty SQL
    strings) is used as the queries run by :class:`AnalyticsOperator`.

    :param python_callable: A reference to an object that is callable.
    :param op_kwargs: A dictionary of keyword arguments that will get unpacked (templated).
    :param op_args: A list of positional arguments that will get unpacked (templated).
    """

    custom_operator_name: str = "@task.analytics"

    overwrite_rtif_after_execution: bool = True

    template_fields: Sequence[str] = (
        *DecoratedOperator.template_fields,
        *AnalyticsOperator.template_fields,
    )
    template_fields_renderers: ClassVar[dict[str, str]] = {
        **DecoratedOperator.template_fields_renderers,
        **AnalyticsOperator.template_fields_renderers,
    }

    def __init__(
        self,
        python_callable: Callable,
        op_args: Collection[Any] | None = None,
        op_kwargs: Mapping[str, Any] | None = None,
        **kwargs,
    ) -> None:
        # ``queries`` is only known once the callable runs, so defer it.
        super().__init__(
            python_callable=python_callable,
            op_args=op_args,
            op_kwargs=op_kwargs,
            queries=SET_DURING_EXECUTION,
            **kwargs,
        )

    @property
    def xcom_push(self) -> bool:
        """Compatibility property for BaseDecorator that expects xcom_push attribute."""
        return self.do_xcom_push

    @xcom_push.setter
    def xcom_push(self, value: bool) -> None:
        """Compatibility setter for BaseDecorator that expects xcom_push attribute."""
        self.do_xcom_push = value

    @staticmethod
    def _is_valid_queries_value(value: Any) -> bool:
        """Return True for a non-empty string or a non-empty list of non-empty strings."""
        if isinstance(value, str):
            return bool(value.strip())
        if isinstance(value, list):
            return bool(value) and all(isinstance(item, str) and item.strip() for item in value)
        return False

    def execute(self, context: Context) -> Any:
        """
        Build the SQL via the wrapped callable, then run it through AnalyticsOperator.

        :param context: Airflow context.
        :return: Any
        """
        context_merge(context, self.op_kwargs)
        call_kwargs = determine_kwargs(self.python_callable, self.op_args, context)

        returned = self.python_callable(*self.op_args, **call_kwargs)

        # Only non-empty strings and non-empty lists of non-empty strings are acceptable.
        if not self._is_valid_queries_value(returned):
            raise TypeError(
                "The returned value from the @task.analytics callable must be a non-empty string "
                "or a non-empty list of non-empty strings."
            )

        # AnalyticsOperator expects ``queries`` to be a list of strings.
        self.queries = [returned] if isinstance(returned, str) else returned

        self.render_template_fields(context)

        return AnalyticsOperator.execute(self, context)
124+
125+
def analytics_task(python_callable=None, **kwargs) -> TaskDecorator:
    """
    Wrap a Python function into an AnalyticsOperator.

    :param python_callable: Function to decorate.

    :meta private:
    """
    kwargs["decorated_operator_class"] = _AnalyticsDecoratedOperator
    return task_decorator_factory(python_callable=python_callable, **kwargs)

providers/common/sql/src/airflow/providers/common/sql/example_dags/example_analytics.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from airflow.providers.common.sql.config import DataSourceConfig
2222
from airflow.providers.common.sql.operators.analytics import AnalyticsOperator
23-
from airflow.sdk import DAG
23+
from airflow.sdk import DAG, task
2424

2525
datasource_config_s3 = DataSourceConfig(
2626
conn_id="aws_default", table_name="users_data", uri="s3://bucket/path/", format="parquet"
@@ -56,3 +56,12 @@
5656
)
5757
analytics_with_s3 >> analytics_with_local
5858
# [END howto_analytics_operator_with_local]
59+
60+
# [START howto_analytics_decorator]
@task.analytics(datasource_configs=[datasource_config_s3])
def get_user_summary_queries():
    # The decorated callable must return a non-empty SQL string, or a
    # non-empty list of non-empty SQL strings, to be run by AnalyticsOperator.
    return ["SELECT * FROM users_data LIMIT 10", "SELECT count(*) FROM users_data"]


# [END howto_analytics_decorator]

analytics_with_local >> get_user_summary_queries()

providers/common/sql/src/airflow/providers/common/sql/get_provider_info.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ def get_provider_info():
7171
{"integration-name": "Common SQL", "python-modules": ["airflow.providers.common.sql.sensors.sql"]}
7272
],
7373
"task-decorators": [
74-
{"class-name": "airflow.providers.common.sql.decorators.sql.sql_task", "name": "sql"}
74+
{"class-name": "airflow.providers.common.sql.decorators.sql.sql_task", "name": "sql"},
75+
{
76+
"class-name": "airflow.providers.common.sql.decorators.analytics.analytics_task",
77+
"name": "analytics",
78+
},
7579
],
7680
}
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
from unittest.mock import MagicMock, patch
21+
22+
import pytest
23+
24+
from airflow.providers.common.sql.config import DataSourceConfig
25+
from airflow.providers.common.sql.decorators.analytics import _AnalyticsDecoratedOperator
26+
27+
# Shared test fixture: a single local-filesystem parquet data source
# passed to every operator constructed in the tests below.
DATASOURCE_CONFIGS = [
    DataSourceConfig(conn_id="", table_name="users_data", uri="file:///path/to/", format="parquet")
]
30+
31+
32+
class TestAnalyticsDecoratedOperator:
    """Unit tests for the ``@task.analytics`` decorated operator."""

    # Patch target for short-circuiting the real AnalyticsOperator execution.
    _EXECUTE = "airflow.providers.common.sql.operators.analytics.AnalyticsOperator.execute"

    @staticmethod
    def _make_operator(python_callable, **extra):
        """Build a decorated operator with the shared datasource fixture."""
        return _AnalyticsDecoratedOperator(
            task_id="test",
            python_callable=python_callable,
            datasource_configs=DATASOURCE_CONFIGS,
            **extra,
        )

    def test_custom_operator_name(self):
        assert _AnalyticsDecoratedOperator.custom_operator_name == "@task.analytics"

    @patch(_EXECUTE, autospec=True)
    def test_execute_calls_callable_and_sets_queries_from_list(self, mock_execute):
        """The callable return value (list) becomes self.queries."""
        mock_execute.return_value = "mocked output"

        op = self._make_operator(
            lambda: ["SELECT * FROM users_data", "SELECT count(*) FROM users_data"]
        )
        outcome = op.execute(context={})

        assert outcome == "mocked output"
        assert op.queries == ["SELECT * FROM users_data", "SELECT count(*) FROM users_data"]
        mock_execute.assert_called_once()

    @patch(_EXECUTE, autospec=True)
    def test_execute_wraps_single_string_into_list(self, mock_execute):
        """A single string return value is wrapped into a list for self.queries."""
        mock_execute.return_value = "mocked output"

        op = self._make_operator(lambda: "SELECT 1")
        op.execute(context={})

        assert op.queries == ["SELECT 1"]

    @pytest.mark.parametrize(
        "return_value",
        [42, "", " ", None, [], [""], ["SELECT 1", ""], ["SELECT 1", " "], [42]],
        ids=[
            "non-string",
            "empty-string",
            "whitespace-string",
            "none",
            "empty-list",
            "list-with-empty-string",
            "list-with-one-valid-one-empty",
            "list-with-one-valid-one-whitespace",
            "list-with-non-string",
        ],
    )
    def test_execute_raises_on_invalid_return_value(self, return_value):
        """TypeError when the callable returns an invalid value."""
        op = self._make_operator(lambda: return_value)
        with pytest.raises(TypeError, match="non-empty string"):
            op.execute(context={})

    @patch(_EXECUTE, autospec=True)
    def test_execute_merges_op_kwargs_into_callable(self, mock_execute):
        """op_kwargs are forwarded to the callable to build queries."""
        mock_execute.return_value = "mocked output"

        def build_queries(table_name):
            return [f"SELECT * FROM {table_name}", f"SELECT count(*) FROM {table_name}"]

        op = self._make_operator(build_queries, op_kwargs={"table_name": "orders"})
        op.execute(context={"task_instance": MagicMock()})

        assert op.queries == ["SELECT * FROM orders", "SELECT count(*) FROM orders"]

0 commit comments

Comments
 (0)