Skip to content

Commit e9272f6

Browse files
committed
Feat: improve signal CLI UX
1 parent 2acf1a2 commit e9272f6

File tree

5 files changed

+206
-12
lines changed

5 files changed

+206
-12
lines changed

sqlmesh/core/console.py

Lines changed: 147 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,37 @@ def log_test_results(self, result: ModelTextTestResult, target_dialect: str) ->
330330
"""
331331

332332

333+
class SignalConsole(abc.ABC):
334+
@abc.abstractmethod
335+
def start_signal_progress(
336+
self,
337+
snapshot: Snapshot,
338+
total_signals: int,
339+
default_catalog: t.Optional[str],
340+
environment_naming_info: EnvironmentNamingInfo,
341+
) -> None:
342+
"""Indicates that signal checking has begun for a snapshot."""
343+
344+
@abc.abstractmethod
345+
def update_signal_progress(
346+
self,
347+
snapshot: Snapshot,
348+
signal_name: str,
349+
signal_idx: int,
350+
total_signals: int,
351+
ready_intervals: Intervals,
352+
check_intervals: Intervals,
353+
duration: float,
354+
) -> None:
355+
"""Updates the signal checking progress."""
356+
357+
@abc.abstractmethod
358+
def stop_signal_progress(self, snapshot: Snapshot) -> None:
359+
"""Indicates that signal checking has completed for a snapshot."""
360+
361+
333362
class Console(
363+
SignalConsole,
334364
PlanBuilderConsole,
335365
LinterConsole,
336366
StateExporterConsole,
@@ -536,6 +566,30 @@ def update_snapshot_evaluation_progress(
536566
def stop_evaluation_progress(self, success: bool = True) -> None:
537567
pass
538568

569+
def start_signal_progress(
570+
self,
571+
snapshot: Snapshot,
572+
total_signals: int,
573+
default_catalog: t.Optional[str],
574+
environment_naming_info: EnvironmentNamingInfo,
575+
) -> None:
576+
pass
577+
578+
def update_signal_progress(
579+
self,
580+
snapshot: Snapshot,
581+
signal_name: str,
582+
signal_idx: int,
583+
total_signals: int,
584+
ready_intervals: Intervals,
585+
check_intervals: Intervals,
586+
duration: float,
587+
) -> None:
588+
pass
589+
590+
def stop_signal_progress(self, snapshot: Snapshot) -> None:
591+
pass
592+
539593
def start_creation_progress(
540594
self,
541595
snapshots: t.List[Snapshot],
@@ -860,6 +914,8 @@ def __init__(
860914
self.table_diff_model_tasks: t.Dict[str, TaskID] = {}
861915
self.table_diff_progress_live: t.Optional[Live] = None
862916

917+
self.signal_status_tree: t.Optional[Tree] = None
918+
863919
self.verbosity = verbosity
864920
self.dialect = dialect
865921
self.ignore_warnings = ignore_warnings
@@ -901,6 +957,9 @@ def start_evaluation_progress(
901957
audit_only: bool = False,
902958
) -> None:
903959
"""Indicates that a new snapshot evaluation/auditing progress has begun."""
960+
# Add a newline to separate signal checking from evaluation
961+
self._print("")
962+
904963
if not self.evaluation_progress_live:
905964
self.evaluation_total_progress = make_progress_bar(
906965
"Executing model batches" if not audit_only else "Auditing models", self.console
@@ -1050,6 +1109,75 @@ def stop_evaluation_progress(self, success: bool = True) -> None:
10501109
self.environment_naming_info = EnvironmentNamingInfo()
10511110
self.default_catalog = None
10521111

1112+
def start_signal_progress(
1113+
self,
1114+
snapshot: Snapshot,
1115+
total_signals: int,
1116+
default_catalog: t.Optional[str],
1117+
environment_naming_info: EnvironmentNamingInfo,
1118+
) -> None:
1119+
"""Indicates that signal checking has begun for a snapshot."""
1120+
display_name = snapshot.display_name(
1121+
environment_naming_info,
1122+
default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None,
1123+
dialect=self.dialect,
1124+
)
1125+
self.signal_status_tree = Tree(f"Checking signals for {display_name}")
1126+
1127+
def update_signal_progress(
1128+
self,
1129+
snapshot: Snapshot,
1130+
signal_name: str,
1131+
signal_idx: int,
1132+
total_signals: int,
1133+
ready_intervals: Intervals,
1134+
check_intervals: Intervals,
1135+
duration: float,
1136+
) -> None:
1137+
"""Updates the signal checking progress."""
1138+
# Format checked intervals
1139+
check_display = []
1140+
for interval in check_intervals[:3]: # Show first 3 intervals
1141+
interval_str = _format_signal_interval(snapshot, interval)
1142+
if interval_str:
1143+
check_display.append(interval_str)
1144+
if len(check_intervals) > 3:
1145+
check_display.append(f"... and {len(check_intervals) - 3} more")
1146+
1147+
# Format ready intervals
1148+
ready_display = []
1149+
for interval in ready_intervals[:3]: # Show first 3 intervals
1150+
interval_str = _format_signal_interval(snapshot, interval)
1151+
if interval_str:
1152+
ready_display.append(interval_str)
1153+
if len(ready_intervals) > 3:
1154+
ready_display.append(f"... and {len(ready_intervals) - 3} more")
1155+
1156+
# Display signal name
1157+
tree = Tree(f"[{signal_idx + 1}/{total_signals}] {signal_name} {duration:.2f}s")
1158+
1159+
# TODO: what about full models and other cases where these lists can be empty?
1160+
check_str = ", ".join(check_display) if check_display else "no intervals"
1161+
tree.add(f"check: {check_str}")
1162+
1163+
ready_str = ", ".join(ready_display) if ready_display else "no intervals"
1164+
if ready_intervals == check_intervals:
1165+
ready_str = f"[green]ready: {ready_str}[/green]"
1166+
elif ready_intervals:
1167+
ready_str = f"[yellow]ready: {ready_str}[/yellow]"
1168+
else:
1169+
ready_str = f"[red]ready: {ready_str}[/red]"
1170+
1171+
tree.add(ready_str)
1172+
if self.signal_status_tree is not None:
1173+
self.signal_status_tree.add(tree)
1174+
1175+
def stop_signal_progress(self, snapshot: Snapshot) -> None:
1176+
"""Indicates that signal checking has completed for a snapshot."""
1177+
if self.signal_status_tree is not None:
1178+
self._print(self.signal_status_tree)
1179+
self.signal_status_tree = None
1180+
10531181
def start_creation_progress(
10541182
self,
10551183
snapshots: t.List[Snapshot],
@@ -3810,22 +3938,35 @@ def _format_audits_errors(error: NodeAuditsErrors) -> str:
38103938
return " " + "\n".join(error_messages)
38113939

38123940

3813-
def _format_evaluation_model_interval(snapshot: Snapshot, interval: Interval) -> str:
3941+
def _format_interval(snapshot: Snapshot, interval: Interval, prefix: str = "") -> str:
3942+
"""Format an interval with an optional prefix."""
38143943
if snapshot.is_model and (
38153944
snapshot.model.kind.is_incremental
38163945
or snapshot.model.kind.is_managed
38173946
or snapshot.model.kind.is_custom
38183947
):
38193948
inclusive_interval = make_inclusive(interval[0], interval[1])
38203949
if snapshot.model.interval_unit.is_date_granularity:
3821-
return f"insert {to_ds(inclusive_interval[0])} - {to_ds(inclusive_interval[1])}"
3822-
# omit end date if interval start/end on same day
3823-
if inclusive_interval[0].date() == inclusive_interval[1].date():
3824-
return f"insert {to_ds(inclusive_interval[0])} {inclusive_interval[0].strftime('%H:%M:%S')}-{inclusive_interval[1].strftime('%H:%M:%S')}"
3825-
return f"insert {inclusive_interval[0].strftime('%Y-%m-%d %H:%M:%S')} - {inclusive_interval[1].strftime('%Y-%m-%d %H:%M:%S')}"
3950+
base = f"{to_ds(inclusive_interval[0])} - {to_ds(inclusive_interval[1])}"
3951+
elif inclusive_interval[0].date() == inclusive_interval[1].date():
3952+
# omit end date if interval start/end on same day
3953+
base = f"{to_ds(inclusive_interval[0])} {inclusive_interval[0].strftime('%H:%M:%S')}-{inclusive_interval[1].strftime('%H:%M:%S')}"
3954+
else:
3955+
base = f"{inclusive_interval[0].strftime('%Y-%m-%d %H:%M:%S')} - {inclusive_interval[1].strftime('%Y-%m-%d %H:%M:%S')}"
3956+
return f"{prefix} {base}" if prefix else base
38263957
return ""
38273958

38283959

3960+
def _format_signal_interval(snapshot: Snapshot, interval: Interval) -> str:
3961+
"""Format an interval for signal output (without 'insert' prefix)."""
3962+
return _format_interval(snapshot, interval)
3963+
3964+
3965+
def _format_evaluation_model_interval(snapshot: Snapshot, interval: Interval) -> str:
3966+
"""Format an interval for evaluation output (with 'insert' prefix)."""
3967+
return _format_interval(snapshot, interval, prefix="insert")
3968+
3969+
38293970
def _create_evaluation_model_annotation(snapshot: Snapshot, interval_info: t.Optional[str]) -> str:
38303971
if snapshot.is_audit:
38313972
return "run standalone audit"

sqlmesh/core/scheduler.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@ def batch_intervals(
269269
self,
270270
merged_intervals: SnapshotToIntervals,
271271
deployability_index: t.Optional[DeployabilityIndex],
272+
environment_naming_info: EnvironmentNamingInfo,
272273
) -> t.Dict[Snapshot, Intervals]:
273274
dag = snapshots_to_dag(merged_intervals)
274275

@@ -303,7 +304,13 @@ def batch_intervals(
303304
default_catalog=self.default_catalog,
304305
)
305306

306-
intervals = snapshot.check_ready_intervals(intervals, context)
307+
intervals = snapshot.check_ready_intervals(
308+
intervals,
309+
context,
310+
console=self.console,
311+
default_catalog=self.default_catalog,
312+
environment_naming_info=environment_naming_info,
313+
)
307314
unready -= set(intervals)
308315

309316
for parent in snapshot.parents:
@@ -324,10 +331,14 @@ def batch_intervals(
324331
):
325332
batches.append((next_batch[0][0], next_batch[-1][-1]))
326333
next_batch = []
334+
327335
next_batch.append(interval)
336+
328337
if next_batch:
329338
batches.append((next_batch[0][0], next_batch[-1][-1]))
339+
330340
snapshot_batches[snapshot] = batches
341+
331342
return snapshot_batches
332343

333344
def run_merged_intervals(
@@ -359,7 +370,9 @@ def run_merged_intervals(
359370
"""
360371
execution_time = execution_time or now_timestamp()
361372

