Skip to content

Commit 976ffee

Browse files
authored
Feat(experimental): DBT project conversion (#4495)
1 parent 7e5f195 commit 976ffee

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+3714
-39
lines changed

sqlmesh/cli/main.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
"rollback",
3434
"run",
3535
"table_name",
36+
"dbt",
3637
)
3738
SKIP_CONTEXT_COMMANDS = ("init", "ui")
3839

@@ -1219,3 +1220,39 @@ def state_import(obj: Context, input_file: Path, replace: bool, no_confirm: bool
12191220
"""Import a state export file back into the state database"""
12201221
confirm = not no_confirm
12211222
obj.import_state(input_file=input_file, clear=replace, confirm=confirm)
1223+
1224+
1225+
@cli.group(no_args_is_help=True, hidden=True)
1226+
def dbt() -> None:
1227+
"""Commands for doing dbt-specific things"""
1228+
pass
1229+
1230+
1231+
@dbt.command("convert")
1232+
@click.option(
1233+
"-i",
1234+
"--input-dir",
1235+
help="Path to the DBT project",
1236+
required=True,
1237+
type=click.Path(exists=True, dir_okay=True, file_okay=False, readable=True, path_type=Path),
1238+
)
1239+
@click.option(
1240+
"-o",
1241+
"--output-dir",
1242+
required=True,
1243+
help="Path to write out the converted SQLMesh project",
1244+
type=click.Path(exists=False, dir_okay=True, file_okay=False, readable=True, path_type=Path),
1245+
)
1246+
@click.option("--no-prompts", is_flag=True, help="Disable interactive prompts", default=False)
1247+
@click.pass_obj
1248+
@error_handler
1249+
@cli_analytics
1250+
def dbt_convert(obj: Context, input_dir: Path, output_dir: Path, no_prompts: bool) -> None:
1251+
"""Convert a DBT project to a SQLMesh project"""
1252+
from sqlmesh.dbt.converter.convert import convert_project_files
1253+
1254+
convert_project_files(
1255+
input_dir.absolute(),
1256+
output_dir.absolute(),
1257+
no_prompts=no_prompts,
1258+
)

sqlmesh/core/config/root.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
scheduler_config_validator,
4040
)
4141
from sqlmesh.core.config.ui import UIConfig
42-
from sqlmesh.core.loader import Loader, SqlMeshLoader
42+
from sqlmesh.core.loader import Loader, SqlMeshLoader, MigratedDbtProjectLoader
4343
from sqlmesh.core.notification_target import NotificationTarget
4444
from sqlmesh.core.user import User
4545
from sqlmesh.utils.date import to_timestamp, now
@@ -219,6 +219,13 @@ def _normalize_and_validate_fields(cls, data: t.Any) -> t.Any:
219219
f"^{k}$": v for k, v in physical_schema_override.items()
220220
}
221221

222+
if (
223+
(variables := data.get("variables", ""))
224+
and isinstance(variables, dict)
225+
and c.MIGRATED_DBT_PROJECT_NAME in variables
226+
):
227+
data["loader"] = MigratedDbtProjectLoader
228+
222229
return data
223230

224231
@model_validator(mode="after")

sqlmesh/core/constants.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
MAX_MODEL_DEFINITION_SIZE = 10000
3232
"""Maximum number of characters in a model definition"""
3333

34+
MIGRATED_DBT_PROJECT_NAME = "__dbt_project_name__"
35+
MIGRATED_DBT_PACKAGES = "__dbt_packages__"
36+
3437

3538
# The maximum number of fork processes, used for loading projects
3639
# None means default to process pool, 1 means don't fork, :N is number of processes

