Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Enhancement or New Feature
body: Extract project/environment helpers for multi-project support
time: 2026-03-17T02:15:39.921671-05:00
21 changes: 20 additions & 1 deletion src/dbt_mcp/main.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,33 @@
import asyncio
import logging
import os

from dbt_mcp.config.config import load_config
from dbt_mcp.config.transport import validate_transport
from dbt_mcp.mcp.server import create_dbt_mcp

logger = logging.getLogger(__name__)


def main() -> None:
config = load_config()
server = asyncio.run(create_dbt_mcp(config))

multi_project_enabled = os.environ.get(
"DBT_MCP_MULTI_PROJECT_ENABLED", ""
).lower() in ("true", "1", "yes")

if multi_project_enabled:
logger.info(
"DBT_MCP_MULTI_PROJECT_ENABLED=true -> Multi-project mode (Server B)"
)
raise NotImplementedError(
"Multi-project mode is not yet implemented. "
"It will be available in a future release."
)
else:
logger.info("Multi-project mode disabled -> Env-var mode (Server A)")
server = asyncio.run(create_dbt_mcp(config))

transport = validate_transport(os.environ.get("MCP_TRANSPORT", "stdio"))
server.run(transport=transport)

Expand Down
131 changes: 15 additions & 116 deletions src/dbt_mcp/oauth/fastapi_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import cast
from urllib.parse import quote

import httpx
from authlib.integrations.requests_client import OAuth2Session
from fastapi import FastAPI, Request
from fastapi.responses import RedirectResponse
Expand All @@ -12,9 +11,7 @@

