Skip to content

Commit 5d9828b

Browse files
committed
feat: ensure project cleanup on gc across singleton refs via weakref finalizer
1 parent 08bcbf2 commit 5d9828b

File tree

5 files changed

+129
-89
lines changed

5 files changed

+129
-89
lines changed

src/dbt_core_interface/container.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import typing as t
99
from collections.abc import Generator
1010
from pathlib import Path
11+
from weakref import WeakValueDictionary
1112

1213
if t.TYPE_CHECKING:
1314
from dbt_core_interface.project import DbtConfiguration, DbtProject
@@ -24,7 +25,7 @@ class DbtProjectContainer:
2425
_instance: DbtProjectContainer | None = None
2526
_instance_lock: threading.Lock = threading.Lock()
2627

27-
_projects: dict[Path, DbtProject] = {}
28+
_projects: WeakValueDictionary[Path, DbtProject] = WeakValueDictionary()
2829
_default_project: Path | None = None
2930
_lock = threading.RLock()
3031

@@ -108,7 +109,6 @@ def drop_project(self, path: Path | str) -> DbtProject | None:
108109
project = self._projects.pop(p := Path(path).expanduser().resolve(), None)
109110
if project is None:
110111
return
111-
project.adapter.connections.cleanup_all()
112112
if p == self._default_project:
113113
self._default_project = next(iter(self._projects), None)
114114
return project
@@ -143,6 +143,11 @@ def __getitem__(self, path: Path | str) -> DbtProject:
143143
raise KeyError(f"No project registered under '{path}'.")
144144
return project
145145

146+
def __delitem__(self, path: Path | str) -> None:
    """Unregister the project at the given path.

    Supports ``del container[path]``. The work is delegated to
    ``drop_project`` while holding the container lock, and the dropped
    project (if any) is discarded.

    Unlike ``__getitem__``, no ``KeyError`` is raised here: ``drop_project``
    returns ``None`` for a path that is not registered, so deleting an
    unknown path is a silent no-op.
    """
    with self._lock:
        _ = self.drop_project(path)
150+
146151
def __contains__(self, path: Path | str) -> bool:
    """Check if a project is registered at the given path.

    The argument is expanded and resolved before the lookup, the same way
    ``drop_project`` normalizes its argument, so that ``str`` paths,
    ``~``-prefixed paths and relative paths match the registered key.
    (Assumes keys are stored as resolved ``Path`` objects, which is how
    ``drop_project`` looks them up.)

    Returns ``False`` for values that cannot be turned into a path, as a
    membership test should not raise.
    """
    try:
        key = Path(path).expanduser().resolve()
    except (TypeError, ValueError, RuntimeError, OSError):
        # Not path-like, or the path cannot be expanded/resolved: it cannot
        # correspond to a registered project.
        return False
    return key in self._projects
