Skip to content

Commit 8a3a522

Browse files
Add support to create connections using uri in SDK (#62211)
* Add support to create connections using uri in SDK * Add support to create connections using uri in SDK * Add support to create connections using uri in SDK * fixing unit tests * Allow URI without adding it as a field * fixing failing tests and mypy * remove unwanted line --------- Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
1 parent 917abea commit 8a3a522

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

task-sdk/src/airflow/sdk/definitions/connection.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import json
2121
import logging
2222
from json import JSONDecodeError
23-
from typing import Any
23+
from typing import Any, overload
2424
from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit
2525

2626
import attrs
@@ -93,7 +93,7 @@ def _prune_dict(val: Any, mode="strict"):
9393
return val
9494

9595

96-
@attrs.define
96+
@attrs.define(slots=False)
9797
class Connection:
9898
"""
9999
A connection to an external data source.
@@ -108,6 +108,7 @@ class Connection:
108108
:param port: The port number.
109109
:param extra: Extra metadata. Non-standard data such as private/SSH keys can be saved here. JSON
110110
encoded object.
111+
:param uri: URI address describing connection parameters.
111112
"""
112113

113114
conn_id: str
@@ -122,6 +123,36 @@ class Connection:
122123

123124
EXTRA_KEY = "__extra__"
124125

126+
@overload
127+
def __init__(self, *, conn_id: str, uri: str) -> None: ...
128+
129+
@overload
130+
def __init__(
131+
self,
132+
*,
133+
conn_id: str,
134+
conn_type: str | None = None,
135+
description: str | None = None,
136+
host: str | None = None,
137+
schema: str | None = None,
138+
login: str | None = None,
139+
password: str | None = None,
140+
port: int | None = None,
141+
extra: str | None = None,
142+
) -> None: ...
143+
144+
def __init__(self, *, conn_id: str, uri: str | None = None, **kwargs) -> None:
145+
if uri is not None and kwargs:
146+
raise AirflowException(
147+
"You must create an object using the URI or individual values "
148+
"(conn_type, host, login, password, schema, port or extra). "
149+
"You can't mix these two ways to create this object."
150+
)
151+
if uri is None:
152+
self.__attrs_init__(conn_id=conn_id, **kwargs) # type: ignore[attr-defined]
153+
else:
154+
self.__dict__.update(self.from_uri(uri, conn_id=conn_id).to_dict(validate=False))
155+
125156
def get_uri(self) -> str:
126157
"""Generate and return connection in URI format."""
127158
from urllib.parse import parse_qsl

task-sdk/tests/task_sdk/definitions/test_connection.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,20 @@ def test_from_uri_invalid_protocol_host_error(self):
377377
with pytest.raises(AirflowException, match="Invalid connection string"):
378378
Connection.from_uri(uri, conn_id="test_conn")
379379

380+
def test_connection_constructor_with_uri(self):
381+
"""Test Connection(uri=..., conn_id=...) constructor form."""
382+
conn = Connection(conn_id="test_conn", uri="postgres://user:pass@host:5432/db")
383+
384+
assert conn.conn_id == "test_conn"
385+
assert conn.conn_type == "postgres"
386+
assert conn.host == "host"
387+
assert conn.login == "user"
388+
assert conn.password == "pass"
389+
assert conn.port == 5432
390+
assert conn.schema == "db"
391+
# uri should not exist as an attribute (it's init-only)
392+
assert not hasattr(conn, "uri")
393+
380394
def test_from_uri_roundtrip(self):
381395
"""Test that from_uri and get_uri are inverse operations."""
382396
original_uri = "postgres://user:pass@host:5432/db?param1=value1&param2=value2"

0 commit comments

Comments
 (0)