Skip to content

Commit 54776d2

Browse files
committed
feat: improve server client interface and contract, prepare for dbt defer behavior
1 parent 49494ab commit 54776d2

File tree

4 files changed

+202
-91
lines changed

4 files changed

+202
-91
lines changed

src/dbt_core_interface/client.py

Lines changed: 87 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
# pyright: reportAny=false
22
"""dbt-core-interface client for interacting with a dbt-core-interface FastAPI server."""
33

4+
from __future__ import annotations
5+
46
import functools
57
import logging
68
import typing as t
9+
from pathlib import Path
10+
from urllib.parse import urljoin
711

812
import requests
913

@@ -25,7 +29,7 @@
2529

2630

2731
@t.final
28-
class ServerError(Exception):
32+
class ServerErrorException(Exception): # noqa: N818
2933
"""Custom exception for handling server errors from the dbt-core-interface."""
3034

3135
def __init__(self, error: _ServerError) -> None:
@@ -54,30 +58,48 @@ def __init__(
5458
target: str | None = None,
5559
base_url: str = "http://localhost:8581",
5660
timeout: float | tuple[float, float] = 10.0,
61+
unregister_on_close: bool = True,
5762
) -> None:
5863
"""Initialize the client with the base URL and optional project name."""
59-
self.project_dir = project_dir
64+
self.project_dir = Path(project_dir).resolve()
6065
self.base_url = base_url.rstrip("/")
6166
self.timeout = timeout
6267
self.session = requests.Session()
68+
self.session.headers.update(
69+
{
70+
"Content-Type": "application/json",
71+
"User-Agent": "dbt-core-interface-client/1.0",
72+
"X-dbt-Project": self.project_dir.name,
73+
}
74+
)
75+
self.unregister_on_close = unregister_on_close
6376
response = self._register_project(profiles_dir=profiles_dir, target=target)
6477
logger.info("Registered project '%s' with server at %s", response.added, self.base_url)
6578

66-
def __del__(self) -> None:
79+
def close(self) -> None:
6780
"""Unregister the project on client destruction."""
68-
try:
69-
response = self._unregister_project()
70-
logger.info(
71-
"Unregistered project '%s' with server at %s", response.removed, self.base_url
72-
)
73-
except Exception as e:
74-
logger.error("Failed to unregister project '%s': %s", self.project_dir, e)
81+
if self.unregister_on_close:
82+
try:
83+
response = self._unregister_project()
84+
logger.info(
85+
"Unregistered project '%s' with server at %s", response.removed, self.base_url
86+
)
87+
except Exception as e:
88+
logger.error("Failed to unregister project '%s': %s", self.project_dir, e)
7589

76-
def _headers(self) -> dict[str, str]:
77-
return {
78-
"User-Agent": "dbt-core-interface-client/1.0",
79-
"X-dbt-Project": self.project_dir,
80-
}
90+
def __enter__(self) -> DbtInterfaceClient:
91+
"""Context manager for the client to ensure proper cleanup."""
92+
return self
93+
94+
def __exit__(
95+
self,
96+
exc_type: type[BaseException] | None,
97+
exc_value: Exception | None,
98+
traceback: t.Any | None,
99+
) -> None:
100+
"""Close the client and unregister the project."""
101+
self.close()
102+
self.session.close()
81103