sqlmesh/core/loader.py

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@
3838
from sqlmesh.core.test import ModelTestMetadata, filter_tests_by_patterns
3939
from sqlmesh.utils import UniqueKeyDict, sys_path
4040
from sqlmesh.utils.errors import ConfigError
41-
from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor
41+
from sqlmesh.utils.jinja import (
42+
JinjaMacroRegistry,
43+
MacroExtractor,
44+
SQLMESH_DBT_COMPATIBILITY_PACKAGE,
45+
)
4246
from sqlmesh.utils.metaprogramming import import_python_file
4347
from sqlmesh.utils.pydantic import validation_error_message
4448
from sqlmesh.utils.process import create_process_pool_executor
@@ -548,6 +552,7 @@ def _load_sql_models(
548552
signals: UniqueKeyDict[str, signal],
549553
cache: CacheBase,
550554
gateway: t.Optional[str],
555+
loading_default_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
551556
) -> UniqueKeyDict[str, Model]:
552557
"""Loads the sql models into a Dict"""
553558
models: UniqueKeyDict[str, Model] = UniqueKeyDict("models")
@@ -590,6 +595,7 @@ def _load_sql_models(
590595
infer_names=self.config.model_naming.infer_names,
591596
signal_definitions=signals,
592597
default_catalog_per_gateway=self.context.default_catalog_per_gateway,
598+
**loading_default_kwargs or {},
593599
)
594600

595601
with create_process_pool_executor(
@@ -942,3 +948,104 @@ def _model_cache_entry_id(self, model_path: Path) -> str:
942948
self._loader.context.gateway or self._loader.config.default_gateway_name,
943949
]
944950
)
951+
952+
953+
class MigratedDbtProjectLoader(SqlMeshLoader):
954+
@property
955+
def migrated_dbt_project_name(self) -> str:
956+
return self.config.variables[c.MIGRATED_DBT_PROJECT_NAME]
957+
958+
def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]:
959+
from sqlmesh.dbt.converter.common import infer_dbt_package_from_path
960+
from sqlmesh.dbt.target import TARGET_TYPE_TO_CONFIG_CLASS
961+
962+
# Store a copy of the macro registry
963+
standard_macros = macro.get_registry()
964+
965+
jinja_macros = JinjaMacroRegistry(
966+
create_builtins_module=SQLMESH_DBT_COMPATIBILITY_PACKAGE,
967+
top_level_packages=["dbt", self.migrated_dbt_project_name],
968+
)
969+
extractor = MacroExtractor()
970+
971+
macros_max_mtime: t.Optional[float] = None
972+
973+
for path in self._glob_paths(
974+
self.config_path / c.MACROS,
975+
ignore_patterns=self.config.ignore_patterns,
976+
extension=".py",
977+
):
978+
if import_python_file(path, self.config_path):
979+
self._track_file(path)
980+
macro_file_mtime = self._path_mtimes[path]
981+
macros_max_mtime = (
982+
max(macros_max_mtime, macro_file_mtime)
983+
if macros_max_mtime
984+
else macro_file_mtime
985+
)
986+
987+
for path in self._glob_paths(
988+
self.config_path / c.MACROS,
989+
ignore_patterns=self.config.ignore_patterns,
990+
extension=".sql",
991+
):
992+
self._track_file(path)
993+
macro_file_mtime = self._path_mtimes[path]
994+
macros_max_mtime = (
995+
max(macros_max_mtime, macro_file_mtime) if macros_max_mtime else macro_file_mtime
996+
)
997+
998+
with open(path, "r", encoding="utf-8") as file:
999+
try:
1000+
package = infer_dbt_package_from_path(path) or self.migrated_dbt_project_name
1001+
1002+
jinja_macros.add_macros(
1003+
extractor.extract(file.read(), dialect=self.config.model_defaults.dialect),
1004+
package=package,
1005+
)
1006+
except Exception as e:
1007+
raise ConfigError(f"Failed to load macro file: {path}", e)
1008+
1009+
self._macros_max_mtime = macros_max_mtime
1010+
1011+
macros = macro.get_registry()
1012+
macro.set_registry(standard_macros)
1013+
1014+
connection_config = self.context.connection_config
1015+
# this triggers the DBT create_builtins_module to have a `target` property which is required for a bunch of DBT macros to work
1016+
if dbt_config_type := TARGET_TYPE_TO_CONFIG_CLASS.get(connection_config.type_):
1017+
try:
1018+
jinja_macros.add_globals(
1019+
{
1020+
"target": dbt_config_type.from_sqlmesh(
1021+
connection_config,
1022+
name=self.config.default_gateway_name,
1023+
).attribute_dict()
1024+
}
1025+
)
1026+
except NotImplementedError:
1027+
raise ConfigError(f"Unsupported dbt target type: {connection_config.type_}")
1028+
1029+
return macros, jinja_macros
1030+
1031+
def _load_sql_models(
1032+
self,
1033+
macros: MacroRegistry,
1034+
jinja_macros: JinjaMacroRegistry,
1035+
audits: UniqueKeyDict[str, ModelAudit],
1036+
signals: UniqueKeyDict[str, signal],
1037+
cache: CacheBase,
1038+
gateway: t.Optional[str],
1039+
loading_default_kwargs: t.Optional[t.Dict[str, t.Any]] = None,
1040+
) -> UniqueKeyDict[str, Model]:
1041+
return super()._load_sql_models(
1042+
macros=macros,
1043+
jinja_macros=jinja_macros,
1044+
audits=audits,
1045+
signals=signals,
1046+
cache=cache,
1047+
gateway=gateway,
1048+
loading_default_kwargs=dict(
1049+
migrated_dbt_project_name=self.migrated_dbt_project_name,
1050+
),
1051+
)

