|
8 | 8 |
|
9 | 9 | from dbt_mcp.config.config_providers import ConfigProvider, DiscoveryConfig |
10 | 10 | from dbt_mcp.discovery.graphql import load_query |
11 | | -from dbt_mcp.errors import InvalidParameterError |
| 11 | +from dbt_mcp.errors import InvalidParameterError, ToolCallError |
12 | 12 | from dbt_mcp.gql.errors import raise_gql_error |
13 | 13 |
|
14 | 14 | DEFAULT_PAGE_SIZE = 100 |
@@ -302,6 +302,9 @@ class GraphQLQueries: |
302 | 302 | } |
303 | 303 | """) |
304 | 304 |
|
| 305 | + # Lineage query |
| 306 | + GET_FULL_LINEAGE = load_query("get_full_lineage.gql") |
| 307 | + |
305 | 308 |
|
306 | 309 | class MetadataAPIClient: |
307 | 310 | def __init__(self, config_provider: ConfigProvider[DiscoveryConfig]): |
@@ -667,3 +670,121 @@ async def fetch_details( |
667 | 670 | if not edges: |
668 | 671 | return [] |
669 | 672 | return [e["node"] for e in edges] |
| 673 | + |
| 674 | + |
| 675 | +class LineageResourceType(StrEnum): |
| 676 | + """Resource types supported by the lineage API.""" |
| 677 | + |
| 678 | + MODEL = "Model" |
| 679 | + SOURCE = "Source" |
| 680 | + SEED = "Seed" |
| 681 | + SNAPSHOT = "Snapshot" |
| 682 | + EXPOSURE = "Exposure" |
| 683 | + METRIC = "Metric" |
| 684 | + SEMANTIC_MODEL = "SemanticModel" |
| 685 | + SAVED_QUERY = "SavedQuery" |
| 686 | + TEST = "Test" |
| 687 | + |
| 688 | + |
| 689 | +class LineageFetcher: |
| 690 | + """Fetcher for lineage data. Returns nodes connected to the target.""" |
| 691 | + |
| 692 | + def __init__(self, api_client: MetadataAPIClient): |
| 693 | + self.api_client = api_client |
| 694 | + |
| 695 | + async def fetch_lineage( |
| 696 | + self, |
| 697 | + unique_id: str, |
| 698 | + depth: int, |
| 699 | + types: list[LineageResourceType] | None = None, |
| 700 | + ) -> list[dict]: |
| 701 | + """Fetch lineage graph filtered to nodes connected to unique_id. |
| 702 | +
|
| 703 | + Args: |
| 704 | + unique_id: The dbt unique ID of the resource to get lineage for. |
| 705 | + types: List of resource types to include. If None, includes all types. |
| 706 | +
|
| 707 | + Returns: |
| 708 | + List of nodes connected to unique_id (upstream + downstream). |
| 709 | + """ |
| 710 | + if depth <= 0: |
| 711 | + raise ToolCallError("Depth must be greater than 0") |
| 712 | + config = await self.api_client.config_provider.get_config() |
| 713 | + type_filter = [ |
| 714 | + t.value for t in (types if types is not None else LineageResourceType) |
| 715 | + ] |
| 716 | + variables = { |
| 717 | + "environmentId": config.environment_id, |
| 718 | + "types": type_filter, |
| 719 | + # uniqueId removed - not used by GraphQL |
| 720 | + } |
| 721 | + |
| 722 | + result = await self.api_client.execute_query( |
| 723 | + GraphQLQueries.GET_FULL_LINEAGE, variables |
| 724 | + ) |
| 725 | + raise_gql_error(result) |
| 726 | + |
| 727 | + all_nodes = ( |
| 728 | + result.get("data", {}) |
| 729 | + .get("environment", {}) |
| 730 | + .get("applied", {}) |
| 731 | + .get("lineage", []) |
| 732 | + ) |
| 733 | + |
| 734 | + # Filter to connected nodes only |
| 735 | + return self._filter_connected_nodes(all_nodes, unique_id, depth) |
| 736 | + |
| 737 | + def _filter_connected_nodes( |
| 738 | + self, nodes: list[dict], target_id: str, depth: int |
| 739 | + ) -> list[dict]: |
| 740 | + """Return only nodes connected to target_id (upstream and downstream). |
| 741 | +
|
| 742 | + Uses BFS to find all nodes reachable from target in both directions. |
| 743 | + """ |
| 744 | + node_map = { |
| 745 | + n["uniqueId"]: n |
| 746 | + for n in nodes |
| 747 | + if (resource_type := n.get("resourceType")) |
| 748 | + and isinstance(resource_type, str) |
| 749 | + # Filtering out macros because they have large |
| 750 | + # dependency graphs that aren't always useful. |
| 751 | + and resource_type.strip().lower() != "macro" |
| 752 | + } |
| 753 | + |
| 754 | + if target_id not in node_map: |
| 755 | + return [] |
| 756 | + |
| 757 | + # BFS to find all connected nodes |
| 758 | + connected = {target_id} |
| 759 | + queue = [(target_id, 0)] |
| 760 | + |
| 761 | + while queue: |
| 762 | + current_id, current_depth = queue.pop(0) |
| 763 | + node = node_map.get(current_id) |
| 764 | + if not node: |
| 765 | + continue |
| 766 | + |
| 767 | + # Stop traversing beyond the depth limit |
| 768 | + if current_depth >= depth: |
| 769 | + continue |
| 770 | + |
| 771 | + # Traverse upstream (parents) |
| 772 | + for parent_id in node.get("parentIds", []): |
| 773 | + if parent_id not in connected and parent_id in node_map: |
| 774 | + connected.add(parent_id) |
| 775 | + queue.append((parent_id, current_depth + 1)) |
| 776 | + |
| 777 | + # Traverse downstream (children) |
| 778 | + for candidate in nodes: |
| 779 | + candidate_id = candidate.get("uniqueId") |
| 780 | + if not candidate_id or candidate_id not in node_map: |
| 781 | + continue |
| 782 | + if ( |
| 783 | + current_id in candidate.get("parentIds", []) |
| 784 | + and candidate_id not in connected |
| 785 | + ): |
| 786 | + connected.add(candidate_id) |
| 787 | + queue.append((candidate_id, current_depth + 1)) |
| 788 | + |
| 789 | + # Return in original order |
| 790 | + return [node_map[uid] for uid in connected] |
0 commit comments