Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions src/dbt_mcp/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,16 @@ def load_config(enable_proxied_tools: bool = True) -> Config:
settings = DbtMcpSettings() # type: ignore

inner_credentials = CredentialsProvider(settings)
credentials_provider = ElicitingCredentialsProvider(inner_credentials)
eliciting_credentials = ElicitingCredentialsProvider(inner_credentials)

# Platform providers get eliciting wrapper when platform toolsets are active.
# CLI-only users get raw credentials — defense-in-depth alongside register.py's
# allowlist gating which already prevents platform tool calls for these users.
platform_credentials = (
eliciting_credentials
if settings.any_platform_toolset_active
else inner_credentials
)

# Set default warn error options if not provided
if settings.dbt_warn_error_options is None:
Expand All @@ -127,34 +136,35 @@ def load_config(enable_proxied_tools: bool = True) -> Config:
if getattr(settings, attr_name, False)
}

# Proxied tools still gated on explicit opt-in flag
# Proxied tools stay on inner_credentials — lifespan-time registration,
# request_ctx is None, elicitation impossible by design
proxied_tool_config_provider = None
if enable_proxied_tools:
proxied_tool_config_provider = DefaultProxiedToolConfigProvider(
credentials_provider=inner_credentials
)

admin_api_config_provider = DefaultAdminApiConfigProvider(
credentials_provider=inner_credentials,
credentials_provider=platform_credentials,
)
admin_client = DbtAdminAPIClient(admin_api_config_provider)
multi_project_discovery_config_provider = MultiProjectDiscoveryConfigProvider(
credentials_provider=inner_credentials,
credentials_provider=platform_credentials,
admin_client=admin_client,
)
multi_project_semantic_layer_config_provider = (
MultiProjectSemanticLayerConfigProvider(
credentials_provider=inner_credentials,
credentials_provider=platform_credentials,
admin_client=admin_client,
metrics_related_max=settings.sl_metrics_related_max,
max_response_chars=settings.sl_metrics_max_response_chars,
)
)
discovery_config_provider = DefaultDiscoveryConfigProvider(
credentials_provider=inner_credentials,
credentials_provider=platform_credentials,
)
semantic_layer_config_provider = DefaultSemanticLayerConfigProvider(
credentials_provider=inner_credentials,
credentials_provider=platform_credentials,
metrics_related_max=settings.sl_metrics_related_max,
max_response_chars=settings.sl_metrics_max_response_chars,
)
Expand Down Expand Up @@ -211,6 +221,6 @@ def load_config(enable_proxied_tools: bool = True) -> Config:
multi_project_semantic_layer_config_provider=multi_project_semantic_layer_config_provider,
semantic_layer_config_provider=semantic_layer_config_provider,
admin_api_config_provider=admin_api_config_provider,
credentials_provider=credentials_provider,
credentials_provider=eliciting_credentials,
lsp_config=lsp_config,
)
9 changes: 2 additions & 7 deletions src/dbt_mcp/config/config_providers/admin_api.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from dbt_mcp.config.headers import AdminApiHeadersProvider

from .base import AdminApiConfig, ConfigProvider

if TYPE_CHECKING:
from dbt_mcp.config.credentials import CredentialsProvider
from .base import AdminApiConfig, ConfigProvider, CredentialsProviderProtocol


class DefaultAdminApiConfigProvider(ConfigProvider[AdminApiConfig]):
def __init__(self, credentials_provider: CredentialsProvider):
def __init__(self, credentials_provider: CredentialsProviderProtocol):
self.credentials_provider = credentials_provider

async def get_config(self) -> AdminApiConfig:
Expand Down
16 changes: 16 additions & 0 deletions src/dbt_mcp/config/config_providers/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol

from dbt_mcp.config.headers import (
HeadersProvider,
ProxiedToolHeadersProvider,
TokenProvider,
)

if TYPE_CHECKING:
from dbt_mcp.config.settings import DbtMcpSettings


class CredentialsProviderProtocol(Protocol):
"""Structural interface for credential providers.

Both CredentialsProvider and ElicitingCredentialsProvider satisfy this
protocol. Config providers accept this protocol so either can be injected.
"""

