Skip to content

Commit 0d6def1

Browse files
committed
PR Feedback 1
1 parent c8e95d4 commit 0d6def1

File tree

6 files changed

+54
-26
lines changed

6 files changed

+54
-26
lines changed

sqlmesh/core/context.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@
116116
run_tests,
117117
)
118118
from sqlmesh.core.user import User
119-
from sqlmesh.utils import UniqueKeyDict, Verbosity
119+
from sqlmesh.utils import UniqueKeyDict, Verbosity, CorrelationId
120120
from sqlmesh.utils.concurrency import concurrent_apply_to_values
121121
from sqlmesh.utils.dag import DAG
122122
from sqlmesh.utils.date import (
@@ -418,7 +418,7 @@ def __init__(
418418
self.config.get_state_connection(self.gateway) or self.connection_config
419419
)
420420

421-
self._snapshot_evaluator: t.Optional[SnapshotEvaluator] = None
421+
self._snapshot_evaluators: t.Dict[t.Optional[CorrelationId], SnapshotEvaluator] = {}
422422

423423
self.console = get_console()
424424
setattr(self.console, "dialect", self.config.dialect)
@@ -446,20 +446,18 @@ def engine_adapter(self) -> EngineAdapter:
446446
self._engine_adapter = self.connection_config.create_engine_adapter()
447447
return self._engine_adapter
448448

449-
def snapshot_evaluator(self, job_id: t.Optional[str] = None) -> SnapshotEvaluator:
449+
def snapshot_evaluator(self, job_id: t.Optional[CorrelationId] = None) -> SnapshotEvaluator:
450450
# Cache snapshot evaluators by job_id to avoid old job_ids being attached to future Context operations
451-
if not self._snapshot_evaluator or any(
452-
adapter._job_id != job_id for adapter in self._snapshot_evaluator.adapters.values()
453-
):
454-
self._snapshot_evaluator = SnapshotEvaluator(
451+
if job_id not in self._snapshot_evaluators:
452+
self._snapshot_evaluators[job_id] = SnapshotEvaluator(
455453
{
456454
gateway: adapter.with_settings(level=logging.INFO, job_id=job_id)
457455
for gateway, adapter in self.engine_adapters.items()
458456
},
459457
ddl_concurrent_tasks=self.concurrent_tasks,
460458
selected_gateway=self.selected_gateway,
461459
)
462-
return self._snapshot_evaluator
460+
return self._snapshot_evaluators[job_id]
463461

464462
def execution_context(
465463
self,
@@ -541,7 +539,7 @@ def scheduler(self, environment: t.Optional[str] = None) -> Scheduler:
541539
return self.create_scheduler(snapshots)
542540

543541
def create_scheduler(
544-
self, snapshots: t.Iterable[Snapshot], job_id: t.Optional[str] = None
542+
self, snapshots: t.Iterable[Snapshot], job_id: t.Optional[CorrelationId] = None
545543
) -> Scheduler:
546544
"""Creates the built-in scheduler.
547545
@@ -1594,7 +1592,7 @@ def apply(
15941592
)
15951593
explainer.evaluate(
15961594
plan.to_evaluatable(),
1597-
snapshot_evaluator=self.snapshot_evaluator(job_id=plan.plan_id),
1595+
snapshot_evaluator=self.snapshot_evaluator(job_id=CorrelationId.from_plan(plan)),
15981596
)
15991597
return
16001598

@@ -2341,8 +2339,8 @@ def print_environment_names(self) -> None:
23412339

23422340
def close(self) -> None:
23432341
"""Releases all resources allocated by this context."""
2344-
if self._snapshot_evaluator:
2345-
self._snapshot_evaluator.close()
2342+
for evaluator in self._snapshot_evaluators.values():
2343+
evaluator.close()
23462344
if self._state_sync:
23472345
self._state_sync.close()
23482346

@@ -2397,7 +2395,7 @@ def _run(
23972395
def _apply(self, plan: Plan, circuit_breaker: t.Optional[t.Callable[[], bool]]) -> None:
23982396
self._scheduler.create_plan_evaluator(self).evaluate(
23992397
plan.to_evaluatable(),
2400-
snapshot_evaluator=self.snapshot_evaluator(job_id=plan.plan_id),
2398+
snapshot_evaluator=self.snapshot_evaluator(job_id=CorrelationId.from_plan(plan)),
24012399
circuit_breaker=circuit_breaker,
24022400
)
24032401

sqlmesh/core/engine_adapter/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
)
4040
from sqlmesh.core.model.kind import TimeColumn
4141
from sqlmesh.core.schema_diff import SchemaDiffer
42-
from sqlmesh.utils import columns_to_types_all_known, random_id
42+
from sqlmesh.utils import columns_to_types_all_known, random_id, CorrelationId
4343
from sqlmesh.utils.connection_pool import create_connection_pool, ConnectionPool
4444
from sqlmesh.utils.date import TimeLike, make_inclusive, to_time_column
4545
from sqlmesh.utils.errors import (
@@ -123,7 +123,7 @@ def __init__(
123123
pre_ping: bool = False,
124124
pretty_sql: bool = False,
125125
shared_connection: bool = False,
126-
job_id: t.Optional[str] = None,
126+
job_id: t.Optional[CorrelationId] = None,
127127
**kwargs: t.Any,
128128
):
129129
self.dialect = dialect.lower() or self.DIALECT
@@ -2179,7 +2179,7 @@ def execute(
21792179
sql = t.cast(str, e)
21802180

21812181
if self._job_id:
2182-
sql = f"/* sqlmesh_ref: {self._job_id} */ {sql}"
2182+
sql = f"/* {self._job_id} */ {sql}"
21832183

21842184
self._log_sql(
21852185
sql,

sqlmesh/utils/__init__.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import types
1414
import typing as t
1515
import uuid
16+
from dataclasses import dataclass
1617
from collections import defaultdict
1718
from contextlib import contextmanager
1819
from copy import deepcopy
@@ -25,6 +26,9 @@
2526

2627
logger = logging.getLogger(__name__)
2728

29+
if t.TYPE_CHECKING:
30+
from sqlmesh.core.plan import Plan
31+
2832
T = t.TypeVar("T")
2933
KEY = t.TypeVar("KEY", bound=t.Hashable)
3034
VALUE = t.TypeVar("VALUE")
@@ -382,3 +386,23 @@ def to_snake_case(name: str) -> str:
382386
return "".join(
383387
f"_{c.lower()}" if c.isupper() and idx != 0 else c.lower() for idx, c in enumerate(name)
384388
)
389+
390+
391+
class JobType(Enum):
392+
PLAN = "PLAN"
393+
RUN = "RUN"
394+
395+
396+
@dataclass(frozen=True)
397+
class CorrelationId:
398+
"""ID that is added to each query in order to identify the job that created it."""
399+
400+
job_type: JobType
401+
job_id: str
402+
403+
def __str__(self) -> str:
404+
return f"{self.job_type.value}: {self.job_id}"
405+
406+
@classmethod
407+
def from_plan(cls, plan: Plan) -> CorrelationId:
408+
return CorrelationId(JobType.PLAN, plan.plan_id)

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
SnapshotDataVersion,
4343
SnapshotFingerprint,
4444
)
45-
from sqlmesh.utils import random_id
45+
from sqlmesh.utils import random_id, CorrelationId
4646
from sqlmesh.utils.date import TimeLike, to_date
4747
from sqlmesh.utils.windows import IS_WINDOWS, fix_windows_path
4848
from sqlmesh.core.engine_adapter.shared import CatalogSupport
@@ -269,7 +269,7 @@ def push_plan(context: Context, plan: Plan) -> None:
269269
context.create_scheduler,
270270
context.default_catalog,
271271
)
272-
plan_evaluator.snapshot_evaluator = context.snapshot_evaluator(job_id=plan.plan_id)
272+
plan_evaluator.snapshot_evaluator = context.snapshot_evaluator(CorrelationId.from_plan(plan))
273273
deployability_index = DeployabilityIndex.create(context.snapshots.values())
274274
evaluatable_plan = plan.to_evaluatable()
275275
stages = plan_stages.build_plan_stages(

tests/core/test_integration.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
SnapshotInfoLike,
6868
SnapshotTableInfo,
6969
)
70+
from sqlmesh.utils import CorrelationId
7071
from sqlmesh.utils.date import TimeLike, now, to_date, to_datetime, to_timestamp
7172
from sqlmesh.utils.errors import NoChangesPlanError, SQLMeshError, PlanError, ConfigError
7273
from sqlmesh.utils.pydantic import validate_string
@@ -6469,23 +6470,27 @@ def plan_with_output(ctx: Context, environment: str):
64696470
assert context_diff.environment == environment
64706471

64716472

6472-
def test_plan_evaluator_job_id(tmp_path: Path, mocker: MockerFixture):
6473-
def _to_sqls(mock_logger):
6474-
return [call[0][0] for call in mock_logger.call_args_list]
6473+
def test_plan_evaluator_correlation_id(tmp_path: Path):
6474+
def _correlation_id_in_sqls(correlation_id: CorrelationId, mock_logger):
6475+
sqls = [call[0][0] for call in mock_logger.call_args_list]
6476+
return any(f"/* {correlation_id} */" in sql for sql in sqls)
64756477

64766478
create_temp_file(
64776479
tmp_path, Path("models") / "test.sql", "MODEL (name test.a, kind FULL); SELECT 1 AS col"
64786480
)
64796481

6480-
# Case 1: Ensure that the job id (plan_id) is included in the SQL
6482+
# Case 1: Ensure that the correlation id (plan_id) is included in the SQL
64816483
with mock.patch("sqlmesh.core.engine_adapter.base.EngineAdapter._log_sql") as mock_logger:
64826484
ctx = Context(paths=[tmp_path], config=Config())
64836485
plan = ctx.plan(auto_apply=True, no_prompts=True)
64846486

6485-
assert any(f"/* sqlmesh_ref: {plan.plan_id} */" in sql for sql in _to_sqls(mock_logger))
6487+
correlation_id = CorrelationId.from_plan(plan)
6488+
assert str(correlation_id) == f"PLAN: {plan.plan_id}"
64866489

6487-
# Case 2: Ensure that the previous job id is not included in the SQL for other operations
6490+
assert _correlation_id_in_sqls(correlation_id, mock_logger)
6491+
6492+
# Case 2: Ensure that the previous correlation id is not included in the SQL for other operations
64886493
with mock.patch("sqlmesh.core.engine_adapter.base.EngineAdapter._log_sql") as mock_logger:
64896494
ctx.snapshot_evaluator().adapter.execute("SELECT 1")
64906495

6491-
assert not any(f"/* sqlmesh_ref: {plan.plan_id} */" in sql for sql in _to_sqls(mock_logger))
6496+
assert not _correlation_id_in_sqls(correlation_id, mock_logger)

tests/core/test_plan_evaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
stages as plan_stages,
1212
)
1313
from sqlmesh.core.snapshot import SnapshotChangeCategory
14+
from sqlmesh.utils import CorrelationId
1415

1516

1617
@pytest.fixture
@@ -63,7 +64,7 @@ def test_builtin_evaluator_push(sushi_context: Context, make_snapshot):
6364
sushi_context.default_catalog,
6465
console=sushi_context.console,
6566
)
66-
evaluator.snapshot_evaluator = sushi_context.snapshot_evaluator(job_id=plan.plan_id)
67+
evaluator.snapshot_evaluator = sushi_context.snapshot_evaluator(CorrelationId.from_plan(plan))
6768

6869
evaluatable_plan = plan.to_evaluatable()
6970
stages = plan_stages.build_plan_stages(

0 commit comments

Comments
 (0)