Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Enhancement or New Feature
body: Add multi-project semantic layer tools with project_id parameter
time: 2026-03-17T08:21:45.385201-05:00
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ __pycache__/
.venv/
.mypy_cache/
.pytest_cache/
pytest-of-*/
*.egg-info/

# IDEs
Expand Down
59 changes: 58 additions & 1 deletion src/dbt_mcp/config/config_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
SemanticLayerHeadersProvider,
TokenProvider,
)
from dbt_mcp.config.settings import CredentialsProvider
from dbt_mcp.config.settings import CredentialsProvider, DbtMcpSettings
from dbt_mcp.oauth.dbt_platform import DbtPlatformEnvironment
from dbt_mcp.project.environment_resolver import get_environments_for_project


@dataclass
Expand Down Expand Up @@ -45,6 +47,37 @@ class ProxiedToolConfig:
headers_provider: ProxiedToolHeadersProvider


async def _resolve_project_environments(
    credentials_provider: CredentialsProvider,
    project_id: int,
) -> tuple[
    DbtMcpSettings,
    TokenProvider,
    DbtPlatformEnvironment,
    DbtPlatformEnvironment | None,
]:
    """Resolve the production (and optional development) environments for a project.

    Args:
        credentials_provider: Source of resolved settings and an auth token.
        project_id: The dbt platform project to look up environments for.

    Returns:
        A tuple of (settings, token provider, production environment,
        development environment or ``None``).

    Raises:
        ValueError: If the project has no production environment.
    """
    settings, token_provider = await credentials_provider.get_credentials()
    assert settings.actual_host and settings.dbt_account_id
    # Build the platform base URL, honoring the optional host prefix.
    base_host = settings.actual_host
    prefix = settings.actual_host_prefix
    platform_url = (
        f"https://{prefix}.{base_host}" if prefix else f"https://{base_host}"
    )
    request_headers = {
        "Accept": "application/json",
        "Authorization": f"Bearer {token_provider.get_token()}",
    }
    prod_env, dev_env = await get_environments_for_project(
        dbt_platform_url=platform_url,
        account_id=settings.dbt_account_id,
        project_id=project_id,
        headers=request_headers,
    )
    if not prod_env:
        raise ValueError(f"No production environment found for project {project_id}")
    return settings, token_provider, prod_env, dev_env


class ConfigProvider[ConfigType](ABC):
@abstractmethod
async def get_config(self) -> ConfigType: ...
Expand All @@ -54,6 +87,30 @@ class DefaultSemanticLayerConfigProvider(ConfigProvider[SemanticLayerConfig]):
def __init__(self, credentials_provider: CredentialsProvider):
    """Store the credentials provider used to resolve settings and auth tokens."""
    self.credentials_provider = credentials_provider

async def get_config_for_project(self, project_id: int) -> SemanticLayerConfig:
    """Build a semantic layer config scoped to a specific dbt project.

    Resolves the project's production environment, then derives the
    semantic layer host and GraphQL endpoint from the configured host.

    Args:
        project_id: The dbt platform project to build a config for.

    Returns:
        A ``SemanticLayerConfig`` pointing at that project's production
        environment.
    """
    settings, token_provider, prod_env, _ = await _resolve_project_environments(
        self.credentials_provider, project_id
    )
    assert settings.actual_host
    base_host = settings.actual_host
    prefix = settings.actual_host_prefix
    running_locally = base_host.startswith("localhost")
    # Local hosts are used verbatim; hosted setups get the
    # "semantic-layer" subdomain (optionally under the host prefix).
    if running_locally:
        semantic_host = base_host
    else:
        subdomain = f"{prefix}.semantic-layer" if prefix else "semantic-layer"
        semantic_host = f"{subdomain}.{base_host}"
    endpoint = (
        f"http://{semantic_host}"
        if running_locally
        else f"https://{semantic_host}/api/graphql"
    )
    return SemanticLayerConfig(
        url=endpoint,
        host=semantic_host,
        prod_environment_id=prod_env.id,
        token_provider=token_provider,
        headers_provider=SemanticLayerHeadersProvider(
            token_provider=token_provider
        ),
    )

async def get_config(self) -> SemanticLayerConfig:
settings, token_provider = await self.credentials_provider.get_credentials()
assert settings.actual_host and settings.actual_prod_environment_id
Expand Down
31 changes: 23 additions & 8 deletions src/dbt_mcp/semantic_layer/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,25 @@ def __init__(
self,
config_provider: ConfigProvider[SemanticLayerConfig],
client_provider: SemanticLayerClientProvider,
config: SemanticLayerConfig | None = None,
):
self.client_provider = client_provider
self.config_provider = config_provider
self._config = config
self.entities_cache: dict[str, list[EntityToolResponse]] = {}
self.dimensions_cache: dict[str, list[DimensionToolResponse]] = {}

async def list_metrics(self, search: str | None = None) -> list[MetricToolResponse]:
async def _resolve_config(self) -> SemanticLayerConfig:
if self._config is not None:
return self._config
return await self.config_provider.get_config()

async def list_metrics(
self,
search: str | None = None,
) -> list[MetricToolResponse]:
metrics_result = await submit_request(
await self.config_provider.get_config(),
await self._resolve_config(),
{"query": GRAPHQL_QUERIES["metrics"], "variables": {"search": search}},
)
return [
Expand All @@ -135,11 +145,12 @@ async def list_metrics(self, search: str | None = None) -> list[MetricToolRespon
]

async def list_saved_queries(
self, search: str | None = None
self,
search: str | None = None,
) -> list[SavedQueryToolResponse]:
"""Fetch all saved queries from the Semantic Layer API."""
saved_queries_result = await submit_request(
await self.config_provider.get_config(),
await self._resolve_config(),
{
"query": GRAPHQL_QUERIES["saved_queries"],
"variables": {"search": search},
Expand Down Expand Up @@ -168,12 +179,14 @@ async def list_saved_queries(
]

async def get_dimensions(
self, metrics: list[str], search: str | None = None
self,
metrics: list[str],
search: str | None = None,
) -> list[DimensionToolResponse]:
metrics_key = ",".join(sorted(metrics))
if metrics_key not in self.dimensions_cache:
dimensions_result = await submit_request(
await self.config_provider.get_config(),
await self._resolve_config(),
{
"query": GRAPHQL_QUERIES["dimensions"],
"variables": {
Expand All @@ -199,12 +212,14 @@ async def get_dimensions(
return self.dimensions_cache[metrics_key]

async def get_entities(
self, metrics: list[str], search: str | None = None
self,
metrics: list[str],
search: str | None = None,
) -> list[EntityToolResponse]:
metrics_key = ",".join(sorted(metrics))
if metrics_key not in self.entities_cache:
entities_result = await submit_request(
await self.config_provider.get_config(),
await self._resolve_config(),
{
"query": GRAPHQL_QUERIES["entities"],
"variables": {
Expand Down
Loading
Loading