Skip to content

Commit e4de653

Browse files
SameerMesiah97 and Sameer Mesiah
authored
Add sync and async helpers to resolve the dbt Cloud account ID from the configured Airflow connection and cache the value on the hook instance to avoid repeated metadata DB lookups (#61757)
Introduce decorators to transparently fall back to the connection-based account_id when not explicitly provided by the caller. Add tests to verify caching behavior, including shared cache semantics between sync and async resolution paths. Co-authored-by: Sameer Mesiah <smesiah971@gmail.com>
1 parent 85085a4 commit e4de653

File tree

2 files changed

+105
-11
lines changed
  • providers/dbt/cloud

2 files changed

+105
-11
lines changed

providers/dbt/cloud/src/airflow/providers/dbt/cloud/hooks/dbt.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,7 @@ def wrapper(*args, **kwargs) -> Callable:
6363
# provided.
6464
if bound_args.arguments.get("account_id") is None:
6565
self = args[0]
66-
default_account_id = self.connection.login
67-
if not default_account_id:
68-
raise AirflowException("Could not determine the dbt Cloud account.")
69-
70-
bound_args.arguments["account_id"] = int(default_account_id)
71-
66+
bound_args.arguments["account_id"] = self._resolve_account_id()
7267
return func(*bound_args.args, **bound_args.kwargs)
7368

7469
return wrapper
@@ -162,11 +157,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any:
162157
if bound_args.arguments.get("account_id") is None:
163158
self = args[0]
164159
if self.dbt_cloud_conn_id:
165-
connection = await get_async_connection(self.dbt_cloud_conn_id)
166-
default_account_id = connection.login
167-
if not default_account_id:
168-
raise AirflowException("Could not determine the dbt Cloud account.")
169-
bound_args.arguments["account_id"] = int(default_account_id)
160+
bound_args.arguments["account_id"] = await self._resolve_account_id_async()
170161

171162
return await func(*bound_args.args, **bound_args.kwargs)
172163

@@ -434,6 +425,32 @@ def _run_and_get_response(
434425
extra_options=extra_options or None,
435426
)
436427

428+
def _resolve_account_id(self) -> int:
429+
"""Resolve and cache the dbt Cloud account ID (sync)."""
430+
# Lazily initialized; absence means "not resolved yet".
431+
if not hasattr(self, "_cached_account_id"):
432+
conn = self.get_connection(self.dbt_cloud_conn_id)
433+
if not conn.login:
434+
raise AirflowException("Could not determine the dbt Cloud account.")
435+
436+
# Cache is shared between sync and async resolution to avoid duplicate
437+
# metadata DB lookups on the same hook instance.
438+
self._cached_account_id = int(conn.login)
439+
return self._cached_account_id
440+
441+
async def _resolve_account_id_async(self) -> int:
442+
"""Resolve and cache the dbt Cloud account ID (async)."""
443+
# Lazily initialized; absence means "not resolved yet".
444+
if not hasattr(self, "_cached_account_id"):
445+
conn = await get_async_connection(self.dbt_cloud_conn_id)
446+
if not conn.login:
447+
raise AirflowException("Could not determine the dbt Cloud account.")
448+
449+
# Cache is shared between sync and async resolution to avoid duplicate
450+
# metadata DB lookups on the same hook instance.
451+
self._cached_account_id = int(conn.login)
452+
return self._cached_account_id
453+
437454
def list_accounts(self) -> list[Response]:
438455
"""
439456
Retrieve all of the dbt Cloud accounts the configured API token is authorized to access.

providers/dbt/cloud/tests/unit/dbt/cloud/hooks/test_dbt.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,83 @@ def test_get_account(self, mock_paginate, mock_http_run, conn_id, account_id):
279279
)
280280
hook._paginate.assert_not_called()
281281

282+
def test_resolve_account_id_cached_sync(self):
283+
hook = DbtCloudHook(ACCOUNT_ID_CONN)
284+
285+
with patch.object(DbtCloudHook, "get_connection") as mock_get_connection:
286+
mock_get_connection.return_value = Connection(
287+
conn_id=ACCOUNT_ID_CONN,
288+
conn_type=DbtCloudHook.conn_type,
289+
login=str(DEFAULT_ACCOUNT_ID),
290+
password=TOKEN,
291+
)
292+
293+
first_call = hook._resolve_account_id()
294+
second_call = hook._resolve_account_id()
295+
296+
assert first_call == DEFAULT_ACCOUNT_ID
297+
assert second_call == DEFAULT_ACCOUNT_ID
298+
assert mock_get_connection.call_count == 1
299+
300+
@pytest.mark.asyncio
301+
async def test_resolve_account_id_cached_async(self):
302+
hook = DbtCloudHook(ACCOUNT_ID_CONN)
303+
304+
with patch(
305+
"airflow.providers.dbt.cloud.hooks.dbt.get_async_connection",
306+
new=AsyncMock(
307+
return_value=Connection(
308+
conn_id=ACCOUNT_ID_CONN,
309+
conn_type=DbtCloudHook.conn_type,
310+
login=str(DEFAULT_ACCOUNT_ID),
311+
password=TOKEN,
312+
)
313+
),
314+
) as mock_get_async_connection:
315+
first_call = await hook._resolve_account_id_async()
316+
second_call = await hook._resolve_account_id_async()
317+
318+
assert first_call == DEFAULT_ACCOUNT_ID
319+
assert second_call == DEFAULT_ACCOUNT_ID
320+
assert mock_get_async_connection.call_count == 1
321+
322+
@pytest.mark.asyncio
323+
async def test_account_id_cache_shared_between_sync_and_async(self):
324+
hook = DbtCloudHook(ACCOUNT_ID_CONN)
325+
326+
with (
327+
patch.object(
328+
DbtCloudHook,
329+
"get_connection",
330+
return_value=Connection(
331+
conn_id=ACCOUNT_ID_CONN,
332+
conn_type=DbtCloudHook.conn_type,
333+
login=str(DEFAULT_ACCOUNT_ID),
334+
password=TOKEN,
335+
),
336+
) as mock_get_connection,
337+
patch(
338+
"airflow.providers.dbt.cloud.hooks.dbt.get_async_connection",
339+
new=AsyncMock(
340+
return_value=Connection(
341+
conn_id=ACCOUNT_ID_CONN,
342+
conn_type=DbtCloudHook.conn_type,
343+
login=str(DEFAULT_ACCOUNT_ID),
344+
password=TOKEN,
345+
)
346+
),
347+
) as mock_get_async_connection,
348+
):
349+
sync_account_id = hook._resolve_account_id()
350+
async_account_id = await hook._resolve_account_id_async()
351+
352+
assert sync_account_id == DEFAULT_ACCOUNT_ID
353+
assert async_account_id == DEFAULT_ACCOUNT_ID
354+
355+
# Only one metadata DB lookup total.
356+
assert mock_get_connection.call_count == 1
357+
assert mock_get_async_connection.call_count == 0
358+
282359
@pytest.mark.parametrize(
283360
argnames=("conn_id", "account_id"),
284361
argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],

0 commit comments

Comments
 (0)