async def get_credentials(self) -> tuple[DbtMcpSettings, TokenProvider]: ...


class ConfigProvider[ConfigType](ABC):
@abstractmethod
Expand Down
12 changes: 8 additions & 4 deletions src/dbt_mcp/config/config_providers/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@
from dbt_mcp.errors import NotFoundError

if TYPE_CHECKING:
from dbt_mcp.config.credentials import CredentialsProvider
from dbt_mcp.dbt_admin.client import DbtAdminAPIClient

from .base import ConfigProvider, DiscoveryConfig, MultiProjectConfigProvider
from .base import (
ConfigProvider,
CredentialsProviderProtocol,
DiscoveryConfig,
MultiProjectConfigProvider,
)


class DefaultDiscoveryConfigProvider(ConfigProvider[DiscoveryConfig]):
def __init__(self, credentials_provider: CredentialsProvider):
def __init__(self, credentials_provider: CredentialsProviderProtocol):
self.credentials_provider = credentials_provider

async def get_config(self) -> DiscoveryConfig:
Expand All @@ -35,7 +39,7 @@ class MultiProjectDiscoveryConfigProvider(MultiProjectConfigProvider[DiscoveryCo
def __init__(
self,
*,
credentials_provider: CredentialsProvider,
credentials_provider: CredentialsProviderProtocol,
admin_client: DbtAdminAPIClient,
):
self.credentials_provider = credentials_provider
Expand Down
7 changes: 4 additions & 3 deletions src/dbt_mcp/config/config_providers/proxied_tool.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

from dbt_mcp.config.headers import ProxiedToolHeadersProvider
from dbt_mcp.config.credentials import CredentialsProvider
from dbt_mcp.errors.common import MissingHostError

from .base import ConfigProvider, ProxiedToolConfig
from .base import ConfigProvider, CredentialsProviderProtocol, ProxiedToolConfig


class DefaultProxiedToolConfigProvider(ConfigProvider[ProxiedToolConfig]):
def __init__(self, credentials_provider: CredentialsProvider):
def __init__(self, credentials_provider: CredentialsProviderProtocol):
self.credentials_provider = credentials_provider

async def get_config(self) -> ProxiedToolConfig:
Expand Down
20 changes: 15 additions & 5 deletions src/dbt_mcp/config/config_providers/semantic_layer.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
from dbt_mcp.config.credentials import CredentialsProvider
from __future__ import annotations

from typing import TYPE_CHECKING

from dbt_mcp.config.headers import (
SemanticLayerHeadersProvider,
)
from dbt_mcp.dbt_admin.client import DbtAdminAPIClient
from dbt_mcp.errors import NotFoundError

from .base import ConfigProvider, MultiProjectConfigProvider, SemanticLayerConfig
from .base import (
ConfigProvider,
CredentialsProviderProtocol,
MultiProjectConfigProvider,
SemanticLayerConfig,
)

if TYPE_CHECKING:
from dbt_mcp.dbt_admin.client import DbtAdminAPIClient


class DefaultSemanticLayerConfigProvider(ConfigProvider[SemanticLayerConfig]):
def __init__(
self,
credentials_provider: CredentialsProvider,
credentials_provider: CredentialsProviderProtocol,
*,
metrics_related_max: int = 10,
max_response_chars: int = 16000,
Expand Down Expand Up @@ -50,7 +60,7 @@ class MultiProjectSemanticLayerConfigProvider(
):
def __init__(
self,
credentials_provider: CredentialsProvider,
credentials_provider: CredentialsProviderProtocol,
admin_client: DbtAdminAPIClient,
*,
metrics_related_max: int = 10,
Expand Down
34 changes: 34 additions & 0 deletions src/dbt_mcp/config/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,40 @@ def actual_disable_sql(self) -> bool:
return self.disable_remote
return True

@property
def any_platform_toolset_active(self) -> bool:
"""Whether any platform toolset is active under the current config mode.

Three modes (mirrors register.py:should_register_tool and config.py
enabled_toolsets/disabled_toolsets computation):
1. Allowlist mode (any ENABLE_* flag set): only explicitly enabled
platform toolsets count.
2. Denylist/default mode: platform toolsets active unless disabled.

Platform toolsets: semantic_layer, discovery, admin_api, sql.
Local toolsets: dbt_cli, dbt_codegen, lsp, product_docs, mcp_server_metadata.
"""
has_any_enable = self.enable_tools is not None or any((
self.enable_semantic_layer, self.enable_discovery,
self.enable_admin_api, self.enable_sql,
self.enable_dbt_cli, self.enable_dbt_codegen,
self.enable_lsp, self.enable_product_docs,
self.enable_mcp_server_metadata,
))
if has_any_enable:
return any((
self.enable_semantic_layer,
self.enable_discovery,
self.enable_admin_api,
self.enable_sql,
Comment on lines +191 to +195
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Detect platform tools enabled via DBT_MCP_ENABLE_TOOLS

In allowlist mode, any_platform_toolset_active is entered when enable_tools is set, but this branch only checks enable_* toolset flags and ignores individually enabled tools from DBT_MCP_ENABLE_TOOLS. That means configs that enable a platform tool by name (without setting DBT_MCP_ENABLE_DISCOVERY/_SEMANTIC_LAYER/etc.) are misclassified as “no platform toolsets active”, so load_config() injects raw CredentialsProvider into platform providers and platform calls fail with MissingHostError instead of triggering elicitation.

Useful? React with 👍 / 👎.

))
return any((
not self.disable_semantic_layer,
not self.disable_discovery,
not self.disable_admin_api,
not self.actual_disable_sql,
))

@property
def actual_host_prefix(self) -> str | None:
if self.host_prefix is not None:
Expand Down
31 changes: 13 additions & 18 deletions src/dbt_mcp/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,16 @@ def __init__(
asyncio.Task[LSPConnectionProviderProtocol] | None
) = None

async def _is_multi_project(self) -> bool:
try:
(
settings,
_,
) = await self.config.credentials_provider.inner_provider.get_credentials()
except MissingHostError as e:
logger.warning(
"Could not resolve credentials — defaulting to single-project mode: %s",
e,
)
return False
return bool(
settings.dbt_project_ids is not None and len(settings.dbt_project_ids) > 0
)
def _is_multi_project(self) -> bool:
"""Check multi-project mode from settings. No credential fetch.

Note: dbt_project_ids may be populated later by the OAuth flow
in CredentialsProvider.get_credentials(). Users who rely on OAuth-derived project IDs
without setting DBT_PROJECT_IDS env var will see single-project
mode until the first platform tool call triggers OAuth.
"""
project_ids = self.config.credentials_provider.settings.dbt_project_ids
return bool(project_ids is not None and len(project_ids) > 0)

async def call_tool(
self, name: str, arguments: dict[str, Any]
Expand All @@ -86,7 +81,7 @@ async def call_tool(
result = None
start_time = int(time.time() * 1000)
try:
if await self._is_multi_project():
if self._is_multi_project():
result = await self.multi_project_mcp.call_tool(name, arguments)
else:
result = await self.single_project_mcp.call_tool(name, arguments)
Expand Down Expand Up @@ -127,7 +122,7 @@ async def call_tool(
return result

async def list_tools(self) -> list[Tool]:
if await self._is_multi_project():
if self._is_multi_project():
return await self.multi_project_mcp.list_tools()
return await self.single_project_mcp.list_tools()

Expand All @@ -143,7 +138,7 @@ async def app_lifespan(server: FastMCP[Any]) -> AsyncIterator[bool | None]:
# this avoids anyio cancel scope violations (see issue #498)
if (
server.config.proxied_tool_config_provider
and not await server._is_multi_project()
and not server._is_multi_project()
):
try:
logger.info("Registering proxied tools")
Expand Down
3 changes: 1 addition & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from contextlib import contextmanager
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, patch
from unittest.mock import patch

import pytest

Expand Down Expand Up @@ -110,7 +110,6 @@ def path(relpath: str):
helpers.write_file(rel, content)
with patch(
"dbt_mcp.mcp.server.DbtMCP._is_multi_project",
new_callable=AsyncMock,
return_value=False,
):
try:
Expand Down
Loading