82104
def _request(
83105
self,
@@ -88,8 +110,11 @@ def _request(
88110
json_payload: t.Any = None,
89111
headers: dict[str, str] | None = None,
90112
) -> requests.Response:
91-
url = f"{self.base_url}{path}"
92-
headers = {**self._headers(), **(headers or {})}
113+
url = urljoin(self.base_url, path)
114+
headers = headers or {}
115+
params = params or {}
116+
params["project_dir"] = str(self.project_dir)
117+
93118
logger.debug(
94119
"Requesting %s %s with params=%s, data=%s, json=%s, headers=%s",
95120
method,
@@ -108,13 +133,15 @@ def _request(
108133
headers=headers,
109134
timeout=self.timeout,
110135
)
136+
111137
if resp.status_code >= 400:
112138
try:
113139
err = ServerErrorContainer.model_validate(resp.json())
114-
raise ServerError(err.error)
140+
raise ServerErrorException(err.error)
115141
except ValueError as e:
116142
logger.error("Failed to parse error response: %s", e)
117143
resp.raise_for_status()
144+
118145
return resp
119146

120147
def _register_project(
@@ -123,32 +150,32 @@ def _register_project(
123150
target: str | None = None,
124151
) -> ServerRegisterResult:
125152
"""Register a new dbt project."""
126-
params: dict[str, t.Any] = {"project_dir": self.project_dir}
153+
params: dict[str, t.Any] = {}
127154
if profiles_dir is not None:
128155
params["profiles_dir"] = profiles_dir
129156
if target is not None:
130157
params["target"] = target
131-
resp = self._request("POST", "/register", params=params)
158+
resp = self._request("GET", "/api/v1/register", params=params)
132159
return ServerRegisterResult.model_validate(resp.json())
133160

134161
def _unregister_project(self) -> ServerUnregisterResult:
135162
"""Unregister the current project."""
136-
resp = self._request("POST", "/unregister")
163+
resp = self._request("DELETE", "/api/v1/register")
137164
return ServerUnregisterResult.model_validate(resp.json())
138165

139166
def run_sql(
140167
self,
141168
raw_sql: str,
142169
limit: int = 200,
143-
path: str | None = None,
170+
model_path: str | None = None,
144171
) -> ServerRunResult:
145172
"""Execute raw SQL against the registered dbt project."""
146173
params: dict[str, t.Any] = {"limit": limit}
147-
if path is not None:
148-
params["path"] = path
174+
if model_path is not None:
175+
params["model_path"] = model_path
149176
resp = self._request(
150177
method="POST",
151-
path="/run",
178+
path="/api/v1/run",
152179
data=raw_sql,
153180
headers={"Content-Type": "text/plain"},
154181
params=params,
@@ -158,48 +185,21 @@ def run_sql(
158185
def compile_sql(
159186
self,
160187
raw_sql: str,
161-
path: str | None = None,
188+
model_path: str | None = None,
162189
) -> ServerCompileResult:
163190
"""Compile raw SQL without executing it."""
164191
params: dict[str, t.Any] = {}
165-
if path is not None:
166-
params["path"] = path
192+
if model_path is not None:
193+
params["model_path"] = model_path
167194
resp = self._request(
168195
method="POST",
169-
path="/compile",
196+
path="/api/v1/compile",
170197
data=raw_sql,
171198
headers={"Content-Type": "text/plain"},
172199
params=params,
173200
)
174201
return ServerCompileResult.model_validate(resp.json())
175202

176-
def reset_project(
177-
self,
178-
target: str | None = None,
179-
reset: bool = False,
180-
write_manifest: bool = False,
181-
) -> ServerResetResult:
182-
"""Re-parse the dbt project."""
183-
params: dict[str, t.Any] = {}
184-
if target is not None:
185-
params["target"] = target
186-
if reset:
187-
params["reset"] = reset
188-
if write_manifest:
189-
params["write_manifest"] = write_manifest
190-
resp = self._request("GET", "/reset", params=params)
191-
return ServerResetResult.model_validate(resp.json())
192-
193-
def health_check(self) -> dict[str, t.Any]:
194-
"""Check server health and project status."""
195-
resp = self._request("GET", "/health")
196-
return resp.json()
197-
198-
def heartbeat(self) -> dict[str, t.Any]:
199-
"""Check server availability."""
200-
resp = self._request("GET", "/heartbeat")
201-
return resp.json()
202-
203203
def lint_sql(
204204
self,
205205
sql_path: str | None = None,
@@ -217,7 +217,7 @@ def lint_sql(
217217
if raw_sql is not None and sql_path is None:
218218
data = raw_sql
219219
headers = {"Content-Type": "text/plain"}
220-
resp = self._request("POST", "/lint", params=params, data=data, headers=headers)
220+
resp = self._request("POST", "/api/v1/lint", params=params, data=data, headers=headers)
221221
return ServerLintResult.model_validate(resp.json())
222222

223223
def format_sql(
@@ -237,9 +237,26 @@ def format_sql(
237237
if raw_sql is not None and sql_path is None:
238238
data = raw_sql
239239
headers = {"Content-Type": "text/plain"}
240-
resp = self._request("POST", "/format", params=params, data=data, headers=headers)
240+
resp = self._request("POST", "/api/v1/format", params=params, data=data, headers=headers)
241241
return ServerFormatResult.model_validate(resp.json())
242242

243+
def parse_project(
244+
self,
245+
target: str | None = None,
246+
reset: bool = False,
247+
write_manifest: bool = False,
248+
) -> ServerResetResult:
249+
"""Re-parse the dbt project."""
250+
params: dict[str, t.Any] = {}
251+
if target is not None:
252+
params["target"] = target
253+
if reset:
254+
params["reset"] = reset
255+
if write_manifest:
256+
params["write_manifest"] = write_manifest
257+
resp = self._request("GET", "/api/v1/parse", params=params)
258+
return ServerResetResult.model_validate(resp.json())
259+
243260
def command(
244261
self,
245262
cmd: str,
@@ -250,7 +267,7 @@ def command(
250267
payload: dict[str, t.Any] = {"args": args, "kwargs": kwargs}
251268
resp = self._request(
252269
method="POST",
253-
path="/command",
270+
path="/api/v1/command",
254271
json_payload=payload,
255272
params={"cmd": cmd},
256273
)
@@ -272,3 +289,14 @@ def command(
272289
snapshot = functools.partialmethod(command, "snapshot")
273290
source_freshness = functools.partialmethod(command, "source freshness")
274291
test = functools.partialmethod(command, "test")
292+
293+
def status(self) -> dict[str, t.Any]:
294+
"""Check server diagnostic status."""
295+
resp = self._request("GET", "/api/v1/status")
296+
return resp.json()
297+
298+
def heartbeat(self) -> bool:
299+
"""Check server availability."""
300+
resp = self._request("GET", "/api/v1/heartbeat")
301+
pulse = resp.json()
302+
return pulse["result"]["status"] == "ready"

src/dbt_core_interface/project.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def _patched_adapter_accessor(config: t.Any) -> t.Any:
4444
from dbt.context.providers import generate_runtime_macro_context, generate_runtime_model_context
4545
from dbt.contracts.graph.manifest import Manifest
4646
from dbt.contracts.graph.nodes import ManifestNode, SourceDefinition
47+
from dbt.contracts.state import PreviousState
4748
from dbt.flags import set_from_args
4849
from dbt.parser.manifest import ManifestLoader, process_node
4950
from dbt.parser.read_files import FileDiff, InputFile, ReadFilesFromFileSystem
@@ -919,3 +920,21 @@ def format(
919920
f"format_command returning success={success}, result_sql={result_sql[:100] if result_sql is not None else 'n/a'}"
920921
)
921922
return success, result_sql
923+
924+
def inject_deferred_state(self, state_path: Path | str) -> None:
925+
"""Merge the manifest from a previous state artifact for dbt deferral behavior."""
926+
previous_state = PreviousState(
927+
state_path=Path(state_path).resolve(),
928+
target_path=self.target_path,
929+
project_root=self.project_root,
930+
)
931+
if previous_state.manifest is None:
932+
logger.warning(f"No manifest found in previous state at {state_path}")
933+
return
934+
self.manifest.merge_from_artifact(previous_state.manifest)
935+
936+
def clear_deferred_state(self) -> None:
937+
"""Clear the deferred state from the manifest."""
938+
for node in self.manifest.nodes.values():
939+
if hasattr(node, "defer_relation"):
940+
node.defer_relation = None # pyright: ignore[reportAttributeAccessIssue]

0 commit comments

Comments
 (0)