diff --git a/.changes/unreleased/Enhancement or New Feature-20260304-120000.yaml b/.changes/unreleased/Enhancement or New Feature-20260304-120000.yaml new file mode 100644 index 000000000..1ed24c7a5 --- /dev/null +++ b/.changes/unreleased/Enhancement or New Feature-20260304-120000.yaml @@ -0,0 +1,3 @@ +kind: Enhancement or New feature +body: "Adds product docs tools: search_product_docs and get_product_doc_pages" +time: 2026-03-04T12:00:00.000000-06:00 diff --git a/README.md b/README.md index 3254683cd..f1b16faf7 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,10 @@ The dbt MCP server architecture allows for your agent to connect to a variety of - `fusion.get_column_lineage`: Traces column-level lineage via dbt Platform. - `get_column_lineage`: Traces column-level lineage locally (requires dbt-lsp via dbt Labs VSCE). +### Product Docs +- `get_product_doc_pages`: Fetches the full Markdown content of one or more docs.getdbt.com pages by path or URL. +- `search_product_docs`: Searches docs.getdbt.com for pages matching a query; returns titles, URLs, and descriptions ranked by relevance. Use get_product_doc_pages to fetch full content. + ### MCP Server Metadata - `get_mcp_server_branch`: Returns the current git branch of the running dbt MCP server. - `get_mcp_server_version`: Returns the current version of the dbt MCP server. diff --git a/docs/diagram.d2 b/docs/diagram.d2 index 3d0803808..dc9ae7b72 100644 --- a/docs/diagram.d2 +++ b/docs/diagram.d2 @@ -56,6 +56,10 @@ tools: Tools { style.border-radius: 8 } + product_docs: Product Docs { + style.border-radius: 8 + } + mcp_server_metadata: MCP Server Metadata { style.border-radius: 8 } diff --git a/src/dbt_mcp/config/config.py b/src/dbt_mcp/config/config.py index 8f2af03f6..6c6385c03 100644 --- a/src/dbt_mcp/config/config.py +++ b/src/dbt_mcp/config/config.py @@ -28,6 +28,7 @@ Toolset.DISCOVERY: "disable_discovery", Toolset.DBT_LSP: "disable_lsp", Toolset.SQL: "actual_disable_sql", + Toolset.PRODUCT_DOCS: "disable_product_docs", Toolset.MCP_SERVER_METADATA: "disable_mcp_server_metadata", } @@ -39,6 +40,7 @@ Toolset.DISCOVERY: "enable_discovery", Toolset.DBT_LSP: "enable_lsp", Toolset.SQL: "enable_sql", + Toolset.PRODUCT_DOCS: "enable_product_docs", Toolset.MCP_SERVER_METADATA: "enable_mcp_server_metadata", } diff --git a/src/dbt_mcp/config/settings.py b/src/dbt_mcp/config/settings.py index 8e458aa2b..621a4dbf2 100644 --- a/src/dbt_mcp/config/settings.py +++ b/src/dbt_mcp/config/settings.py @@ -98,6 +98,7 @@ class DbtMcpSettings(BaseSettings): None, alias="DISABLE_TOOLS" ) disable_lsp: bool | None = Field(None, alias="DISABLE_LSP") + disable_product_docs: bool = Field(False, alias="DISABLE_PRODUCT_DOCS") disable_mcp_server_metadata: bool = Field(True, alias="DISABLE_MCP_SERVER_METADATA") # Enable tool settings (allowlist) @@ -111,6 +112,7 @@ class DbtMcpSettings(BaseSettings): enable_discovery: bool = Field(False, alias="DBT_MCP_ENABLE_DISCOVERY") enable_lsp: bool = Field(False, alias="DBT_MCP_ENABLE_LSP") enable_sql: bool = Field(False, alias="DBT_MCP_ENABLE_SQL") + enable_product_docs: bool = Field(False, alias="DBT_MCP_ENABLE_PRODUCT_DOCS") enable_mcp_server_metadata: bool = Field( False, alias="DBT_MCP_ENABLE_MCP_SERVER_METADATA" ) @@ -135,6 +137,7 @@ def __repr__(self): f"disable_discovery={self.disable_discovery}, " f"disable_admin_api={self.disable_admin_api}, " f"disable_sql={self.disable_sql}, " + f"disable_product_docs={self.disable_product_docs}, " f"disable_tools={self.disable_tools}, " f"disable_lsp={self.disable_lsp}, " # enable settings @@ -145,6 +148,7 @@ def __repr__(self): f"enable_dbt_codegen={self.enable_dbt_codegen}, " f"enable_discovery={self.enable_discovery}, " f"enable_lsp={self.enable_lsp}, " + f"enable_product_docs={self.enable_product_docs}, " f"enable_sql={self.enable_sql}, " # everything else f"dbt_prod_env_id={self.dbt_prod_env_id}, " diff --git a/src/dbt_mcp/mcp/server.py b/src/dbt_mcp/mcp/server.py index 3a4995ba0..db5c7b716 100644 --- a/src/dbt_mcp/mcp/server.py +++ b/src/dbt_mcp/mcp/server.py @@ -15,6 +15,7 @@ from dbt_mcp.dbt_admin.tools import register_admin_api_tools from dbt_mcp.dbt_cli.tools import register_dbt_cli_tools from dbt_mcp.dbt_codegen.tools import register_dbt_codegen_tools +from dbt_mcp.product_docs.tools import register_product_docs_tools from dbt_mcp.discovery.tools import register_discovery_tools from dbt_mcp.mcp_server_metadata.tools import register_mcp_server_tools from dbt_mcp.lsp.providers.local_lsp_client_provider import LocalLSPClientProvider @@ -166,6 +167,16 @@ async def create_dbt_mcp(config: Config) -> DbtMCP: enabled_toolsets = config.enabled_toolsets disabled_toolsets = config.disabled_toolsets + # Register product docs tools (always available, fetches from public docs.getdbt.com) + logger.info("Registering product docs tools") + register_product_docs_tools( + dbt_mcp, + disabled_tools=disabled_tools, + enabled_tools=enabled_tools, + enabled_toolsets=enabled_toolsets, + disabled_toolsets=disabled_toolsets, + ) + # Register MCP server tools (always available) logger.info("Registering MCP server tools") register_mcp_server_tools( diff --git a/src/dbt_mcp/product_docs/__init__.py b/src/dbt_mcp/product_docs/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/src/dbt_mcp/product_docs/__init__.py @@ -0,0 +1 @@ + diff --git a/src/dbt_mcp/product_docs/client.py b/src/dbt_mcp/product_docs/client.py new file mode 100644 index 000000000..c4728bae5 --- /dev/null +++ b/src/dbt_mcp/product_docs/client.py @@ -0,0 +1,425 @@ +"""Client for fetching, caching, and searching docs.getdbt.com content. + +Provides ``ProductDocsClient``, which handles HTTP requests, in-memory +caching, and keyword search/ranking over the llms.txt index and +llms-full.txt full-text corpus. + +Caches live for the lifetime of the MCP server process; restart the +server to refresh. +""" + +import logging +import re +from typing import Any + +import httpx + +logger = logging.getLogger(__name__) + +LLMS_TXT_URL = "https://docs.getdbt.com/llms.txt" +LLMS_FULL_TXT_URL = "https://docs.getdbt.com/llms-full.txt" +DOCS_BASE_URL = "https://docs.getdbt.com" + +# Cap page content length in tool responses to avoid Cursor IDE freezing. +MAX_CONTENT_CHARS_PER_PAGE = 28_000 + +_STOP_WORDS = frozenset( + { + "a", + "an", + "the", + "in", + "on", + "of", + "for", + "to", + "and", + "or", + "is", + "it", + "how", + "do", + "i", + "my", + "what", + "with", + "from", + "can", + "does", + "about", + "using", + "use", + "set", + "up", + "dbt", + } +) + +_ABBREVIATION_EXPANSIONS: dict[str, list[str]] = { + "udf": ["user-defined function", "user-defined functions"], + "udfs": ["user-defined function", "user-defined functions"], + "ci": ["continuous integration"], + "cd": ["continuous deployment"], + "ci/cd": ["continuous integration", "continuous deployment"], + "sl": ["semantic layer"], + "ide": ["studio ide", "ide"], + "cli": [ + "command line", + "dbt cli", + "dbt platform cli", + "fusion cli", + "dbt core cli", + ], + "vc": ["version control"], + "pr": ["pull request"], + "mr": ["merge request"], + "prs": ["pull requests"], + "env": ["environment"], + "envs": ["environments"], + "sso": ["single sign-on"], + "oauth": ["open authentication"], + "repo": ["repository"], + "gh": ["github"], + "aws": ["amazon web services"], + "gcp": ["google cloud platform"], +} + +_LLMS_TXT_ENTRY_RE = re.compile(r"^-\s+\[([^\]]+)\]\(([^)]+)\)(?::\s*(.+))?$") + +# Relevance scoring weights for search_index ranking. +# Higher weights push results toward the top when terms appear in +# more-specific fields (title > description > section). +SCORE_KEYWORD_IN_TITLE = 10 +SCORE_KEYWORD_IN_DESCRIPTION = 5 +SCORE_KEYWORD_IN_SECTION = 1 +SCORE_BIGRAM_IN_TITLE = 20 +SCORE_BIGRAM_IN_DESCRIPTION = 12 +SCORE_EXACT_QUERY_IN_TITLE = 40 +SCORE_EXACT_QUERY_IN_DESCRIPTION = 25 +SCORE_ALL_KEYWORDS_MATCHED = 15 +SCORE_EXACT_TITLE_MATCH = 50 +SCORE_TITLE_FOCUS_MAX = 15 # scaled by len(query) / len(title) + + +# --------------------------------------------------------------------------- +# Pure helper functions (no state needed) +# --------------------------------------------------------------------------- + + +def truncate_content(content: str, max_chars: int, url: str) -> str: + """Truncate page content to *max_chars* and append a note with the URL.""" + if len(content) <= max_chars: + return content + return ( + content[:max_chars].rstrip() + + f"\n\n---\n*Content truncated for length. Full page: {url}*" + ) + + +def display_url(url: str) -> str: + """Convert an internal ``.md`` URL to a browser-friendly URL.""" + if url.endswith(".md"): + return url[:-3] + return url + + +def normalize_doc_url(path: str) -> str: + """Turn a path or URL into a full ``docs.getdbt.com`` ``.md`` URL. + + Raises ``ValueError`` if the resulting URL is not on docs.getdbt.com. + """ + url = path.strip() + + if not url.startswith("http"): + url = url.lstrip("/") + url = f"{DOCS_BASE_URL}/{url}" + + url = url.rstrip("/") + + if not url.endswith(".md"): + url = f"{url}.md" + + if not url.startswith(f"{DOCS_BASE_URL}/"): + raise ValueError(f"URL must be on docs.getdbt.com, got: {url}") + + return url + + +def parse_llms_full_txt(text: str) -> list[dict[str, str]]: + """Parse ``llms-full.txt`` into a list of page entries with full content.""" + pages: list[dict[str, str]] = [] + url_pattern = re.compile(r"https://docs\.getdbt\.com/[^\s\]\)]+") + + chunks = re.split(r"\n---\n", text) + + for chunk in chunks: + lines = chunk.strip().splitlines() + if not lines: + continue + + title = "" + url = "" + for line in lines[:10]: + stripped = line.strip() + if stripped.startswith("### ") and not title: + title = stripped.removeprefix("### ").strip() + if not url: + url_match = url_pattern.search(stripped) + if url_match: + url = url_match.group(0) + if title and url: + break + + if not url: + continue + + content_lower = chunk.lower() + pages.append({"url": url, "title": title, "content_lower": content_lower}) + + return pages + + +def parse_llms_txt(text: str) -> list[dict[str, str]]: + """Parse ``llms.txt`` markdown into a list of page entries.""" + entries: list[dict[str, str]] = [] + current_section = "" + + for line in text.splitlines(): + stripped = line.strip() + + if stripped.startswith("## "): + current_section = stripped.removeprefix("## ").strip() + continue + + match = _LLMS_TXT_ENTRY_RE.match(stripped) + if not match: + continue + + title = match.group(1).strip() + url = match.group(2).strip() + description = (match.group(3) or "").strip() + + entries.append( + { + "title": title, + "url": url, + "description": description, + "section": current_section, + "title_lower": title.lower(), + "description_lower": description.lower(), + "section_lower": current_section.lower(), + } + ) + + return entries + + +def score_index_entry( + entry: dict[str, str], + keywords: list[str], + bigrams: list[str], + query_lower: str, +) -> float | None: + """Score a single llms.txt index entry against the search terms. + + Returns the relevance score, or ``None`` if no keyword matched at all. + + The scoring heuristic ranks entries by *where* a match occurs + (title > description > section) and gives bonuses for bigram + matches, full-query substring matches, and exact title matches. + """ + title_lower = entry["title_lower"] + desc_lower = entry["description_lower"] + section_lower = entry["section_lower"] + searchable = f"{title_lower} {desc_lower} {section_lower}" + + matching = [kw for kw in keywords if kw in searchable] + if not matching: + return None + + score: float = 0 + + for kw in matching: + if kw in title_lower: + score += SCORE_KEYWORD_IN_TITLE + if kw in desc_lower: + score += SCORE_KEYWORD_IN_DESCRIPTION + if kw in section_lower: + score += SCORE_KEYWORD_IN_SECTION + + for bigram in bigrams: + if bigram in title_lower: + score += SCORE_BIGRAM_IN_TITLE + if bigram in desc_lower: + score += SCORE_BIGRAM_IN_DESCRIPTION + + if query_lower in title_lower: + score += SCORE_EXACT_QUERY_IN_TITLE + if query_lower in desc_lower: + score += SCORE_EXACT_QUERY_IN_DESCRIPTION + + if len(matching) == len(keywords): + score += SCORE_ALL_KEYWORDS_MATCHED + + if query_lower == title_lower: + score += SCORE_EXACT_TITLE_MATCH + + if title_lower and query_lower in title_lower: + focus_ratio = len(query_lower) / len(title_lower) + score += focus_ratio * SCORE_TITLE_FOCUS_MAX + + return score + + +def expand_keywords(query: str) -> list[str]: + """Extract keywords from *query*, filtering stop words and expanding abbreviations.""" + all_words = query.lower().split() + if not all_words: + return [] + + keywords = [w for w in all_words if w not in _STOP_WORDS] + if not keywords: + keywords = list(all_words) + + expanded: list[str] = [] + for kw in keywords: + if kw in _ABBREVIATION_EXPANSIONS: + expanded.extend( + term for exp in _ABBREVIATION_EXPANSIONS[kw] for term in exp.split() + ) + for term in expanded: + if term not in keywords and term not in _STOP_WORDS: + keywords.append(term) + + return keywords + + +# --------------------------------------------------------------------------- +# Client class +# --------------------------------------------------------------------------- + + +class ProductDocsClient: + """Async client for fetching and searching docs.getdbt.com content. + + Caches are simple dicts that live for the lifetime of the instance + (and thus the MCP server process). Restart the server to refresh. + """ + + def __init__(self) -> None: + self._cache: dict[str, Any] = {} + + # -- fetchers ------------------------------------------------------------ + + async def get_index(self) -> list[dict[str, str]]: + """Return the cached llms.txt index, fetching on first call.""" + if "index" not in self._cache: + logger.info("Fetching llms.txt index from %s", LLMS_TXT_URL) + async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: + response = await client.get(LLMS_TXT_URL) + response.raise_for_status() + self._cache["index"] = parse_llms_txt(response.text) + logger.info("Cached llms.txt index: %d pages", len(self._cache["index"])) + return self._cache["index"] + + async def get_full_text_index(self) -> list[dict[str, str]]: + """Return the cached llms-full.txt page index, fetching on first call.""" + if "full_text" not in self._cache: + logger.info("Fetching llms-full.txt from %s", LLMS_FULL_TXT_URL) + async with httpx.AsyncClient( + timeout=120.0, follow_redirects=True + ) as client: + response = await client.get(LLMS_FULL_TXT_URL) + response.raise_for_status() + self._cache["full_text"] = parse_llms_full_txt(response.text) + logger.info("Cached llms-full.txt: %d pages", len(self._cache["full_text"])) + return self._cache["full_text"] + + async def get_page(self, url: str) -> str: + """Fetch a page with caching. + + Returns the page content as a string on success. + Raises httpx.HTTPStatusError on 4xx/5xx responses. + Raises httpx.RequestError on network/connection failures. + """ + if url not in self._cache: + logger.info("Fetching product doc page: %s", url) + async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: + response = await client.get(url) + response.raise_for_status() + self._cache[url] = response.text + return self._cache[url] + + # -- search -------------------------------------------------------------- + + async def search_index(self, query: str) -> list[dict[str, str]]: + """Keyword search over the llms.txt index with relevance ranking.""" + index = await self.get_index() + query_lower = query.lower() + all_words = query_lower.split() + + if not all_words: + return [] + + keywords = expand_keywords(query) + + bigrams = [] + for i in range(len(all_words) - 1): + bigram = f"{all_words[i]} {all_words[i + 1]}" + if not all(w in _STOP_WORDS for w in [all_words[i], all_words[i + 1]]): + bigrams.append(bigram) + + scored: list[tuple[float, dict[str, str]]] = [] + + for entry in index: + score = score_index_entry(entry, keywords, bigrams, query_lower) + if score is not None: + scored.append((score, entry)) + + scored.sort(key=lambda x: x[0], reverse=True) + + results = [] + for _, entry in scored[:10]: + result: dict[str, str] = { + "title": entry["title"], + "url": display_url(entry["url"]), + } + if entry["description"]: + result["description"] = entry["description"] + if entry["section"]: + result["section"] = entry["section"] + results.append(result) + + return results + + async def search_full_text( + self, keywords: list[str], max_results: int = 10 + ) -> list[dict[str, str]]: + """Search llms-full.txt page content for keyword matches (OR logic).""" + pages = await self.get_full_text_index() + scored: list[tuple[int, dict[str, str]]] = [] + + keywords_lower = [kw.lower() for kw in keywords] + + for page in pages: + content = page["content_lower"] + hit_count = 0 + keywords_matched = 0 + for kw in keywords_lower: + count = content.count(kw) + if count > 0: + keywords_matched += 1 + hit_count += count + + if keywords_matched == 0: + continue + + score = hit_count + (keywords_matched * 50) + scored.append((score, page)) + + scored.sort(key=lambda x: x[0], reverse=True) + + return [ + {"url": display_url(entry["url"]), "title": entry["title"]} + for _, entry in scored[:max_results] + ] diff --git a/src/dbt_mcp/product_docs/tools.py b/src/dbt_mcp/product_docs/tools.py new file mode 100644 index 000000000..0bfe2af88 --- /dev/null +++ b/src/dbt_mcp/product_docs/tools.py @@ -0,0 +1,219 @@ +"""Product Docs MCP tool definitions. + +Registers two tools that let AI agents query the public dbt product +documentation at docs.getdbt.com in real time: + +- ``search_product_docs`` — keyword search against the llms.txt index + (with query-expansion retry for abbreviations/synonyms) +- ``get_product_doc_pages`` — fetch one or more pages as clean Markdown +""" + +import asyncio +import logging +from dataclasses import dataclass, field + +import httpx +from mcp.server.fastmcp import FastMCP +from pydantic import Field + +from dbt_mcp.product_docs.client import ( + MAX_CONTENT_CHARS_PER_PAGE, + ProductDocsClient, + display_url, + expand_keywords, + normalize_doc_url, + truncate_content, +) +from dbt_mcp.product_docs.types import ( + DocSearchResult, + GetProductDocPagesResponse, + ProductDocPageResponse, + SearchProductDocsResponse, +) +from dbt_mcp.prompts.prompts import get_prompt +from dbt_mcp.tools.definitions import dbt_mcp_tool +from dbt_mcp.tools.register import register_tools +from dbt_mcp.tools.tool_names import ToolName +from dbt_mcp.tools.toolsets import Toolset + +logger = logging.getLogger(__name__) + +QUERY_EXPANSION_THRESHOLD = 3 + + +@dataclass +class ProductDocsToolContext: + client: ProductDocsClient = field(default_factory=ProductDocsClient) + + +def _dict_to_doc_search_result(entry: dict[str, str]) -> DocSearchResult: + return DocSearchResult( + title=entry.get("title", ""), + url=entry.get("url", ""), + description=entry.get("description", ""), + section=entry.get("section", ""), + ) + + +async def _fetch_page(client: ProductDocsClient, url: str) -> ProductDocPageResponse: + """Fetch a single page, returning a typed response with error handling.""" + try: + normalized = normalize_doc_url(url) + except ValueError as e: + return ProductDocPageResponse(url=url, content="", error=str(e)) + try: + content = await client.get_page(normalized) + except httpx.HTTPStatusError as e: + return ProductDocPageResponse( + url=display_url(normalized), + content="", + error=f"Page not found or unavailable: {normalized} (HTTP {e.response.status_code})", + ) + except httpx.RequestError as e: + return ProductDocPageResponse( + url=display_url(normalized), + content="", + error=f"Failed to fetch page: {normalized} ({e})", + ) + + content = truncate_content( + content, MAX_CONTENT_CHARS_PER_PAGE, display_url(normalized) + ) + return ProductDocPageResponse(url=display_url(normalized), content=content) + + +@dbt_mcp_tool( + description=get_prompt("product_docs/search_product_docs"), + title="Search Product Docs", + # read_only: the in-memory cache is internal process state, + # not an externally observable side-effect. + read_only_hint=True, + destructive_hint=False, + idempotent_hint=True, + open_world_hint=True, +) +async def search_product_docs( + context: ProductDocsToolContext, query: str +) -> SearchProductDocsResponse: + """Search docs.getdbt.com for pages matching a keyword query. + + Args: + query: Search terms to match against page titles and descriptions. + """ + if not query.strip(): + return SearchProductDocsResponse( + query=query, + total_matches=0, + showing=0, + results=[], + error="Query must not be empty.", + ) + + client = context.client + results = await client.search_index(query) + used_query_expansion = False + + if len(results) < QUERY_EXPANSION_THRESHOLD: + expanded_keywords = expand_keywords(query) + expanded_query = " ".join(expanded_keywords) + if expanded_query and expanded_query.lower() != query.strip().lower(): + try: + expansion_results = await client.search_index(expanded_query) + seen_urls = {r["url"] for r in results} + for entry in expansion_results: + if entry["url"] not in seen_urls: + results.append(entry) + seen_urls.add(entry["url"]) + if expansion_results: + used_query_expansion = True + except Exception as e: + logger.warning("Query expansion fallback failed: %s", e) + + doc_results = [_dict_to_doc_search_result(r) for r in results] + + return SearchProductDocsResponse( + query=query, + total_matches=len(doc_results), + showing=len(doc_results), + results=doc_results, + search_method="query_expansion" if used_query_expansion else None, + ) + + +@dbt_mcp_tool( + description=get_prompt("product_docs/get_product_doc_pages"), + title="Get Product Doc Pages", + read_only_hint=True, + destructive_hint=False, + idempotent_hint=True, + open_world_hint=True, +) +async def get_product_doc_pages( + context: ProductDocsToolContext, + paths: list[str] = Field( + description="List of docs.getdbt.com URLs or relative paths to fetch " + "(e.g. ['/docs/build/incremental-models', '/docs/build/models']). " + "Max 10 pages per call.", + ), +) -> GetProductDocPagesResponse: + """Fetch the full Markdown content of one or more docs.getdbt.com pages in parallel. + + Args: + paths: List of docs.getdbt.com URLs or relative paths to fetch. + """ + paths = paths[:10] + results = await asyncio.gather( + *[_fetch_page(context.client, path) for path in paths], + return_exceptions=True, + ) + pages: list[ProductDocPageResponse] = [] + for i, result in enumerate(results): + if isinstance(result, BaseException): + try: + err_url = display_url(normalize_doc_url(paths[i])) + except Exception: + err_url = paths[i] + logger.warning("Failed to fetch %s: %s", err_url, result) + pages.append( + ProductDocPageResponse( + url=err_url, + content="", + error=f"Failed to fetch page: {err_url} ({result})", + ) + ) + else: + pages.append(result) + + return GetProductDocPagesResponse(pages=pages) + + +PRODUCT_DOCS_TOOLS = [ + search_product_docs, + get_product_doc_pages, +] + + +def register_product_docs_tools( + dbt_mcp: FastMCP, + *, + disabled_tools: set[ToolName], + enabled_tools: set[ToolName] | None, + enabled_toolsets: set[Toolset], + disabled_toolsets: set[Toolset], +) -> None: + """Register Product Docs tools.""" + shared_client = ProductDocsClient() + + def bind_context() -> ProductDocsToolContext: + return ProductDocsToolContext(client=shared_client) + + register_tools( + dbt_mcp, + tool_definitions=[ + tool.adapt_context(bind_context) for tool in PRODUCT_DOCS_TOOLS + ], + disabled_tools=disabled_tools, + enabled_tools=enabled_tools, + enabled_toolsets=enabled_toolsets, + disabled_toolsets=disabled_toolsets, + ) diff --git a/src/dbt_mcp/product_docs/types.py b/src/dbt_mcp/product_docs/types.py new file mode 100644 index 000000000..88def9f20 --- /dev/null +++ b/src/dbt_mcp/product_docs/types.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass, field + + +@dataclass +class DocSearchResult: + title: str + url: str + description: str = "" + section: str = "" + + +@dataclass +class SearchProductDocsResponse: + query: str + total_matches: int + showing: int + results: list[DocSearchResult] + search_method: str | None = None + error: str | None = None + + +@dataclass +class ProductDocPageResponse: + url: str + content: str + error: str | None = None + + +@dataclass +class GetProductDocPagesResponse: + pages: list[ProductDocPageResponse] = field(default_factory=list) diff --git a/src/dbt_mcp/prompts/product_docs/get_product_doc_pages.md b/src/dbt_mcp/prompts/product_docs/get_product_doc_pages.md new file mode 100644 index 000000000..247ae5c35 --- /dev/null +++ b/src/dbt_mcp/prompts/product_docs/get_product_doc_pages.md @@ -0,0 +1,8 @@ +Retrieve the full Markdown content of one or more dbt documentation pages from docs.getdbt.com. Pass a list of URLs or relative paths (e.g. from search_product_docs results). Up to 10 pages can be fetched per call; pages are fetched in parallel for speed. + +IMPORTANT — how to present docs content to the user: +1. Give a direct answer first, then explain. Don't just summarize — be specific about WHERE to find a feature in the UI and HOW it works. +2. Structure around the user's task: lead with what they asked, not the doc's own structure. Include step-by-step actionable detail when the docs describe a workflow or UI feature. +3. Call out practical limitations and edge cases from the docs. +4. Close with clear guidance: if the feature only partially answers the question, say what else they can do and where. +5. ALWAYS include the docs page URL(s) as markdown hyperlinks at the end of your response, e.g. [Page title](https://docs.getdbt.com/...). diff --git a/src/dbt_mcp/prompts/product_docs/search_product_docs.md b/src/dbt_mcp/prompts/product_docs/search_product_docs.md new file mode 100644 index 000000000..646229753 --- /dev/null +++ b/src/dbt_mcp/prompts/product_docs/search_product_docs.md @@ -0,0 +1,5 @@ +Search the dbt product documentation at docs.getdbt.com for pages matching a query. Returns matching page titles, URLs, and descriptions ranked by relevance. Only returns metadata, not page content — use get_product_doc_pages with URLs from the results to retrieve full page content. + +If the title/description search finds few matches, it automatically falls back to a deep full-text search across all documentation content to find pages where the topic is discussed in the body text. + +If your first query returns few results, try rephrasing with synonyms or the full term (e.g. 'user-defined functions' instead of 'UDFs', 'version control' instead of 'git'). Use the abbreviations on the page as well. diff --git a/src/dbt_mcp/tools/human_descriptions.py b/src/dbt_mcp/tools/human_descriptions.py index 3aa58e2e6..40a3545cc 100644 --- a/src/dbt_mcp/tools/human_descriptions.py +++ b/src/dbt_mcp/tools/human_descriptions.py @@ -67,6 +67,9 @@ ToolName.FUSION_GET_COLUMN_LINEAGE: "Traces column-level lineage via dbt Platform.", # MCP Server tools ToolName.GET_MCP_SERVER_VERSION: "Returns the current version of the dbt MCP server.", + # Product docs tools + ToolName.SEARCH_PRODUCT_DOCS: "Searches docs.getdbt.com for pages matching a query; returns titles, URLs, and descriptions ranked by relevance. Use get_product_doc_pages to fetch full content.", + ToolName.GET_PRODUCT_DOC_PAGES: "Fetches the full Markdown content of one or more docs.getdbt.com pages by path or URL.", ToolName.GET_MCP_SERVER_BRANCH: "Returns the current git branch of the running dbt MCP server.", } diff --git a/src/dbt_mcp/tools/tool_names.py b/src/dbt_mcp/tools/tool_names.py index 699fdc3cb..9d8ee8311 100644 --- a/src/dbt_mcp/tools/tool_names.py +++ b/src/dbt_mcp/tools/tool_names.py @@ -73,6 +73,10 @@ class ToolName(Enum): FUSION_COMPILE_SQL = "fusion.compile_sql" FUSION_GET_COLUMN_LINEAGE = "fusion.get_column_lineage" + # Product Docs tools (docs.getdbt.com) + SEARCH_PRODUCT_DOCS = "search_product_docs" + GET_PRODUCT_DOC_PAGES = "get_product_doc_pages" + # MCP Server tools GET_MCP_SERVER_VERSION = "get_mcp_server_version" GET_MCP_SERVER_BRANCH = "get_mcp_server_branch" diff --git a/src/dbt_mcp/tools/toolsets.py b/src/dbt_mcp/tools/toolsets.py index 3c0fa8e7e..08533e154 100644 --- a/src/dbt_mcp/tools/toolsets.py +++ b/src/dbt_mcp/tools/toolsets.py @@ -19,6 +19,7 @@ class Toolset(Enum): ADMIN_API = "admin_api" DBT_CODEGEN = "dbt_codegen" DBT_LSP = "dbt_lsp" + PRODUCT_DOCS = "product_docs" MCP_SERVER_METADATA = "mcp_server_metadata" @@ -112,6 +113,10 @@ class Toolset(Enum): ToolName.FUSION_COMPILE_SQL, ToolName.FUSION_GET_COLUMN_LINEAGE, }, + Toolset.PRODUCT_DOCS: { + ToolName.SEARCH_PRODUCT_DOCS, + ToolName.GET_PRODUCT_DOC_PAGES, + }, Toolset.MCP_SERVER_METADATA: { ToolName.GET_MCP_SERVER_VERSION, ToolName.GET_MCP_SERVER_BRANCH, diff --git a/tests/integration/product_docs/test_product_docs.py b/tests/integration/product_docs/test_product_docs.py new file mode 100644 index 000000000..12addd77e --- /dev/null +++ b/tests/integration/product_docs/test_product_docs.py @@ -0,0 +1,110 @@ +import pytest + +from dbt_mcp.product_docs.client import ProductDocsClient +from dbt_mcp.product_docs.tools import ( + ProductDocsToolContext, + get_product_doc_pages, + search_product_docs, +) + + +@pytest.fixture +def context() -> ProductDocsToolContext: + return ProductDocsToolContext() + + +@pytest.fixture +def client() -> ProductDocsClient: + return ProductDocsClient() + + +class TestProductDocsClient: + @pytest.mark.asyncio + async def test_get_index(self, client: ProductDocsClient): + index = await client.get_index() + + assert isinstance(index, list) + assert len(index) > 0 + for entry in index[:3]: + assert "title" in entry + assert "url" in entry + + @pytest.mark.asyncio + async def test_search_index(self, client: ProductDocsClient): + results = await client.search_index("incremental models") + + assert isinstance(results, list) + assert len(results) > 0 + assert any("incremental" in r["title"].lower() for r in results) + + @pytest.mark.asyncio + async def test_get_page(self, client: ProductDocsClient): + content = await client.get_page( + "https://docs.getdbt.com/docs/build/incremental-models.md" + ) + + assert isinstance(content, str) + assert len(content) > 0 + + @pytest.mark.asyncio + async def test_search_full_text(self, client: ProductDocsClient): + results = await client.search_full_text(["incremental"], max_results=5) + + assert isinstance(results, list) + assert len(results) > 0 + for entry in results: + assert "url" in entry + assert "title" in entry + + +class TestProductDocsTools: + @pytest.mark.asyncio + async def test_search_product_docs(self, context: ProductDocsToolContext): + result = await search_product_docs.fn(context, "incremental models") + + assert result.error is None + assert result.total_matches > 0 + assert len(result.results) > 0 + + @pytest.mark.asyncio + async def test_search_product_docs_empty_query( + self, context: ProductDocsToolContext + ): + result = await search_product_docs.fn(context, " ") + + assert result.error is not None + assert result.total_matches == 0 + + @pytest.mark.asyncio + async def test_get_single_page(self, context: ProductDocsToolContext): + result = await get_product_doc_pages.fn( + context, ["/docs/build/incremental-models"] + ) + + assert len(result.pages) == 1 + page = result.pages[0] + assert page.error is None + assert len(page.content) > 0 + assert page.url.startswith("https://docs.getdbt.com/") + + @pytest.mark.asyncio + async def test_get_page_not_found(self, context: ProductDocsToolContext): + result = await get_product_doc_pages.fn( + context, ["/docs/this-page-does-not-exist-abc123"] + ) + + assert len(result.pages) == 1 + assert result.pages[0].error is not None + + @pytest.mark.asyncio + async def test_get_multiple_pages(self, context: ProductDocsToolContext): + result = await get_product_doc_pages.fn( + context, + ["/docs/build/incremental-models", "/docs/build/models"], + ) + + assert len(result.pages) == 2 + for page in result.pages: + assert page.url + assert page.error is None + assert len(page.content) > 0 diff --git a/tests/unit/tools/test_product_docs.py b/tests/unit/tools/test_product_docs.py new file mode 100644 index 000000000..5af67e739 --- /dev/null +++ b/tests/unit/tools/test_product_docs.py @@ -0,0 +1,559 @@ +"""Unit tests for the Product Docs toolset.""" + +from unittest.mock import AsyncMock, Mock, patch + +import httpx +import pytest + +from dbt_mcp.config.config import load_config +from dbt_mcp.dbt_cli.binary_type import BinaryType +from dbt_mcp.mcp.server import create_dbt_mcp +from dbt_mcp.product_docs.client import ( + normalize_doc_url, + parse_llms_full_txt, + parse_llms_txt, +) +from dbt_mcp.product_docs.tools import ( + ProductDocsToolContext, + get_product_doc_pages, + search_product_docs, +) + +SAMPLE_LLMS_TXT = """\ +# dbt Developer Hub + +> End user documentation, guides and technical reference for dbt + +- [The dbt Developer Hub](https://docs.getdbt.com/index.md) + +## Build + +- [Incremental models](https://docs.getdbt.com/docs/build/incremental-models.md): dbt incremental models let you transform and load only new or changed data. +- [Models](https://docs.getdbt.com/docs/build/models.md): A model is a SELECT statement. Models can be materialized as tables, views, or ephemeral. +- [Seeds](https://docs.getdbt.com/docs/build/seeds.md): Seeds are CSV files that dbt can load into your data warehouse. + +## Deploy + +- [Deploy jobs](https://docs.getdbt.com/docs/deploy/deploy-jobs.md): Create and schedule deploy jobs in dbt Cloud. +- [Continuous integration](https://docs.getdbt.com/docs/deploy/continuous-integration.md): Set up CI checks for your dbt project. + +## Reference + +- [Incremental strategy configs](https://docs.getdbt.com/reference/resource-configs/incremental-strategy.md): Configure incremental strategies for your models. +""" + +SAMPLE_LLMS_FULL_TXT = """\ +# dbt Developer Hub + +> End user documentation, guides and technical reference for dbt + +--- + +### Incremental models + +[About incremental models](https://docs.getdbt.com/docs/build/incremental-models-overview.md) + +Incremental models in dbt is a materialization strategy designed to efficiently +update your data warehouse tables. The merge strategy uses MERGE statements. +Incremental incremental incremental for keyword frequency testing. + +--- + +### About the Fusion engine + +[About Fusion](https://docs.getdbt.com/docs/fusion/about-fusion.md) + +The dbt Fusion engine is the next-generation engine for dbt, written from +the ground up in Rust. It catches incorrect SQL immediately and enables +state-aware orchestration. + +--- + +### Deploy jobs + +[Deploy jobs](https://docs.getdbt.com/docs/deploy/deploy-jobs.md) + +Create and schedule deploy jobs in dbt Cloud for production runs. +Use CI checks to validate changes before merging. +""" + + +@pytest.fixture +def mock_client(): + """Create a mock ProductDocsClient with AsyncMock methods.""" + client = Mock() + client.get_index = AsyncMock(return_value=parse_llms_txt(SAMPLE_LLMS_TXT)) + client.get_page = AsyncMock(return_value="# Page Content") + client.search_index = AsyncMock(return_value=[]) + return client + + +@pytest.fixture +def context(mock_client): + """Create ProductDocsToolContext with a mocked client.""" + ctx = ProductDocsToolContext.__new__(ProductDocsToolContext) + ctx.client = mock_client + return ctx + + +class TestParseLlmsTxt: + def test_parses_entries(self): + entries = parse_llms_txt(SAMPLE_LLMS_TXT) + assert len(entries) == 7 + + def test_parses_title_and_url(self): + entries = parse_llms_txt(SAMPLE_LLMS_TXT) + inc = next(e for e in entries if "Incremental models" in e["title"]) + assert inc["url"] == "https://docs.getdbt.com/docs/build/incremental-models.md" + + def test_parses_description(self): + entries = parse_llms_txt(SAMPLE_LLMS_TXT) + inc = next(e for e in entries if "Incremental models" in e["title"]) + assert "incremental" in inc["description"].lower() + + def test_parses_section(self): + entries = parse_llms_txt(SAMPLE_LLMS_TXT) + inc = next(e for e in entries if "Incremental models" in e["title"]) + assert inc["section"] == "Build" + + def test_entry_without_description(self): + entries = parse_llms_txt(SAMPLE_LLMS_TXT) + hub = next(e for e in entries if "Developer Hub" in e["title"]) + assert hub["description"] == "" + + def test_empty_input(self): + entries = parse_llms_txt("") + assert entries == [] + + +class TestNormalizeDocUrl: + def test_relative_path(self): + result = normalize_doc_url("/docs/build/incremental-models") + assert result == "https://docs.getdbt.com/docs/build/incremental-models.md" + + def test_full_url(self): + result = normalize_doc_url( + "https://docs.getdbt.com/docs/build/incremental-models" + ) + assert result == "https://docs.getdbt.com/docs/build/incremental-models.md" + + def test_already_has_md(self): + result = normalize_doc_url( + "https://docs.getdbt.com/docs/build/incremental-models.md" + ) + assert result == "https://docs.getdbt.com/docs/build/incremental-models.md" + + def test_trailing_slash(self): + result = normalize_doc_url("/docs/build/incremental-models/") + assert result == "https://docs.getdbt.com/docs/build/incremental-models.md" + + def test_relative_no_leading_slash(self): + result = normalize_doc_url("docs/build/incremental-models") + assert result == "https://docs.getdbt.com/docs/build/incremental-models.md" + + def test_whitespace_stripped(self): + result = normalize_doc_url(" /docs/build/incremental-models ") + assert result == "https://docs.getdbt.com/docs/build/incremental-models.md" + + def test_rejects_external_https_url(self): + with pytest.raises(ValueError, match="docs.getdbt.com"): + normalize_doc_url("https://evil.com/foo") + + def test_rejects_external_http_url(self): + with pytest.raises(ValueError, match="docs.getdbt.com"): + normalize_doc_url("http://internal:8080/secret") + + def test_rejects_subdomain_spoofing(self): + with pytest.raises(ValueError, match="docs.getdbt.com"): + normalize_doc_url("https://docs.getdbt.com.evil.com/page") + + +class TestSearchProductDocs: + @pytest.mark.asyncio + async def test_search_returns_results(self, context, mock_client): + mock_client.search_index.return_value = [ + { + "title": "Incremental models", + "url": "https://docs.getdbt.com/docs/build/incremental-models", + "description": "dbt incremental models", + "section": "Build", + }, + { + "title": "Incremental strategy configs", + "url": "https://docs.getdbt.com/reference/resource-configs/incremental-strategy", + "description": "Configure incremental strategies", + "section": "Reference", + }, + ] + result = await search_product_docs.fn(context, "incremental") + assert result.total_matches == 2 + assert len(result.results) == 2 + titles = [r.title for r in result.results] + assert any("Incremental" in t for t in titles) + + @pytest.mark.asyncio + async def test_search_no_matches(self, context, mock_client): + mock_client.search_index.return_value = [] + result = await search_product_docs.fn(context, "zzzznonexistent") + assert result.total_matches == 0 + assert result.results == [] + + @pytest.mark.asyncio + async def test_search_empty_query(self, context): + result = await search_product_docs.fn(context, "") + assert result.error is not None + assert "empty" in result.error.lower() + + @pytest.mark.asyncio + async def test_search_ranks_title_matches_higher(self, context, mock_client): + mock_client.search_index.return_value = [ + {"title": "Models", "url": "https://docs.getdbt.com/docs/build/models"}, + { + "title": "Incremental models", + "url": "https://docs.getdbt.com/docs/build/incremental-models", + }, + ] + result = await search_product_docs.fn(context, "models") + titles = [r.title for r in result.results] + assert titles[0] == "Models" + + @pytest.mark.asyncio + async def test_search_multi_keyword(self, context, mock_client): + mock_client.search_index.return_value = [ + { + "title": "Incremental strategy configs", + "url": "https://docs.getdbt.com/reference/resource-configs/incremental-strategy", + }, + ] + result = await search_product_docs.fn(context, "incremental strategy") + assert result.total_matches >= 1 + + @pytest.mark.asyncio + async def test_results_are_typed(self, context, mock_client): + mock_client.search_index.return_value = [ + { + "title": "Incremental models", + "url": "https://docs.getdbt.com/docs/build/incremental-models", + "description": "dbt incremental models", + "section": "Build", + }, + ] + result = await search_product_docs.fn(context, "incremental") + assert result.results[0].title == "Incremental models" + assert result.results[0].description == "dbt incremental models" + assert result.results[0].section == "Build" + + @pytest.mark.asyncio + async def test_query_expansion_when_few_title_matches(self, context, mock_client): + mock_client.search_index.side_effect = [ + [], + [ + { + "url": "https://docs.getdbt.com/docs/build/udfs", + "title": "User-defined functions", + }, + ], + ] + result = await search_product_docs.fn(context, "udf") + assert len(result.results) > 0 + assert result.results[0].title == "User-defined functions" + assert result.search_method == "query_expansion" + assert mock_client.search_index.call_count == 2 + + @pytest.mark.asyncio + async def test_query_expansion_deduplicates_urls(self, context, mock_client): + mock_client.search_index.side_effect = [ + [ + { + "title": "Incremental models", + "url": "https://docs.getdbt.com/docs/build/incremental-models", + }, + ], + [ + { + "url": "https://docs.getdbt.com/docs/build/incremental-models", + "title": "Duplicate", + }, + {"url": "https://docs.getdbt.com/docs/new-page", "title": "New Page"}, + ], + ] + result = await search_product_docs.fn(context, "ci") + urls = [r.url for r in result.results] + assert len(urls) == len(set(urls)) + + @pytest.mark.asyncio + async def test_no_query_expansion_when_enough_results(self, context, mock_client): + mock_client.search_index.return_value = [ + {"title": "A", "url": "https://docs.getdbt.com/a"}, + {"title": "B", "url": "https://docs.getdbt.com/b"}, + {"title": "C", "url": "https://docs.getdbt.com/c"}, + ] + result = await search_product_docs.fn(context, "test") + assert result.search_method is None + assert mock_client.search_index.call_count == 1 + + +class TestGetProductDocPages: + @pytest.mark.asyncio + async def test_fetch_single_page(self, context, mock_client): + mock_client.get_page.return_value = "# Incremental Models\n\nContent here." + + result = await get_product_doc_pages.fn( + context, ["/docs/build/incremental-models"] + ) + assert len(result.pages) == 1 + page = result.pages[0] + assert page.content == "# Incremental Models\n\nContent here." + assert "incremental-models" in page.url + assert not page.url.endswith(".md") + assert page.error is None + + @pytest.mark.asyncio + async def test_fetch_not_found(self, context, mock_client): + mock_response = Mock() + mock_response.status_code = 404 + mock_client.get_page.side_effect = httpx.HTTPStatusError( + "404 Not Found", + request=Mock(), + response=mock_response, + ) + + result = await get_product_doc_pages.fn(context, ["/docs/nonexistent"]) + page = result.pages[0] + assert page.error is not None + assert "404" in page.error + assert page.content == "" + + @pytest.mark.asyncio + async def test_fetch_request_error(self, context, mock_client): + mock_client.get_page.side_effect = httpx.RequestError( + "Connection failed", request=Mock() + ) + + result = await get_product_doc_pages.fn(context, ["/docs/build/models"]) + page = result.pages[0] + assert page.error is not None + assert "Failed to fetch" in page.error + assert page.content == "" + + @pytest.mark.asyncio + async def test_fetches_multiple_pages(self, context, mock_client): + mock_client.get_page.return_value = "# Page Content" + result = await get_product_doc_pages.fn( + context, + ["/docs/build/models", "/docs/build/seeds"], + ) + assert len(result.pages) == 2 + assert all(p.content == "# Page Content" for p in result.pages) + assert all(p.error is None for p in result.pages) + + @pytest.mark.asyncio + async def test_handles_partial_failures(self, context, mock_client): + mock_response = Mock() + mock_response.status_code = 500 + + async def side_effect(url): + if "models" in url: + raise httpx.HTTPStatusError( + "HTTP 500", request=Mock(), response=mock_response + ) + return "# Good page" + + mock_client.get_page.side_effect = side_effect + + result = await get_product_doc_pages.fn( + context, + ["/docs/build/models", "/docs/build/seeds"], + ) + assert len(result.pages) == 2 + models_page = result.pages[0] + seeds_page = result.pages[1] + assert models_page.error is not None + assert models_page.content == "" + assert seeds_page.content == "# Good page" + assert seeds_page.error is None + + @pytest.mark.asyncio + async def test_all_pages_fail(self, context, mock_client): + mock_response = Mock() + mock_response.status_code = 500 + mock_client.get_page.side_effect = httpx.HTTPStatusError( + "HTTP 500", request=Mock(), response=mock_response + ) + + result = await get_product_doc_pages.fn( + context, + ["/docs/build/models", "/docs/build/seeds"], + ) + assert len(result.pages) == 2 + assert all(p.error is not None for p in result.pages) + + @pytest.mark.asyncio + async def test_clamped_to_10_pages(self, context, mock_client): + mock_client.get_page.return_value = "# Content" + paths = [f"/docs/page-{i}" for i in range(15)] + + result = await get_product_doc_pages.fn(context, paths) + assert len(result.pages) == 10 + + @pytest.mark.asyncio + async def test_empty_paths_list(self, context, mock_client): + result = await get_product_doc_pages.fn(context, []) + assert len(result.pages) == 0 + + @pytest.mark.asyncio + async def test_urls_stripped_of_md(self, context, mock_client): + mock_client.get_page.return_value = "# Content" + result = await get_product_doc_pages.fn(context, ["/docs/build/models"]) + assert not result.pages[0].url.endswith(".md") + + @pytest.mark.asyncio + async def test_rejects_external_url(self, context, mock_client): + result = await get_product_doc_pages.fn( + context, ["https://evil.com/steal-data"] + ) + page = result.pages[0] + assert page.error is not None + assert page.error.startswith("URL must be on docs.getdbt.com") + assert page.content == "" + + @pytest.mark.asyncio + async def test_error_url_is_normalized(self, context, mock_client): + mock_response = Mock() + mock_response.status_code = 404 + mock_client.get_page.side_effect = httpx.HTTPStatusError( + "404 Not Found", + request=Mock(), + response=mock_response, + ) + result = await get_product_doc_pages.fn(context, ["/docs/build/models"]) + page = result.pages[0] + assert page.error is not None + assert page.url.startswith("https://docs.getdbt.com/") + assert not page.url.endswith(".md") + + +class TestParseLlmsFullTxt: + def test_parses_pages(self): + pages = parse_llms_full_txt(SAMPLE_LLMS_FULL_TXT) + assert len(pages) >= 2 + + def test_extracts_urls(self): + pages = parse_llms_full_txt(SAMPLE_LLMS_FULL_TXT) + urls = [p["url"] for p in pages] + assert any("about-fusion" in u for u in urls) + assert any("incremental-models" in u for u in urls) + + def test_extracts_titles(self): + pages = parse_llms_full_txt(SAMPLE_LLMS_FULL_TXT) + titles = [p["title"] for p in pages] + assert any("Fusion" in t for t in titles) + + def test_content_is_lowercase(self): + pages = parse_llms_full_txt(SAMPLE_LLMS_FULL_TXT) + for page in pages: + assert page["content_lower"] == page["content_lower"].lower() + + def test_empty_input(self): + pages = parse_llms_full_txt("") + assert pages == [] + + +class TestSearchFullText: + @pytest.mark.asyncio + async def test_finds_keyword_in_body(self): + from dbt_mcp.product_docs.client import ProductDocsClient + + client = ProductDocsClient() + client._cache["full_text"] = parse_llms_full_txt(SAMPLE_LLMS_FULL_TXT) + results = await client.search_full_text(["rust"]) + urls = [r["url"] for r in results] + assert any("fusion" in u for u in urls) + + @pytest.mark.asyncio + async def test_no_results_for_missing_keyword(self): + from dbt_mcp.product_docs.client import ProductDocsClient + + client = ProductDocsClient() + client._cache["full_text"] = parse_llms_full_txt(SAMPLE_LLMS_FULL_TXT) + results = await client.search_full_text(["zzzznonexistent"]) + assert results == [] + + @pytest.mark.asyncio + async def test_or_logic_across_keywords(self): + from dbt_mcp.product_docs.client import ProductDocsClient + + client = ProductDocsClient() + client._cache["full_text"] = parse_llms_full_txt(SAMPLE_LLMS_FULL_TXT) + results = await client.search_full_text(["rust", "deploy"]) + urls = [r["url"] for r in results] + assert any("fusion" in u for u in urls) + assert any("deploy" in u for u in urls) + + @pytest.mark.asyncio + async def test_ranks_by_keyword_frequency(self): + from dbt_mcp.product_docs.client import ProductDocsClient + + client = ProductDocsClient() + client._cache["full_text"] = parse_llms_full_txt(SAMPLE_LLMS_FULL_TXT) + results = await client.search_full_text(["incremental"]) + assert results[0]["url"].endswith("incremental-models-overview") + + +class TestProductDocsRegistration: + @pytest.mark.asyncio + async def test_tools_registered_by_default(self, env_setup): + with ( + env_setup(), + patch( + "dbt_mcp.config.config.detect_binary_type", + return_value=BinaryType.DBT_CORE, + ), + ): + config = load_config(enable_proxied_tools=False) + dbt_mcp = await create_dbt_mcp(config) + server_tools = await dbt_mcp.list_tools() + tool_names = {tool.name for tool in server_tools} + + assert "search_product_docs" in tool_names + assert "get_product_doc_pages" in tool_names + + @pytest.mark.asyncio + async def test_tools_disabled(self, env_setup): + with ( + env_setup(env_vars={"DISABLE_PRODUCT_DOCS": "true"}), + patch( + "dbt_mcp.config.config.detect_binary_type", + return_value=BinaryType.DBT_CORE, + ), + ): + config = load_config(enable_proxied_tools=False) + dbt_mcp = await create_dbt_mcp(config) + server_tools = await dbt_mcp.list_tools() + tool_names = {tool.name for tool in server_tools} + + assert "search_product_docs" not in tool_names + assert "get_product_doc_pages" not in tool_names + + @pytest.mark.asyncio + async def test_tools_available_without_cloud_credentials(self, env_setup): + """Product docs tools should work even without DBT_HOST or DBT_TOKEN.""" + with ( + env_setup( + env_vars={ + "DBT_HOST": "", + "DBT_TOKEN": "", + } + ), + patch( + "dbt_mcp.config.config.detect_binary_type", + return_value=BinaryType.DBT_CORE, + ), + ): + config = load_config(enable_proxied_tools=False) + dbt_mcp = await create_dbt_mcp(config) + server_tools = await dbt_mcp.list_tools() + tool_names = {tool.name for tool in server_tools} + + assert "search_product_docs" in tool_names + assert "get_product_doc_pages" in tool_names