Skip to content

Commit c8e95d4

Browse files
committed
Refactor APIs
1 parent 6082a00 commit c8e95d4

File tree

9 files changed

+35
-25
lines changed

9 files changed

+35
-25
lines changed

sqlmesh/core/config/scheduler.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,11 @@ class SchedulerConfig(abc.ABC):
2828
"""Abstract base class for Scheduler configurations."""
2929

3030
@abc.abstractmethod
31-
def create_plan_evaluator(
32-
self, context: GenericContext, job_id: t.Optional[str] = None
33-
) -> PlanEvaluator:
31+
def create_plan_evaluator(self, context: GenericContext) -> PlanEvaluator:
3432
"""Creates a Plan Evaluator instance.
3533
3634
Args:
3735
context: The SQLMesh Context.
38-
job_id: The plan ID.
3936
"""
4037

4138
@abc.abstractmethod
@@ -130,12 +127,9 @@ class BuiltInSchedulerConfig(_EngineAdapterStateSyncSchedulerConfig, BaseConfig)
130127

131128
type_: t.Literal["builtin"] = Field(alias="type", default="builtin")
132129

133-
def create_plan_evaluator(
134-
self, context: GenericContext, job_id: t.Optional[str] = None
135-
) -> PlanEvaluator:
130+
def create_plan_evaluator(self, context: GenericContext) -> PlanEvaluator:
136131
return BuiltInPlanEvaluator(
137132
state_sync=context.state_sync,
138-
snapshot_evaluator=context.snapshot_evaluator(job_id),
139133
create_scheduler=context.create_scheduler,
140134
default_catalog=context.default_catalog,
141135
console=context.console,

sqlmesh/core/context.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def snapshot_evaluator(self, job_id: t.Optional[str] = None) -> SnapshotEvaluato
453453
):
454454
self._snapshot_evaluator = SnapshotEvaluator(
455455
{
456-
gateway: adapter.with_log_level(logging.INFO, job_id)
456+
gateway: adapter.with_settings(level=logging.INFO, job_id=job_id)
457457
for gateway, adapter in self.engine_adapters.items()
458458
},
459459
ddl_concurrent_tasks=self.concurrent_tasks,
@@ -1592,7 +1592,10 @@ def apply(
15921592
default_catalog=self.default_catalog,
15931593
console=self.console,
15941594
)
1595-
explainer.evaluate(plan.to_evaluatable())
1595+
explainer.evaluate(
1596+
plan.to_evaluatable(),
1597+
snapshot_evaluator=self.snapshot_evaluator(job_id=plan.plan_id),
1598+
)
15961599
return
15971600

15981601
self.notification_target_manager.notify(
@@ -1905,7 +1908,7 @@ def _table_diff(
19051908
)
19061909

19071910
return TableDiff(
1908-
adapter=adapter.with_log_level(logger.getEffectiveLevel()),
1911+
adapter=adapter.with_settings(logger.getEffectiveLevel()),
19091912
source=source,
19101913
target=target,
19111914
on=on,
@@ -2392,8 +2395,10 @@ def _run(
23922395
return completion_status
23932396

23942397
def _apply(self, plan: Plan, circuit_breaker: t.Optional[t.Callable[[], bool]]) -> None:
2395-
self._scheduler.create_plan_evaluator(self, job_id=plan.plan_id).evaluate(
2396-
plan.to_evaluatable(), circuit_breaker=circuit_breaker
2398+
self._scheduler.create_plan_evaluator(self).evaluate(
2399+
plan.to_evaluatable(),
2400+
snapshot_evaluator=self.snapshot_evaluator(job_id=plan.plan_id),
2401+
circuit_breaker=circuit_breaker,
23972402
)
23982403

23992404
@python_api_analytics

sqlmesh/core/engine_adapter/athena.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(
4646
self, *args: t.Any, s3_warehouse_location: t.Optional[str] = None, **kwargs: t.Any
4747
):
4848
# Need to pass s3_warehouse_location to the superclass so that it goes into _extra_config
49-
# which means that EngineAdapter.with_log_level() keeps this property when it makes a clone
49+
# which means that EngineAdapter.with_settings() keeps this property when it makes a clone
5050
super().__init__(*args, s3_warehouse_location=s3_warehouse_location, **kwargs)
5151
self.s3_warehouse_location = s3_warehouse_location
5252

sqlmesh/core/engine_adapter/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def __init__(
147147
self._multithreaded = multithreaded
148148
self._job_id = job_id
149149

150-
def with_log_level(self, level: int, job_id: t.Optional[str] = None) -> EngineAdapter:
150+
def with_settings(self, level: int, **kwargs: t.Any) -> EngineAdapter:
151151
adapter = self.__class__(
152152
self._connection_pool,
153153
dialect=self.dialect,
@@ -158,8 +158,8 @@ def with_log_level(self, level: int, job_id: t.Optional[str] = None) -> EngineAd
158158
null_connection=True,
159159
multithreaded=self._multithreaded,
160160
pretty_sql=self._pretty_sql,
161-
job_id=job_id,
162161
**self._extra_config,
162+
**kwargs,
163163
)
164164

165165
return adapter

sqlmesh/core/plan/evaluator.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@
4949
class PlanEvaluator(abc.ABC):
5050
@abc.abstractmethod
5151
def evaluate(
52-
self, plan: EvaluatablePlan, circuit_breaker: t.Optional[t.Callable[[], bool]] = None
52+
self,
53+
plan: EvaluatablePlan,
54+
snapshot_evaluator: SnapshotEvaluator,
55+
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
5356
) -> None:
5457
"""Evaluates a plan by pushing snapshots and backfilling data.
5558
@@ -60,20 +63,20 @@ def evaluate(
6063
6164
Args:
6265
plan: The plan to evaluate.
66+
snapshot_evaluator: The snapshot evaluator to use.
67+
circuit_breaker: The circuit breaker to use.
6368
"""
6469

6570

6671
class BuiltInPlanEvaluator(PlanEvaluator):
6772
def __init__(
6873
self,
6974
state_sync: StateSync,
70-
snapshot_evaluator: SnapshotEvaluator,
7175
create_scheduler: t.Callable[[t.Iterable[Snapshot]], Scheduler],
7276
default_catalog: t.Optional[str],
7377
console: t.Optional[Console] = None,
7478
):
7579
self.state_sync = state_sync
76-
self.snapshot_evaluator = snapshot_evaluator
7780
self.create_scheduler = create_scheduler
7881
self.default_catalog = default_catalog
7982
self.console = console or get_console()
@@ -82,9 +85,12 @@ def __init__(
8285
def evaluate(
8386
self,
8487
plan: EvaluatablePlan,
88+
snapshot_evaluator: SnapshotEvaluator,
8589
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
8690
) -> None:
8791
self._circuit_breaker = circuit_breaker
92+
self.snapshot_evaluator = snapshot_evaluator
93+
8894
self.console.start_plan_evaluation(plan)
8995
analytics.collector.on_plan_apply_start(
9096
plan=plan,

sqlmesh/core/plan/explainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from sqlmesh.utils import Verbosity, rich as srich, to_snake_case
2121
from sqlmesh.utils.date import to_ts
2222
from sqlmesh.utils.errors import SQLMeshError
23+
from sqlmesh.core.snapshot.evaluator import SnapshotEvaluator
2324

2425

2526
logger = logging.getLogger(__name__)
@@ -37,7 +38,10 @@ def __init__(
3738
self.console = console or get_console()
3839

3940
def evaluate(
40-
self, plan: EvaluatablePlan, circuit_breaker: t.Optional[t.Callable[[], bool]] = None
41+
self,
42+
plan: EvaluatablePlan,
43+
snapshot_evaluator: SnapshotEvaluator,
44+
circuit_breaker: t.Optional[t.Callable[[], bool]] = None,
4145
) -> None:
4246
plan_stages = stages.build_plan_stages(plan, self.state_reader, self.default_catalog)
4347
explainer_console = _get_explainer_console(

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,10 +266,10 @@ def duck_conn() -> duckdb.DuckDBPyConnection:
266266
def push_plan(context: Context, plan: Plan) -> None:
267267
plan_evaluator = BuiltInPlanEvaluator(
268268
context.state_sync,
269-
context.snapshot_evaluator(),
270269
context.create_scheduler,
271270
context.default_catalog,
272271
)
272+
plan_evaluator.snapshot_evaluator = context.snapshot_evaluator(job_id=plan.plan_id)
273273
deployability_index = DeployabilityIndex.create(context.snapshots.values())
274274
evaluatable_plan = plan.to_evaluatable()
275275
stages = plan_stages.build_plan_stages(

tests/core/test_plan_evaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,12 @@ def test_builtin_evaluator_push(sushi_context: Context, make_snapshot):
5959

6060
evaluator = BuiltInPlanEvaluator(
6161
sushi_context.state_sync,
62-
sushi_context.snapshot_evaluator(),
6362
sushi_context.create_scheduler,
6463
sushi_context.default_catalog,
6564
console=sushi_context.console,
6665
)
66+
evaluator.snapshot_evaluator = sushi_context.snapshot_evaluator(job_id=plan.plan_id)
67+
6768
evaluatable_plan = plan.to_evaluatable()
6869
stages = plan_stages.build_plan_stages(
6970
evaluatable_plan, sushi_context.state_sync, sushi_context.default_catalog

tests/core/test_table_diff.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -335,11 +335,11 @@ def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture)
335335
sample_query_sql = 'WITH "source_only" AS (SELECT \'source_only\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "s_exists" = 1 AND "row_joined" = 0 ORDER BY "s__key" NULLS FIRST LIMIT 20), "target_only" AS (SELECT \'target_only\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "t_exists" = 1 AND "row_joined" = 0 ORDER BY "t__key" NULLS FIRST LIMIT 20), "common_rows" AS (SELECT \'common_rows\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "row_joined" = 1 AND "row_full_match" = 0 ORDER BY "s__key" NULLS FIRST, "t__key" NULLS FIRST LIMIT 20) SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "source_only" UNION ALL SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "target_only" UNION ALL SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "common_rows"'
336336
drop_sql = 'DROP TABLE IF EXISTS "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh"'
337337

338-
# make with_log_level() return the current instance of engine_adapter so we can still spy on _execute
338+
# make with_settings() return the current instance of engine_adapter so we can still spy on _execute
339339
mocker.patch.object(
340-
engine_adapter, "with_log_level", new_callable=lambda: lambda _: engine_adapter
340+
engine_adapter, "with_settings", new_callable=lambda: lambda _: engine_adapter
341341
)
342-
assert engine_adapter.with_log_level(1) == engine_adapter
342+
assert engine_adapter.with_settings(1) == engine_adapter
343343

344344
spy_execute = mocker.spy(engine_adapter, "_execute")
345345
mocker.patch("sqlmesh.core.engine_adapter.base.random_id", return_value="abcdefgh")

0 commit comments

Comments
 (0)