Skip to content

Commit 38097f2

Browse files
authored
Add query_tags parameter support for execute methods (#736)
* Add statement level query tag support by introducing it as a parameter on execute* methods Signed-off-by: Jiabin Hu <jiabin.hu@databricks.com> * Add query_tags support to executemany method - Added query_tags parameter to executemany() method - Query tags are applied to all queries in the batch - Updated example to demonstrate executemany usage with query_tags - All tests pass (122/122 client tests) Signed-off-by: Jiabin Hu <jiabin.hu@databricks.com> * add example that doesn't have tag Signed-off-by: Jiabin Hu <jiabin.hu@databricks.com> * fix presubmit errors Signed-off-by: Jiabin Hu <jiabin.hu@databricks.com> * another lint Signed-off-by: Jiabin Hu <jiabin.hu@databricks.com> * address review comments Signed-off-by: Jiabin Hu <jiabin.hu@databricks.com> --------- Signed-off-by: Jiabin Hu <jiabin.hu@databricks.com>
1 parent 9fe7356 commit 38097f2

File tree

7 files changed

+244
-10
lines changed

7 files changed

+244
-10
lines changed

examples/query_tags_example.py

Lines changed: 94 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,23 @@
77
Query Tags are key-value pairs that can be attached to SQL executions and will appear
88
in the system.query.history table for analytical purposes.
99
10-
Format: "key1:value1,key2:value2,key3:value3"
10+
There are two ways to set query tags:
11+
1. Session-level: Set in session_configuration (applies to all queries in the session)
12+
2. Per-query level: Pass the query_tags parameter to execute(), execute_async(), or executemany() (applies only to the queries issued by that call)
13+
14+
Format: Dictionary with string keys and optional string values
15+
Example: {"team": "engineering", "application": "etl", "priority": "high"}
16+
17+
Special cases:
18+
- If a value is None, only the key is included (no colon or value)
19+
- Special characters (comma, colon and backslash) in values are automatically escaped
20+
- Backslashes in keys are escaped; other special characters in keys are not (keys should be controlled identifiers)
1121
"""
1222

1323
print("=== Query Tags Example ===\n")
1424

25+
# Example 1: Session-level query tags (old approach)
26+
print("Example 1: Session-level query tags")
1527
with sql.connect(
1628
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
1729
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
@@ -21,10 +33,89 @@
2133
'ansi_mode': False
2234
}
2335
) as connection:
24-
36+
2537
with connection.cursor() as cursor:
2638
cursor.execute("SELECT 1")
2739
result = cursor.fetchone()
2840
print(f" Result: {result[0]}")
2941

30-
print("\n=== Query Tags Example Complete ===")
42+
print()
43+
44+
# Example 2: Per-query query tags (new approach)
45+
print("Example 2: Per-query query tags")
46+
with sql.connect(
47+
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
48+
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
49+
access_token=os.getenv("DATABRICKS_TOKEN"),
50+
) as connection:
51+
52+
with connection.cursor() as cursor:
53+
# Query 1: Tags for a critical ETL job
54+
cursor.execute(
55+
"SELECT 1",
56+
query_tags={"team": "data-eng", "application": "etl", "priority": "high"}
57+
)
58+
result = cursor.fetchone()
59+
print(f" ETL Query Result: {result[0]}")
60+
61+
# Query 2: Tags with None value (key-only tag)
62+
cursor.execute(
63+
"SELECT 2",
64+
query_tags={"team": "analytics", "experimental": None}
65+
)
66+
result = cursor.fetchone()
67+
print(f" Experimental Query Result: {result[0]}")
68+
69+
# Query 3: Tags with special characters (automatically escaped)
70+
cursor.execute(
71+
"SELECT 3",
72+
query_tags={"description": "test:with:colons,and,commas"}
73+
)
74+
result = cursor.fetchone()
75+
print(f" Special Chars Query Result: {result[0]}")
76+
77+
# Query 4: No tags (demonstrates tags don't persist from previous queries)
78+
cursor.execute("SELECT 4")
79+
result = cursor.fetchone()
80+
print(f" No Tags Query Result: {result[0]}")
81+
82+
print()
83+
84+
# Example 3: Async execution with query tags
85+
print("Example 3: Async execution with query tags")
86+
with sql.connect(
87+
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
88+
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
89+
access_token=os.getenv("DATABRICKS_TOKEN"),
90+
) as connection:
91+
92+
with connection.cursor() as cursor:
93+
cursor.execute_async(
94+
"SELECT 5",
95+
query_tags={"team": "data-eng", "mode": "async"}
96+
)
97+
cursor.get_async_execution_result()
98+
result = cursor.fetchone()
99+
print(f" Async Query Result: {result[0]}")
100+
101+
print()
102+
103+
# Example 4: executemany with query tags
104+
print("Example 4: executemany with query tags")
105+
with sql.connect(
106+
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
107+
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
108+
access_token=os.getenv("DATABRICKS_TOKEN"),
109+
) as connection:
110+
111+
with connection.cursor() as cursor:
112+
# Execute multiple queries with the same tags
113+
cursor.executemany(
114+
"SELECT ?",
115+
[[6], [7], [8]],
116+
query_tags={"team": "data-eng", "batch": "executemany"}
117+
)
118+
result = cursor.fetchone()
119+
print(f" Executemany Query Result (last): {result[0]}")
120+
121+
print("\n=== Query Tags Example Complete ===")

src/databricks/sql/backend/databricks_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def execute_command(
8383
async_op: bool,
8484
enforce_embedded_schema_correctness: bool,
8585
row_limit: Optional[int] = None,
86+
query_tags: Optional[Dict[str, Optional[str]]] = None,
8687
) -> Union[ResultSet, None]:
8788
"""
8889
Executes a SQL command or query within the specified session.
@@ -102,6 +103,7 @@ def execute_command(
102103
async_op: Whether to execute the command asynchronously
103104
enforce_embedded_schema_correctness: Whether to enforce schema correctness
104105
row_limit: Maximum number of rows in the response.
106+
query_tags: Optional dictionary of query tags to apply for this query only.
105107
106108
Returns:
107109
If async_op is False, returns a ResultSet object containing the

src/databricks/sql/backend/sea/backend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,9 @@ def execute_command(
463463
async_op: bool,
464464
enforce_embedded_schema_correctness: bool,
465465
row_limit: Optional[int] = None,
466+
query_tags: Optional[
467+
Dict[str, Optional[str]]
468+
] = None, # TODO: implement query_tags for SEA backend
466469
) -> Union[SeaResultSet, None]:
467470
"""
468471
Execute a SQL command using the SEA backend.

src/databricks/sql/backend/thrift_backend.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import math
66
import time
77
import threading
8-
from typing import List, Optional, Union, Any, TYPE_CHECKING
8+
from typing import Dict, List, Optional, Union, Any, TYPE_CHECKING
99
from uuid import UUID
1010

1111
from databricks.sql.common.unified_http_client import UnifiedHttpClient
@@ -53,6 +53,7 @@
5353
convert_arrow_based_set_to_arrow_table,
5454
convert_decimals_in_arrow_table,
5555
convert_column_based_set_to_arrow_table,
56+
serialize_query_tags,
5657
)
5758
from databricks.sql.types import SSLOptions
5859
from databricks.sql.backend.databricks_client import DatabricksClient
@@ -1003,6 +1004,7 @@ def execute_command(
10031004
async_op=False,
10041005
enforce_embedded_schema_correctness=False,
10051006
row_limit: Optional[int] = None,
1007+
query_tags: Optional[Dict[str, Optional[str]]] = None,
10061008
) -> Union["ResultSet", None]:
10071009
thrift_handle = session_id.to_thrift_handle()
10081010
if not thrift_handle:
@@ -1022,6 +1024,19 @@ def execute_command(
10221024
# DBR should be changed to use month_day_nano_interval
10231025
intervalTypesAsArrow=False,
10241026
)
1027+
1028+
# Build confOverlay with default configs and query_tags
1029+
merged_conf_overlay = {
1030+
# We want to receive proper Timestamp arrow types.
1031+
"spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
1032+
}
1033+
1034+
# Serialize and add query_tags to confOverlay if provided
1035+
if query_tags:
1036+
serialized_tags = serialize_query_tags(query_tags)
1037+
if serialized_tags:
1038+
merged_conf_overlay["query_tags"] = serialized_tags
1039+
10251040
req = ttypes.TExecuteStatementReq(
10261041
sessionHandle=thrift_handle,
10271042
statement=operation,
@@ -1036,10 +1051,7 @@ def execute_command(
10361051
canReadArrowResult=True if pyarrow else False,
10371052
canDecompressLZ4Result=lz4_compression,
10381053
canDownloadResult=use_cloud_fetch,
1039-
confOverlay={
1040-
# We want to receive proper Timestamp arrow types.
1041-
"spark.thriftserver.arrowBasedRowSet.timestampAsString": "false"
1042-
},
1054+
confOverlay=merged_conf_overlay,
10431055
useArrowNativeTypes=spark_arrow_types,
10441056
parameters=parameters,
10451057
enforceEmbeddedSchemaCorrectness=enforce_embedded_schema_correctness,

src/databricks/sql/client.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,6 +1263,7 @@ def execute(
12631263
parameters: Optional[TParameterCollection] = None,
12641264
enforce_embedded_schema_correctness=False,
12651265
input_stream: Optional[BinaryIO] = None,
1266+
query_tags: Optional[Dict[str, Optional[str]]] = None,
12661267
) -> "Cursor":
12671268
"""
12681269
Execute a query and wait for execution to complete.
@@ -1293,6 +1294,10 @@ def execute(
12931294
Both will result in the query equivalent to "SELECT * FROM table WHERE field = 'foo'
12941295
being sent to the server
12951296
1297+
:param query_tags: Optional dictionary of query tags to apply for this query only.
1298+
Tags are key-value pairs that can be used to identify and categorize queries.
1299+
Example: {"team": "data-eng", "application": "etl"}
1300+
12961301
:returns self
12971302
"""
12981303

@@ -1333,6 +1338,7 @@ def execute(
13331338
async_op=False,
13341339
enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
13351340
row_limit=self.row_limit,
1341+
query_tags=query_tags,
13361342
)
13371343

13381344
if self.active_result_set and self.active_result_set.is_staging_operation:
@@ -1349,13 +1355,17 @@ def execute_async(
13491355
operation: str,
13501356
parameters: Optional[TParameterCollection] = None,
13511357
enforce_embedded_schema_correctness=False,
1358+
query_tags: Optional[Dict[str, Optional[str]]] = None,
13521359
) -> "Cursor":
13531360
"""
13541361
13551362
Execute a query without waiting for it to complete; retrieve the result later with get_async_execution_result()
13561363
13571364
:param operation:
13581365
:param parameters:
1366+
:param query_tags: Optional dictionary of query tags to apply for this query only.
1367+
Tags are key-value pairs that can be used to identify and categorize queries.
1368+
Example: {"team": "data-eng", "application": "etl"}
13591369
:return:
13601370
"""
13611371

@@ -1392,6 +1402,7 @@ def execute_async(
13921402
async_op=True,
13931403
enforce_embedded_schema_correctness=enforce_embedded_schema_correctness,
13941404
row_limit=self.row_limit,
1405+
query_tags=query_tags,
13951406
)
13961407

13971408
return self
@@ -1448,7 +1459,12 @@ def get_async_execution_result(self):
14481459
session_id_hex=self.connection.get_session_id_hex(),
14491460
)
14501461

1451-
def executemany(self, operation, seq_of_parameters):
1462+
def executemany(
1463+
self,
1464+
operation,
1465+
seq_of_parameters,
1466+
query_tags: Optional[Dict[str, Optional[str]]] = None,
1467+
):
14521468
"""
14531469
Execute the operation once for every set of passed in parameters.
14541470
@@ -1457,10 +1473,14 @@ def executemany(self, operation, seq_of_parameters):
14571473
14581474
Only the final result set is retained.
14591475
1476+
:param query_tags: Optional dictionary of query tags to apply for all queries in this batch.
1477+
Tags are key-value pairs that can be used to identify and categorize queries.
1478+
Example: {"team": "data-eng", "application": "etl"}
1479+
14601480
:returns self
14611481
"""
14621482
for parameters in seq_of_parameters:
1463-
self.execute(operation, parameters)
1483+
self.execute(operation, parameters, query_tags=query_tags)
14641484
return self
14651485

14661486
@log_latency(StatementType.METADATA)

src/databricks/sql/utils.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,49 @@ def concat_table_chunks(
898898
return pyarrow.concat_tables(table_chunks)
899899

900900

901+
def serialize_query_tags(
902+
query_tags: Optional[Dict[str, Optional[str]]]
903+
) -> Optional[str]:
904+
"""
905+
Serialize query_tags dictionary to a string format.
906+
907+
Format: "key1:value1,key2:value2"
908+
Special cases:
909+
- If value is None, omit the colon and value (e.g., "key1:value1,key2,key3:value3")
910+
- Escape special characters in values (colon, comma, and backslash) with a leading backslash
911+
- Backslashes in keys are escaped; other special characters in keys are not escaped
912+
913+
Args:
914+
query_tags: Dictionary of query tags where keys are strings and values are optional strings
915+
916+
Returns:
917+
Serialized string or None if query_tags is None or empty
918+
"""
919+
if not query_tags:
920+
return None
921+
922+
def escape_value(value: str) -> str:
923+
"""Escape special characters in tag values."""
924+
# Escape backslash first to avoid double-escaping
925+
value = value.replace("\\", r"\\")
926+
# Escape colon and comma
927+
value = value.replace(":", r"\:")
928+
value = value.replace(",", r"\,")
929+
return value
930+
931+
serialized_parts = []
932+
for key, value in query_tags.items():
933+
escaped_key = key.replace("\\", r"\\")
934+
if value is None:
935+
# No colon or value when value is None
936+
serialized_parts.append(escaped_key)
937+
else:
938+
escaped_value = escape_value(value)
939+
serialized_parts.append(f"{escaped_key}:{escaped_value}")
940+
941+
return ",".join(serialized_parts)
942+
943+
901944
def build_client_context(server_hostname: str, version: str, **kwargs):
902945
"""Build ClientContext for HTTP client configuration."""
903946
from databricks.sql.auth.common import ClientContext

0 commit comments

Comments
 (0)