-
Notifications
You must be signed in to change notification settings - Fork 122
Expand file tree
/
Copy pathtools.py
More file actions
213 lines (192 loc) · 7.93 KB
/
tools.py
File metadata and controls
213 lines (192 loc) · 7.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import json
import os
import subprocess
from typing import Any
from mcp.server.fastmcp import FastMCP
from pydantic import Field
from dbt_mcp.config.config import DbtCodegenConfig
from dbt_mcp.dbt_cli.binary_type import get_color_disable_flag
from dbt_mcp.prompts.prompts import get_prompt
from dbt_mcp.tools.annotations import create_tool_annotations
from dbt_mcp.tools.definitions import ToolDefinition
from dbt_mcp.tools.register import register_tools
from dbt_mcp.tools.tool_names import ToolName
from dbt_mcp.tools.toolsets import Toolset
def create_dbt_codegen_tool_definitions(
    config: DbtCodegenConfig,
) -> list[ToolDefinition]:
    """Build the dbt-codegen tool definitions bound to ``config``.

    Each tool wraps one dbt-codegen macro and runs it through
    ``dbt run-operation`` using the binary, project directory, binary type
    and timeout read from ``config``.

    Args:
        config: dbt-codegen settings; this function reads ``dbt_path``,
            ``project_dir``, ``binary_type`` and ``dbt_cli_timeout``.

    Returns:
        Tool definitions for ``generate_source``, ``generate_model_yaml`` and
        ``generate_staging_model``, each annotated as read-only,
        non-destructive and idempotent.
    """

    def _run_codegen_operation(
        macro_name: str,
        args: dict[str, Any] | None = None,
    ) -> str:
        """Execute a dbt-codegen macro using dbt run-operation.

        Args:
            macro_name: Name of the dbt-codegen macro to invoke.
            args: Macro arguments, serialized to JSON and passed via
                ``--args``; omitted from the command when empty or ``None``.

        Returns:
            The command's stdout (``"OK"`` when it printed nothing) if it
            exited with code 0, otherwise a human-readable error or timeout
            message. Failures are reported through the returned string rather
            than raised.
        """
        try:
            # Final argv: <dbt> <color flag> run-operation --quiet <macro> [--args <json>]
            # --quiet reduces output verbosity so stdout is mostly macro output.
            command = ["run-operation", "--quiet", macro_name]
            if args:
                command.extend(["--args", json.dumps(args)])
            # We change the path only if this is an absolute path, otherwise we can have
            # problems with relative paths applied multiple times as DBT_PROJECT_DIR
            # is applied to dbt Core and Fusion as well (but not the dbt Cloud CLI)
            cwd_path = config.project_dir if os.path.isabs(config.project_dir) else None
            # Add appropriate color disable flag based on binary type
            color_flag = get_color_disable_flag(config.binary_type)
            process = subprocess.Popen(
                args=[config.dbt_path, color_flag, *command],
                cwd=cwd_path,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                stdin=subprocess.DEVNULL,
                text=True,
            )
            try:
                stdout, stderr = process.communicate(timeout=config.dbt_cli_timeout)
            except subprocess.TimeoutExpired:
                # communicate() does NOT kill the child when the timeout
                # expires. Kill it and reap it (draining its pipes) so no
                # orphaned dbt process or open pipe handles are left behind.
                process.kill()
                process.communicate()
                return (
                    "Timeout: dbt-codegen operation took longer than "
                    f"{config.dbt_cli_timeout} seconds."
                )
            # Use returncode as the success signal so noise on stderr (e.g.
            # urllib3 deprecation warnings under externalbrowser auth) can't
            # masquerade as the result of a successful command. On failure
            # we surface stderr too, since some dbt errors only appear there.
            if process.returncode != 0:
                parts = []
                if stdout:
                    parts.append(f"--- stdout ---\n{stdout.rstrip()}")
                if stderr:
                    parts.append(f"--- stderr ---\n{stderr.rstrip()}")
                combined = "\n".join(parts)
                detail = combined or f"exit code {process.returncode} (no output)"
                # Heuristic text match on dbt's output suggesting the codegen
                # macros could not be resolved; hint at installing the package.
                if "dbt found" in combined and "resource" in combined:
                    return (
                        "Error: dbt-codegen package may not be installed. "
                        f"Run 'dbt deps' to install it.\n{detail}"
                    )
                return f"Error running dbt-codegen macro: {detail}"
            return stdout or "OK"
        except Exception as e:
            # Best-effort tool boundary: report unexpected failures (e.g. the
            # dbt binary not being found) as text instead of raising.
            return str(e)

    # Tool: run codegen's generate_source macro. Optional filters are only
    # forwarded when truthy; the boolean flags are always forwarded.
    def generate_source(
        schema_name: str = Field(
            description=get_prompt("dbt_codegen/args/schema_name")
        ),
        database_name: str | None = Field(
            default=None, description=get_prompt("dbt_codegen/args/database_name")
        ),
        table_names: list[str] | None = Field(
            default=None, description=get_prompt("dbt_codegen/args/table_names")
        ),
        generate_columns: bool = Field(
            default=False, description=get_prompt("dbt_codegen/args/generate_columns")
        ),
        include_descriptions: bool = Field(
            default=False,
            description=get_prompt("dbt_codegen/args/include_descriptions"),
        ),
    ) -> str:
        args: dict[str, Any] = {"schema_name": schema_name}
        if database_name:
            args["database_name"] = database_name
        if table_names:
            args["table_names"] = table_names
        args["generate_columns"] = generate_columns
        args["include_descriptions"] = include_descriptions
        return _run_codegen_operation("generate_source", args)

    # Tool: run codegen's generate_model_yaml macro for the given models.
    def generate_model_yaml(
        model_names: list[str] = Field(
            description=get_prompt("dbt_codegen/args/model_names")
        ),
        upstream_descriptions: bool = Field(
            default=False,
            description=get_prompt("dbt_codegen/args/upstream_descriptions"),
        ),
        include_data_types: bool = Field(
            default=True, description=get_prompt("dbt_codegen/args/include_data_types")
        ),
    ) -> str:
        args: dict[str, Any] = {
            "model_names": model_names,
            "upstream_descriptions": upstream_descriptions,
            "include_data_types": include_data_types,
        }
        return _run_codegen_operation("generate_model_yaml", args)

    # Tool: generate a staging model. Note that this is backed by codegen's
    # generate_base_model macro; `materialized` is only forwarded when truthy.
    def generate_staging_model(
        source_name: str = Field(
            description=get_prompt("dbt_codegen/args/source_name")
        ),
        table_name: str = Field(description=get_prompt("dbt_codegen/args/table_name")),
        leading_commas: bool = Field(
            default=False, description=get_prompt("dbt_codegen/args/leading_commas")
        ),
        case_sensitive_cols: bool = Field(
            default=False,
            description=get_prompt("dbt_codegen/args/case_sensitive_cols"),
        ),
        materialized: str | None = Field(
            default=None, description=get_prompt("dbt_codegen/args/materialized")
        ),
    ) -> str:
        args: dict[str, Any] = {
            "source_name": source_name,
            "table_name": table_name,
            "leading_commas": leading_commas,
            "case_sensitive_cols": case_sensitive_cols,
        }
        if materialized:
            args["materialized"] = materialized
        return _run_codegen_operation("generate_base_model", args)

    return [
        ToolDefinition(
            fn=generate_source,
            title="Generate Source",
            description=get_prompt("dbt_codegen/generate_source"),
            annotations=create_tool_annotations(
                read_only_hint=True,
                destructive_hint=False,
                idempotent_hint=True,
            ),
        ),
        ToolDefinition(
            fn=generate_model_yaml,
            title="Generate Model YAML",
            description=get_prompt("dbt_codegen/generate_model_yaml"),
            annotations=create_tool_annotations(
                read_only_hint=True,
                destructive_hint=False,
                idempotent_hint=True,
            ),
        ),
        ToolDefinition(
            fn=generate_staging_model,
            title="Generate Staging Model",
            description=get_prompt("dbt_codegen/generate_staging_model"),
            annotations=create_tool_annotations(
                read_only_hint=True,
                destructive_hint=False,
                idempotent_hint=True,
            ),
        ),
    ]
def register_dbt_codegen_tools(
    dbt_mcp: FastMCP,
    config: DbtCodegenConfig,
    *,
    disabled_tools: set[ToolName],
    enabled_tools: set[ToolName] | None,
    enabled_toolsets: set[Toolset],
    disabled_toolsets: set[Toolset],
) -> None:
    """Register the dbt-codegen tools built from ``config`` on ``dbt_mcp``.

    The tool/toolset enable and disable sets are passed through unchanged to
    ``register_tools``, which decides which of the definitions get registered.
    """
    codegen_definitions = create_dbt_codegen_tool_definitions(config)
    register_tools(
        dbt_mcp,
        tool_definitions=codegen_definitions,
        enabled_tools=enabled_tools,
        disabled_tools=disabled_tools,
        enabled_toolsets=enabled_toolsets,
        disabled_toolsets=disabled_toolsets,
    )