Skip to content

Commit 6f6fafb

Browse files
authored
feat: Use rich Panel for showing warning in train_test_split (#1086)
closes #1060 It is the alternative to #1060 using `rich`. I added a test to check that we can filter the warning since we are not using the usual `warnings` module. In the future, we could factor out the code in a utils to be sure that we can also transform the warnings into error.
1 parent b177651 commit 6f6fafb

File tree

2 files changed

+58
-12
lines changed

2 files changed

+58
-12
lines changed

skore/src/skore/sklearn/train_test_split/train_test_split.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import numpy as np
99
from numpy.random import RandomState
10+
from rich.panel import Panel
1011

1112
from skore.sklearn.find_ml_task import _find_ml_task
1213
from skore.sklearn.train_test_split.warning import TRAIN_TEST_SPLIT_WARNINGS
@@ -88,26 +89,27 @@ class labels.
8889
>>> X, y = np.arange(10).reshape((5, 2)), range(5)
8990
9091
>>> # Drop-in replacement for sklearn train_test_split
91-
>>> X_train, X_test, y_train, y_test = train_test_split(X, y,
92+
>>> X_train, X_test, y_train, y_test = train_test_split(X, y, # doctest: +SKIP
9293
... test_size=0.33, random_state=42)
93-
>>> X_train
94+
>>> X_train # doctest: +SKIP
9495
array([[4, 5],
9596
[0, 1],
9697
[6, 7]])
9798
9899
>>> # Explicit X and y, makes detection of problems easier
99-
>>> X_train, X_test, y_train, y_test = train_test_split(X=X, y=y,
100+
>>> X_train, X_test, y_train, y_test = train_test_split(X=X, y=y, # doctest: +SKIP
100101
... test_size=0.33, random_state=42)
101-
>>> X_train
102+
>>> X_train # doctest: +SKIP
102103
array([[4, 5],
103104
[0, 1],
104105
[6, 7]])
105106
106107
>>> # When passing X and y explicitly, X is returned before y
107108
>>> arr = np.arange(10).reshape((5, 2))
108-
>>> arr_train, arr_test, X_train, X_test, y_train, y_test = train_test_split(
109+
>>> splits = train_test_split( # doctest: +SKIP
109110
... arr, y=y, X=X, test_size=0.33, random_state=42)
110-
>>> X_train
111+
>>> arr_train, arr_test, X_train, X_test, y_train, y_test = splits # doctest: +SKIP
112+
>>> X_train # doctest: +SKIP
111113
array([[4, 5],
112114
[0, 1],
113115
[6, 7]])
@@ -158,10 +160,24 @@ class labels.
158160
ml_task=ml_task,
159161
)
160162

163+
from skore import console # avoid circular import
164+
161165
for warning_class in TRAIN_TEST_SPLIT_WARNINGS:
162166
warning = warning_class.check(**kwargs)
163167

164-
if warning is not None:
165-
warnings.warn(message=warning, category=warning_class, stacklevel=1)
168+
if warning is not None and (
169+
not warnings.filters
170+
or not any(
171+
f[0] == "ignore" and f[2] == warning_class for f in warnings.filters
172+
)
173+
):
174+
console.print(
175+
Panel(
176+
title=warning_class.__name__,
177+
renderable=warning,
178+
style="orange1",
179+
border_style="cyan",
180+
)
181+
)
166182

167183
return output

skore/tests/unit/sklearn/train_test_split/test_train_test_split.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,15 +129,45 @@ def case_time_based_column_polars_dates():
129129
case_time_based_column_polars_dates,
130130
],
131131
)
132-
def test_train_test_split_warns(params):
132+
def test_train_test_split_warns(params, capsys):
133133
"""When train_test_split is called with these args and kwargs, the corresponding
134-
warning should fire."""
135-
warnings.simplefilter("ignore")
134+
warning should be printed to the console."""
135+
args, kwargs, warning_cls = params()
136+
137+
train_test_split(*args, **kwargs)
138+
139+
captured = capsys.readouterr()
140+
assert warning_cls.__name__ in captured.out
141+
142+
143+
@pytest.mark.parametrize(
144+
"params",
145+
[
146+
case_high_class_imbalance,
147+
case_high_class_imbalance_too_few_examples,
148+
case_high_class_imbalance_too_few_examples_kwargs,
149+
case_high_class_imbalance_too_few_examples_kwargs_mixed,
150+
case_stratify,
151+
case_random_state_unset,
152+
case_shuffle_true,
153+
case_shuffle_none,
154+
case_time_based_column,
155+
case_time_based_columns_several,
156+
case_time_based_column_polars,
157+
case_time_based_column_polars_dates,
158+
],
159+
)
160+
def test_train_test_split_warns_suppressed(params, capsys):
161+
"""Verify that warnings can be suppressed and don't appear in the console output."""
136162
args, kwargs, warning_cls = params()
137163

138-
with pytest.warns(warning_cls):
164+
with warnings.catch_warnings():
165+
warnings.filterwarnings("ignore", category=warning_cls)
139166
train_test_split(*args, **kwargs)
140167

168+
captured = capsys.readouterr()
169+
assert warning_cls.__name__ not in captured.out
170+
141171

142172
def test_train_test_split_kwargs():
143173
"""Passing data by keyword arguments should produce the same results as passing

0 commit comments

Comments
 (0)