Skip to content

Commit 36776dd

Browse files
committed
Refactor test_test to validate rich output [skip ci]
1 parent 10f9313 commit 36776dd

File tree

5 files changed

+165
-108
lines changed

5 files changed

+165
-108
lines changed

sqlmesh/core/console.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,9 @@ def loading_stop(self, id: uuid.UUID) -> None:
498498
"""Stop loading for the given id."""
499499

500500
@abc.abstractmethod
501-
def log_unit_test_results(self, result: ModelTextTestResult, test_duration: float) -> None:
501+
def log_unit_test_results(
502+
self, result: ModelTextTestResult, test_duration: t.Optional[float] = None
503+
) -> None:
502504
"""Print the unit test results."""
503505

504506

@@ -782,7 +784,9 @@ def start_destroy(self) -> bool:
782784
def stop_destroy(self, success: bool = True) -> None:
783785
pass
784786

785-
def log_unit_test_results(self, result: ModelTextTestResult, test_duration: float) -> None:
787+
def log_unit_test_results(
788+
self, result: ModelTextTestResult, test_duration: t.Optional[float] = None
789+
) -> None:
786790
pass
787791

788792

@@ -1964,6 +1968,7 @@ def log_test_results(
19641968
self, result: unittest.result.TestResult, output: t.Optional[str], target_dialect: str
19651969
) -> None:
19661970
divider_length = 70
1971+
result.log_test_report()
19671972
if result.wasSuccessful():
19681973
self._print("=" * divider_length)
19691974
self._print(
@@ -2500,7 +2505,9 @@ def show_linter_violations(
25002505
else:
25012506
self.log_warning(msg)
25022507

2503-
def log_unit_test_results(self, result: ModelTextTestResult, test_duration: float) -> None:
2508+
def log_unit_test_results(
2509+
self, result: ModelTextTestResult, test_duration: t.Optional[float] = None
2510+
) -> None:
25042511
tests_run = result.testsRun
25052512
errors = result.errors
25062513
failures = result.original_failures
@@ -2524,22 +2531,27 @@ def log_unit_test_results(self, result: ModelTextTestResult, test_duration: floa
25242531

25252532
if test_description := test_case.shortDescription():
25262533
self._print(test_description)
2527-
self._print(f"{unittest.TextTestResult.separator2}\n")
2534+
self._print(f"{unittest.TextTestResult.separator2}")
25282535

25292536
if exception := failure[1]:
2530-
for arg in exception.args:
2537+
for i, arg in enumerate(exception.args):
2538+
arg = f"Exception: {arg}" if isinstance(arg, str) else arg
25312539
self._print(arg)
2532-
self._print("\n")
2540+
2541+
if i < len(exception.args) - 1:
2542+
self._print("\n")
25332543

25342544
for test_case, error in errors:
25352545
self._print(unittest.TextTestResult.separator1)
25362546
self._print(f"ERROR: {test_case}")
2547+
self._print(f"{unittest.TextTestResult.separator2}")
25372548
self._print(error)
25382549

25392550
# Output final report
25402551
self._print(unittest.TextTestResult.separator2)
2552+
test_duration_msg = f" in {test_duration:.3f}s" if test_duration else ""
25412553
self._print(
2542-
f"Ran {tests_run} {'tests' if tests_run > 1 else 'test'} in {test_duration:.3f}s \n"
2554+
f"\nRan {tests_run} {'tests' if tests_run > 1 else 'test'}{test_duration_msg} \n"
25432555
)
25442556
self._print(
25452557
f"{'OK' if is_success else 'FAILED'}{' (' + ', '.join(infos) + ')' if infos else ''}"

sqlmesh/core/test/definition.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ def __init__(
6565
default_catalog: str | None = None,
6666
concurrency: bool = False,
6767
verbosity: Verbosity = Verbosity.DEFAULT,
68-
rich_output: bool = True,
6968
) -> None:
7069
"""ModelTest encapsulates a unit test for a model.
7170
@@ -90,7 +89,6 @@ def __init__(
9089
self.dialect = dialect
9190
self.concurrency = concurrency
9291
self.verbosity = verbosity
93-
self.rich_output = rich_output
9492

9593
self._fixture_table_cache: t.Dict[str, exp.Table] = {}
9694
self._normalized_column_name_cache: t.Dict[str, str] = {}
@@ -142,6 +140,12 @@ def __init__(
142140

143141
super().__init__()
144142

143+
def defaultTestResult(self) -> unittest.TestResult:
144+
from sqlmesh.core.test.result import ModelTextTestResult
145+
import sys
146+
147+
return ModelTextTestResult(stream=sys.stdout, descriptions=True, verbosity=self.verbosity)
148+
145149
def shortDescription(self) -> t.Optional[str]:
146150
return self.body.get("description")
147151

@@ -293,25 +297,23 @@ def _to_hashable(x: t.Any) -> t.Any:
293297
if expected.shape != actual.shape:
294298
_raise_if_unexpected_columns(expected.columns, actual.columns)
295299

296-
error_msg = "Data mismatch (rows are different)"
300+
args.append("Data mismatch (rows are different)")
297301

298302
missing_rows = _row_difference(expected, actual)
299303
if not missing_rows.empty:
300-
error_msg += f"\n\nMissing rows:\n\n{missing_rows}"
304+
args.append(df_to_table("Missing rows", missing_rows))
301305

302306
unexpected_rows = _row_difference(actual, expected)
307+
303308
if not unexpected_rows.empty:
304-
error_msg += f"\n\nUnexpected rows:\n\n{unexpected_rows}"
309+
args.append(df_to_table("Unexpected rows", unexpected_rows))
305310

306-
args.append(error_msg)
307311
else:
308312
diff = expected.compare(actual).rename(
309313
columns={"self": "Expected", "other": "Actual"}
310314
)
311315

312-
if not self.rich_output:
313-
args.append(f"Data mismatch\n\n{diff}")
314-
elif self.verbosity == Verbosity.DEFAULT:
316+
if self.verbosity == Verbosity.DEFAULT:
315317
args.append(df_to_table("Data mismatch", diff))
316318
else:
317319
from pandas import MultiIndex
@@ -714,7 +716,6 @@ def __init__(
714716
default_catalog: str | None = None,
715717
concurrency: bool = False,
716718
verbosity: Verbosity = Verbosity.DEFAULT,
717-
rich_output: bool = True,
718719
) -> None:
719720
"""PythonModelTest encapsulates a unit test for a Python model.
720721
@@ -742,7 +743,6 @@ def __init__(
742743
default_catalog,
743744
concurrency,
744745
verbosity,
745-
rich_output,
746746
)
747747

748748
self.context = TestExecutionContext(
@@ -996,7 +996,7 @@ def df_to_table(
996996
rich_table = Table(title=f"[bold red]{header}[/bold red]", show_lines=True, min_width=60)
997997
if show_index:
998998
index_name = str(index_name) if index_name else ""
999-
rich_table.add_column(index_name)
999+
rich_table.add_column(Align.center(index_name))
10001000

10011001
for column in df.columns:
10021002
column_name = column if isinstance(column, str) else ": ".join(str(col) for col in column)

sqlmesh/core/test/result.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def addSuccess(self, test: unittest.TestCase) -> None:
7676
super().addSuccess(test)
7777
self.successes.append(test)
7878

79-
def log_test_report(self, test_duration: float) -> None:
79+
def log_test_report(self, test_duration: t.Optional[float] = None) -> None:
8080
"""
8181
Log the test report following unittest's conventions.
8282

sqlmesh/core/test/runner.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,10 @@ def _run_single_test(
184184

185185
end_time = time.perf_counter()
186186

187+
from sqlmesh.core.console import get_console
188+
console = get_console()
189+
print(f"console {console}")
190+
187191
combined_results.log_test_report(test_duration=end_time - start_time)
188192

189193
return combined_results

0 commit comments

Comments
 (0)