362-
batched_intervals = self.batch_intervals(merged_intervals, deployability_index)
373+
batched_intervals = self.batch_intervals(
374+
merged_intervals, deployability_index, environment_naming_info
375+
)
363376

364377
self.console.start_evaluation_progress(
365378
batched_intervals,

sqlmesh/core/snapshot/definition.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import sys
4+
import time
45
import typing as t
56
from collections import defaultdict
67
from datetime import datetime, timedelta
@@ -49,6 +50,7 @@
4950
from sqlmesh.utils.pydantic import PydanticModel, field_validator
5051

5152
if t.TYPE_CHECKING:
53+
from sqlmesh.core.console import Console
5254
from sqlglot.dialects.dialect import DialectType
5355
from sqlmesh.core.environment import EnvironmentNamingInfo
5456
from sqlmesh.core.context import ExecutionContext
@@ -947,7 +949,14 @@ def missing_intervals(
947949
model_end_ts,
948950
)
949951

950-
def check_ready_intervals(self, intervals: Intervals, context: ExecutionContext) -> Intervals:
952+
def check_ready_intervals(
953+
self,
954+
intervals: Intervals,
955+
context: ExecutionContext,
956+
console: t.Optional[Console] = None,
957+
default_catalog: t.Optional[str] = None,
958+
environment_naming_info: t.Optional[EnvironmentNamingInfo] = None,
959+
) -> Intervals:
951960
"""Returns a list of intervals that are considered ready by the provided signal.
952961
953962
Note that this will handle gaps in the provided intervals. The returned intervals
@@ -961,7 +970,20 @@ def check_ready_intervals(self, intervals: Intervals, context: ExecutionContext)
961970
python_env = self.model.python_env
962971
env = prepare_env(python_env)
963972

964-
for signal_name, kwargs in signals.items():
973+
if console:
974+
console.start_signal_progress(
975+
self,
976+
len(signals),
977+
default_catalog,
978+
environment_naming_info or EnvironmentNamingInfo(),
979+
)
980+
981+
for signal_idx, (signal_name, kwargs) in enumerate(signals.items()):
982+
# Capture intervals before signal check for display
983+
intervals_to_check = merge_intervals(intervals)
984+
985+
signal_start_ts = time.perf_counter()
986+
965987
try:
966988
intervals = _check_ready_intervals(
967989
env[signal_name],
@@ -978,6 +1000,23 @@ def check_ready_intervals(self, intervals: Intervals, context: ExecutionContext)
9781000
f"{e} '{signal_name}' for '{self.model.name}' at {self.model._path}"
9791001
)
9801002

1003+
duration = time.perf_counter() - signal_start_ts
1004+
1005+
if console:
1006+
console.update_signal_progress(
1007+
snapshot=self,
1008+
signal_name=signal_name,
1009+
signal_idx=signal_idx,
1010+
total_signals=len(signals),
1011+
ready_intervals=merge_intervals(intervals),
1012+
check_intervals=intervals_to_check,
1013+
duration=duration,
1014+
)
1015+
1016+
# Stop signal progress tracking
1017+
if console:
1018+
console.stop_signal_progress(self)
1019+
9811020
return intervals
9821021

9831022
def categorize_as(self, category: SnapshotChangeCategory) -> None:

tests/core/test_scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def _get_batched_missing_intervals(
7777
execution_time: t.Optional[TimeLike] = None,
7878
) -> SnapshotToIntervals:
7979
merged_intervals = scheduler.merged_missing_intervals(start, end, execution_time)
80-
return scheduler.batch_intervals(merged_intervals, mocker.Mock())
80+
return scheduler.batch_intervals(merged_intervals, mocker.Mock(), mocker.Mock())
8181

8282
return _get_batched_missing_intervals
8383

web/server/api/endpoints/plan.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from starlette.status import HTTP_204_NO_CONTENT
88

99
from sqlmesh.core.context import Context
10+
from sqlmesh.core.environment import EnvironmentNamingInfo
1011
from sqlmesh.core.plan import Plan, PlanBuilder
1112
from sqlmesh.core.snapshot.definition import SnapshotChangeCategory
1213
from sqlmesh.utils.date import make_inclusive, to_ds
@@ -132,7 +133,7 @@ def _get_plan_changes(context: Context, plan: Plan) -> models.PlanChanges:
132133
def _get_plan_backfills(context: Context, plan: Plan) -> t.Dict[str, t.Any]:
133134
"""Get plan backfills"""
134135
merged_intervals = context.scheduler().merged_missing_intervals()
135-
batches = context.scheduler().batch_intervals(merged_intervals, None)
136+
batches = context.scheduler().batch_intervals(merged_intervals, None, EnvironmentNamingInfo())
136137
tasks = {snapshot.name: len(intervals) for snapshot, intervals in batches.items()}
137138
snapshots = plan.context_diff.snapshots
138139
default_catalog = context.default_catalog

0 commit comments

Comments
 (0)