
Commit 6082a00

Feat: Tag all queries under a plan
1 parent ca58880 commit 6082a00

6 files changed: +56 -19 lines changed

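In short, the plan ID is threaded as job_id from Context._apply through create_plan_evaluator, create_scheduler, and Context.snapshot_evaluator down to EngineAdapter.execute, which prepends a sqlmesh_ref comment to every statement executed while applying that plan. A minimal sketch of the resulting statement text, using a hypothetical plan ID:

    # Sketch only: reproduces the tag format added to EngineAdapter.execute below.
    plan_id = "a1b2c3d4"  # hypothetical plan ID
    sql = "SELECT 1 AS col"
    tagged_sql = f"/* sqlmesh_ref: {plan_id} */ {sql}"
    print(tagged_sql)  # /* sqlmesh_ref: a1b2c3d4 */ SELECT 1 AS col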

sqlmesh/core/config/scheduler.py

Lines changed: 8 additions & 3 deletions

@@ -28,11 +28,14 @@ class SchedulerConfig(abc.ABC):
     """Abstract base class for Scheduler configurations."""

     @abc.abstractmethod
-    def create_plan_evaluator(self, context: GenericContext) -> PlanEvaluator:
+    def create_plan_evaluator(
+        self, context: GenericContext, job_id: t.Optional[str] = None
+    ) -> PlanEvaluator:
         """Creates a Plan Evaluator instance.

         Args:
             context: The SQLMesh Context.
+            job_id: The plan ID.
         """

     @abc.abstractmethod
@@ -127,10 +130,12 @@ class BuiltInSchedulerConfig(_EngineAdapterStateSyncSchedulerConfig, BaseConfig)

     type_: t.Literal["builtin"] = Field(alias="type", default="builtin")

-    def create_plan_evaluator(self, context: GenericContext) -> PlanEvaluator:
+    def create_plan_evaluator(
+        self, context: GenericContext, job_id: t.Optional[str] = None
+    ) -> PlanEvaluator:
         return BuiltInPlanEvaluator(
             state_sync=context.state_sync,
-            snapshot_evaluator=context.snapshot_evaluator,
+            snapshot_evaluator=context.snapshot_evaluator(job_id),
             create_scheduler=context.create_scheduler,
             default_catalog=context.default_catalog,
             console=context.console,
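
Because the abstract signature now carries job_id, any custom SchedulerConfig has to accept the extra argument in its override. A hedged sketch of such a subclass (subclass name invented; the import of BuiltInPlanEvaluator/PlanEvaluator from sqlmesh.core.plan is assumed; other abstract members omitted):

    import typing as t

    from sqlmesh.core.config.scheduler import SchedulerConfig
    from sqlmesh.core.plan import BuiltInPlanEvaluator, PlanEvaluator  # import path assumed


    class MySchedulerConfig(SchedulerConfig):  # hypothetical subclass; other abstract methods omitted
        def create_plan_evaluator(
            self, context, job_id: t.Optional[str] = None
        ) -> PlanEvaluator:
            # Forward the plan ID so the evaluator's snapshot evaluator (and its
            # engine adapters) tag every query they execute under this plan.
            return BuiltInPlanEvaluator(
                state_sync=context.state_sync,
                snapshot_evaluator=context.snapshot_evaluator(job_id),
                create_scheduler=context.create_scheduler,
                default_catalog=context.default_catalog,
                console=context.console,
            )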

sqlmesh/core/context.py

Lines changed: 16 additions & 12 deletions

@@ -446,12 +446,14 @@ def engine_adapter(self) -> EngineAdapter:
             self._engine_adapter = self.connection_config.create_engine_adapter()
         return self._engine_adapter

-    @property
-    def snapshot_evaluator(self) -> SnapshotEvaluator:
-        if not self._snapshot_evaluator:
+    def snapshot_evaluator(self, job_id: t.Optional[str] = None) -> SnapshotEvaluator:
+        # Cache snapshot evaluators by job_id to avoid old job_ids being attached to future Context operations
+        if not self._snapshot_evaluator or any(
+            adapter._job_id != job_id for adapter in self._snapshot_evaluator.adapters.values()
+        ):
             self._snapshot_evaluator = SnapshotEvaluator(
                 {
-                    gateway: adapter.with_log_level(logging.INFO)
+                    gateway: adapter.with_log_level(logging.INFO, job_id)
                     for gateway, adapter in self.engine_adapters.items()
                 },
                 ddl_concurrent_tasks=self.concurrent_tasks,
@@ -538,7 +540,9 @@ def scheduler(self, environment: t.Optional[str] = None) -> Scheduler:

         return self.create_scheduler(snapshots)

-    def create_scheduler(self, snapshots: t.Iterable[Snapshot]) -> Scheduler:
+    def create_scheduler(
+        self, snapshots: t.Iterable[Snapshot], job_id: t.Optional[str] = None
+    ) -> Scheduler:
         """Creates the built-in scheduler.

         Args:
@@ -549,7 +553,7 @@ def create_scheduler(self, snapshots: t.Iterable[Snapshot]) -> Scheduler:
         """
         return Scheduler(
             snapshots,
-            self.snapshot_evaluator,
+            self.snapshot_evaluator(job_id),
             self.state_sync,
             default_catalog=self.default_catalog,
             max_workers=self.concurrent_tasks,
@@ -714,7 +718,7 @@ def run(
             NotificationEvent.RUN_START, environment=environment
         )
         analytics_run_id = analytics.collector.on_run_start(
-            engine_type=self.snapshot_evaluator.adapter.dialect,
+            engine_type=self.snapshot_evaluator().adapter.dialect,
             state_sync_type=self.state_sync.state_type(),
         )
         self._load_materializations()
@@ -1076,7 +1080,7 @@ def evaluate(
             and not parent_snapshot.categorized
         ]

-        df = self.snapshot_evaluator.evaluate_and_fetch(
+        df = self.snapshot_evaluator().evaluate_and_fetch(
             snapshot,
             start=start,
             end=end,
@@ -2110,7 +2114,7 @@ def audit(
         errors = []
         skipped_count = 0
         for snapshot in snapshots:
-            for audit_result in self.snapshot_evaluator.audit(
+            for audit_result in self.snapshot_evaluator().audit(
                 snapshot=snapshot,
                 start=start,
                 end=end,
@@ -2142,7 +2146,7 @@ def audit(
                 self.console.log_status_update(f"Got {error.count} results, expected 0.")
                 if error.query:
                     self.console.show_sql(
-                        f"{error.query.sql(dialect=self.snapshot_evaluator.adapter.dialect)}"
+                        f"{error.query.sql(dialect=self.snapshot_evaluator().adapter.dialect)}"
                     )

         self.console.log_status_update("Done.")
@@ -2388,7 +2392,7 @@ def _run(
         return completion_status

     def _apply(self, plan: Plan, circuit_breaker: t.Optional[t.Callable[[], bool]]) -> None:
-        self._scheduler.create_plan_evaluator(self).evaluate(
+        self._scheduler.create_plan_evaluator(self, job_id=plan.plan_id).evaluate(
             plan.to_evaluatable(), circuit_breaker=circuit_breaker
         )

@@ -2682,7 +2686,7 @@ def _run_janitor(self, ignore_ttl: bool = False) -> None:
         )

         # Remove the expired snapshots tables
-        self.snapshot_evaluator.cleanup(
+        self.snapshot_evaluator().cleanup(
             target_snapshots=cleanup_targets,
             on_complete=self.console.update_cleanup_progress,
         )
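
A usage sketch of the new caching rule in Context.snapshot_evaluator: the evaluator is rebuilt whenever the requested job_id differs from the one its adapters already carry, so a later call without a job_id does not keep tagging queries with a stale plan ID (context is assumed to be an existing Context; the plan ID is hypothetical):

    tagged = context.snapshot_evaluator("a1b2c3d4")            # adapters get _job_id = "a1b2c3d4"
    assert context.snapshot_evaluator("a1b2c3d4") is tagged    # same job_id reuses the cached evaluator

    untagged = context.snapshot_evaluator()                    # different (no) job_id rebuilds it
    assert untagged is not tagged                              # the stale plan ID is dropped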

sqlmesh/core/engine_adapter/base.py

Lines changed: 7 additions & 1 deletion

@@ -123,6 +123,7 @@ def __init__(
         pre_ping: bool = False,
         pretty_sql: bool = False,
         shared_connection: bool = False,
+        job_id: t.Optional[str] = None,
         **kwargs: t.Any,
     ):
         self.dialect = dialect.lower() or self.DIALECT
@@ -144,8 +145,9 @@ def __init__(
         self._pre_ping = pre_ping
         self._pretty_sql = pretty_sql
         self._multithreaded = multithreaded
+        self._job_id = job_id

-    def with_log_level(self, level: int) -> EngineAdapter:
+    def with_log_level(self, level: int, job_id: t.Optional[str] = None) -> EngineAdapter:
         adapter = self.__class__(
             self._connection_pool,
             dialect=self.dialect,
@@ -156,6 +158,7 @@ def with_log_level(self, level: int) -> EngineAdapter:
             null_connection=True,
             multithreaded=self._multithreaded,
             pretty_sql=self._pretty_sql,
+            job_id=job_id,
             **self._extra_config,
         )

@@ -2175,6 +2178,9 @@ def execute(
             else:
                 sql = t.cast(str, e)

+            if self._job_id:
+                sql = f"/* sqlmesh_ref: {self._job_id} */ {sql}"
+
             self._log_sql(
                 sql,
                 expression=e if isinstance(e, exp.Expression) else None,
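
The tag lives only on the adapter copy produced by with_log_level, so the Context's base adapter keeps emitting untagged SQL. A sketch of that propagation, again assuming an existing context and a hypothetical plan ID:

    import logging

    base_adapter = context.engine_adapter
    plan_adapter = base_adapter.with_log_level(logging.INFO, "a1b2c3d4")  # hypothetical plan ID

    assert base_adapter._job_id is None          # base adapter: no comment prefix on execute()
    assert plan_adapter._job_id == "a1b2c3d4"    # plan-scoped copy: every execute() is prefixed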

tests/conftest.py

Lines changed: 1 addition & 1 deletion

@@ -266,7 +266,7 @@ def duck_conn() -> duckdb.DuckDBPyConnection:
 def push_plan(context: Context, plan: Plan) -> None:
     plan_evaluator = BuiltInPlanEvaluator(
         context.state_sync,
-        context.snapshot_evaluator,
+        context.snapshot_evaluator(),
         context.create_scheduler,
         context.default_catalog,
     )

tests/core/test_integration.py

Lines changed: 23 additions & 1 deletion

@@ -1137,7 +1137,7 @@ def test_non_breaking_change_after_forward_only_in_dev(
     init_and_plan_context: t.Callable, has_view_binding: bool
 ):
     context, plan = init_and_plan_context("examples/sushi")
-    context.snapshot_evaluator.adapter.HAS_VIEW_BINDING = has_view_binding
+    context.snapshot_evaluator().adapter.HAS_VIEW_BINDING = has_view_binding
     context.apply(plan)

     model = context.get_model("sushi.waiter_revenue_by_day")
@@ -6467,3 +6467,25 @@ def plan_with_output(ctx: Context, environment: str):
     for environment in ["dev", "prod"]:
         context_diff = ctx._context_diff(environment)
         assert context_diff.environment == environment
+
+
+def test_plan_evaluator_job_id(tmp_path: Path, mocker: MockerFixture):
+    def _to_sqls(mock_logger):
+        return [call[0][0] for call in mock_logger.call_args_list]
+
+    create_temp_file(
+        tmp_path, Path("models") / "test.sql", "MODEL (name test.a, kind FULL); SELECT 1 AS col"
+    )
+
+    # Case 1: Ensure that the job id (plan_id) is included in the SQL
+    with mock.patch("sqlmesh.core.engine_adapter.base.EngineAdapter._log_sql") as mock_logger:
+        ctx = Context(paths=[tmp_path], config=Config())
+        plan = ctx.plan(auto_apply=True, no_prompts=True)
+
+        assert any(f"/* sqlmesh_ref: {plan.plan_id} */" in sql for sql in _to_sqls(mock_logger))
+
+    # Case 2: Ensure that the previous job id is not included in the SQL for other operations
+    with mock.patch("sqlmesh.core.engine_adapter.base.EngineAdapter._log_sql") as mock_logger:
+        ctx.snapshot_evaluator().adapter.execute("SELECT 1")
+
+        assert not any(f"/* sqlmesh_ref: {plan.plan_id} */" in sql for sql in _to_sqls(mock_logger))

tests/core/test_plan_evaluator.py

Lines changed: 1 addition & 1 deletion

@@ -59,7 +59,7 @@ def test_builtin_evaluator_push(sushi_context: Context, make_snapshot):

     evaluator = BuiltInPlanEvaluator(
         sushi_context.state_sync,
-        sushi_context.snapshot_evaluator,
+        sushi_context.snapshot_evaluator(),
         sushi_context.create_scheduler,
         sushi_context.default_catalog,
         console=sushi_context.console,