sqlmesh/core/model/definition.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2017,6 +2017,7 @@ def load_sql_based_model(
20172017
variables: t.Optional[t.Dict[str, t.Any]] = None,
20182018
infer_names: t.Optional[bool] = False,
20192019
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
2020+
migrated_dbt_project_name: t.Optional[str] = None,
20202021
**kwargs: t.Any,
20212022
) -> Model:
20222023
"""Load a model from a parsed SQLMesh model SQL file.
@@ -2193,6 +2194,7 @@ def load_sql_based_model(
21932194
query_or_seed_insert,
21942195
kind=kind,
21952196
time_column_format=time_column_format,
2197+
migrated_dbt_project_name=migrated_dbt_project_name,
21962198
**common_kwargs,
21972199
)
21982200

@@ -2400,6 +2402,7 @@ def _create_model(
24002402
signal_definitions: t.Optional[SignalRegistry] = None,
24012403
variables: t.Optional[t.Dict[str, t.Any]] = None,
24022404
blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None,
2405+
migrated_dbt_project_name: t.Optional[str] = None,
24032406
**kwargs: t.Any,
24042407
) -> Model:
24052408
validate_extra_and_required_fields(
@@ -2455,13 +2458,28 @@ def _create_model(
24552458

24562459
if jinja_macros:
24572460
jinja_macros = (
2458-
jinja_macros if jinja_macros.trimmed else jinja_macros.trim(jinja_macro_references)
2461+
jinja_macros
2462+
if jinja_macros.trimmed
2463+
else jinja_macros.trim(jinja_macro_references, package=migrated_dbt_project_name)
24592464
)
24602465
else:
24612466
jinja_macros = JinjaMacroRegistry()
24622467

2463-
for jinja_macro in jinja_macros.root_macros.values():
2464-
used_variables.update(extract_macro_references_and_variables(jinja_macro.definition)[1])
2468+
if migrated_dbt_project_name:
2469+
# extract {{ var() }} references used in all jinja macro dependencies to check for any variables specific
2470+
# to a migrated DBT package and resolve them accordingly
2471+
# vars are added into __sqlmesh_vars__ in the Python env so that the native SQLMesh var() function can resolve them
2472+
variables = variables or {}
2473+
2474+
nested_macro_used_variables, flattened_package_variables = (
2475+
_extract_migrated_dbt_variable_references(jinja_macros, variables)
2476+
)
2477+
2478+
used_variables.update(nested_macro_used_variables)
2479+
variables.update(flattened_package_variables)
2480+
else:
2481+
for jinja_macro in jinja_macros.root_macros.values():
2482+
used_variables.update(extract_macro_references_and_variables(jinja_macro.definition)[1])
24652483

24662484
model = klass(
24672485
name=name,
@@ -2844,7 +2862,7 @@ def render_expression(
28442862
"cron_tz": lambda value: exp.Literal.string(value),
28452863
"partitioned_by_": _single_expr_or_tuple,
28462864
"clustered_by": _single_expr_or_tuple,
2847-
"depends_on_": lambda value: exp.Tuple(expressions=sorted(value)),
2865+
"depends_on_": lambda value: exp.Tuple(expressions=sorted(value)) if value else "()",
28482866
"pre": _list_of_calls_to_exp,
28492867
"post": _list_of_calls_to_exp,
28502868
"audits": _list_of_calls_to_exp,
@@ -2915,4 +2933,37 @@ def clickhouse_partition_func(
29152933
)
29162934

29172935

2936+
def _extract_migrated_dbt_variable_references(
2937+
jinja_macros: JinjaMacroRegistry, project_variables: t.Dict[str, t.Any]
2938+
) -> t.Tuple[t.Set[str], t.Dict[str, t.Any]]:
2939+
if not jinja_macros.trimmed:
2940+
raise ValueError("Expecting a trimmed JinjaMacroRegistry")
2941+
2942+
used_variables = set()
2943+
# note: JinjaMacroRegistry is trimmed here so "all_macros" should just be all the macros used by this model
2944+
for _, _, jinja_macro in jinja_macros.all_macros:
2945+
_, extracted_variable_names = extract_macro_references_and_variables(jinja_macro.definition)
2946+
used_variables.update(extracted_variable_names)
2947+
2948+
flattened = {}
2949+
if (dbt_package_variables := project_variables.get(c.MIGRATED_DBT_PACKAGES)) and isinstance(
2950+
dbt_package_variables, dict
2951+
):
2952+
# flatten the nested dict structure from the migrated dbt package variables in the SQLMesh config into __dbt_packages__.<package>.<variable>
2953+
# to match what extract_macro_references_and_variables() returns. This allows the usage checks in create_python_env() to work
2954+
def _flatten(prefix: str, root: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
2955+
acc = {}
2956+
for k, v in root.items():
2957+
key_with_prefix = f"{prefix}.{k}"
2958+
if isinstance(v, dict):
2959+
acc.update(_flatten(key_with_prefix, v))
2960+
else:
2961+
acc[key_with_prefix] = v
2962+
return acc
2963+
2964+
flattened = _flatten(c.MIGRATED_DBT_PACKAGES, dbt_package_variables)
2965+
2966+
return used_variables, flattened
2967+
2968+
29182969
TIME_COL_PARTITION_FUNC = {"clickhouse": clickhouse_partition_func}

sqlmesh/core/model/kind.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from enum import Enum
55
from typing_extensions import Self
66

7-
from pydantic import Field
7+
from pydantic import Field, BeforeValidator
88
from sqlglot import exp
99
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
1010
from sqlglot.optimizer.qualify_columns import quote_identifiers
@@ -33,6 +33,7 @@
3333
field_validator,
3434
get_dialect,
3535
validate_string,
36+
positive_int_validator,
3637
)
3738

3839

@@ -455,7 +456,7 @@ class IncrementalByUniqueKeyKind(_IncrementalBy):
455456
unique_key: SQLGlotListOfFields
456457
when_matched: t.Optional[exp.Whens] = None
457458
merge_filter: t.Optional[exp.Expression] = None
458-
batch_concurrency: t.Literal[1] = 1
459+
batch_concurrency: t.Annotated[t.Literal[1], BeforeValidator(positive_int_validator)] = 1
459460

460461
@field_validator("when_matched", mode="before")
461462
def _when_matched_validator(

sqlmesh/core/renderer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def _resolve_table(table: str | exp.Table) -> str:
179179
)
180180

181181
render_kwargs = {
182+
"dialect": self._dialect,
182183
**date_dict(
183184
to_datetime(execution_time or c.EPOCH),
184185
start_time,

sqlmesh/dbt/adapter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ def __init__(
3838
self.jinja_globals = jinja_globals.copy() if jinja_globals else {}
3939
self.jinja_globals["adapter"] = self
4040
self.project_dialect = project_dialect
41+
self.jinja_globals["dialect"] = (
42+
project_dialect # so the dialect is available in the jinja env created by self.dispatch()
43+
)
4144
self.quote_policy = quote_policy or Policy()
4245

4346
@abc.abstractmethod

sqlmesh/dbt/builtin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ class Var:
156156
def __init__(self, variables: t.Dict[str, t.Any]) -> None:
157157
self.variables = variables
158158

159-
def __call__(self, name: str, default: t.Optional[t.Any] = None) -> t.Any:
159+
def __call__(self, name: str, default: t.Optional[t.Any] = None, **kwargs: t.Any) -> t.Any:
160160
return self.variables.get(name, default)
161161

162162
def has_var(self, name: str) -> bool:

sqlmesh/dbt/converter/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)