diff --git a/.changes/unreleased/Enhancement or New Feature-20260317-021539.yaml b/.changes/unreleased/Enhancement or New Feature-20260317-021539.yaml new file mode 100644 index 00000000..2bb8c9f2 --- /dev/null +++ b/.changes/unreleased/Enhancement or New Feature-20260317-021539.yaml @@ -0,0 +1,3 @@ +kind: Enhancement or New Feature +body: Extract project/environment helpers for multi-project support +time: 2026-03-17T02:15:39.921671-05:00 diff --git a/.changes/unreleased/Enhancement or New Feature-20260317-082145.yaml b/.changes/unreleased/Enhancement or New Feature-20260317-082145.yaml new file mode 100644 index 00000000..602a135d --- /dev/null +++ b/.changes/unreleased/Enhancement or New Feature-20260317-082145.yaml @@ -0,0 +1,3 @@ +kind: Enhancement or New Feature +body: Add multi-project semantic layer tools with project_id parameter +time: 2026-03-17T08:21:45.385201-05:00 diff --git a/.gitignore b/.gitignore index d3cedca2..346db069 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ __pycache__/ .venv/ .mypy_cache/ .pytest_cache/ +pytest-of-*/ *.egg-info/ # IDEs diff --git a/src/dbt_mcp/config/config_providers.py b/src/dbt_mcp/config/config_providers.py index ee57c494..9c741352 100644 --- a/src/dbt_mcp/config/config_providers.py +++ b/src/dbt_mcp/config/config_providers.py @@ -9,7 +9,9 @@ SemanticLayerHeadersProvider, TokenProvider, ) -from dbt_mcp.config.settings import CredentialsProvider +from dbt_mcp.config.settings import CredentialsProvider, DbtMcpSettings +from dbt_mcp.oauth.dbt_platform import DbtPlatformEnvironment +from dbt_mcp.project.environment_resolver import get_environments_for_project @dataclass @@ -45,6 +47,37 @@ class ProxiedToolConfig: headers_provider: ProxiedToolHeadersProvider +async def _resolve_project_environments( + credentials_provider: CredentialsProvider, + project_id: int, +) -> tuple[ + DbtMcpSettings, + TokenProvider, + DbtPlatformEnvironment, + DbtPlatformEnvironment | None, +]: + settings, token_provider = await credentials_provider.get_credentials() + assert settings.actual_host and settings.dbt_account_id + dbt_platform_url = ( + f"https://{settings.actual_host_prefix}.{settings.actual_host}" + if settings.actual_host_prefix + else f"https://{settings.actual_host}" + ) + headers = { + "Accept": "application/json", + "Authorization": f"Bearer {token_provider.get_token()}", + } + prod_env, dev_env = await get_environments_for_project( + dbt_platform_url=dbt_platform_url, + account_id=settings.dbt_account_id, + project_id=project_id, + headers=headers, + ) + if not prod_env: + raise ValueError(f"No production environment found for project {project_id}") + return settings, token_provider, prod_env, dev_env + + class ConfigProvider[ConfigType](ABC): @abstractmethod async def get_config(self) -> ConfigType: ... @@ -54,6 +87,30 @@ class DefaultSemanticLayerConfigProvider(ConfigProvider[SemanticLayerConfig]): def __init__(self, credentials_provider: CredentialsProvider): self.credentials_provider = credentials_provider + async def get_config_for_project(self, project_id: int) -> SemanticLayerConfig: + settings, token_provider, prod_env, _ = await _resolve_project_environments( + self.credentials_provider, project_id + ) + assert settings.actual_host + host = settings.actual_host + host_prefix = settings.actual_host_prefix + is_local = host.startswith("localhost") + if is_local: + sl_host = host + elif host_prefix: + sl_host = f"{host_prefix}.semantic-layer.{host}" + else: + sl_host = f"semantic-layer.{host}" + return SemanticLayerConfig( + url=f"http://{sl_host}" if is_local else f"https://{sl_host}/api/graphql", + host=sl_host, + prod_environment_id=prod_env.id, + token_provider=token_provider, + headers_provider=SemanticLayerHeadersProvider( + token_provider=token_provider + ), + ) + async def get_config(self) -> SemanticLayerConfig: settings, token_provider = await self.credentials_provider.get_credentials() assert settings.actual_host and settings.actual_prod_environment_id diff --git a/src/dbt_mcp/main.py b/src/dbt_mcp/main.py index 46a66a57..a05d49b0 100644 --- a/src/dbt_mcp/main.py +++ b/src/dbt_mcp/main.py @@ -1,14 +1,19 @@ import asyncio +import logging import os from dbt_mcp.config.config import load_config from dbt_mcp.config.transport import validate_transport from dbt_mcp.mcp.server import create_dbt_mcp +logger = logging.getLogger(__name__) + def main() -> None: config = load_config() + server = asyncio.run(create_dbt_mcp(config)) + transport = validate_transport(os.environ.get("MCP_TRANSPORT", "stdio")) server.run(transport=transport) diff --git a/src/dbt_mcp/mcp/server.py b/src/dbt_mcp/mcp/server.py index db5c7b71..3fc3c497 100644 --- a/src/dbt_mcp/mcp/server.py +++ b/src/dbt_mcp/mcp/server.py @@ -1,5 +1,6 @@ import asyncio import logging +import os import time import uuid from collections.abc import AsyncIterator, Callable, Sequence @@ -15,18 +16,19 @@ from dbt_mcp.dbt_admin.tools import register_admin_api_tools from dbt_mcp.dbt_cli.tools import register_dbt_cli_tools from dbt_mcp.dbt_codegen.tools import register_dbt_codegen_tools -from dbt_mcp.product_docs.tools import register_product_docs_tools from dbt_mcp.discovery.tools import register_discovery_tools -from dbt_mcp.mcp_server_metadata.tools import register_mcp_server_tools from dbt_mcp.lsp.providers.local_lsp_client_provider import LocalLSPClientProvider from dbt_mcp.lsp.providers.local_lsp_connection_provider import ( LocalLSPConnectionProvider, ) from dbt_mcp.lsp.providers.lsp_connection_provider import LSPConnectionProviderProtocol from dbt_mcp.lsp.tools import register_lsp_tools +from dbt_mcp.mcp_server_metadata.tools import register_mcp_server_tools +from dbt_mcp.product_docs.tools import register_product_docs_tools from dbt_mcp.proxy.tools import ProxiedToolsManager, register_proxied_tools from dbt_mcp.semantic_layer.client import DefaultSemanticLayerClientProvider from dbt_mcp.semantic_layer.tools import register_sl_tools +from dbt_mcp.semantic_layer.tools_multiproject import register_multiproject_sl_tools from dbt_mcp.tracking.tracking import DefaultUsageTracker, ToolCalledEvent, UsageTracker logger = logging.getLogger(__name__) @@ -101,6 +103,14 @@ async def call_tool( return result +def _multi_project_enabled() -> bool: + return os.environ.get("DBT_MCP_MULTI_PROJECT_ENABLED", "").lower() in ( + "true", + "1", + "yes", + ) + + @asynccontextmanager async def app_lifespan(server: FastMCP[Any]) -> AsyncIterator[bool | None]: if not isinstance(server, DbtMCP): @@ -110,7 +120,7 @@ async def app_lifespan(server: FastMCP[Any]) -> AsyncIterator[bool | None]: # register proxied tools inside the app lifespan to ensure the StreamableHTTP client (specific # to dbt Platform connection) lives on the same event loop as the running server # this avoids anyio cancel scope violations (see issue #498) - if server.config.proxied_tool_config_provider: + if server.config.proxied_tool_config_provider and not _multi_project_enabled(): logger.info("Registering proxied tools") await register_proxied_tools( dbt_mcp=server, @@ -149,6 +159,29 @@ async def app_lifespan(server: FastMCP[Any]) -> AsyncIterator[bool | None]: logger.exception("Error shutting down MCP server") +async def register_multi_project_dbt_mcp(dbt_mcp: DbtMCP, config: Config) -> None: + disabled_tools = set(config.disable_tools) + enabled_tools = ( + set(config.enable_tools) if config.enable_tools is not None else None + ) + enabled_toolsets = config.enabled_toolsets + disabled_toolsets = config.disabled_toolsets + + logger.info("Registering semantic layer tools for multi-project") + if config.semantic_layer_config_provider: + register_multiproject_sl_tools( + dbt_mcp=dbt_mcp, + config_provider=config.semantic_layer_config_provider, + client_provider=DefaultSemanticLayerClientProvider( + config_provider=config.semantic_layer_config_provider, + ), + disabled_tools=disabled_tools, + enabled_tools=enabled_tools, + enabled_toolsets=enabled_toolsets, + disabled_toolsets=disabled_toolsets, + ) + + async def create_dbt_mcp(config: Config) -> DbtMCP: dbt_mcp = DbtMCP( config=config, @@ -160,6 +193,20 @@ async def create_dbt_mcp(config: Config) -> DbtMCP: lifespan=app_lifespan, ) + multi_project_enabled = os.environ.get( + "DBT_MCP_MULTI_PROJECT_ENABLED", "" + ).lower() in ("true", "1", "yes") + + if multi_project_enabled: + logger.info("DBT_MCP_MULTI_PROJECT_ENABLED=true -> Multi-project mode") + await register_multi_project_dbt_mcp(dbt_mcp, config) + else: + logger.info("Multi-project mode disabled -> Env-var mode") + await register_dbt_mcp_tools(dbt_mcp, config) + return dbt_mcp + + +async def register_dbt_mcp_tools(dbt_mcp: DbtMCP, config: Config) -> None: disabled_tools = set(config.disable_tools) enabled_tools = ( set(config.enable_tools) if config.enable_tools is not None else None @@ -263,5 +310,3 @@ async def create_dbt_mcp(config: Config) -> DbtMCP: enabled_toolsets=enabled_toolsets, disabled_toolsets=disabled_toolsets, ) - - return dbt_mcp diff --git a/src/dbt_mcp/oauth/fastapi_app.py b/src/dbt_mcp/oauth/fastapi_app.py index febfb362..59885691 100644 --- a/src/dbt_mcp/oauth/fastapi_app.py +++ b/src/dbt_mcp/oauth/fastapi_app.py @@ -2,7 +2,6 @@ from typing import cast from urllib.parse import quote -import httpx from authlib.integrations.requests_client import OAuth2Session from fastapi import FastAPI, Request from fastapi.responses import RedirectResponse @@ -12,9 +11,7 @@ from dbt_mcp.oauth.context_manager import DbtPlatformContextManager from dbt_mcp.oauth.dbt_platform import ( - DbtPlatformAccount, DbtPlatformContext, - DbtPlatformEnvironment, DbtPlatformEnvironmentResponse, DbtPlatformProject, GetEnvironmentsRequest, @@ -24,6 +21,14 @@ from dbt_mcp.oauth.token import ( DecodedAccessToken, ) +from dbt_mcp.project.environment_resolver import ( + _get_all_environments_for_project, + resolve_environments, +) +from dbt_mcp.project.project_resolver import ( + get_all_accounts, + get_all_projects_for_account, +) logger = logging.getLogger(__name__) @@ -56,77 +61,6 @@ async def send_wrapper(message): await super().__call__(scope, receive, send_wrapper) -async def _get_all_accounts( - *, - dbt_platform_url: str, - headers: dict[str, str], -) -> list[DbtPlatformAccount]: - async with httpx.AsyncClient() as client: - response = await client.get( - url=f"{dbt_platform_url}/api/v3/accounts/", - headers=headers, - ) - response.raise_for_status() - data = response.json() - return [DbtPlatformAccount(**account) for account in data["data"]] - - -async def _get_all_projects_for_account( - *, - dbt_platform_url: str, - account: DbtPlatformAccount, - headers: dict[str, str], - page_size: int = 100, -) -> list[DbtPlatformProject]: - """Fetch all projects for an account using offset/page_size pagination.""" - offset = 0 - projects: list[DbtPlatformProject] = [] - async with httpx.AsyncClient() as client: - while True: - response = await client.get( - f"{dbt_platform_url}/api/v3/accounts/{account.id}/projects/?state=1&offset={offset}&limit={page_size}", - headers=headers, - ) - response.raise_for_status() - page = response.json()["data"] - projects.extend( - DbtPlatformProject(**project, account_name=account.name) - for project in page - ) - if len(page) < page_size: - break - offset += page_size - return projects - - -async def _get_all_environments_for_project( - *, - dbt_platform_url: str, - account_id: int, - project_id: int, - headers: dict[str, str], - page_size: int = 100, -) -> list[DbtPlatformEnvironmentResponse]: - """Fetch all environments for a project using offset/page_size pagination.""" - offset = 0 - environments: list[DbtPlatformEnvironmentResponse] = [] - async with httpx.AsyncClient() as client: - while True: - response = await client.get( - f"{dbt_platform_url}/api/v3/accounts/{account_id}/projects/{project_id}/environments/?state=1&offset={offset}&limit={page_size}", - headers=headers, - ) - response.raise_for_status() - page = response.json()["data"] - environments.extend( - DbtPlatformEnvironmentResponse(**environment) for environment in page - ) - if len(page) < page_size: - break - offset += page_size - return environments - - def create_app( *, oauth_client: OAuth2Session, @@ -206,14 +140,14 @@ async def projects() -> list[DbtPlatformProject]: "Accept": "application/json", "Authorization": f"Bearer {access_token}", } - accounts = await _get_all_accounts( + accounts = await get_all_accounts( dbt_platform_url=dbt_platform_url, headers=headers, ) projects: list[DbtPlatformProject] = [] for account in [a for a in accounts if a.state == 1 and not a.locked]: projects.extend( - await _get_all_projects_for_account( + await get_all_projects_for_account( dbt_platform_url=dbt_platform_url, account=account, headers=headers, @@ -265,7 +199,7 @@ async def set_selected_project( "Accept": "application/json", "Authorization": f"Bearer {access_token}", } - accounts = await _get_all_accounts( + accounts = await get_all_accounts( dbt_platform_url=dbt_platform_url, headers=headers, ) @@ -282,45 +216,10 @@ async def set_selected_project( page_size=100, ) - prod_environment = None - dev_environment = None - - # If a specific prod_environment_id was provided, use it - if selected_project_request.prod_environment_id: - for environment in environments: - if environment.id == selected_project_request.prod_environment_id: - prod_environment = DbtPlatformEnvironment( - id=environment.id, - name=environment.name, - deployment_type=environment.deployment_type or "production", - ) - break - else: - # Fall back to auto-detection based on deployment_type - for environment in environments: - if ( - environment.deployment_type - and environment.deployment_type.lower() == "production" - ): - prod_environment = DbtPlatformEnvironment( - id=environment.id, - name=environment.name, - deployment_type=environment.deployment_type, - ) - break - - # Always try to auto-detect dev environment - for environment in environments: - if ( - environment.deployment_type - and environment.deployment_type.lower() == "development" - ): - dev_environment = DbtPlatformEnvironment( - id=environment.id, - name=environment.name, - deployment_type=environment.deployment_type, - ) - break + prod_environment, dev_environment = resolve_environments( + environments, + prod_environment_id=selected_project_request.prod_environment_id, + ) dbt_platform_context = dbt_platform_context_manager.update_context( new_dbt_platform_context=DbtPlatformContext( diff --git a/src/dbt_mcp/project/__init__.py b/src/dbt_mcp/project/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/dbt_mcp/project/environment_resolver.py b/src/dbt_mcp/project/environment_resolver.py new file mode 100644 index 00000000..b48858c4 --- /dev/null +++ b/src/dbt_mcp/project/environment_resolver.py @@ -0,0 +1,115 @@ +import logging + +import httpx + +from dbt_mcp.oauth.dbt_platform import ( + DbtPlatformEnvironment, + DbtPlatformEnvironmentResponse, +) + +logger = logging.getLogger(__name__) + + +async def _get_all_environments_for_project( + *, + dbt_platform_url: str, + account_id: int, + project_id: int, + headers: dict[str, str], + page_size: int = 100, +) -> list[DbtPlatformEnvironmentResponse]: + """Fetch all environments for a project using offset/page_size pagination.""" + offset = 0 + environments: list[DbtPlatformEnvironmentResponse] = [] + async with httpx.AsyncClient() as client: + while True: + response = await client.get( + f"{dbt_platform_url}/api/v3/accounts/{account_id}/projects/{project_id}/environments/?state=1&offset={offset}&limit={page_size}", + headers=headers, + ) + response.raise_for_status() + page = response.json()["data"] + environments.extend( + DbtPlatformEnvironmentResponse(**environment) for environment in page + ) + if len(page) < page_size: + break + offset += page_size + return environments + + +def resolve_environments( + environments: list[DbtPlatformEnvironmentResponse], + *, + prod_environment_id: int | None = None, +) -> tuple[DbtPlatformEnvironment | None, DbtPlatformEnvironment | None]: + """Resolve prod and dev environments from a list of environment responses. + + Returns a tuple of (prod_environment, dev_environment). + + If prod_environment_id is provided, that specific environment is used as prod. + Otherwise, auto-detects based on deployment_type == "production". + Dev environment is always auto-detected based on deployment_type == "development". + """ + prod_environment: DbtPlatformEnvironment | None = None + dev_environment: DbtPlatformEnvironment | None = None + + if prod_environment_id: + for environment in environments: + if environment.id == prod_environment_id: + prod_environment = DbtPlatformEnvironment( + id=environment.id, + name=environment.name, + deployment_type=environment.deployment_type or "production", + ) + break + else: + for environment in environments: + if ( + environment.deployment_type + and environment.deployment_type.lower() == "production" + ): + prod_environment = DbtPlatformEnvironment( + id=environment.id, + name=environment.name, + deployment_type=environment.deployment_type, + ) + break + + for environment in environments: + if ( + environment.deployment_type + and environment.deployment_type.lower() == "development" + ): + dev_environment = DbtPlatformEnvironment( + id=environment.id, + name=environment.name, + deployment_type=environment.deployment_type, + ) + break + + return prod_environment, dev_environment + + +async def get_environments_for_project( + *, + dbt_platform_url: str, + account_id: int, + project_id: int, + headers: dict[str, str], + prod_environment_id: int | None = None, +) -> tuple[DbtPlatformEnvironment | None, DbtPlatformEnvironment | None]: + """Fetch environments for a project and resolve prod/dev. + + Returns a tuple of (prod_environment, dev_environment). + """ + environments = await _get_all_environments_for_project( + dbt_platform_url=dbt_platform_url, + account_id=account_id, + project_id=project_id, + headers=headers, + ) + return resolve_environments( + environments, + prod_environment_id=prod_environment_id, + ) diff --git a/src/dbt_mcp/project/project_resolver.py b/src/dbt_mcp/project/project_resolver.py new file mode 100644 index 00000000..07e9ac7a --- /dev/null +++ b/src/dbt_mcp/project/project_resolver.py @@ -0,0 +1,53 @@ +import logging + +import httpx + +from dbt_mcp.oauth.dbt_platform import ( + DbtPlatformAccount, + DbtPlatformProject, +) + +logger = logging.getLogger(__name__) + + +async def get_all_accounts( + *, + dbt_platform_url: str, + headers: dict[str, str], +) -> list[DbtPlatformAccount]: + async with httpx.AsyncClient() as client: + response = await client.get( + url=f"{dbt_platform_url}/api/v3/accounts/", + headers=headers, + ) + response.raise_for_status() + data = response.json() + return [DbtPlatformAccount(**account) for account in data["data"]] + + +async def get_all_projects_for_account( + *, + dbt_platform_url: str, + account: DbtPlatformAccount, + headers: dict[str, str], + page_size: int = 100, +) -> list[DbtPlatformProject]: + """Fetch all projects for an account using offset/page_size pagination.""" + offset = 0 + projects: list[DbtPlatformProject] = [] + async with httpx.AsyncClient() as client: + while True: + response = await client.get( + f"{dbt_platform_url}/api/v3/accounts/{account.id}/projects/?state=1&offset={offset}&limit={page_size}", + headers=headers, + ) + response.raise_for_status() + page = response.json()["data"] + projects.extend( + DbtPlatformProject(**project, account_name=account.name) + for project in page + ) + if len(page) < page_size: + break + offset += page_size + return projects diff --git a/src/dbt_mcp/semantic_layer/client.py b/src/dbt_mcp/semantic_layer/client.py index 2634e42f..ebea1209 100644 --- a/src/dbt_mcp/semantic_layer/client.py +++ b/src/dbt_mcp/semantic_layer/client.py @@ -112,15 +112,25 @@ def __init__( self, config_provider: ConfigProvider[SemanticLayerConfig], client_provider: SemanticLayerClientProvider, + config: SemanticLayerConfig | None = None, ): self.client_provider = client_provider self.config_provider = config_provider + self._config = config self.entities_cache: dict[str, list[EntityToolResponse]] = {} self.dimensions_cache: dict[str, list[DimensionToolResponse]] = {} - async def list_metrics(self, search: str | None = None) -> list[MetricToolResponse]: + async def _resolve_config(self) -> SemanticLayerConfig: + if self._config is not None: + return self._config + return await self.config_provider.get_config() + + async def list_metrics( + self, + search: str | None = None, + ) -> list[MetricToolResponse]: metrics_result = await submit_request( - await self.config_provider.get_config(), + await self._resolve_config(), {"query": GRAPHQL_QUERIES["metrics"], "variables": {"search": search}}, ) return [ @@ -135,11 +145,12 @@ async def list_metrics(self, search: str | None = None) -> list[MetricToolRespon ] async def list_saved_queries( - self, search: str | None = None + self, + search: str | None = None, ) -> list[SavedQueryToolResponse]: """Fetch all saved queries from the Semantic Layer API.""" saved_queries_result = await submit_request( - await self.config_provider.get_config(), + await self._resolve_config(), { "query": GRAPHQL_QUERIES["saved_queries"], "variables": {"search": search}, @@ -168,12 +179,14 @@ async def list_saved_queries( ] async def get_dimensions( - self, metrics: list[str], search: str | None = None + self, + metrics: list[str], + search: str | None = None, ) -> list[DimensionToolResponse]: metrics_key = ",".join(sorted(metrics)) if metrics_key not in self.dimensions_cache: dimensions_result = await submit_request( - await self.config_provider.get_config(), + await self._resolve_config(), { "query": GRAPHQL_QUERIES["dimensions"], "variables": { @@ -199,12 +212,14 @@ async def get_dimensions( return self.dimensions_cache[metrics_key] async def get_entities( - self, metrics: list[str], search: str | None = None + self, + metrics: list[str], + search: str | None = None, ) -> list[EntityToolResponse]: metrics_key = ",".join(sorted(metrics)) if metrics_key not in self.entities_cache: entities_result = await submit_request( - await self.config_provider.get_config(), + await self._resolve_config(), { "query": GRAPHQL_QUERIES["entities"], "variables": { diff --git a/src/dbt_mcp/semantic_layer/tools_multiproject.py b/src/dbt_mcp/semantic_layer/tools_multiproject.py new file mode 100644 index 00000000..cf4d9be7 --- /dev/null +++ b/src/dbt_mcp/semantic_layer/tools_multiproject.py @@ -0,0 +1,247 @@ +import logging +from dataclasses import dataclass + +from dbtsl.api.shared.query_params import GroupByParam +from mcp.server.fastmcp import FastMCP + +from dbt_mcp.config.config_providers import ( + DefaultSemanticLayerConfigProvider, + SemanticLayerConfig, +) +from dbt_mcp.prompts.prompts import get_prompt +from dbt_mcp.semantic_layer.client import ( + SemanticLayerClientProvider, + SemanticLayerFetcher, +) +from dbt_mcp.semantic_layer.types import ( + DimensionToolResponse, + EntityToolResponse, + GetMetricsCompiledSqlSuccess, + MetricToolResponse, + OrderByParam, + QueryMetricsSuccess, + SavedQueryToolResponse, +) +from dbt_mcp.tools.definitions import GenericToolDefinition, generic_dbt_mcp_tool +from dbt_mcp.tools.register import register_tools +from dbt_mcp.tools.tool_names import ToolName +from dbt_mcp.tools.toolsets import Toolset + +logger = logging.getLogger(__name__) + + +@dataclass +class MultiProjectSemanticLayerToolContext: + semantic_layer_config_provider: DefaultSemanticLayerConfigProvider + _client_provider: SemanticLayerClientProvider + + def __init__( + self, + config_provider: DefaultSemanticLayerConfigProvider, + client_provider: SemanticLayerClientProvider, + ): + self.semantic_layer_config_provider = config_provider + self._client_provider = client_provider + + def fetcher_for_config( + self, config: SemanticLayerConfig + ) -> SemanticLayerFetcher: + return SemanticLayerFetcher( + config_provider=self.semantic_layer_config_provider, + client_provider=self._client_provider, + config=config, + ) + + +@generic_dbt_mcp_tool( + description=get_prompt("semantic_layer/list_metrics"), + name_enum=ToolName, + name="list_metrics", + title="List Metrics", + read_only_hint=True, + destructive_hint=False, + idempotent_hint=True, +) +async def list_metrics( + context: MultiProjectSemanticLayerToolContext, + project_id: int, + search: str | None = None, +) -> list[MetricToolResponse]: + config = await context.semantic_layer_config_provider.get_config_for_project( + project_id + ) + fetcher = context.fetcher_for_config(config) + return await fetcher.list_metrics(search=search) + + +@generic_dbt_mcp_tool( + description=get_prompt("semantic_layer/list_saved_queries"), + name_enum=ToolName, + name="list_saved_queries", + title="List Saved Queries", + read_only_hint=True, + destructive_hint=False, + idempotent_hint=True, +) +async def list_saved_queries( + context: MultiProjectSemanticLayerToolContext, + project_id: int, + search: str | None = None, +) -> list[SavedQueryToolResponse]: + config = await context.semantic_layer_config_provider.get_config_for_project( + project_id + ) + fetcher = context.fetcher_for_config(config) + return await fetcher.list_saved_queries(search=search) + + +@generic_dbt_mcp_tool( + description=get_prompt("semantic_layer/get_dimensions"), + name_enum=ToolName, + name="get_dimensions", + title="Get Dimensions", + read_only_hint=True, + destructive_hint=False, + idempotent_hint=True, +) +async def get_dimensions( + context: MultiProjectSemanticLayerToolContext, + project_id: int, + metrics: list[str], + search: str | None = None, +) -> list[DimensionToolResponse]: + config = await context.semantic_layer_config_provider.get_config_for_project( + project_id + ) + fetcher = context.fetcher_for_config(config) + return await fetcher.get_dimensions(metrics=metrics, search=search) + + +@generic_dbt_mcp_tool( + description=get_prompt("semantic_layer/get_entities"), + name_enum=ToolName, + name="get_entities", + title="Get Entities", + read_only_hint=True, + destructive_hint=False, + idempotent_hint=True, +) +async def get_entities( + context: MultiProjectSemanticLayerToolContext, + project_id: int, + metrics: list[str], + search: str | None = None, +) -> list[EntityToolResponse]: + config = await context.semantic_layer_config_provider.get_config_for_project( + project_id + ) + fetcher = context.fetcher_for_config(config) + return await fetcher.get_entities(metrics=metrics, search=search) + + +@generic_dbt_mcp_tool( + description=get_prompt("semantic_layer/query_metrics"), + name_enum=ToolName, + name="query_metrics", + title="Query Metrics", + read_only_hint=True, + destructive_hint=False, + idempotent_hint=True, +) +async def query_metrics( + context: MultiProjectSemanticLayerToolContext, + project_id: int, + metrics: list[str], + group_by: list[GroupByParam] | None = None, + order_by: list[OrderByParam] | None = None, + where: str | None = None, + limit: int | None = None, +) -> str: + config = await context.semantic_layer_config_provider.get_config_for_project( + project_id + ) + fetcher = context.fetcher_for_config(config) + result = await fetcher.query_metrics( + metrics=metrics, + group_by=group_by, + order_by=order_by, + where=where, + limit=limit, + ) + if isinstance(result, QueryMetricsSuccess): + return result.result + else: + return result.error + + +@generic_dbt_mcp_tool( + description=get_prompt("semantic_layer/get_metrics_compiled_sql"), + name_enum=ToolName, + name="get_metrics_compiled_sql", + title="Compile SQL", + read_only_hint=True, + destructive_hint=False, + idempotent_hint=True, +) +async def get_metrics_compiled_sql( + context: MultiProjectSemanticLayerToolContext, + project_id: int, + metrics: list[str], + group_by: list[GroupByParam] | None = None, + order_by: list[OrderByParam] | None = None, + where: str | None = None, + limit: int | None = None, +) -> str: + config = await context.semantic_layer_config_provider.get_config_for_project( + project_id + ) + fetcher = context.fetcher_for_config(config) + result = await fetcher.get_metrics_compiled_sql( + metrics=metrics, + group_by=group_by, + order_by=order_by, + where=where, + limit=limit, + ) + if isinstance(result, GetMetricsCompiledSqlSuccess): + return result.sql + else: + return result.error + + +MULTIPROJECT_SEMANTIC_LAYER_TOOLS: list[GenericToolDefinition[ToolName]] = [ + list_metrics, + list_saved_queries, + get_dimensions, + get_entities, + query_metrics, + get_metrics_compiled_sql, +] + + +def register_multiproject_sl_tools( + dbt_mcp: FastMCP, + config_provider: DefaultSemanticLayerConfigProvider, + client_provider: SemanticLayerClientProvider, + *, + disabled_tools: set[ToolName], + enabled_tools: set[ToolName] | None, + enabled_toolsets: set[Toolset], + disabled_toolsets: set[Toolset], +) -> None: + def bind_context() -> MultiProjectSemanticLayerToolContext: + return MultiProjectSemanticLayerToolContext( + config_provider=config_provider, client_provider=client_provider + ) + + register_tools( + dbt_mcp, + [ + tool.adapt_context(bind_context) + for tool in MULTIPROJECT_SEMANTIC_LAYER_TOOLS + ], + disabled_tools=disabled_tools, + enabled_tools=enabled_tools, + enabled_toolsets=enabled_toolsets, + disabled_toolsets=disabled_toolsets, + ) diff --git a/tests/unit/oauth/test_fastapi_app_pagination.py b/tests/unit/oauth/test_fastapi_app_pagination.py index d850ec08..c5daca16 100644 --- a/tests/unit/oauth/test_fastapi_app_pagination.py +++ b/tests/unit/oauth/test_fastapi_app_pagination.py @@ -3,10 +3,8 @@ import pytest from dbt_mcp.oauth.dbt_platform import DbtPlatformAccount -from dbt_mcp.oauth.fastapi_app import ( - _get_all_environments_for_project, - _get_all_projects_for_account, -) +from dbt_mcp.project.environment_resolver import _get_all_environments_for_project +from dbt_mcp.project.project_resolver import get_all_projects_for_account @pytest.fixture @@ -62,7 +60,7 @@ async def test_get_all_projects_for_account_paginates(base_headers, account): mock_client = create_mock_httpx_client([first_page_resp, second_page_resp]) with patch("httpx.AsyncClient", return_value=mock_client): - result = await _get_all_projects_for_account( + result = await get_all_projects_for_account( dbt_platform_url="https://cloud.getdbt.com", account=account, headers=base_headers, diff --git a/tests/unit/tools/test_tool_names.py b/tests/unit/tools/test_tool_names.py index 2d69cd2a..004a2cbc 100644 --- a/tests/unit/tools/test_tool_names.py +++ b/tests/unit/tools/test_tool_names.py @@ -30,8 +30,8 @@ async def test_tool_names_match_server_tools(env_setup): # Get all tools from the server server_tools = await dbt_mcp.list_tools() - # Manually adding proxied tools here because the server doesn't get them - # in this unit test. + # Manually adding proxied tools here because + # they are not registered on the default server in this unit test. server_tool_names = {tool.name for tool in server_tools} | { p.value for p in proxied_tools } diff --git a/tests/unit/tools/test_toolsets.py b/tests/unit/tools/test_toolsets.py index c8321fe0..eecb812d 100644 --- a/tests/unit/tools/test_toolsets.py +++ b/tests/unit/tools/test_toolsets.py @@ -55,8 +55,8 @@ async def test_toolsets_match_server_tools(env_setup): # Get all tools from the server server_tools = await dbt_mcp.list_tools() - # Manually adding SQL tools here because the server doesn't get them - # in this unit test. + # Manually adding proxied tools here because + # they are not registered on the default server in this unit test. server_tool_names = {tool.name for tool in server_tools} | { p.value for p in proxied_tools }