|
38 | 38 | from sqlmesh.core.test import ModelTestMetadata, filter_tests_by_patterns |
39 | 39 | from sqlmesh.utils import UniqueKeyDict, sys_path |
40 | 40 | from sqlmesh.utils.errors import ConfigError |
41 | | -from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor |
| 41 | +from sqlmesh.utils.jinja import ( |
| 42 | + JinjaMacroRegistry, |
| 43 | + MacroExtractor, |
| 44 | + SQLMESH_DBT_COMPATIBILITY_PACKAGE, |
| 45 | +) |
42 | 46 | from sqlmesh.utils.metaprogramming import import_python_file |
43 | 47 | from sqlmesh.utils.pydantic import validation_error_message |
44 | 48 | from sqlmesh.utils.process import create_process_pool_executor |
@@ -548,6 +552,7 @@ def _load_sql_models( |
548 | 552 | signals: UniqueKeyDict[str, signal], |
549 | 553 | cache: CacheBase, |
550 | 554 | gateway: t.Optional[str], |
| 555 | + loading_default_kwargs: t.Optional[t.Dict[str, t.Any]] = None, |
551 | 556 | ) -> UniqueKeyDict[str, Model]: |
552 | 557 | """Loads the sql models into a Dict""" |
553 | 558 | models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") |
@@ -590,6 +595,7 @@ def _load_sql_models( |
590 | 595 | infer_names=self.config.model_naming.infer_names, |
591 | 596 | signal_definitions=signals, |
592 | 597 | default_catalog_per_gateway=self.context.default_catalog_per_gateway, |
| 598 | + **loading_default_kwargs or {}, |
593 | 599 | ) |
594 | 600 |
|
595 | 601 | with create_process_pool_executor( |
@@ -942,3 +948,104 @@ def _model_cache_entry_id(self, model_path: Path) -> str: |
942 | 948 | self._loader.context.gateway or self._loader.config.default_gateway_name, |
943 | 949 | ] |
944 | 950 | ) |
| 951 | + |
| 952 | + |
| 953 | +class MigratedDbtProjectLoader(SqlMeshLoader): |
| 954 | + @property |
| 955 | + def migrated_dbt_project_name(self) -> str: |
| 956 | + return self.config.variables[c.MIGRATED_DBT_PROJECT_NAME] |
| 957 | + |
| 958 | + def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]: |
| 959 | + from sqlmesh.dbt.converter.common import infer_dbt_package_from_path |
| 960 | + from sqlmesh.dbt.target import TARGET_TYPE_TO_CONFIG_CLASS |
| 961 | + |
| 962 | + # Store a copy of the macro registry |
| 963 | + standard_macros = macro.get_registry() |
| 964 | + |
| 965 | + jinja_macros = JinjaMacroRegistry( |
| 966 | + create_builtins_module=SQLMESH_DBT_COMPATIBILITY_PACKAGE, |
| 967 | + top_level_packages=["dbt", self.migrated_dbt_project_name], |
| 968 | + ) |
| 969 | + extractor = MacroExtractor() |
| 970 | + |
| 971 | + macros_max_mtime: t.Optional[float] = None |
| 972 | + |
| 973 | + for path in self._glob_paths( |
| 974 | + self.config_path / c.MACROS, |
| 975 | + ignore_patterns=self.config.ignore_patterns, |
| 976 | + extension=".py", |
| 977 | + ): |
| 978 | + if import_python_file(path, self.config_path): |
| 979 | + self._track_file(path) |
| 980 | + macro_file_mtime = self._path_mtimes[path] |
| 981 | + macros_max_mtime = ( |
| 982 | + max(macros_max_mtime, macro_file_mtime) |
| 983 | + if macros_max_mtime |
| 984 | + else macro_file_mtime |
| 985 | + ) |
| 986 | + |
| 987 | + for path in self._glob_paths( |
| 988 | + self.config_path / c.MACROS, |
| 989 | + ignore_patterns=self.config.ignore_patterns, |
| 990 | + extension=".sql", |
| 991 | + ): |
| 992 | + self._track_file(path) |
| 993 | + macro_file_mtime = self._path_mtimes[path] |
| 994 | + macros_max_mtime = ( |
| 995 | + max(macros_max_mtime, macro_file_mtime) if macros_max_mtime else macro_file_mtime |
| 996 | + ) |
| 997 | + |
| 998 | + with open(path, "r", encoding="utf-8") as file: |
| 999 | + try: |
| 1000 | + package = infer_dbt_package_from_path(path) or self.migrated_dbt_project_name |
| 1001 | + |
| 1002 | + jinja_macros.add_macros( |
| 1003 | + extractor.extract(file.read(), dialect=self.config.model_defaults.dialect), |
| 1004 | + package=package, |
| 1005 | + ) |
| 1006 | + except Exception as e: |
| 1007 | + raise ConfigError(f"Failed to load macro file: {path}", e) |
| 1008 | + |
| 1009 | + self._macros_max_mtime = macros_max_mtime |
| 1010 | + |
| 1011 | + macros = macro.get_registry() |
| 1012 | + macro.set_registry(standard_macros) |
| 1013 | + |
| 1014 | + connection_config = self.context.connection_config |
| 1015 | + # this triggers the DBT create_builtins_module to have a `target` property which is required for a bunch of DBT macros to work |
| 1016 | + if dbt_config_type := TARGET_TYPE_TO_CONFIG_CLASS.get(connection_config.type_): |
| 1017 | + try: |
| 1018 | + jinja_macros.add_globals( |
| 1019 | + { |
| 1020 | + "target": dbt_config_type.from_sqlmesh( |
| 1021 | + connection_config, |
| 1022 | + name=self.config.default_gateway_name, |
| 1023 | + ).attribute_dict() |
| 1024 | + } |
| 1025 | + ) |
| 1026 | + except NotImplementedError: |
| 1027 | + raise ConfigError(f"Unsupported dbt target type: {connection_config.type_}") |
| 1028 | + |
| 1029 | + return macros, jinja_macros |
| 1030 | + |
| 1031 | + def _load_sql_models( |
| 1032 | + self, |
| 1033 | + macros: MacroRegistry, |
| 1034 | + jinja_macros: JinjaMacroRegistry, |
| 1035 | + audits: UniqueKeyDict[str, ModelAudit], |
| 1036 | + signals: UniqueKeyDict[str, signal], |
| 1037 | + cache: CacheBase, |
| 1038 | + gateway: t.Optional[str], |
| 1039 | + loading_default_kwargs: t.Optional[t.Dict[str, t.Any]] = None, |
| 1040 | + ) -> UniqueKeyDict[str, Model]: |
| 1041 | + return super()._load_sql_models( |
| 1042 | + macros=macros, |
| 1043 | + jinja_macros=jinja_macros, |
| 1044 | + audits=audits, |
| 1045 | + signals=signals, |
| 1046 | + cache=cache, |
| 1047 | + gateway=gateway, |
| 1048 | + loading_default_kwargs=dict( |
| 1049 | + migrated_dbt_project_name=self.migrated_dbt_project_name, |
| 1050 | + ), |
| 1051 | + ) |
0 commit comments