from dbt_mcp.oauth.context_manager import DbtPlatformContextManager
from dbt_mcp.oauth.dbt_platform import (
DbtPlatformAccount,
DbtPlatformContext,
DbtPlatformEnvironment,
DbtPlatformEnvironmentResponse,
DbtPlatformProject,
GetEnvironmentsRequest,
Expand All @@ -24,6 +21,14 @@
from dbt_mcp.oauth.token import (
DecodedAccessToken,
)
from dbt_mcp.project.environment_resolver import (
_get_all_environments_for_project,
resolve_environments,
)
from dbt_mcp.project.project_resolver import (
get_all_accounts,
get_all_projects_for_account,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -56,77 +61,6 @@ async def send_wrapper(message):
await super().__call__(scope, receive, send_wrapper)


async def _get_all_accounts(
*,
dbt_platform_url: str,
headers: dict[str, str],
) -> list[DbtPlatformAccount]:
async with httpx.AsyncClient() as client:
response = await client.get(
url=f"{dbt_platform_url}/api/v3/accounts/",
headers=headers,
)
response.raise_for_status()
data = response.json()
return [DbtPlatformAccount(**account) for account in data["data"]]


async def _get_all_projects_for_account(
*,
dbt_platform_url: str,
account: DbtPlatformAccount,
headers: dict[str, str],
page_size: int = 100,
) -> list[DbtPlatformProject]:
"""Fetch all projects for an account using offset/page_size pagination."""
offset = 0
projects: list[DbtPlatformProject] = []
async with httpx.AsyncClient() as client:
while True:
response = await client.get(
f"{dbt_platform_url}/api/v3/accounts/{account.id}/projects/?state=1&offset={offset}&limit={page_size}",
headers=headers,
)
response.raise_for_status()
page = response.json()["data"]
projects.extend(
DbtPlatformProject(**project, account_name=account.name)
for project in page
)
if len(page) < page_size:
break
offset += page_size
return projects


async def _get_all_environments_for_project(
*,
dbt_platform_url: str,
account_id: int,
project_id: int,
headers: dict[str, str],
page_size: int = 100,
) -> list[DbtPlatformEnvironmentResponse]:
"""Fetch all environments for a project using offset/page_size pagination."""
offset = 0
environments: list[DbtPlatformEnvironmentResponse] = []
async with httpx.AsyncClient() as client:
while True:
response = await client.get(
f"{dbt_platform_url}/api/v3/accounts/{account_id}/projects/{project_id}/environments/?state=1&offset={offset}&limit={page_size}",
headers=headers,
)
response.raise_for_status()
page = response.json()["data"]
environments.extend(
DbtPlatformEnvironmentResponse(**environment) for environment in page
)
if len(page) < page_size:
break
offset += page_size
return environments


def create_app(
*,
oauth_client: OAuth2Session,
Expand Down Expand Up @@ -206,14 +140,14 @@ async def projects() -> list[DbtPlatformProject]:
"Accept": "application/json",
"Authorization": f"Bearer {access_token}",
}
accounts = await _get_all_accounts(
accounts = await get_all_accounts(
dbt_platform_url=dbt_platform_url,
headers=headers,
)
projects: list[DbtPlatformProject] = []
for account in [a for a in accounts if a.state == 1 and not a.locked]:
projects.extend(
await _get_all_projects_for_account(
await get_all_projects_for_account(
dbt_platform_url=dbt_platform_url,
account=account,
headers=headers,
Expand Down Expand Up @@ -265,7 +199,7 @@ async def set_selected_project(
"Accept": "application/json",
"Authorization": f"Bearer {access_token}",
}
accounts = await _get_all_accounts(
accounts = await get_all_accounts(
dbt_platform_url=dbt_platform_url,
headers=headers,
)
Expand All @@ -282,45 +216,10 @@ async def set_selected_project(
page_size=100,
)

prod_environment = None
dev_environment = None

# If a specific prod_environment_id was provided, use it
if selected_project_request.prod_environment_id:
for environment in environments:
if environment.id == selected_project_request.prod_environment_id:
prod_environment = DbtPlatformEnvironment(
id=environment.id,
name=environment.name,
deployment_type=environment.deployment_type or "production",
)
break
else:
# Fall back to auto-detection based on deployment_type
for environment in environments:
if (
environment.deployment_type
and environment.deployment_type.lower() == "production"
):
prod_environment = DbtPlatformEnvironment(
id=environment.id,
name=environment.name,
deployment_type=environment.deployment_type,
)
break

# Always try to auto-detect dev environment
for environment in environments:
if (
environment.deployment_type
and environment.deployment_type.lower() == "development"
):
dev_environment = DbtPlatformEnvironment(
id=environment.id,
name=environment.name,
deployment_type=environment.deployment_type,
)
break
prod_environment, dev_environment = resolve_environments(
environments,
prod_environment_id=selected_project_request.prod_environment_id,
)

dbt_platform_context = dbt_platform_context_manager.update_context(
new_dbt_platform_context=DbtPlatformContext(
Expand Down
Empty file.
115 changes: 115 additions & 0 deletions src/dbt_mcp/project/environment_resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import logging

import httpx

from dbt_mcp.oauth.dbt_platform import (
DbtPlatformEnvironment,
DbtPlatformEnvironmentResponse,
)

logger = logging.getLogger(__name__)


async def _get_all_environments_for_project(
*,
dbt_platform_url: str,
account_id: int,
project_id: int,
headers: dict[str, str],
page_size: int = 100,
) -> list[DbtPlatformEnvironmentResponse]:
"""Fetch all environments for a project using offset/page_size pagination."""
offset = 0
environments: list[DbtPlatformEnvironmentResponse] = []
async with httpx.AsyncClient() as client:
while True:
response = await client.get(
f"{dbt_platform_url}/api/v3/accounts/{account_id}/projects/{project_id}/environments/?state=1&offset={offset}&limit={page_size}",
headers=headers,
)
Comment on lines +26 to +29

Check failure

Code scanning / CodeQL

Partial server-side request forgery

Part of the URL of this request depends on a [user-provided value](1). Part of the URL of this request depends on a [user-provided value](2).

Copilot Autofix

AI about 22 hours ago

In general, to fix partial SSRF in this pattern you constrain user-controlled path components to a safe, expected subset and/or verify that they refer to resources the caller is actually allowed to access. For numeric IDs, that means enforcing integer types and checking that the requested IDs are present in the set of accounts/projects associated with the authenticated user, instead of blindly passing any ID into a backend HTTP call.

Concretely here, the vulnerable pieces are the account_id/project_id values passed into _get_all_environments_for_project from get_deployment_environments and set_selected_project. The best fix without changing intended functionality is to validate that the requested account_id corresponds to one of the accounts returned by get_all_accounts for the current token, and that the project_id corresponds to one of the projects for that account from get_all_projects_for_account. If either does not exist, we return a 4xx-style error (by raising a ValueError as is already done for a missing account in set_selected_project). This both satisfies CodeQL (the URL path still uses user-supplied IDs but now only after authorization/validation) and enforces proper access control.

Specifically:

  • In get_deployment_environments (/environments handler):
    • After building headers, call get_all_accounts(dbt_platform_url, headers) and ensure request.account_id is present. If not, raise ValueError.
    • Then call get_all_projects_for_account for that account and ensure request.project_id exists. If not, raise ValueError.
    • Only then call _get_all_environments_for_project with those IDs.
  • In set_selected_project (/selected_project handler):
    • It already validates the account by ID. Add a similar validation step for selected_project_request.project_id by calling get_all_projects_for_account with the resolved account and verifying the project exists before calling _get_all_environments_for_project.

All of these changes occur in src/dbt_mcp/oauth/fastapi_app.py. We do not need to modify _get_all_environments_for_project in environment_resolver.py; it continues to accept IDs but will only be called with IDs that have been vetted against the user’s accessible resources. No new imports are required, as we already import get_all_accounts and get_all_projects_for_account.


Suggested changeset 1
src/dbt_mcp/oauth/fastapi_app.py
Outside changed files

Autofix patch

Autofix patch
Run the following command in your local git repository to apply this patch
cat << 'EOF' | git apply
diff --git a/src/dbt_mcp/oauth/fastapi_app.py b/src/dbt_mcp/oauth/fastapi_app.py
--- a/src/dbt_mcp/oauth/fastapi_app.py
+++ b/src/dbt_mcp/oauth/fastapi_app.py
@@ -173,10 +173,26 @@
             "Accept": "application/json",
             "Authorization": f"Bearer {access_token}",
         }
+        # Validate that the requested account and project belong to the authenticated user
+        accounts = await get_all_accounts(
+            dbt_platform_url=dbt_platform_url,
+            headers=headers,
+        )
+        account = next((a for a in accounts if a.id == request.account_id), None)
+        if account is None:
+            raise ValueError(f"Account {request.account_id} not found")
+        projects = await get_all_projects_for_account(
+            dbt_platform_url=dbt_platform_url,
+            account=account,
+            headers=headers,
+        )
+        project = next((p for p in projects if p.id == request.project_id), None)
+        if project is None:
+            raise ValueError(f"Project {request.project_id} not found for account {request.account_id}")
         environments = await _get_all_environments_for_project(
             dbt_platform_url=dbt_platform_url,
-            account_id=request.account_id,
-            project_id=request.project_id,
+            account_id=account.id,
+            project_id=project.id,
             headers=headers,
             page_size=100,
         )
@@ -208,10 +222,23 @@
         )
         if account is None:
             raise ValueError(f"Account {selected_project_request.account_id} not found")
+        # Validate that the selected project belongs to the selected account
+        projects = await get_all_projects_for_account(
+            dbt_platform_url=dbt_platform_url,
+            account=account,
+            headers=headers,
+        )
+        project = next(
+            (p for p in projects if p.id == selected_project_request.project_id), None
+        )
+        if project is None:
+            raise ValueError(
+                f"Project {selected_project_request.project_id} not found for account {selected_project_request.account_id}"
+            )
         environments = await _get_all_environments_for_project(
             dbt_platform_url=dbt_platform_url,
-            account_id=selected_project_request.account_id,
-            project_id=selected_project_request.project_id,
+            account_id=account.id,
+            project_id=project.id,
             headers=headers,
             page_size=100,
         )
EOF
@@ -173,10 +173,26 @@
"Accept": "application/json",
"Authorization": f"Bearer {access_token}",
}
# Validate that the requested account and project belong to the authenticated user
accounts = await get_all_accounts(
dbt_platform_url=dbt_platform_url,
headers=headers,
)
account = next((a for a in accounts if a.id == request.account_id), None)
if account is None:
raise ValueError(f"Account {request.account_id} not found")
projects = await get_all_projects_for_account(
dbt_platform_url=dbt_platform_url,
account=account,
headers=headers,
)
project = next((p for p in projects if p.id == request.project_id), None)
if project is None:
raise ValueError(f"Project {request.project_id} not found for account {request.account_id}")
environments = await _get_all_environments_for_project(
dbt_platform_url=dbt_platform_url,
account_id=request.account_id,
project_id=request.project_id,
account_id=account.id,
project_id=project.id,
headers=headers,
page_size=100,
)
@@ -208,10 +222,23 @@
)
if account is None:
raise ValueError(f"Account {selected_project_request.account_id} not found")
# Validate that the selected project belongs to the selected account
projects = await get_all_projects_for_account(
dbt_platform_url=dbt_platform_url,
account=account,
headers=headers,
)
project = next(
(p for p in projects if p.id == selected_project_request.project_id), None
)
if project is None:
raise ValueError(
f"Project {selected_project_request.project_id} not found for account {selected_project_request.account_id}"
)
environments = await _get_all_environments_for_project(
dbt_platform_url=dbt_platform_url,
account_id=selected_project_request.account_id,
project_id=selected_project_request.project_id,
account_id=account.id,
project_id=project.id,
headers=headers,
page_size=100,
)
Copilot is powered by AI and may make mistakes. Always verify output.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a false positive.

response.raise_for_status()
page = response.json()["data"]
environments.extend(
DbtPlatformEnvironmentResponse(**environment) for environment in page
)
if len(page) < page_size:
break
offset += page_size
return environments


def resolve_environments(
environments: list[DbtPlatformEnvironmentResponse],
*,
prod_environment_id: int | None = None,
) -> tuple[DbtPlatformEnvironment | None, DbtPlatformEnvironment | None]:
"""Resolve prod and dev environments from a list of environment responses.
Returns a tuple of (prod_environment, dev_environment).
If prod_environment_id is provided, that specific environment is used as prod.
Otherwise, auto-detects based on deployment_type == "production".
Dev environment is always auto-detected based on deployment_type == "development".
"""
prod_environment: DbtPlatformEnvironment | None = None
dev_environment: DbtPlatformEnvironment | None = None

if prod_environment_id:
for environment in environments:
if environment.id == prod_environment_id:
prod_environment = DbtPlatformEnvironment(
id=environment.id,
name=environment.name,
deployment_type=environment.deployment_type or "production",
)
break
else:
for environment in environments:
if (
environment.deployment_type
and environment.deployment_type.lower() == "production"
):
prod_environment = DbtPlatformEnvironment(
id=environment.id,
name=environment.name,
deployment_type=environment.deployment_type,
)
break

for environment in environments:
if (
environment.deployment_type
and environment.deployment_type.lower() == "development"
):
dev_environment = DbtPlatformEnvironment(
id=environment.id,
name=environment.name,
deployment_type=environment.deployment_type,
)
break

return prod_environment, dev_environment


async def get_environments_for_project(
*,
dbt_platform_url: str,
account_id: int,
project_id: int,
headers: dict[str, str],
prod_environment_id: int | None = None,
) -> tuple[DbtPlatformEnvironment | None, DbtPlatformEnvironment | None]:
"""Fetch environments for a project and resolve prod/dev.
Returns a tuple of (prod_environment, dev_environment).
"""
environments = await _get_all_environments_for_project(
dbt_platform_url=dbt_platform_url,
account_id=account_id,
project_id=project_id,
headers=headers,
)
return resolve_environments(
environments,
prod_environment_id=prod_environment_id,
)
53 changes: 53 additions & 0 deletions src/dbt_mcp/project/project_resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import logging

import httpx

from dbt_mcp.oauth.dbt_platform import (
DbtPlatformAccount,
DbtPlatformProject,
)

logger = logging.getLogger(__name__)


async def get_all_accounts(
*,
dbt_platform_url: str,
headers: dict[str, str],
) -> list[DbtPlatformAccount]:
async with httpx.AsyncClient() as client:
response = await client.get(
url=f"{dbt_platform_url}/api/v3/accounts/",
headers=headers,
)
response.raise_for_status()
data = response.json()
return [DbtPlatformAccount(**account) for account in data["data"]]


async def get_all_projects_for_account(
*,
dbt_platform_url: str,
account: DbtPlatformAccount,
headers: dict[str, str],
page_size: int = 100,
) -> list[DbtPlatformProject]:
"""Fetch all projects for an account using offset/page_size pagination."""
offset = 0
projects: list[DbtPlatformProject] = []
async with httpx.AsyncClient() as client:
while True:
response = await client.get(
f"{dbt_platform_url}/api/v3/accounts/{account.id}/projects/?state=1&offset={offset}&limit={page_size}",
headers=headers,
)
response.raise_for_status()
page = response.json()["data"]
projects.extend(
DbtPlatformProject(**project, account_name=account.name)
for project in page
)
if len(page) < page_size:
break
offset += page_size
return projects
Loading
Loading