Skip to content

Commit 4b877d3

Browse files
committed
Refactor test_test to validate rich output [skip ci]
1 parent 10f9313 commit 4b877d3

File tree

6 files changed

+182
-117
lines changed

6 files changed

+182
-117
lines changed

sqlmesh/core/console.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ def plan(
463463

464464
@abc.abstractmethod
465465
def log_test_results(
466-
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
466+
self, result: ModelTextTestResult, output: t.Optional[str], target_dialect: str
467467
) -> None:
468468
"""Display the test result and output.
469469
@@ -498,7 +498,9 @@ def loading_stop(self, id: uuid.UUID) -> None:
498498
"""Stop loading for the given id."""
499499

500500
@abc.abstractmethod
501-
def log_unit_test_results(self, result: ModelTextTestResult, test_duration: float) -> None:
501+
def log_unit_test_results(
502+
self, result: ModelTextTestResult, test_duration: t.Optional[float] = None
503+
) -> None:
502504
"""Print the unit test results."""
503505

504506

@@ -674,7 +676,7 @@ def plan(
674676
plan_builder.apply()
675677

676678
def log_test_results(
677-
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
679+
self, result: ModelTextTestResult, output: t.Optional[str], target_dialect: str
678680
) -> None:
679681
pass
680682

@@ -782,7 +784,9 @@ def start_destroy(self) -> bool:
782784
def stop_destroy(self, success: bool = True) -> None:
783785
pass
784786

785-
def log_unit_test_results(self, result: ModelTextTestResult, test_duration: float) -> None:
787+
def log_unit_test_results(
788+
self, result: ModelTextTestResult, test_duration: t.Optional[float] = None
789+
) -> None:
786790
pass
787791

788792

@@ -1961,9 +1965,13 @@ def _prompt_promote(self, plan_builder: PlanBuilder) -> None:
19611965
plan_builder.apply()
19621966

19631967
def log_test_results(
1964-
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
1968+
self, result: ModelTextTestResult, output: t.Optional[str], target_dialect: str
19651969
) -> None:
19661970
divider_length = 70
1971+
1972+
self.log_unit_test_results(result)
1973+
1974+
self._print("\n")
19671975
if result.wasSuccessful():
19681976
self._print("=" * divider_length)
19691977
self._print(
@@ -1980,7 +1988,7 @@ def log_test_results(
19801988
)
19811989
for test, _ in result.failures + result.errors:
19821990
if isinstance(test, ModelTest):
1983-
self._print(f"Failure Test: {test.model.name} {test.test_name}")
1991+
self._print(f"Failure Test: {test.path}::{test.test_name}")
19841992
self._print("=" * divider_length)
19851993
self._print(output)
19861994

@@ -2500,7 +2508,9 @@ def show_linter_violations(
25002508
else:
25012509
self.log_warning(msg)
25022510

2503-
def log_unit_test_results(self, result: ModelTextTestResult, test_duration: float) -> None:
2511+
def log_unit_test_results(
2512+
self, result: ModelTextTestResult, test_duration: t.Optional[float] = None
2513+
) -> None:
25042514
tests_run = result.testsRun
25052515
errors = result.errors
25062516
failures = result.original_failures
@@ -2524,22 +2534,27 @@ def log_unit_test_results(self, result: ModelTextTestResult, test_duration: floa
25242534

25252535
if test_description := test_case.shortDescription():
25262536
self._print(test_description)
2527-
self._print(f"{unittest.TextTestResult.separator2}\n")
2537+
self._print(f"{unittest.TextTestResult.separator2}")
25282538

25292539
if exception := failure[1]:
2530-
for arg in exception.args:
2540+
for i, arg in enumerate(exception.args):
2541+
arg = f"Exception: {arg}" if isinstance(arg, str) else arg
25312542
self._print(arg)
2532-
self._print("\n")
2543+
2544+
if i < len(exception.args) - 1:
2545+
self._print("\n")
25332546

25342547
for test_case, error in errors:
25352548
self._print(unittest.TextTestResult.separator1)
25362549
self._print(f"ERROR: {test_case}")
2550+
self._print(f"{unittest.TextTestResult.separator2}")
25372551
self._print(error)
25382552

25392553
# Output final report
25402554
self._print(unittest.TextTestResult.separator2)
2555+
test_duration_msg = f" in {test_duration:.3f}s" if test_duration else ""
25412556
self._print(
2542-
f"Ran {tests_run} {'tests' if tests_run > 1 else 'test'} in {test_duration:.3f}s \n"
2557+
f"\nRan {tests_run} {'tests' if tests_run > 1 else 'test'}{test_duration_msg} \n"
25432558
)
25442559
self._print(
25452560
f"{'OK' if is_success else 'FAILED'}{' (' + ', '.join(infos) + ')' if infos else ''}"
@@ -2817,7 +2832,7 @@ def radio_button_selected(change: t.Dict[str, t.Any]) -> None:
28172832
self.display(radio)
28182833

28192834
def log_test_results(
2820-
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
2835+
self, result: ModelTextTestResult, output: t.Optional[str], target_dialect: str
28212836
) -> None:
28222837
import ipywidgets as widgets
28232838

@@ -3191,8 +3206,12 @@ def log_success(self, message: str) -> None:
31913206
self._print(message)
31923207

31933208
def log_test_results(
3194-
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
3209+
self, result: ModelTextTestResult, output: t.Optional[str], target_dialect: str
31953210
) -> None:
3211+
# self._print("```")
3212+
self.log_unit_test_results(result)
3213+
# self._print("```\n\n")
3214+
31963215
if result.wasSuccessful():
31973216
self._print(
31983217
f"**Successfully Ran `{str(result.testsRun)}` Tests Against `{target_dialect}`**\n\n"
@@ -3204,6 +3223,7 @@ def log_test_results(
32043223
for test, _ in result.failures + result.errors:
32053224
if isinstance(test, ModelTest):
32063225
self._print(f"* Failure Test: `{test.model.name}` - `{test.test_name}`\n\n")
3226+
32073227
self._print(f"```{output}```\n\n")
32083228

32093229
def log_skipped_models(self, snapshot_names: t.Set[str]) -> None:
@@ -3584,7 +3604,7 @@ def show_model_difference_summary(
35843604
self._write(f" Modified: {modified}")
35853605

35863606
def log_test_results(
3587-
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
3607+
self, result: ModelTextTestResult, output: t.Optional[str], target_dialect: str
35883608
) -> None:
35893609
self._write("Test Results:", result)
35903610

sqlmesh/core/context.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2053,7 +2053,7 @@ def test(
20532053

20542054
test_meta = self.load_model_tests(tests=tests, patterns=match_patterns)
20552055

2056-
return run_tests(
2056+
result = run_tests(
20572057
model_test_metadata=test_meta,
20582058
models=self._models,
20592059
config=self.config,
@@ -2066,6 +2066,10 @@ def test(
20662066
default_catalog_dialect=self.config.dialect or "",
20672067
)
20682068

2069+
self.console.log_test_results(result, output="", target_dialect=self.default_dialect)
2070+
2071+
return result
2072+
20692073
@python_api_analytics
20702074
def audit(
20712075
self,

sqlmesh/core/test/definition.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ def __init__(
6565
default_catalog: str | None = None,
6666
concurrency: bool = False,
6767
verbosity: Verbosity = Verbosity.DEFAULT,
68-
rich_output: bool = True,
6968
) -> None:
7069
"""ModelTest encapsulates a unit test for a model.
7170
@@ -90,7 +89,6 @@ def __init__(
9089
self.dialect = dialect
9190
self.concurrency = concurrency
9291
self.verbosity = verbosity
93-
self.rich_output = rich_output
9492

9593
self._fixture_table_cache: t.Dict[str, exp.Table] = {}
9694
self._normalized_column_name_cache: t.Dict[str, str] = {}
@@ -142,6 +140,12 @@ def __init__(
142140

143141
super().__init__()
144142

143+
def defaultTestResult(self) -> unittest.TestResult:
144+
from sqlmesh.core.test.result import ModelTextTestResult
145+
import sys
146+
147+
return ModelTextTestResult(stream=sys.stdout, descriptions=True, verbosity=self.verbosity)
148+
145149
def shortDescription(self) -> t.Optional[str]:
146150
return self.body.get("description")
147151

@@ -293,25 +297,23 @@ def _to_hashable(x: t.Any) -> t.Any:
293297
if expected.shape != actual.shape:
294298
_raise_if_unexpected_columns(expected.columns, actual.columns)
295299

296-
error_msg = "Data mismatch (rows are different)"
300+
args.append("Data mismatch (rows are different)")
297301

298302
missing_rows = _row_difference(expected, actual)
299303
if not missing_rows.empty:
300-
error_msg += f"\n\nMissing rows:\n\n{missing_rows}"
304+
args.append(df_to_table("Missing rows", missing_rows))
301305

302306
unexpected_rows = _row_difference(actual, expected)
307+
303308
if not unexpected_rows.empty:
304-
error_msg += f"\n\nUnexpected rows:\n\n{unexpected_rows}"
309+
args.append(df_to_table("Unexpected rows", unexpected_rows))
305310

306-
args.append(error_msg)
307311
else:
308312
diff = expected.compare(actual).rename(
309313
columns={"self": "Expected", "other": "Actual"}
310314
)
311315

312-
if not self.rich_output:
313-
args.append(f"Data mismatch\n\n{diff}")
314-
elif self.verbosity == Verbosity.DEFAULT:
316+
if self.verbosity == Verbosity.DEFAULT:
315317
args.append(df_to_table("Data mismatch", diff))
316318
else:
317319
from pandas import MultiIndex
@@ -714,7 +716,6 @@ def __init__(
714716
default_catalog: str | None = None,
715717
concurrency: bool = False,
716718
verbosity: Verbosity = Verbosity.DEFAULT,
717-
rich_output: bool = True,
718719
) -> None:
719720
"""PythonModelTest encapsulates a unit test for a Python model.
720721
@@ -742,7 +743,6 @@ def __init__(
742743
default_catalog,
743744
concurrency,
744745
verbosity,
745-
rich_output,
746746
)
747747

748748
self.context = TestExecutionContext(
@@ -996,7 +996,7 @@ def df_to_table(
996996
rich_table = Table(title=f"[bold red]{header}[/bold red]", show_lines=True, min_width=60)
997997
if show_index:
998998
index_name = str(index_name) if index_name else ""
999-
rich_table.add_column(index_name)
999+
rich_table.add_column(Align.center(index_name))
10001000

10011001
for column in df.columns:
10021002
column_name = column if isinstance(column, str) else ": ".join(str(col) for col in column)

sqlmesh/core/test/result.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(self, *args: t.Any, **kwargs: t.Any):
1919
self.successes = []
2020
self.original_failures: t.List[t.Tuple[unittest.TestCase, ErrorType]] = []
2121
self.original_errors: t.List[t.Tuple[unittest.TestCase, ErrorType]] = []
22+
self.duration: t.Optional[float] = None
2223

2324
def addSubTest(
2425
self,
@@ -76,7 +77,7 @@ def addSuccess(self, test: unittest.TestCase) -> None:
7677
super().addSuccess(test)
7778
self.successes.append(test)
7879

79-
def log_test_report(self, test_duration: float) -> None:
80+
def log_test_report(self, test_duration: t.Optional[float] = None) -> None:
8081
"""
8182
Log the test report following unittest's conventions.
8283

sqlmesh/core/test/runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,6 @@ def _run_single_test(
184184

185185
end_time = time.perf_counter()
186186

187-
combined_results.log_test_report(test_duration=end_time - start_time)
187+
combined_results.duration = end_time - start_time
188188

189189
return combined_results

0 commit comments

Comments
 (0)