Skip to content

Commit 640ef92

Browse files
committed
Feat: Tag all queries under a plan
1 parent 4a44d89 commit 640ef92

File tree

11 files changed

+123
-36
lines changed

11 files changed

+123
-36
lines changed

sqlmesh/core/config/scheduler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@ class BuiltInSchedulerConfig(_EngineAdapterStateSyncSchedulerConfig, BaseConfig)
130130
def create_plan_evaluator(self, context: GenericContext) -> PlanEvaluator:
131131
return BuiltInPlanEvaluator(
132132
state_sync=context.state_sync,
133-
snapshot_evaluator=context.snapshot_evaluator,
134133
create_scheduler=context.create_scheduler,
135134
default_catalog=context.default_catalog,
136135
console=context.console,

sqlmesh/core/context.py

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@
116116
run_tests,
117117
)
118118
from sqlmesh.core.user import User
119-
from sqlmesh.utils import UniqueKeyDict, Verbosity
119+
from sqlmesh.utils import UniqueKeyDict, Verbosity, CorrelationId
120120
from sqlmesh.utils.concurrency import concurrent_apply_to_values
121121
from sqlmesh.utils.dag import DAG
122122
from sqlmesh.utils.date import (
@@ -418,7 +418,7 @@ def __init__(
418418
self.config.get_state_connection(self.gateway) or self.connection_config
419419
)
420420

421-
self._snapshot_evaluator: t.Optional[SnapshotEvaluator] = None
421+
self._snapshot_evaluators: t.Dict[t.Optional[CorrelationId], SnapshotEvaluator] = {}
422422

423423
self.console = get_console()
424424
setattr(self.console, "dialect", self.config.dialect)
@@ -446,18 +446,22 @@ def engine_adapter(self) -> EngineAdapter:
446446
self._engine_adapter = self.connection_config.create_engine_adapter()
447447
return self._engine_adapter
448448

449-
@property
450-
def snapshot_evaluator(self) -> SnapshotEvaluator:
451-
if not self._snapshot_evaluator:
452-
self._snapshot_evaluator = SnapshotEvaluator(
449+
def snapshot_evaluator(
450+
self, correlation_id: t.Optional[CorrelationId] = None
451+
) -> SnapshotEvaluator:
452+
# Cache snapshot evaluators by correlation_id to avoid old correlation_ids being attached to future Context operations
453+
if correlation_id not in self._snapshot_evaluators:
454+
self._snapshot_evaluators[correlation_id] = SnapshotEvaluator(
453455
{
454-
gateway: adapter.with_log_level(logging.INFO)
456+
gateway: adapter.with_settings(
457+
log_level=logging.INFO, correlation_id=correlation_id
458+
)
455459
for gateway, adapter in self.engine_adapters.items()
456460
},
457461
ddl_concurrent_tasks=self.concurrent_tasks,
458462
selected_gateway=self.selected_gateway,
459463
)
460-
return self._snapshot_evaluator
464+
return self._snapshot_evaluators[correlation_id]
461465

462466
def execution_context(
463467
self,
@@ -538,7 +542,9 @@ def scheduler(self, environment: t.Optional[str] = None) -> Scheduler:
538542

539543
return self.create_scheduler(snapshots)
540544

541-
def create_scheduler(self, snapshots: t.Iterable[Snapshot]) -> Scheduler:
545+
def create_scheduler(
546+
self, snapshots: t.Iterable[Snapshot], correlation_id: t.Optional[CorrelationId] = None
547+
) -> Scheduler:
542548
"""Creates the built-in scheduler.
543549
544550
Args:
@@ -549,7 +555,7 @@ def create_scheduler(self, snapshots: t.Iterable[Snapshot]) -> Scheduler:
549555
"""
550556
return Scheduler(
551557
snapshots,
552-
self.snapshot_evaluator,
558+
self.snapshot_evaluator(correlation_id),
553559
self.state_sync,
554560
default_catalog=self.default_catalog,
555561
max_workers=self.concurrent_tasks,
@@ -714,7 +720,7 @@ def run(
714720
NotificationEvent.RUN_START, environment=environment
715721
)
716722
analytics_run_id = analytics.collector.on_run_start(
717-
engine_type=self.snapshot_evaluator.adapter.dialect,
723+
engine_type=self.snapshot_evaluator().adapter.dialect,
718724
state_sync_type=self.state_sync.state_type(),
719725
)
720726
self._load_materializations()
@@ -1076,7 +1082,7 @@ def evaluate(
10761082
and not parent_snapshot.categorized
10771083
]
10781084

1079-
df = self.snapshot_evaluator.evaluate_and_fetch(
1085+
df = self.snapshot_evaluator().evaluate_and_fetch(
10801086
snapshot,
10811087
start=start,
10821088
end=end,
@@ -1588,7 +1594,12 @@ def apply(
15881594
default_catalog=self.default_catalog,
15891595
console=self.console,
15901596
)
1591-
explainer.evaluate(plan.to_evaluatable())
1597+
explainer.evaluate(
1598+
plan.to_evaluatable(),
1599+
snapshot_evaluator=self.snapshot_evaluator(
1600+
correlation_id=CorrelationId.from_plan_id(plan.plan_id)
1601+
),
1602+
)
15921603
return
15931604

15941605
self.notification_target_manager.notify(
@@ -1902,7 +1913,7 @@ def _table_diff(
19021913
)
19031914

19041915
return TableDiff(
1905-
adapter=adapter.with_log_level(logger.getEffectiveLevel()),
1916+
adapter=adapter.with_settings(logger.getEffectiveLevel()),
19061917
source=source,
19071918
target=target,
19081919
on=on,
@@ -2111,7 +2122,7 @@ def audit(
21112122
errors = []
21122123
skipped_count = 0
21132124
for snapshot in snapshots:
2114-
for audit_result in self.snapshot_evaluator.audit(
2125+
for audit_result in self.snapshot_evaluator().audit(
21152126
snapshot=snapshot,
21162127
start=start,
21172128
end=end,
@@ -2143,7 +2154,7 @@ def audit(
21432154
self.console.log_status_update(f"Got {error.count} results, expected 0.")
21442155
if error.query:
21452156
self.console.show_sql(
2146-
f"{error.query.sql(dialect=self.snapshot_evaluator.adapter.dialect)}"
2157+
f"{error.query.sql(dialect=self.snapshot_evaluator().adapter.dialect)}"
21472158
)
21482159

21492160
self.console.log_status_update("Done.")
@@ -2335,11 +2346,14 @@ def print_environment_names(self) -> None:
23352346

23362347
def close(self) -> None:
23372348
"""Releases all resources allocated by this context."""
2338-
if self._snapshot_evaluator:
2339-
self._snapshot_evaluator.close()
2349+
for evaluator in self._snapshot_evaluators.values():
2350+
evaluator.close()
2351+
23402352
if self._state_sync:
23412353
self._state_sync.close()
23422354

2355+
self._snapshot_evaluators.clear()
2356+
23432357
def _run(
23442358
self,
23452359
environment: str,
@@ -2390,7 +2404,11 @@ def _run(
23902404

23912405
def _apply(self, plan: Plan, circuit_breaker: t.Optional[t.Callable[[], bool]]) -> None:
23922406
self._scheduler.create_plan_evaluator(self).evaluate(
2393-
plan.to_evaluatable(), circuit_breaker=circuit_breaker
2407+
plan.to_evaluatable(),
2408+
snapshot_evaluator=self.snapshot_evaluator(
2409+
correlation_id=CorrelationId.from_plan_id(plan.plan_id)
2410+
),
2411+
circuit_breaker=circuit_breaker,
23942412
)
23952413

23962414
@python_api_analytics
@@ -2683,7 +2701,7 @@ def _run_janitor(self, ignore_ttl: bool = False) -> None:
26832701
)
26842702

26852703
# Remove the expired snapshots tables
2686-
self.snapshot_evaluator.cleanup(
2704+
self.snapshot_evaluator().cleanup(
26872705
target_snapshots=cleanup_targets,
26882706
on_complete=self.console.update_cleanup_progress,
26892707
)

sqlmesh/core/engine_adapter/athena.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(
4646
self, *args: t.Any, s3_warehouse_location: t.Optional[str] = None, **kwargs: t.Any
4747
):
4848
# Need to pass s3_warehouse_location to the superclass so that it goes into _extra_config
49-
# which means that EngineAdapter.with_log_level() keeps this property when it makes a clone
49+
# which means that EngineAdapter.with_settings() keeps this property when it makes a clone
5050
super().__init__(*args, s3_warehouse_location=s3_warehouse_location, **kwargs)
5151
self.s3_warehouse_location = s3_warehouse_location
5252

sqlmesh/core/engine_adapter/base.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
)
4040
from sqlmesh.core.model.kind import TimeColumn
4141
from sqlmesh.core.schema_diff import SchemaDiffer
42-
from sqlmesh.utils import columns_to_types_all_known, random_id
42+
from sqlmesh.utils import columns_to_types_all_known, random_id, CorrelationId
4343
from sqlmesh.utils.connection_pool import create_connection_pool, ConnectionPool
4444
from sqlmesh.utils.date import TimeLike, make_inclusive, to_time_column
4545
from sqlmesh.utils.errors import (
@@ -123,6 +123,7 @@ def __init__(
123123
pre_ping: bool = False,
124124
pretty_sql: bool = False,
125125
shared_connection: bool = False,
126+
correlation_id: t.Optional[CorrelationId] = None,
126127
**kwargs: t.Any,
127128
):
128129
self.dialect = dialect.lower() or self.DIALECT
@@ -144,19 +145,21 @@ def __init__(
144145
self._pre_ping = pre_ping
145146
self._pretty_sql = pretty_sql
146147
self._multithreaded = multithreaded
148+
self.correlation_id = correlation_id
147149

148-
def with_log_level(self, level: int) -> EngineAdapter:
150+
def with_settings(self, log_level: int, **kwargs: t.Any) -> EngineAdapter:
149151
adapter = self.__class__(
150152
self._connection_pool,
151153
dialect=self.dialect,
152154
sql_gen_kwargs=self._sql_gen_kwargs,
153155
default_catalog=self._default_catalog,
154-
execute_log_level=level,
156+
execute_log_level=log_level,
155157
register_comments=self._register_comments,
156158
null_connection=True,
157159
multithreaded=self._multithreaded,
158160
pretty_sql=self._pretty_sql,
159161
**self._extra_config,
162+
**kwargs,
160163
)
161164

162165
return adapter
@@ -2211,6 +2214,9 @@ def execute(
22112214
else:
22122215
sql = t.cast(str, e)
22132216

2217+
if self.correlation_id:
2218+
sql = f"/* {self.correlation_id} */ {sql}"
2219+
22142220
self._log_sql(
22152221
sql,
22162222
expression=e if isinstance(e, exp.Expression) else None,

sqlmesh/core/plan/evaluator.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@
4949
class PlanEvaluator(abc.ABC):
5050
@abc.abstractmethod
5151
def evaluate(
52-
self, plan: EvaluatablePlan, circuit_breaker: t.Optional[t.Callable[[], bool]] = None
52+
self,
53+
plan: EvaluatablePlan,
54+
snapshot_evaluator: SnapshotEvaluator,
55+
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
5356
) -> None:
5457
"""Evaluates a plan by pushing snapshots and backfilling data.
5558
@@ -60,20 +63,20 @@ def evaluate(
6063
6164
Args:
6265
plan: The plan to evaluate.
66+
snapshot_evaluator: The snapshot evaluator to use.
67+
circuit_breaker: The circuit breaker to use.
6368
"""
6469

6570

6671
class BuiltInPlanEvaluator(PlanEvaluator):
6772
def __init__(
6873
self,
6974
state_sync: StateSync,
70-
snapshot_evaluator: SnapshotEvaluator,
7175
create_scheduler: t.Callable[[t.Iterable[Snapshot]], Scheduler],
7276
default_catalog: t.Optional[str],
7377
console: t.Optional[Console] = None,
7478
):
7579
self.state_sync = state_sync
76-
self.snapshot_evaluator = snapshot_evaluator
7780
self.create_scheduler = create_scheduler
7881
self.default_catalog = default_catalog
7982
self.console = console or get_console()
@@ -82,9 +85,12 @@ def __init__(
8285
def evaluate(
8386
self,
8487
plan: EvaluatablePlan,
88+
snapshot_evaluator: SnapshotEvaluator,
8589
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
8690
) -> None:
8791
self._circuit_breaker = circuit_breaker
92+
self.snapshot_evaluator = snapshot_evaluator
93+
8894
self.console.start_plan_evaluation(plan)
8995
analytics.collector.on_plan_apply_start(
9096
plan=plan,

sqlmesh/core/plan/explainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sqlmesh.utils import Verbosity, rich as srich, to_snake_case
2121
from sqlmesh.utils.date import to_ts
2222
from sqlmesh.utils.errors import SQLMeshError
23+
from sqlmesh.core.snapshot.evaluator import SnapshotEvaluator
2324

2425

2526
logger = logging.getLogger(__name__)
@@ -37,7 +38,10 @@ def __init__(
3738
self.console = console or get_console()
3839

3940
def evaluate(
40-
self, plan: EvaluatablePlan, circuit_breaker: t.Optional[t.Callable[[], bool]] = None
41+
self,
42+
plan: EvaluatablePlan,
43+
snapshot_evaluator: SnapshotEvaluator,
44+
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
4145
) -> None:
4246
plan_stages = stages.build_plan_stages(plan, self.state_reader, self.default_catalog)
4347
explainer_console = _get_explainer_console(

sqlmesh/utils/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import types
1414
import typing as t
1515
import uuid
16+
from dataclasses import dataclass
1617
from collections import defaultdict
1718
from contextlib import contextmanager
1819
from copy import deepcopy
@@ -382,3 +383,23 @@ def to_snake_case(name: str) -> str:
382383
return "".join(
383384
f"_{c.lower()}" if c.isupper() and idx != 0 else c.lower() for idx, c in enumerate(name)
384385
)
386+
387+
388+
class JobType(Enum):
389+
PLAN = "SQLMESH_PLAN"
390+
RUN = "SQLMESH_RUN"
391+
392+
393+
@dataclass(frozen=True)
394+
class CorrelationId:
395+
"""ID that is added to each query in order to identify the job that created it."""
396+
397+
job_type: JobType
398+
job_id: str
399+
400+
def __str__(self) -> str:
401+
return f"{self.job_type.value}: {self.job_id}"
402+
403+
@classmethod
404+
def from_plan_id(cls, plan_id: str) -> CorrelationId:
405+
return CorrelationId(JobType.PLAN, plan_id)

tests/conftest.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
SnapshotDataVersion,
4343
SnapshotFingerprint,
4444
)
45-
from sqlmesh.utils import random_id
45+
from sqlmesh.utils import random_id, CorrelationId
4646
from sqlmesh.utils.date import TimeLike, to_date
4747
from sqlmesh.utils.windows import IS_WINDOWS, fix_windows_path
4848
from sqlmesh.core.engine_adapter.shared import CatalogSupport
@@ -266,10 +266,12 @@ def duck_conn() -> duckdb.DuckDBPyConnection:
266266
def push_plan(context: Context, plan: Plan) -> None:
267267
plan_evaluator = BuiltInPlanEvaluator(
268268
context.state_sync,
269-
context.snapshot_evaluator,
270269
context.create_scheduler,
271270
context.default_catalog,
272271
)
272+
plan_evaluator.snapshot_evaluator = context.snapshot_evaluator(
273+
CorrelationId.from_plan_id(plan.plan_id)
274+
)
273275
deployability_index = DeployabilityIndex.create(context.snapshots.values())
274276
evaluatable_plan = plan.to_evaluatable()
275277
stages = plan_stages.build_plan_stages(

0 commit comments

Comments
 (0)