@@ -157,8 +162,8 @@ def __repr__(self) -> str: # pyright: ignore[reportImplicitOverride]
157162
if len(self._projects) == 0:
158163
return s.format("<empty>")
159164
return s.format(
160-
"\n ",
161-
"\n ".join(
165+
"\n "
166+
+ "\n ".join(
162167
f"DbtProject(name={proj.project_name}, root={proj.project_root}),"
163168
for proj in self._projects.values()
164169
)

src/dbt_core_interface/project.py

Lines changed: 62 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,14 @@
1616
import time
1717
import typing as t
1818
import uuid
19+
import weakref
1920
from concurrent.futures import ThreadPoolExecutor
2021
from dataclasses import asdict, dataclass, field
22+
from dataclasses import replace as dc_replace
2123
from datetime import datetime
2224
from multiprocessing import get_context as get_mp_context
2325
from pathlib import Path
26+
from weakref import WeakValueDictionary
2427

2528
import dbt.adapters.factory
2629

@@ -45,7 +48,7 @@ def _patched_adapter_accessor(config: t.Any) -> t.Any:
4548
from dbt.contracts.graph.manifest import Manifest
4649
from dbt.contracts.graph.nodes import ManifestNode, SourceDefinition
4750
from dbt.contracts.state import PreviousState
48-
from dbt.flags import set_from_args
51+
from dbt.flags import get_flag_dict, set_from_args
4952
from dbt.parser.manifest import ManifestLoader, process_node
5053
from dbt.parser.read_files import FileDiff, InputFile, ReadFilesFromFileSystem
5154
from dbt.parser.sql import SqlBlockParser, SqlMacroParser
@@ -79,7 +82,7 @@ def _set_invocation_context() -> None:
7982
LoggerConfig(name=__name__, logger=logger),
8083
)
8184

82-
__all__ = ["DbtProject", "DbtConfiguration", "ExecutionResult", "CompilationResult"]
85+
__all__ = ["DbtProject", "DbtConfiguration"]
8386

8487
T = t.TypeVar("T")
8588

@@ -91,39 +94,49 @@ def _set_invocation_context() -> None:
9194
P = t_ext.ParamSpec("P")
9295

9396

94-
def _get_project_dir() -> str:
    """Get the default project directory following dbt heuristics.

    Resolution order:

    1. The ``DBT_PROJECT_DIR`` environment variable, user-expanded and
       resolved, when it is set.
    2. The nearest directory containing ``dbt_project.yml``, searching the
       current working directory and then its parents. The search stops
       after the user's home directory has been checked.
    3. The current working directory.

    Returns:
        The chosen directory as an absolute path string.
    """
    env_dir = os.environ.get("DBT_PROJECT_DIR")
    if env_dir is not None:
        return str(Path(env_dir).expanduser().resolve())

    cwd = Path.cwd()
    try:
        # Loop-invariant, so look it up once.
        home: Path | None = Path.home()
    except RuntimeError:
        # Home directory cannot be determined; search all the way up instead.
        home = None

    for candidate in (cwd, *cwd.parents):
        if (candidate / "dbt_project.yml").exists():
            return str(candidate.resolve())
        if candidate == home:
            break
    return str(cwd.resolve())
105109

106110

107-
def _get_profiles_dir(project_dir: Path | str | None = None) -> str:
    """Get the default profiles directory following dbt heuristics.

    Resolution order:

    1. The ``DBT_PROFILES_DIR`` environment variable, user-expanded and
       resolved, when it is set.
    2. The project directory (``project_dir`` if given and truthy, otherwise
       the result of ``_get_project_dir()``) when it is a directory that
       contains ``profiles.yml``.
    3. ``~/.dbt`` under the user's home directory.

    Args:
        project_dir: Project directory to probe for ``profiles.yml``.

    Returns:
        The chosen directory as a path string.
    """
    env_dir = os.environ.get("DBT_PROFILES_DIR")
    if env_dir is not None:
        return str(Path(env_dir).expanduser().resolve())

    candidate = Path(project_dir or _get_project_dir())
    if candidate.is_dir() and (candidate / "profiles.yml").exists():
        return str(candidate.resolve())
    return str(Path.home() / ".dbt")
121+
122+
123+
DEFAULT_PROFILES_DIR = _get_profiles_dir()
115124

116125

117126
@dataclass(frozen=True)
118127
class DbtConfiguration:
119128
"""Minimal dbt configuration."""
120129

121-
project_dir: str = field(default_factory=lambda: str(_get_project_dir()))
122-
profiles_dir: str = field(default_factory=lambda: str(_get_profiles_dir()))
123-
target: str | None = field(default_factory=lambda: os.getenv("DBT_TARGET"))
130+
project_dir: str = field(default_factory=_get_project_dir)
131+
profiles_dir: str = field(default_factory=_get_profiles_dir)
132+
target: str | None = field(
133+
default_factory=functools.partial(os.getenv, "DBT_TARGET"),
134+
)
135+
profile: str | None = field(
136+
default_factory=functools.partial(os.getenv, "DBT_PROFILE", "default"),
137+
)
124138
threads: int = 1
125139
vars: dict[str, t.Any] = field(default_factory=dict)
126-
profile: str | None = field(default_factory=lambda: os.getenv("DBT_PROFILE"))
127140

128141
quiet: bool = True
129142
use_experimental_parser: bool = True
@@ -139,6 +152,13 @@ def single_threaded(self) -> bool:
139152
"""Return whether the project is single-threaded."""
140153
return self.threads <= 1
141154

155+
def __getattr__(self, item: str) -> t.Any:
    """Resolve attributes missing on the configuration from dbt's flag dict.

    Python only calls this hook after normal attribute lookup has failed,
    so declared dataclass fields never reach it. Any other name is looked
    up in the mapping returned by ``dbt.flags.get_flag_dict()``.

    Args:
        item: Name of the attribute being accessed.

    Returns:
        The value stored under ``item`` in the flag dict.

    Raises:
        AttributeError: If ``item`` is a dunder name or is not present in
            the flag dict.
    """
    missing = f"'{self.__class__.__name__}' object has no attribute '{item}'"
    # Protocol probes (copy, pickle, hasattr on dunders) are never dbt flags;
    # answer them without touching the flag machinery.
    if item.startswith("__") and item.endswith("__"):
        raise AttributeError(missing)
    flags = get_flag_dict()
    if item in flags:
        return flags[item]
    raise AttributeError(missing)
161+
142162

143163
_use_slots = {}
144164
if sys.version_info >= (3, 10):
@@ -182,7 +202,7 @@ class DbtProject:
182202

183203
ADAPTER_TTL: int = 3600
184204

185-
_instances: dict[Path, DbtProject] = {}
205+
_instances: WeakValueDictionary[Path, DbtProject] = WeakValueDictionary()
186206
_instance_lock: threading.Lock = threading.Lock()
187207

188208
def __new__(
@@ -204,22 +224,11 @@ def __new__(
204224
project = super().__new__(cls)
205225
cls._instances[p] = project
206226
else:
207-
if profiles_dir is None:
208-
profiles_dir = str(_get_profiles_dir(project_dir).resolve())
209-
if (
210-
(target and target != project.runtime_config.target_name)
211-
or (vars and vars != project.args.vars)
212-
or Path(profiles_dir).expanduser().resolve() != project.profiles_yml.parent
227+
if (target and target != project.runtime_config.target_name) or (
228+
vars and vars != project.args.vars
213229
):
214-
project.args = DbtConfiguration(
215-
target=target,
216-
profiles_dir=profiles_dir,
217-
project_dir=str(p),
218-
threads=threads,
219-
vars=vars or {},
220-
)
221-
project.parse_project(reparse_configuration=True)
222-
return project
230+
project.set_args(target=target, vars=vars or {}, threads=threads)
231+
return project
223232

224233
def __init__(
225234
self,
@@ -236,11 +245,8 @@ def __init__(
236245
if hasattr(self, "_args"):
237246
return
238247

239-
if project_dir is not None and profiles_dir is None:
240-
profiles_dir = str(_get_profiles_dir(project_dir).resolve())
241-
242-
project_dir = project_dir or str(_get_project_dir())
243-
profiles_dir = profiles_dir or str(_get_project_dir())
248+
project_dir = project_dir or _get_profiles_dir()
249+
profiles_dir = profiles_dir or _get_profiles_dir(project_dir)
244250

245251
self._args = DbtConfiguration(
246252
target=target,
@@ -281,31 +287,25 @@ def __init__(
281287

282288
CONTAINER.add_project(self)
283289

284-
def __del__(self) -> None:
285-
"""Clean up resources on deletion."""
286-
from dbt_core_interface.container import DbtProjectContainer
287-
from dbt_core_interface.watcher import DbtProjectWatcher
290+
ref = weakref.ref(self)
288291

289-
with self._adapter_lock:
290-
if self._adapter:
291-
self._adapter.connections.cleanup_all()
292-
_ = atexit.unregister(self._adapter.connections.cleanup_all)
293-
self._adapter = None
292+
def finalizer() -> None:
293+
from dbt_core_interface.container import CONTAINER
294294

295-
if self._pool:
296-
self._pool.shutdown(wait=False, cancel_futures=True)
297-
self._pool = None
295+
if (instance := ref()) is not None:
296+
with DbtProject._instance_lock:
297+
del DbtProject._instances[instance.project_root]
298298

299-
with contextlib.suppress(Exception):
300-
container = DbtProjectContainer()
301-
_ = container.drop_project(self.project_root)
299+
if instance._adapter:
300+
atexit.unregister(instance._adapter.connections.cleanup_all)
301+
instance._adapter.connections.cleanup_all()
302302

303-
with contextlib.suppress(Exception):
304-
watcher = DbtProjectWatcher(self)
305-
del watcher
303+
if instance._pool:
304+
instance._pool.shutdown(wait=True, cancel_futures=True)
306305

307-
with self._instance_lock:
308-
_ = self._instances.pop(self.project_root, None)
306+
del CONTAINER[instance.project_root]
307+
308+
self._finalize = weakref.finalize(self, finalizer)
309309

310310
def __repr__(self) -> str: # pyright: ignore[reportImplicitOverride]
311311
"""Return a string representation of the DbtProject instance."""
@@ -322,6 +322,7 @@ def from_config(cls, config: DbtConfiguration) -> DbtProject:
322322
project_dir=config.project_dir,
323323
threads=config.threads,
324324
vars=config.vars,
325+
profile=config.profile,
325326
)
326327

327328
@property
@@ -333,13 +334,15 @@ def args(self) -> DbtConfiguration:
333334
def args(self, value: DbtConfiguration | dict[str, t.Any]) -> None:  # pyright: ignore[reportPropertyTypeMismatch]
    """Set the args for the DbtProject instance and update runtime config.

    Args:
        value: Either a complete ``DbtConfiguration`` or a dict of field
            overrides. A dict is applied on top of the current
            configuration with ``dataclasses.replace``, so fields it does
            not mention keep their current values.
    """
    if isinstance(value, dict):
        value = dc_replace(self._args, **value)
    # Push the new configuration into dbt's flags before reparsing.
    set_from_args(value, None)  # pyright: ignore[reportArgumentType]
    # NOTE(review): the reparse runs before ``self._args`` is reassigned;
    # confirm ``parse_project`` does not read ``self.args`` and see the
    # previous configuration.
    self.parse_project(reparse_configuration=True)
    self._args = value
342341

342+
def set_args(self, **kwargs: t.Any) -> None:
    """Set the args for the DbtProject instance.

    Keyword-argument convenience over assigning a dict to ``args``: the
    collected ``kwargs`` dict is assigned to ``self.args`` as-is.

    Args:
        **kwargs: Configuration fields to override.
    """
    self.args = kwargs
345+
343346
@property
344347
def adapter(self) -> BaseAdapter:
345348
"""Get adapter with TTL management for long-running processes."""
@@ -440,7 +443,7 @@ def create_adapter(
440443
self._pool = None
441444

442445
self._adapter.connections.cleanup_all()
443-
_ = atexit.unregister(self._adapter.connections.cleanup_all)
446+
atexit.unregister(self._adapter.connections.cleanup_all)
444447

445448
adapter_cls = get_adapter_class_by_name(self.runtime_config.credentials.type)
446449
self._adapter = t.cast(

0 commit comments

Comments
 (0)