diff --git a/.changes/unreleased/Enhancement or New Feature-20260317-021539.yaml b/.changes/unreleased/Enhancement or New Feature-20260317-021539.yaml new file mode 100644 index 00000000..2bb8c9f2 --- /dev/null +++ b/.changes/unreleased/Enhancement or New Feature-20260317-021539.yaml @@ -0,0 +1,3 @@ +kind: Enhancement or New Feature +body: Extract project/environment helpers for multi-project support +time: 2026-03-17T02:15:39.921671-05:00 diff --git a/src/dbt_mcp/main.py b/src/dbt_mcp/main.py index 46a66a57..c7443d5b 100644 --- a/src/dbt_mcp/main.py +++ b/src/dbt_mcp/main.py @@ -1,14 +1,33 @@ import asyncio +import logging import os from dbt_mcp.config.config import load_config from dbt_mcp.config.transport import validate_transport from dbt_mcp.mcp.server import create_dbt_mcp +logger = logging.getLogger(__name__) + def main() -> None: config = load_config() - server = asyncio.run(create_dbt_mcp(config)) + + multi_project_enabled = os.environ.get( + "DBT_MCP_MULTI_PROJECT_ENABLED", "" + ).lower() in ("true", "1", "yes") + + if multi_project_enabled: + logger.info( + "DBT_MCP_MULTI_PROJECT_ENABLED=true -> Multi-project mode (Server B)" + ) + raise NotImplementedError( + "Multi-project mode is not yet implemented. " + "It will be available in a future release." + ) + else: + logger.info("Multi-project mode disabled -> Env-var mode (Server A)") + server = asyncio.run(create_dbt_mcp(config)) + transport = validate_transport(os.environ.get("MCP_TRANSPORT", "stdio")) server.run(transport=transport) diff --git a/src/dbt_mcp/oauth/fastapi_app.py b/src/dbt_mcp/oauth/fastapi_app.py index febfb362..59885691 100644 --- a/src/dbt_mcp/oauth/fastapi_app.py +++ b/src/dbt_mcp/oauth/fastapi_app.py @@ -2,7 +2,6 @@ from typing import cast from urllib.parse import quote -import httpx from authlib.integrations.requests_client import OAuth2Session from fastapi import FastAPI, Request from fastapi.responses import RedirectResponse @@ -12,9 +11,7 @@ from dbt_mcp.oauth.context_manager import DbtPlatformContextManager from dbt_mcp.oauth.dbt_platform import ( - DbtPlatformAccount, DbtPlatformContext, - DbtPlatformEnvironment, DbtPlatformEnvironmentResponse, DbtPlatformProject, GetEnvironmentsRequest, @@ -24,6 +21,14 @@ from dbt_mcp.oauth.token import ( DecodedAccessToken, ) +from dbt_mcp.project.environment_resolver import ( + _get_all_environments_for_project, + resolve_environments, +) +from dbt_mcp.project.project_resolver import ( + get_all_accounts, + get_all_projects_for_account, +) logger = logging.getLogger(__name__) @@ -56,77 +61,6 @@ async def send_wrapper(message): await super().__call__(scope, receive, send_wrapper) -async def _get_all_accounts( - *, - dbt_platform_url: str, - headers: dict[str, str], -) -> list[DbtPlatformAccount]: - async with httpx.AsyncClient() as client: - response = await client.get( - url=f"{dbt_platform_url}/api/v3/accounts/", - headers=headers, - ) - response.raise_for_status() - data = response.json() - return [DbtPlatformAccount(**account) for account in data["data"]] - - -async def _get_all_projects_for_account( - *, - dbt_platform_url: str, - account: DbtPlatformAccount, - headers: dict[str, str], - page_size: int = 100, -) -> list[DbtPlatformProject]: - """Fetch all projects for an account using offset/page_size pagination.""" - offset = 0 - projects: list[DbtPlatformProject] = [] - async with httpx.AsyncClient() as client: - while True: - response = await client.get( - f"{dbt_platform_url}/api/v3/accounts/{account.id}/projects/?state=1&offset={offset}&limit={page_size}", - headers=headers, - ) - response.raise_for_status() - page = response.json()["data"] - projects.extend( - DbtPlatformProject(**project, account_name=account.name) - for project in page - ) - if len(page) < page_size: - break - offset += page_size - return projects - - -async def _get_all_environments_for_project( - *, - dbt_platform_url: str, - account_id: int, - project_id: int, - headers: dict[str, str], - page_size: int = 100, -) -> list[DbtPlatformEnvironmentResponse]: - """Fetch all environments for a project using offset/page_size pagination.""" - offset = 0 - environments: list[DbtPlatformEnvironmentResponse] = [] - async with httpx.AsyncClient() as client: - while True: - response = await client.get( - f"{dbt_platform_url}/api/v3/accounts/{account_id}/projects/{project_id}/environments/?state=1&offset={offset}&limit={page_size}", - headers=headers, - ) - response.raise_for_status() - page = response.json()["data"] - environments.extend( - DbtPlatformEnvironmentResponse(**environment) for environment in page - ) - if len(page) < page_size: - break - offset += page_size - return environments - - def create_app( *, oauth_client: OAuth2Session, @@ -206,14 +140,14 @@ async def projects() -> list[DbtPlatformProject]: "Accept": "application/json", "Authorization": f"Bearer {access_token}", } - accounts = await _get_all_accounts( + accounts = await get_all_accounts( dbt_platform_url=dbt_platform_url, headers=headers, ) projects: list[DbtPlatformProject] = [] for account in [a for a in accounts if a.state == 1 and not a.locked]: projects.extend( - await _get_all_projects_for_account( + await get_all_projects_for_account( dbt_platform_url=dbt_platform_url, account=account, headers=headers, @@ -265,7 +199,7 @@ async def set_selected_project( "Accept": "application/json", "Authorization": f"Bearer {access_token}", } - accounts = await _get_all_accounts( + accounts = await get_all_accounts( dbt_platform_url=dbt_platform_url, headers=headers, ) @@ -282,45 +216,10 @@ async def set_selected_project( page_size=100, ) - prod_environment = None - dev_environment = None - - # If a specific prod_environment_id was provided, use it - if selected_project_request.prod_environment_id: - for environment in environments: - if environment.id == selected_project_request.prod_environment_id: - prod_environment = DbtPlatformEnvironment( - id=environment.id, - name=environment.name, - deployment_type=environment.deployment_type or "production", - ) - break - else: - # Fall back to auto-detection based on deployment_type - for environment in environments: - if ( - environment.deployment_type - and environment.deployment_type.lower() == "production" - ): - prod_environment = DbtPlatformEnvironment( - id=environment.id, - name=environment.name, - deployment_type=environment.deployment_type, - ) - break - - # Always try to auto-detect dev environment - for environment in environments: - if ( - environment.deployment_type - and environment.deployment_type.lower() == "development" - ): - dev_environment = DbtPlatformEnvironment( - id=environment.id, - name=environment.name, - deployment_type=environment.deployment_type, - ) - break + prod_environment, dev_environment = resolve_environments( + environments, + prod_environment_id=selected_project_request.prod_environment_id, + ) dbt_platform_context = dbt_platform_context_manager.update_context( new_dbt_platform_context=DbtPlatformContext( diff --git a/src/dbt_mcp/project/__init__.py b/src/dbt_mcp/project/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/dbt_mcp/project/environment_resolver.py b/src/dbt_mcp/project/environment_resolver.py new file mode 100644 index 00000000..b48858c4 --- /dev/null +++ b/src/dbt_mcp/project/environment_resolver.py @@ -0,0 +1,115 @@ +import logging + +import httpx + +from dbt_mcp.oauth.dbt_platform import ( + DbtPlatformEnvironment, + DbtPlatformEnvironmentResponse, +) + +logger = logging.getLogger(__name__) + + +async def _get_all_environments_for_project( + *, + dbt_platform_url: str, + account_id: int, + project_id: int, + headers: dict[str, str], + page_size: int = 100, +) -> list[DbtPlatformEnvironmentResponse]: + """Fetch all environments for a project using offset/page_size pagination.""" + offset = 0 + environments: list[DbtPlatformEnvironmentResponse] = [] + async with httpx.AsyncClient() as client: + while True: + response = await client.get( + f"{dbt_platform_url}/api/v3/accounts/{account_id}/projects/{project_id}/environments/?state=1&offset={offset}&limit={page_size}", + headers=headers, + ) + response.raise_for_status() + page = response.json()["data"] + environments.extend( + DbtPlatformEnvironmentResponse(**environment) for environment in page + ) + if len(page) < page_size: + break + offset += page_size + return environments + + +def resolve_environments( + environments: list[DbtPlatformEnvironmentResponse], + *, + prod_environment_id: int | None = None, +) -> tuple[DbtPlatformEnvironment | None, DbtPlatformEnvironment | None]: + """Resolve prod and dev environments from a list of environment responses. + + Returns a tuple of (prod_environment, dev_environment). + + If prod_environment_id is provided, that specific environment is used as prod. + Otherwise, auto-detects based on deployment_type == "production". + Dev environment is always auto-detected based on deployment_type == "development". + """ + prod_environment: DbtPlatformEnvironment | None = None + dev_environment: DbtPlatformEnvironment | None = None + + if prod_environment_id: + for environment in environments: + if environment.id == prod_environment_id: + prod_environment = DbtPlatformEnvironment( + id=environment.id, + name=environment.name, + deployment_type=environment.deployment_type or "production", + ) + break + else: + for environment in environments: + if ( + environment.deployment_type + and environment.deployment_type.lower() == "production" + ): + prod_environment = DbtPlatformEnvironment( + id=environment.id, + name=environment.name, + deployment_type=environment.deployment_type, + ) + break + + for environment in environments: + if ( + environment.deployment_type + and environment.deployment_type.lower() == "development" + ): + dev_environment = DbtPlatformEnvironment( + id=environment.id, + name=environment.name, + deployment_type=environment.deployment_type, + ) + break + + return prod_environment, dev_environment + + +async def get_environments_for_project( + *, + dbt_platform_url: str, + account_id: int, + project_id: int, + headers: dict[str, str], + prod_environment_id: int | None = None, +) -> tuple[DbtPlatformEnvironment | None, DbtPlatformEnvironment | None]: + """Fetch environments for a project and resolve prod/dev. + + Returns a tuple of (prod_environment, dev_environment). + """ + environments = await _get_all_environments_for_project( + dbt_platform_url=dbt_platform_url, + account_id=account_id, + project_id=project_id, + headers=headers, + ) + return resolve_environments( + environments, + prod_environment_id=prod_environment_id, + ) diff --git a/src/dbt_mcp/project/project_resolver.py b/src/dbt_mcp/project/project_resolver.py new file mode 100644 index 00000000..07e9ac7a --- /dev/null +++ b/src/dbt_mcp/project/project_resolver.py @@ -0,0 +1,53 @@ +import logging + +import httpx + +from dbt_mcp.oauth.dbt_platform import ( + DbtPlatformAccount, + DbtPlatformProject, +) + +logger = logging.getLogger(__name__) + + +async def get_all_accounts( + *, + dbt_platform_url: str, + headers: dict[str, str], +) -> list[DbtPlatformAccount]: + async with httpx.AsyncClient() as client: + response = await client.get( + url=f"{dbt_platform_url}/api/v3/accounts/", + headers=headers, + ) + response.raise_for_status() + data = response.json() + return [DbtPlatformAccount(**account) for account in data["data"]] + + +async def get_all_projects_for_account( + *, + dbt_platform_url: str, + account: DbtPlatformAccount, + headers: dict[str, str], + page_size: int = 100, +) -> list[DbtPlatformProject]: + """Fetch all projects for an account using offset/page_size pagination.""" + offset = 0 + projects: list[DbtPlatformProject] = [] + async with httpx.AsyncClient() as client: + while True: + response = await client.get( + f"{dbt_platform_url}/api/v3/accounts/{account.id}/projects/?state=1&offset={offset}&limit={page_size}", + headers=headers, + ) + response.raise_for_status() + page = response.json()["data"] + projects.extend( + DbtPlatformProject(**project, account_name=account.name) + for project in page + ) + if len(page) < page_size: + break + offset += page_size + return projects diff --git a/tests/unit/oauth/test_fastapi_app_pagination.py b/tests/unit/oauth/test_fastapi_app_pagination.py index d850ec08..c5daca16 100644 --- a/tests/unit/oauth/test_fastapi_app_pagination.py +++ b/tests/unit/oauth/test_fastapi_app_pagination.py @@ -3,10 +3,8 @@ import pytest from dbt_mcp.oauth.dbt_platform import DbtPlatformAccount -from dbt_mcp.oauth.fastapi_app import ( - _get_all_environments_for_project, - _get_all_projects_for_account, -) +from dbt_mcp.project.environment_resolver import _get_all_environments_for_project +from dbt_mcp.project.project_resolver import get_all_projects_for_account @pytest.fixture @@ -62,7 +60,7 @@ async def test_get_all_projects_for_account_paginates(base_headers, account): mock_client = create_mock_httpx_client([first_page_resp, second_page_resp]) with patch("httpx.AsyncClient", return_value=mock_client): - result = await _get_all_projects_for_account( + result = await get_all_projects_for_account( dbt_platform_url="https://cloud.getdbt.com", account=account, headers=base_headers,