Skip to content

Commit 69ce9db

Browse files
DevonFulcherclaude
andcommitted
feat: use httpx async in project resolver modules
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent c7820f5 commit 69ce9db

File tree

4 files changed

+121
-111
lines changed

4 files changed

+121
-111
lines changed

src/dbt_mcp/oauth/fastapi_app.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -132,22 +132,22 @@ def shutdown_server() -> dict[str, bool]:
132132
return {"ok": True}
133133

134134
@app.get("/projects")
135-
def projects() -> list[DbtPlatformProject]:
135+
async def projects() -> list[DbtPlatformProject]:
136136
if app.state.decoded_access_token is None:
137137
raise RuntimeError("Access token missing; OAuth flow not completed")
138138
access_token = app.state.decoded_access_token.access_token_response.access_token
139139
headers = {
140140
"Accept": "application/json",
141141
"Authorization": f"Bearer {access_token}",
142142
}
143-
accounts = get_all_accounts(
143+
accounts = await get_all_accounts(
144144
dbt_platform_url=dbt_platform_url,
145145
headers=headers,
146146
)
147147
projects: list[DbtPlatformProject] = []
148148
for account in [a for a in accounts if a.state == 1 and not a.locked]:
149149
projects.extend(
150-
get_all_projects_for_account(
150+
await get_all_projects_for_account(
151151
dbt_platform_url=dbt_platform_url,
152152
account=account,
153153
headers=headers,
@@ -161,7 +161,7 @@ def get_dbt_platform_context() -> DbtPlatformContext:
161161
return dbt_platform_context_manager.read_context() or DbtPlatformContext()
162162

163163
@app.post("/environments")
164-
def get_deployment_environments(
164+
async def get_deployment_environments(
165165
request: GetEnvironmentsRequest,
166166
) -> list[DbtPlatformEnvironmentResponse]:
167167
"""Get all deployment environments for a project, excluding development environments."""
@@ -173,7 +173,7 @@ def get_deployment_environments(
173173
"Accept": "application/json",
174174
"Authorization": f"Bearer {access_token}",
175175
}
176-
environments = _get_all_environments_for_project(
176+
environments = await _get_all_environments_for_project(
177177
dbt_platform_url=dbt_platform_url,
178178
account_id=request.account_id,
179179
project_id=request.project_id,
@@ -188,7 +188,7 @@ def get_deployment_environments(
188188
]
189189

190190
@app.post("/selected_project")
191-
def set_selected_project(
191+
async def set_selected_project(
192192
selected_project_request: SelectedProjectRequest,
193193
) -> DbtPlatformContext:
194194
logger.info("Selected project received")
@@ -199,7 +199,7 @@ def set_selected_project(
199199
"Accept": "application/json",
200200
"Authorization": f"Bearer {access_token}",
201201
}
202-
accounts = get_all_accounts(
202+
accounts = await get_all_accounts(
203203
dbt_platform_url=dbt_platform_url,
204204
headers=headers,
205205
)
@@ -208,7 +208,7 @@ def set_selected_project(
208208
)
209209
if account is None:
210210
raise ValueError(f"Account {selected_project_request.account_id} not found")
211-
environments = _get_all_environments_for_project(
211+
environments = await _get_all_environments_for_project(
212212
dbt_platform_url=dbt_platform_url,
213213
account_id=selected_project_request.account_id,
214214
project_id=selected_project_request.project_id,

src/dbt_mcp/project/environment_resolver.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
import requests
3+
import httpx
44

55
from dbt_mcp.oauth.dbt_platform import (
66
DbtPlatformEnvironment,
@@ -10,7 +10,7 @@
1010
logger = logging.getLogger(__name__)
1111

1212

13-
def _get_all_environments_for_project(
13+
async def _get_all_environments_for_project(
1414
*,
1515
dbt_platform_url: str,
1616
account_id: int,
@@ -21,19 +21,20 @@ def _get_all_environments_for_project(
2121
"""Fetch all environments for a project using offset/page_size pagination."""
2222
offset = 0
2323
environments: list[DbtPlatformEnvironmentResponse] = []
24-
while True:
25-
environments_response = requests.get(
26-
f"{dbt_platform_url}/api/v3/accounts/{account_id}/projects/{project_id}/environments/?state=1&offset={offset}&limit={page_size}",
27-
headers=headers,
28-
)
29-
environments_response.raise_for_status()
30-
page = environments_response.json()["data"]
31-
environments.extend(
32-
DbtPlatformEnvironmentResponse(**environment) for environment in page
33-
)
34-
if len(page) < page_size:
35-
break
36-
offset += page_size
24+
async with httpx.AsyncClient() as client:
25+
while True:
26+
response = await client.get(
27+
f"{dbt_platform_url}/api/v3/accounts/{account_id}/projects/{project_id}/environments/?state=1&offset={offset}&limit={page_size}",
28+
headers=headers,
29+
)
30+
response.raise_for_status()
31+
page = response.json()["data"]
32+
environments.extend(
33+
DbtPlatformEnvironmentResponse(**environment) for environment in page
34+
)
35+
if len(page) < page_size:
36+
break
37+
offset += page_size
3738
return environments
3839

3940

@@ -90,7 +91,7 @@ def resolve_environments(
9091
return prod_environment, dev_environment
9192

9293

93-
def get_environments_for_project(
94+
async def get_environments_for_project(
9495
*,
9596
dbt_platform_url: str,
9697
account_id: int,
@@ -102,7 +103,7 @@ def get_environments_for_project(
102103
103104
Returns a tuple of (prod_environment, dev_environment).
104105
"""
105-
environments = _get_all_environments_for_project(
106+
environments = await _get_all_environments_for_project(
106107
dbt_platform_url=dbt_platform_url,
107108
account_id=account_id,
108109
project_id=project_id,
Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22

3-
import requests
3+
import httpx
44

55
from dbt_mcp.oauth.dbt_platform import (
66
DbtPlatformAccount,
@@ -10,22 +10,22 @@
1010
logger = logging.getLogger(__name__)
1111

1212

13-
def get_all_accounts(
13+
async def get_all_accounts(
1414
*,
1515
dbt_platform_url: str,
1616
headers: dict[str, str],
1717
) -> list[DbtPlatformAccount]:
18-
accounts_response = requests.get(
19-
url=f"{dbt_platform_url}/api/v3/accounts/",
20-
headers=headers,
21-
)
22-
accounts_response.raise_for_status()
23-
return [
24-
DbtPlatformAccount(**account) for account in accounts_response.json()["data"]
25-
]
18+
async with httpx.AsyncClient() as client:
19+
response = await client.get(
20+
url=f"{dbt_platform_url}/api/v3/accounts/",
21+
headers=headers,
22+
)
23+
response.raise_for_status()
24+
data = response.json()
25+
return [DbtPlatformAccount(**account) for account in data["data"]]
2626

2727

28-
def get_all_projects_for_account(
28+
async def get_all_projects_for_account(
2929
*,
3030
dbt_platform_url: str,
3131
account: DbtPlatformAccount,
@@ -35,17 +35,19 @@ def get_all_projects_for_account(
3535
"""Fetch all projects for an account using offset/page_size pagination."""
3636
offset = 0
3737
projects: list[DbtPlatformProject] = []
38-
while True:
39-
projects_response = requests.get(
40-
f"{dbt_platform_url}/api/v3/accounts/{account.id}/projects/?state=1&offset={offset}&limit={page_size}",
41-
headers=headers,
42-
)
43-
projects_response.raise_for_status()
44-
page = projects_response.json()["data"]
45-
projects.extend(
46-
DbtPlatformProject(**project, account_name=account.name) for project in page
47-
)
48-
if len(page) < page_size:
49-
break
50-
offset += page_size
38+
async with httpx.AsyncClient() as client:
39+
while True:
40+
response = await client.get(
41+
f"{dbt_platform_url}/api/v3/accounts/{account.id}/projects/?state=1&offset={offset}&limit={page_size}",
42+
headers=headers,
43+
)
44+
response.raise_for_status()
45+
page = response.json()["data"]
46+
projects.extend(
47+
DbtPlatformProject(**project, account_name=account.name)
48+
for project in page
49+
)
50+
if len(page) < page_size:
51+
break
52+
offset += page_size
5153
return projects
Lines changed: 69 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import Mock, patch
1+
from unittest.mock import AsyncMock, MagicMock, patch
22

33
import pytest
44

@@ -24,35 +24,49 @@ def account():
2424
)
2525

2626

27-
@patch("dbt_mcp.project.project_resolver.requests.get")
28-
def test_get_all_projects_for_account_paginates(mock_get: Mock, base_headers, account):
27+
def create_mock_response(data: dict) -> MagicMock:
28+
resp = MagicMock()
29+
resp.json.return_value = data
30+
resp.raise_for_status.return_value = None
31+
return resp
32+
33+
34+
def create_mock_httpx_client(responses: list) -> AsyncMock:
35+
mock_client = AsyncMock()
36+
mock_client.get = AsyncMock(side_effect=responses)
37+
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
38+
mock_client.__aexit__ = AsyncMock(return_value=None)
39+
return mock_client
40+
41+
42+
async def test_get_all_projects_for_account_paginates(base_headers, account):
2943
# Two pages: first full page (limit=2), second partial page (1 item) -> stop
30-
first_page_resp = Mock()
31-
first_page_resp.json.return_value = {
32-
"data": [
33-
{"id": 101, "name": "Proj A", "account_id": account.id},
34-
{"id": 102, "name": "Proj B", "account_id": account.id},
35-
]
36-
}
37-
first_page_resp.raise_for_status.return_value = None
38-
39-
second_page_resp = Mock()
40-
second_page_resp.json.return_value = {
41-
"data": [
42-
{"id": 103, "name": "Proj C", "account_id": account.id},
43-
]
44-
}
45-
second_page_resp.raise_for_status.return_value = None
46-
47-
mock_get.side_effect = [first_page_resp, second_page_resp]
48-
49-
result = get_all_projects_for_account(
50-
dbt_platform_url="https://cloud.getdbt.com",
51-
account=account,
52-
headers=base_headers,
53-
page_size=2,
44+
first_page_resp = create_mock_response(
45+
{
46+
"data": [
47+
{"id": 101, "name": "Proj A", "account_id": account.id},
48+
{"id": 102, "name": "Proj B", "account_id": account.id},
49+
]
50+
}
51+
)
52+
second_page_resp = create_mock_response(
53+
{
54+
"data": [
55+
{"id": 103, "name": "Proj C", "account_id": account.id},
56+
]
57+
}
5458
)
5559

60+
mock_client = create_mock_httpx_client([first_page_resp, second_page_resp])
61+
62+
with patch("httpx.AsyncClient", return_value=mock_client):
63+
result = await get_all_projects_for_account(
64+
dbt_platform_url="https://cloud.getdbt.com",
65+
account=account,
66+
headers=base_headers,
67+
page_size=2,
68+
)
69+
5670
# Should aggregate 3 projects and include account_name field
5771
assert len(result) == 3
5872
assert {p.id for p in result} == {101, 102, 103}
@@ -63,52 +77,45 @@ def test_get_all_projects_for_account_paginates(mock_get: Mock, base_headers, ac
6377
"https://cloud.getdbt.com/api/v3/accounts/1/projects/?state=1&offset=0&limit=2",
6478
"https://cloud.getdbt.com/api/v3/accounts/1/projects/?state=1&offset=2&limit=2",
6579
]
66-
actual_urls = [
67-
call.kwargs["url"] if "url" in call.kwargs else call.args[0]
68-
for call in mock_get.call_args_list
69-
]
80+
actual_urls = [call.args[0] for call in mock_client.get.call_args_list]
7081
assert actual_urls == expected_urls
7182

7283

73-
@patch("dbt_mcp.project.environment_resolver.requests.get")
74-
def test_get_all_environments_for_project_paginates(mock_get: Mock, base_headers):
84+
async def test_get_all_environments_for_project_paginates(base_headers):
7585
# Two pages: first full page (limit=2), second partial (1 item)
76-
first_page_resp = Mock()
77-
first_page_resp.json.return_value = {
78-
"data": [
79-
{"id": 201, "name": "Dev", "deployment_type": "development"},
80-
{"id": 202, "name": "Prod", "deployment_type": "production"},
81-
]
82-
}
83-
first_page_resp.raise_for_status.return_value = None
84-
85-
second_page_resp = Mock()
86-
second_page_resp.json.return_value = {
87-
"data": [
88-
{"id": 203, "name": "Staging", "deployment_type": "development"},
89-
]
90-
}
91-
second_page_resp.raise_for_status.return_value = None
92-
93-
mock_get.side_effect = [first_page_resp, second_page_resp]
94-
95-
result = _get_all_environments_for_project(
96-
dbt_platform_url="https://cloud.getdbt.com",
97-
account_id=1,
98-
project_id=9,
99-
headers=base_headers,
100-
page_size=2,
86+
first_page_resp = create_mock_response(
87+
{
88+
"data": [
89+
{"id": 201, "name": "Dev", "deployment_type": "development"},
90+
{"id": 202, "name": "Prod", "deployment_type": "production"},
91+
]
92+
}
93+
)
94+
second_page_resp = create_mock_response(
95+
{
96+
"data": [
97+
{"id": 203, "name": "Staging", "deployment_type": "development"},
98+
]
99+
}
101100
)
102101

102+
mock_client = create_mock_httpx_client([first_page_resp, second_page_resp])
103+
104+
with patch("httpx.AsyncClient", return_value=mock_client):
105+
result = await _get_all_environments_for_project(
106+
dbt_platform_url="https://cloud.getdbt.com",
107+
account_id=1,
108+
project_id=9,
109+
headers=base_headers,
110+
page_size=2,
111+
)
112+
103113
assert len(result) == 3
104114
assert {e.id for e in result} == {201, 202, 203}
105115

106116
expected_urls = [
107117
"https://cloud.getdbt.com/api/v3/accounts/1/projects/9/environments/?state=1&offset=0&limit=2",
108118
"https://cloud.getdbt.com/api/v3/accounts/1/projects/9/environments/?state=1&offset=2&limit=2",
109119
]
110-
actual_urls = [
111-
call.kwargs["url"] if "url" in call.kwargs else call.args[0]
112-
for call in mock_get.call_args_list
113-
]
120+
actual_urls = [call.args[0] for call in mock_client.get.call_args_list]
114121
assert actual_urls == expected_urls

0 commit comments

Comments
 (0)