11# pyright: reportAny=false
22"""dbt-core-interface client for interacting with a dbt-core-interface FastAPI server."""
33
4+ from __future__ import annotations
5+
46import functools
57import logging
68import typing as t
9+ from pathlib import Path
10+ from urllib .parse import urljoin
711
812import requests
913
2529
2630
2731@t .final
28- class ServerError (Exception ):
32+ class ServerErrorException (Exception ): # noqa: N818
2933 """Custom exception for handling server errors from the dbt-core-interface."""
3034
3135 def __init__ (self , error : _ServerError ) -> None :
@@ -54,30 +58,48 @@ def __init__(
5458 target : str | None = None ,
5559 base_url : str = "http://localhost:8581" ,
5660 timeout : float | tuple [float , float ] = 10.0 ,
61+ unregister_on_close : bool = True ,
5762 ) -> None :
5863 """Initialize the client with the base URL and optional project name."""
59- self .project_dir = project_dir
64+ self .project_dir = Path ( project_dir ). resolve ()
6065 self .base_url = base_url .rstrip ("/" )
6166 self .timeout = timeout
6267 self .session = requests .Session ()
68+ self .session .headers .update (
69+ {
70+ "Content-Type" : "application/json" ,
71+ "User-Agent" : "dbt-core-interface-client/1.0" ,
72+ "X-dbt-Project" : self .project_dir .name ,
73+ }
74+ )
75+ self .unregister_on_close = unregister_on_close
6376 response = self ._register_project (profiles_dir = profiles_dir , target = target )
6477 logger .info ("Registered project '%s' with server at %s" , response .added , self .base_url )
6578
66- def __del__ (self ) -> None :
79+ def close (self ) -> None :
6780 """Unregister the project on client destruction."""
68- try :
69- response = self ._unregister_project ()
70- logger .info (
71- "Unregistered project '%s' with server at %s" , response .removed , self .base_url
72- )
73- except Exception as e :
74- logger .error ("Failed to unregister project '%s': %s" , self .project_dir , e )
81+ if self .unregister_on_close :
82+ try :
83+ response = self ._unregister_project ()
84+ logger .info (
85+ "Unregistered project '%s' with server at %s" , response .removed , self .base_url
86+ )
87+ except Exception as e :
88+ logger .error ("Failed to unregister project '%s': %s" , self .project_dir , e )
7589
76- def _headers (self ) -> dict [str , str ]:
77- return {
78- "User-Agent" : "dbt-core-interface-client/1.0" ,
79- "X-dbt-Project" : self .project_dir ,
80- }
90+ def __enter__ (self ) -> DbtInterfaceClient :
91+ """Context manager for the client to ensure proper cleanup."""
92+ return self
93+
94+ def __exit__ (
95+ self ,
96+ exc_type : type [BaseException ] | None ,
97+ exc_value : Exception | None ,
98+ traceback : t .Any | None ,
99+ ) -> None :
100+ """Close the client and unregister the project."""
101+ self .close ()
102+ self .session .close ()
81103
82104 def _request (
83105 self ,
@@ -88,8 +110,11 @@ def _request(
88110 json_payload : t .Any = None ,
89111 headers : dict [str , str ] | None = None ,
90112 ) -> requests .Response :
91- url = f"{ self .base_url } { path } "
92- headers = {** self ._headers (), ** (headers or {})}
113+ url = urljoin (self .base_url , path )
114+ headers = headers or {}
115+ params = params or {}
116+ params ["project_dir" ] = str (self .project_dir )
117+
93118 logger .debug (
94119 "Requesting %s %s with params=%s, data=%s, json=%s, headers=%s" ,
95120 method ,
@@ -108,13 +133,15 @@ def _request(
108133 headers = headers ,
109134 timeout = self .timeout ,
110135 )
136+
111137 if resp .status_code >= 400 :
112138 try :
113139 err = ServerErrorContainer .model_validate (resp .json ())
114- raise ServerError (err .error )
140+ raise ServerErrorException (err .error )
115141 except ValueError as e :
116142 logger .error ("Failed to parse error response: %s" , e )
117143 resp .raise_for_status ()
144+
118145 return resp
119146
120147 def _register_project (
@@ -123,32 +150,32 @@ def _register_project(
123150 target : str | None = None ,
124151 ) -> ServerRegisterResult :
125152 """Register a new dbt project."""
126- params : dict [str , t .Any ] = {"project_dir" : self . project_dir }
153+ params : dict [str , t .Any ] = {}
127154 if profiles_dir is not None :
128155 params ["profiles_dir" ] = profiles_dir
129156 if target is not None :
130157 params ["target" ] = target
131- resp = self ._request ("POST " , "/register" , params = params )
158+ resp = self ._request ("GET " , "/api/v1 /register" , params = params )
132159 return ServerRegisterResult .model_validate (resp .json ())
133160
134161 def _unregister_project (self ) -> ServerUnregisterResult :
135162 """Unregister the current project."""
136- resp = self ._request ("POST " , "/unregister " )
163+ resp = self ._request ("DELETE " , "/api/v1/register " )
137164 return ServerUnregisterResult .model_validate (resp .json ())
138165
139166 def run_sql (
140167 self ,
141168 raw_sql : str ,
142169 limit : int = 200 ,
143- path : str | None = None ,
170+ model_path : str | None = None ,
144171 ) -> ServerRunResult :
145172 """Execute raw SQL against the registered dbt project."""
146173 params : dict [str , t .Any ] = {"limit" : limit }
147- if path is not None :
148- params ["path " ] = path
174+ if model_path is not None :
175+ params ["model_path " ] = model_path
149176 resp = self ._request (
150177 method = "POST" ,
151- path = "/run" ,
178+ path = "/api/v1/ run" ,
152179 data = raw_sql ,
153180 headers = {"Content-Type" : "text/plain" },
154181 params = params ,
@@ -158,48 +185,21 @@ def run_sql(
158185 def compile_sql (
159186 self ,
160187 raw_sql : str ,
161- path : str | None = None ,
188+ model_path : str | None = None ,
162189 ) -> ServerCompileResult :
163190 """Compile raw SQL without executing it."""
164191 params : dict [str , t .Any ] = {}
165- if path is not None :
166- params ["path " ] = path
192+ if model_path is not None :
193+ params ["model_path " ] = model_path
167194 resp = self ._request (
168195 method = "POST" ,
169- path = "/compile" ,
196+ path = "/api/v1/ compile" ,
170197 data = raw_sql ,
171198 headers = {"Content-Type" : "text/plain" },
172199 params = params ,
173200 )
174201 return ServerCompileResult .model_validate (resp .json ())
175202
176- def reset_project (
177- self ,
178- target : str | None = None ,
179- reset : bool = False ,
180- write_manifest : bool = False ,
181- ) -> ServerResetResult :
182- """Re-parse the dbt project."""
183- params : dict [str , t .Any ] = {}
184- if target is not None :
185- params ["target" ] = target
186- if reset :
187- params ["reset" ] = reset
188- if write_manifest :
189- params ["write_manifest" ] = write_manifest
190- resp = self ._request ("GET" , "/reset" , params = params )
191- return ServerResetResult .model_validate (resp .json ())
192-
193- def health_check (self ) -> dict [str , t .Any ]:
194- """Check server health and project status."""
195- resp = self ._request ("GET" , "/health" )
196- return resp .json ()
197-
198- def heartbeat (self ) -> dict [str , t .Any ]:
199- """Check server availability."""
200- resp = self ._request ("GET" , "/heartbeat" )
201- return resp .json ()
202-
203203 def lint_sql (
204204 self ,
205205 sql_path : str | None = None ,
@@ -217,7 +217,7 @@ def lint_sql(
217217 if raw_sql is not None and sql_path is None :
218218 data = raw_sql
219219 headers = {"Content-Type" : "text/plain" }
220- resp = self ._request ("POST" , "/lint" , params = params , data = data , headers = headers )
220+ resp = self ._request ("POST" , "/api/v1/ lint" , params = params , data = data , headers = headers )
221221 return ServerLintResult .model_validate (resp .json ())
222222
223223 def format_sql (
@@ -237,9 +237,26 @@ def format_sql(
237237 if raw_sql is not None and sql_path is None :
238238 data = raw_sql
239239 headers = {"Content-Type" : "text/plain" }
240- resp = self ._request ("POST" , "/format" , params = params , data = data , headers = headers )
240+ resp = self ._request ("POST" , "/api/v1/ format" , params = params , data = data , headers = headers )
241241 return ServerFormatResult .model_validate (resp .json ())
242242
243+ def parse_project (
244+ self ,
245+ target : str | None = None ,
246+ reset : bool = False ,
247+ write_manifest : bool = False ,
248+ ) -> ServerResetResult :
249+ """Re-parse the dbt project."""
250+ params : dict [str , t .Any ] = {}
251+ if target is not None :
252+ params ["target" ] = target
253+ if reset :
254+ params ["reset" ] = reset
255+ if write_manifest :
256+ params ["write_manifest" ] = write_manifest
257+ resp = self ._request ("GET" , "/api/v1/parse" , params = params )
258+ return ServerResetResult .model_validate (resp .json ())
259+
243260 def command (
244261 self ,
245262 cmd : str ,
@@ -250,7 +267,7 @@ def command(
250267 payload : dict [str , t .Any ] = {"args" : args , "kwargs" : kwargs }
251268 resp = self ._request (
252269 method = "POST" ,
253- path = "/command" ,
270+ path = "/api/v1/ command" ,
254271 json_payload = payload ,
255272 params = {"cmd" : cmd },
256273 )
@@ -272,3 +289,14 @@ def command(
272289 snapshot = functools .partialmethod (command , "snapshot" )
273290 source_freshness = functools .partialmethod (command , "source freshness" )
274291 test = functools .partialmethod (command , "test" )
292+
293+ def status (self ) -> dict [str , t .Any ]:
294+ """Check server diagnostic status."""
295+ resp = self ._request ("GET" , "/api/v1/status" )
296+ return resp .json ()
297+
298+ def heartbeat (self ) -> bool :
299+ """Check server availability."""
300+ resp = self ._request ("GET" , "/api/v1/heartbeat" )
301+ pulse = resp .json ()
302+ return pulse ["result" ]["status" ] == "ready"
0 commit comments