Skip to content

Commit dc0a4bf

Browse files
committed
feat: get default project within current tree, fix test fixture for sqlfluff 3.4
1 parent fa7e741 commit dc0a4bf

File tree

4 files changed

+140
-122
lines changed

4 files changed

+140
-122
lines changed

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "dbt-core-interface"
3-
version = "1.0.2"
3+
version = "1.0.3"
44
dynamic = []
55
description = "Dbt Core Interface"
66
authors = [
@@ -18,8 +18,8 @@ keywords = [
1818
]
1919
dependencies = [
2020
"dbt-core>=1.8.0,<2.0.0",
21-
"rich",
22-
"typing-extensions; python_version < '3.10'"
21+
"rich>10.0",
22+
"typing-extensions; python_version < '3.10'",
2323
]
2424

2525
[project.urls]

src/dbt_core_interface/project.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -92,35 +92,38 @@ def _set_invocation_context() -> None:
9292

9393
def _get_project_dir() -> Path:
9494
"""Get the default project directory following dbt heuristics."""
95-
return Path(os.getenv("DBT_PROJECT_DIR", os.getcwd())).expanduser().resolve()
95+
if "DBT_PROJECT_DIR" in os.environ:
96+
return Path(os.environ["DBT_PROJECT_DIR"]).expanduser().resolve()
97+
cwd = Path.cwd()
98+
for path in [cwd, *list(cwd.parents)]:
99+
if (path / "dbt_project.yml").exists():
100+
return path.resolve()
101+
if path == Path.home():
102+
break
103+
return cwd.resolve()
96104

97105

98106
def _get_profiles_dir(project_dir: Path | str | None = None) -> Path:
99107
"""Get the default profiles directory following dbt heuristics."""
100-
if "DBT_PROFILES_DIR" not in os.environ:
101-
_project_dir = Path(project_dir or _get_project_dir())
102-
if _project_dir.is_dir() and _project_dir.joinpath("profiles.yml").exists():
103-
return _project_dir
104-
return Path.home() / ".dbt"
105-
return Path(os.environ["DBT_PROFILES_DIR"]).expanduser().resolve()
106-
107-
108-
DEFAULT_PROFILES_DIR = str(_get_profiles_dir())
109-
DEFAULT_PROJECT_DIR = str(_get_project_dir())
108+
if "DBT_PROFILES_DIR" in os.environ:
109+
return Path(os.environ["DBT_PROFILES_DIR"]).expanduser().resolve()
110+
_project_dir = Path(project_dir or _get_project_dir())
111+
if _project_dir.is_dir() and _project_dir.joinpath("profiles.yml").exists():
112+
return _project_dir
113+
return Path.home() / ".dbt"
110114

111115

112116
@dataclass(frozen=True)
113117
class DbtConfiguration:
114118
"""Minimal dbt configuration."""
115119

116-
project_dir: str = DEFAULT_PROJECT_DIR
117-
profiles_dir: str = DEFAULT_PROFILES_DIR
120+
project_dir: str = field(default_factory=lambda: str(_get_project_dir()))
121+
profiles_dir: str = field(default_factory=lambda: str(_get_profiles_dir()))
118122
target: str | None = None
119123
threads: int = 1
120124
vars: dict[str, t.Any] = field(default_factory=dict)
121125
profile: str | None = None
122126

123-
single_threaded: bool = True
124127
quiet: bool = True
125128
use_experimental_parser: bool = True
126129
static_parser: bool = True
@@ -130,6 +133,11 @@ class DbtConfiguration:
130133
which: str = "zezima was here"
131134
REQUIRE_RESOURCE_NAMES_WITHOUT_SPACES: bool = field(default_factory=bool)
132135

136+
@property
137+
def single_threaded(self) -> bool:
138+
"""Return whether the project is single-threaded."""
139+
return self.threads <= 1
140+
133141

134142
_use_slots = {}
135143
if sys.version_info >= (3, 10):
@@ -181,14 +189,15 @@ def __new__(
181189
target: str | None = None,
182190
project_dir: str | None = None,
183191
profiles_dir: str | None = None,
192+
profile: str | None = None,
184193
threads: int = 1,
185194
vars: dict[str, t.Any] | None = None,
186195
load: bool = True,
187196
autoregister: bool = True,
188197
) -> DbtProject:
189198
"""Create a new DbtProject instance, ensuring only one instance per project root."""
190199
with cls._instance_lock:
191-
p = Path(project_dir or DEFAULT_PROJECT_DIR).expanduser().resolve()
200+
p = Path(project_dir or _get_project_dir()).expanduser().resolve()
192201
project = cls._instances.get(p)
193202
if not project:
194203
project = super().__new__(cls)
@@ -216,6 +225,7 @@ def __init__(
216225
target: str | None = None,
217226
project_dir: str | None = None,
218227
profiles_dir: str | None = None,
228+
profile: str | None = None,
219229
threads: int = 1,
220230
vars: dict[str, t.Any] | None = None,
221231
load: bool = True,
@@ -228,15 +238,16 @@ def __init__(
228238
if project_dir is not None and profiles_dir is None:
229239
profiles_dir = str(_get_profiles_dir(project_dir).resolve())
230240

231-
project_dir = project_dir or DEFAULT_PROJECT_DIR
232-
profiles_dir = profiles_dir or DEFAULT_PROFILES_DIR
241+
project_dir = project_dir or str(_get_project_dir())
242+
profiles_dir = profiles_dir or str(_get_project_dir())
233243

234244
self._args = DbtConfiguration(
235245
target=target,
236246
profiles_dir=profiles_dir,
237247
project_dir=project_dir,
238248
threads=threads,
239249
vars=vars or {},
250+
profile=profile,
240251
)
241252

242253
set_from_args(self._args, None) # pyright: ignore[reportArgumentType]

tests/sqlfluff_templater/fixtures/dbt/templater.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,11 @@ def sqlfluff_config_path() -> str:
4242
@pytest.fixture()
4343
def dbt_templater() -> DbtTemplater:
4444
"""Return an instance of the DbtTemplater."""
45-
return t.cast(DbtTemplater, FluffConfig(overrides={"dialect": "ansi"}).get_templater("dbt"))
45+
try:
46+
return t.cast(DbtTemplater, FluffConfig(overrides={"dialect": "ansi"}).get_templater("dbt"))
47+
except TypeError:
48+
# 3.4+
49+
return t.cast(
50+
DbtTemplater,
51+
FluffConfig(overrides={"dialect": "ansi", "templater": "dbt"}).get_templater(),
52+
)

0 commit comments

Comments
 (0)