Skip to content

Commit 3d6ac35

Browse files
committed
Allow specifying query tags as a dict upon connection creation
1 parent 38097f2 commit 3d6ac35

File tree

3 files changed

+38
-9
lines changed

3 files changed

+38
-9
lines changed

examples/query_tags_example.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
in the system.query.history table for analytical purposes.
99
1010
There are two ways to set query tags:
11-
1. Session-level: Set in session_configuration (applies to all queries in the session)
11+
1. Connection-level: Pass query_tags parameter to sql.connect() (applies to all queries in the session)
1212
2. Per-query level: Pass query_tags parameter to execute() or execute_async() (applies to specific query)
1313
1414
Format: Dictionary with string keys and optional string values
@@ -17,21 +17,18 @@
1717
Special cases:
1818
- If a value is None, only the key is included (no colon or value)
1919
- Special characters (comma, colon and backslash) in values are automatically escaped
20-
- Keys are not escaped (should be controlled identifiers)
20+
- Backslashes in keys are automatically escaped; other special characters in keys are not escaped
2121
"""
2222

2323
print("=== Query Tags Example ===\n")
2424

25-
# Example 1: Session-level query tags (old approach)
26-
print("Example 1: Session-level query tags")
25+
# Example 1: Connection-level query tags
26+
print("Example 1: Connection-level query tags")
2727
with sql.connect(
2828
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
2929
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
3030
access_token=os.getenv("DATABRICKS_TOKEN"),
31-
session_configuration={
32-
'QUERY_TAGS': 'team:engineering,test:query-tags',
33-
'ansi_mode': False
34-
}
31+
query_tags={"team": "engineering", "application": "etl"},
3532
) as connection:
3633

3734
with connection.cursor() as cursor:
@@ -41,7 +38,7 @@
4138

4239
print()
4340

44-
# Example 2: Per-query query tags (new approach)
41+
# Example 2: Per-query query tags
4542
print("Example 2: Per-query query tags")
4643
with sql.connect(
4744
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),

src/databricks/sql/client.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
ColumnQueue,
3737
build_client_context,
3838
get_session_config_value,
39+
serialize_query_tags,
3940
)
4041
from databricks.sql.parameters.native import (
4142
DbsqlParameterBase,
@@ -106,6 +107,7 @@ def __init__(
106107
schema: Optional[str] = None,
107108
_use_arrow_native_complex_types: Optional[bool] = True,
108109
ignore_transactions: bool = True,
110+
query_tags: Optional[Dict[str, Optional[str]]] = None,
109111
**kwargs,
110112
) -> None:
111113
"""
@@ -281,6 +283,15 @@ def read(self) -> Optional[OAuthToken]:
281283
"spark.sql.thriftserver.metadata.metricview.enabled"
282284
] = "true"
283285

286+
if query_tags is not None:
287+
if session_configuration is None:
288+
session_configuration = {}
289+
serialized = serialize_query_tags(query_tags)
290+
if serialized:
291+
session_configuration["QUERY_TAGS"] = serialized
292+
else:
293+
session_configuration.pop("QUERY_TAGS", None)
294+
284295
self.disable_pandas = kwargs.get("_disable_pandas", False)
285296
self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
286297
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)

tests/unit/test_session.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,24 @@ def test_finalizer_closes_abandoned_connection(self, mock_client_class):
202202
close_session_call_args = instance.close_session.call_args[0][0]
203203
assert close_session_call_args.guid == b"\x22"
204204
assert close_session_call_args.secret == b"\x33"
205+
206+
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
207+
def test_query_tags_dict_sets_session_config(self, mock_client_class):
208+
databricks.sql.connect(
209+
query_tags={"team": "data-eng", "project": "etl"},
210+
**self.DUMMY_CONNECTION_ARGS,
211+
)
212+
213+
call_kwargs = mock_client_class.return_value.open_session.call_args[1]
214+
assert call_kwargs["session_configuration"]["QUERY_TAGS"] == "team:data-eng,project:etl"
215+
216+
@patch("%s.session.ThriftDatabricksClient" % PACKAGE_NAME)
217+
def test_query_tags_dict_takes_precedence_over_session_config(self, mock_client_class):
218+
databricks.sql.connect(
219+
query_tags={"team": "new-team"},
220+
session_configuration={"QUERY_TAGS": "team:old-team,other:value"},
221+
**self.DUMMY_CONNECTION_ARGS,
222+
)
223+
224+
call_kwargs = mock_client_class.return_value.open_session.call_args[1]
225+
assert call_kwargs["session_configuration"]["QUERY_TAGS"] == "team:new-team"

0 commit comments

Comments
 (0)