Skip to content

Commit da4ce85

Browse files
kaxilgopidesupavan
andauthored
Add LLMSQLQueryOperator and @task.llm_sql to common.ai provider (#62599)
SQL query generation from natural language, inheriting from `LLMOperator` for shared LLM connection handling, `agent_params`, and `system_prompt`. - Schema introspection via `db_conn_id` + `table_names` using DbApiHook - Defense-in-depth SQL safety: AST validation via sqlglot (allowlist + single-statement enforcement + system prompt instructions) - User-provided `system_prompt` appended to built-in SQL safety prompt for domain-specific guidance (e.g. "prefer CTEs over subqueries") - `agent_params` inherited from LLMOperator (retries, temperature, etc.) - Generate-only mode: returns SQL string, does not execute Co-authored-by: GPK <gopidesupavan@gmail.com>
1 parent e5301ea commit da4ce85

File tree

18 files changed

+1239
-5
lines changed

18 files changed

+1239
-5
lines changed

dev/breeze/tests/test_selective_checks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2249,7 +2249,7 @@ def test_upgrade_to_newer_dependencies(
22492249
("providers/common/sql/src/airflow/providers/common/sql/common_sql_python.py",),
22502250
{
22512251
"docs-list-as-string": "amazon apache.drill apache.druid apache.hive "
2252-
"apache.impala apache.pinot common.compat common.sql databricks elasticsearch "
2252+
"apache.impala apache.pinot common.ai common.compat common.sql databricks elasticsearch "
22532253
"exasol google jdbc microsoft.mssql mysql odbc openlineage "
22542254
"oracle pgvector postgres presto slack snowflake sqlite teradata trino vertica ydb",
22552255
},

docs/spelling_wordlist.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,7 @@ del
502502
delim
503503
delitem
504504
deltalake
505+
denylist
505506
dep
506507
DependencyMixin
507508
deployable
@@ -1726,6 +1727,7 @@ sql
17261727
sqla
17271728
Sqlalchemy
17281729
sqlalchemy
1730+
sqlglot
17291731
Sqlite
17301732
sqlite
17311733
sqlproxy
@@ -1803,6 +1805,7 @@ Subpath
18031805
subpath
18041806
subprocess
18051807
subprocesses
1808+
subqueries
18061809
subquery
18071810
SubscriberClient
18081811
subscriptionId

providers/common/ai/docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ You can install such cross-provider dependencies when installing from PyPI. For
121121
Dependent package Extra
122122
================================================================================================================== =================
123123
`apache-airflow-providers-common-compat <https://airflow.apache.org/docs/apache-airflow-providers-common-compat>`_ ``common.compat``
124+
`apache-airflow-providers-common-sql <https://airflow.apache.org/docs/apache-airflow-providers-common-sql>`_ ``common.sql``
124125
================================================================================================================== =================
125126

126127
Downloading official packages
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
.. Licensed to the Apache Software Foundation (ASF) under one
2+
or more contributor license agreements. See the NOTICE file
3+
distributed with this work for additional information
4+
regarding copyright ownership. The ASF licenses this file
5+
to you under the Apache License, Version 2.0 (the
6+
"License"); you may not use this file except in compliance
7+
with the License. You may obtain a copy of the License at
8+
9+
.. http://www.apache.org/licenses/LICENSE-2.0
10+
11+
.. Unless required by applicable law or agreed to in writing,
12+
software distributed under the License is distributed on an
13+
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
KIND, either express or implied. See the License for the
15+
specific language governing permissions and limitations
16+
under the License.
17+
18+
.. _howto/operator:llm_sql_query:
19+
20+
``LLMSQLQueryOperator``
21+
========================
22+
23+
Use :class:`~airflow.providers.common.ai.operators.llm_sql.LLMSQLQueryOperator` to generate
24+
SQL queries from natural language using an LLM.
25+
26+
The operator generates SQL but does not execute it. The generated query is returned
27+
as XCom and can be passed to ``SQLExecuteQueryOperator`` or used in downstream tasks.
28+
29+
.. seealso::
30+
:ref:`Connection configuration <howto/connection:pydantic_ai>`
31+
32+
Basic Usage
33+
-----------
34+
35+
Provide a natural language ``prompt`` and the operator generates a SQL query:
36+
37+
.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llm_sql.py
38+
:language: python
39+
:start-after: [START howto_operator_llm_sql_basic]
40+
:end-before: [END howto_operator_llm_sql_basic]
41+
42+
With Schema Introspection
43+
-------------------------
44+
45+
Use ``db_conn_id`` and ``table_names`` to automatically include database schema
46+
in the LLM's context. This produces more accurate queries because the LLM knows
47+
the actual column names and types:
48+
49+
.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llm_sql.py
50+
:language: python
51+
:start-after: [START howto_operator_llm_sql_schema]
52+
:end-before: [END howto_operator_llm_sql_schema]
53+
54+
TaskFlow Decorator
55+
------------------
56+
57+
The ``@task.llm_sql`` decorator lets you write a function that returns the
58+
prompt. The decorator handles LLM connection, schema introspection, SQL generation,
59+
and safety validation:
60+
61+
.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llm_sql.py
62+
:language: python
63+
:start-after: [START howto_decorator_llm_sql]
64+
:end-before: [END howto_decorator_llm_sql]
65+
66+
Dynamic Task Mapping
67+
--------------------
68+
69+
Generate SQL for multiple prompts in parallel using ``expand()``:
70+
71+
.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_llm_sql.py
72+
:language: python
73+
:start-after: [START howto_operator_llm_sql_expand]
74+
:end-before: [END howto_operator_llm_sql_expand]
75+
76+
SQL Safety Validation
77+
---------------------
78+
79+
By default, the operator validates generated SQL using an allowlist approach:
80+
81+
- Only ``SELECT``, ``UNION``, ``INTERSECT``, and ``EXCEPT`` statements are allowed.
82+
- Multi-statement SQL (semicolon-separated) is rejected.
83+
- Disallowed statements (``INSERT``, ``UPDATE``, ``DELETE``, ``DROP``, etc.) raise
84+
:class:`~airflow.providers.common.ai.utils.sql_validation.SQLSafetyError`.
85+
86+
You can disable validation with ``validate_sql=False`` or customize the allowed
87+
statement types with ``allowed_sql_types``.

providers/common/ai/provider.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ integrations:
3333
external-doc-url: https://airflow.apache.org/docs/apache-airflow-providers-common-ai/
3434
how-to-guide:
3535
- /docs/apache-airflow-providers-common-ai/operators/llm.rst
36+
- /docs/apache-airflow-providers-common-ai/operators/llm_sql.rst
3637
tags: [ai]
3738
- integration-name: Pydantic AI
3839
external-doc-url: https://ai.pydantic.dev/
@@ -61,7 +62,10 @@ operators:
6162
- integration-name: Common AI
6263
python-modules:
6364
- airflow.providers.common.ai.operators.llm
65+
- airflow.providers.common.ai.operators.llm_sql
6466

6567
task-decorators:
6668
- class-name: airflow.providers.common.ai.decorators.llm.llm_task
6769
name: llm
70+
- class-name: airflow.providers.common.ai.decorators.llm_sql.llm_sql_task
71+
name: llm_sql

providers/common/ai/pyproject.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,14 +72,23 @@ dependencies = [
7272
"common.compat" = [
7373
"apache-airflow-providers-common-compat"
7474
]
75+
"sql" = [
76+
"apache-airflow-providers-common-sql",
77+
"sqlglot>=26.0.0",
78+
]
79+
"common.sql" = [
80+
"apache-airflow-providers-common-sql"
81+
]
7582

7683
[dependency-groups]
7784
dev = [
7885
"apache-airflow",
7986
"apache-airflow-task-sdk",
8087
"apache-airflow-devel-common",
8188
"apache-airflow-providers-common-compat",
89+
"apache-airflow-providers-common-sql",
8290
# Additional devel dependencies (do not remove this line and add extra development dependencies)
91+
"sqlglot>=26.0.0",
8392
]
8493

8594
# To build docs:
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""
18+
TaskFlow decorator for LLM SQL generation.
19+
20+
The user writes a function that **returns the prompt**. The decorator handles
21+
the LLM call, schema introspection, and safety validation. The decorated task's
22+
XCom output is the generated SQL string.
23+
"""
24+
25+
from __future__ import annotations
26+
27+
from collections.abc import Callable, Collection, Mapping, Sequence
28+
from typing import TYPE_CHECKING, Any, ClassVar
29+
30+
from airflow.providers.common.ai.operators.llm_sql import LLMSQLQueryOperator
31+
from airflow.providers.common.compat.sdk import (
32+
DecoratedOperator,
33+
TaskDecorator,
34+
context_merge,
35+
task_decorator_factory,
36+
)
37+
from airflow.sdk.definitions._internal.types import SET_DURING_EXECUTION
38+
from airflow.utils.operator_helpers import determine_kwargs
39+
40+
if TYPE_CHECKING:
41+
from airflow.sdk import Context
42+
43+
44+
class _LLMSQLDecoratedOperator(DecoratedOperator, LLMSQLQueryOperator):
45+
"""
46+
Wraps a callable that returns a prompt for LLM SQL generation.
47+
48+
The user function is called at execution time to produce the prompt string.
49+
All other parameters (``llm_conn_id``, ``db_conn_id``, ``table_names``, etc.)
50+
are passed through to :class:`~airflow.providers.common.ai.operators.llm_sql.LLMSQLQueryOperator`.
51+
52+
:param python_callable: A reference to a callable that returns the prompt string.
53+
:param op_args: Positional arguments for the callable.
54+
:param op_kwargs: Keyword arguments for the callable.
55+
"""
56+
57+
template_fields: Sequence[str] = (
58+
*DecoratedOperator.template_fields,
59+
*LLMSQLQueryOperator.template_fields,
60+
)
61+
template_fields_renderers: ClassVar[dict[str, str]] = {
62+
**DecoratedOperator.template_fields_renderers,
63+
}
64+
65+
custom_operator_name: str = "@task.llm_sql"
66+
67+
def __init__(
68+
self,
69+
*,
70+
python_callable: Callable,
71+
op_args: Collection[Any] | None = None,
72+
op_kwargs: Mapping[str, Any] | None = None,
73+
**kwargs,
74+
) -> None:
75+
super().__init__(
76+
python_callable=python_callable,
77+
op_args=op_args,
78+
op_kwargs=op_kwargs,
79+
prompt=SET_DURING_EXECUTION,
80+
**kwargs,
81+
)
82+
83+
def execute(self, context: Context) -> Any:
84+
context_merge(context, self.op_kwargs)
85+
kwargs = determine_kwargs(self.python_callable, self.op_args, context)
86+
87+
self.prompt = self.python_callable(*self.op_args, **kwargs)
88+
89+
if not isinstance(self.prompt, str) or not self.prompt.strip():
90+
raise TypeError("The returned value from the @task.llm_sql callable must be a non-empty string.")
91+
92+
self.render_template_fields(context)
93+
# Call LLMSQLQueryOperator.execute directly, not super().execute(),
94+
# because we need to skip DecoratedOperator.execute — the callable
95+
# invocation is already handled above.
96+
return LLMSQLQueryOperator.execute(self, context)
97+
98+
99+
def llm_sql_task(
100+
python_callable: Callable | None = None,
101+
**kwargs,
102+
) -> TaskDecorator:
103+
"""
104+
Wrap a function that returns a natural language prompt into an LLM SQL task.
105+
106+
The function body constructs the prompt (can use Airflow context, XCom, etc.).
107+
The decorator handles: LLM connection, schema introspection, SQL generation,
108+
and safety validation.
109+
110+
Usage::
111+
112+
@task.llm_sql(
113+
llm_conn_id="openai_default",
114+
db_conn_id="postgres_default",
115+
table_names=["customers", "orders"],
116+
)
117+
def build_query(ds=None):
118+
return f"Find top 10 customers by revenue in {ds}"
119+
120+
:param python_callable: Function to decorate.
121+
"""
122+
return task_decorator_factory(
123+
python_callable=python_callable,
124+
decorated_operator_class=_LLMSQLDecoratedOperator,
125+
**kwargs,
126+
)

0 commit comments

Comments
 (0)