
Commit a1ef819

thomass-devauguste, Auguste, and sylvaincom authored
feat: Add ComparisonReport to compare instances of EstimatorReport (#1286)
- [x] Rename to `ComparisonReport`
- [x] Rebase on top of #1239 and adapt
- [x] Raise if `report.metrics.accuracy(data_source="train")` is called with at least one EstimatorReport that does not have training data
- [x] Test
- [x] Docstrings
- [x] MetricsAccessor
- [x] Move index column "#0" in front of each metric
- [x] Pass report names in comparator
- [ ] ~Update plots legend~ see #1309
  - The actual `RocCurveDisplay` needs a full refactor to be split by use case: estimator report, cross-validation report, and finally comparison report. In each of these use cases, there are two scenarios: binary classification and multi-class classification. Otherwise, it will be unmaintainable.
- [ ] ~Investigate missing metrics in `report_metrics`~ **(deferred to future PR)**
  - The logic is split between `report_metrics` and `available_if`; it should be merged (ideally everything in `available_if`?)
- [ ] ~Refactor to make `CrossValidationReport` depend on it~ **(deferred to future PR)**
- [x] ~Change EstimatorReport `repr`?~ Issue #1293

Closes #1245

Co-authored-by: Auguste <[email protected]>
Co-authored-by: Sylvain Combettes <[email protected]>
1 parent 6ea43f6 commit a1ef819
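
For orientation before the per-file diffs, here is a minimal sketch of the API this commit introduces, pieced together from the commit message and the getting-started example changed below. `log_reg_report` and `rf_report` stand for two `EstimatorReport` objects built on the same train/test split; their construction is shown in the example diff.

```python
# Sketch only: names follow the getting-started example changed in this commit.
from skore import ComparisonReport

comparator = ComparisonReport(reports=[log_reg_report, rf_report])
comparator.help()                    # rich help tree of the available tools
comparator.metrics.report_metrics()  # one metrics table covering both reports

# Per the commit message, train-set metrics raise if any report lacks training data:
comparator.metrics.accuracy(data_source="train")
```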

File tree

20 files changed: +1986 -41 lines changed


README.md

Lines changed: 5 additions & 3 deletions
@@ -33,7 +33,8 @@ skore is a Python open-source library designed to help data scientists apply rec
   - `train_test_split` supercharged with methodological guidance: the API is the same as scikit-learn's, but skore displays warnings when applicable. For example, it warns you against shuffling time series data or when you have class imbalance.
 - **Evaluate**: automated insightful reports.
   - `EstimatorReport`: feed your scikit-learn compatible estimator and dataset, and it generates recommended metrics and plots to help you analyze your estimator. All these are computed and generated for you in 1 line of code. Under the hood, we use efficient caching to make the computations blazing fast.
-  - `CrossValidationReport`: Get a skore estimator report for each fold of your cross-validation.
+  - `CrossValidationReport`: get a skore estimator report for each fold of your cross-validation.
+  - `ComparisonReport`: benchmark your skore estimator reports.
 
 ## What's next?
 
@@ -91,7 +92,7 @@ You can find information on the latest version [here](https://anaconda.org/conda
 ```python
 # Display the ROC curve that was generated for you:
 roc_plot = cv_report.metrics.roc()
-roc_plot
+roc_plot.plot()
 ```
 
 1. Store your results for safe-keeping.
@@ -109,7 +110,8 @@ You can find information on the latest version [here](https://anaconda.org/conda
 
 ```python
 # Get your results
-df_get = my_project.put("df_cv_report_metrics")
+df_get = my_project.get("df_cv_report_metrics")
+df_get
 ```
 
 Learn more in our [documentation](https://skore.probabl.ai).
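
The last README fix swaps `put` for `get` when reading a stored result back. For context, a hedged sketch of the round trip this snippet relies on; the project name and the dataframe variable are placeholders, and `skore.open` is assumed to accept a project path as elsewhere in the README:

```python
import skore

# Placeholder round trip: store a result under a key, then read it back.
my_project = skore.open("my_project")                         # open (or create) a local project
my_project.put("df_cv_report_metrics", df_cv_report_metrics)  # write the metrics dataframe
df_get = my_project.get("df_cv_report_metrics")               # read it back later
```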

examples/getting_started/plot_skore_getting_started.py

Lines changed: 57 additions & 11 deletions
@@ -17,6 +17,8 @@
 # * :class:`skore.CrossValidationReport`: get an insightful report on your
 #   cross-validation results
 #
+# * :class:`skore.ComparisonReport`: benchmark your skore estimator reports
+#
 # * :func:`skore.train_test_split`: get diagnostics when splitting your data
 #
 # #. Track your ML/DS results using skore's :class:`~skore.Project`
@@ -50,33 +52,33 @@
 X, y = make_classification(n_classes=2, n_samples=100_000, n_informative=4)
 X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
 
-clf = LogisticRegression(random_state=0)
+log_reg = LogisticRegression(random_state=0)
 
-est_report = EstimatorReport(
-    clf, X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test
+log_reg_report = EstimatorReport(
+    log_reg, X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test
 )
 
 # %%
 # Now, we can display the help tree to see all the insights that are available to us
 # (skore detected that we are doing binary classification):
 
 # %%
-est_report.help()
+log_reg_report.help()
 
 # %%
 # We can get the report metrics that was computed for us:
 
 # %%
-df_est_report_metrics = est_report.metrics.report_metrics()
-df_est_report_metrics
+df_log_reg_report_metrics = log_reg_report.metrics.report_metrics()
+df_log_reg_report_metrics
 
 # %%
 # We can also plot the ROC curve that was generated for us:
 
 # %%
 import matplotlib.pyplot as plt
 
-roc_plot = est_report.metrics.roc()
+roc_plot = log_reg_report.metrics.roc()
 roc_plot.plot()
 plt.tight_layout()
 
@@ -97,7 +99,7 @@
 # %%
 from skore import CrossValidationReport
 
-cv_report = CrossValidationReport(clf, X, y, cv_splitter=5)
+cv_report = CrossValidationReport(log_reg, X, y, cv_splitter=5)
 
 # %%
 # We display the cross-validation report helper:
@@ -125,16 +127,60 @@
 # for example the first fold:
 
 # %%
-est_report_fold = cv_report.estimator_reports_[0]
-df_report_metrics_fold = est_report_fold.metrics.report_metrics()
-df_report_metrics_fold
+log_reg_report_fold = cv_report.estimator_reports_[0]
+df_log_reg_report_fold_metrics = log_reg_report_fold.metrics.report_metrics()
+df_log_reg_report_fold_metrics
 
 # %%
 # .. seealso::
 #
 #   For more information about the motivation and usage of
 #   :class:`skore.CrossValidationReport`, see :ref:`example_use_case_employee_salaries`.
 
+# %%
+# Comparing estimator reports
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# :class:`skore.ComparisonReport` enables users to compare several estimator reports
+# (corresponding to several estimators) on the same test set, as in a benchmark of
+# estimators.
+#
+# Apart from the previous ``log_reg_report``, let us define another estimator report:
+
+# %%
+from sklearn.ensemble import RandomForestClassifier
+
+rf = RandomForestClassifier(max_depth=2, random_state=0)
+rf_report = EstimatorReport(
+    rf, X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test
+)
+
+# %%
+# Now, let us compare these two estimator reports that were applied to the exact
+# same test set:
+
+# %%
+from skore import ComparisonReport
+
+comparator = ComparisonReport(reports=[log_reg_report, rf_report])
+
+# %%
+# As for the :class:`~skore.EstimatorReport` and the
+# :class:`~skore.CrossValidationReport`, we have a helper:
+
+# %%
+comparator.help()
+
+# %%
+# Let us display the result of our benchmark:
+
+# %%
+benchmark_metrics = comparator.metrics.report_metrics()
+benchmark_metrics
+
+# %%
+# We have the result of our benchmark.
+
 # %%
 # Train-test split with skore
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^
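
As a follow-up to the new comparison section, here is a sketch of how the same pattern extends to more candidates, reusing the split and the two reports defined in the example above; the gradient-boosting report is an illustration, not part of this commit:

```python
# Sketch: add a third candidate to the benchmark built in the example above.
from sklearn.ensemble import HistGradientBoostingClassifier

from skore import ComparisonReport, EstimatorReport

hgb_report = EstimatorReport(
    HistGradientBoostingClassifier(random_state=0),
    X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test,
)

# All reports share the exact same test set, which is what ComparisonReport expects.
comparator = ComparisonReport(reports=[log_reg_report, rf_report, hgb_report])
comparator.metrics.report_metrics()
```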

examples/use_cases/plot_employee_salaries.py

Lines changed: 5 additions & 0 deletions
@@ -298,6 +298,11 @@ def periodic_spline_transformer(period, n_splines=None, degree=3):
 )
 results
 
+# %%
+# .. note::
+#   We could have also used the :class:`skore.ComparisonReport` to compare estimator
+#   reports.
+
 # %%
 #
 # Finally, we can even get the individual :class:`~skore.EstimatorReport` for each fold
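
The added note stops at the suggestion; a hypothetical sketch of what it would look like in that example follows. The report variable names below are made up for illustration, not taken from the script:

```python
# Hypothetical: wrap the example's two estimators in EstimatorReport objects sharing
# one test set, then benchmark them with ComparisonReport.
from skore import ComparisonReport

comparator = ComparisonReport(reports=[linear_model_report, hgbdt_model_report])
comparator.metrics.report_metrics()
```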

skore/src/skore/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,7 @@
 from skore._config import config_context, get_config, set_config
 from skore.project import Project, open
 from skore.sklearn import (
+    ComparisonReport,
     CrossValidationReport,
     EstimatorReport,
     PrecisionRecallCurveDisplay,
@@ -20,6 +21,7 @@
 
 __all__ = [
     "CrossValidationReport",
+    "ComparisonReport",
     "EstimatorReport",
     "PrecisionRecallCurveDisplay",
     "PredictionErrorDisplay",

skore/src/skore/sklearn/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Enhance `sklearn` functions."""
22

3+
from skore.sklearn._comparison import ComparisonReport
34
from skore.sklearn._cross_validation import CrossValidationReport
45
from skore.sklearn._estimator import EstimatorReport
56
from skore.sklearn._plot import (
@@ -13,6 +14,7 @@
1314
"train_test_split",
1415
"CrossValidationReport",
1516
"EstimatorReport",
17+
"ComparisonReport",
1618
"RocCurveDisplay",
1719
"PrecisionRecallCurveDisplay",
1820
"PredictionErrorDisplay",

skore/src/skore/sklearn/_base.py

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ def _get_attributes_for_help(self):
 
     def _create_help_tree(self):
         """Create a rich Tree with the available tools and accessor methods."""
-        tree = Tree("report")
+        tree = Tree(self.__class__.__name__)
 
         # Add accessor methods first
         for accessor_attr, config in self._ACCESSOR_CONFIG.items():
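
A small illustration of what this one-line change affects, using the report objects from the getting-started example; the rest of the `.help()` output is skore's rich tree and is not reproduced here:

```python
# Before this change, the help tree root was always labeled "report"; now it is the
# concrete report class name.
log_reg_report.help()   # root labeled "EstimatorReport"
cv_report.help()        # root labeled "CrossValidationReport"
comparator.help()       # root labeled "ComparisonReport"
```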
skore/src/skore/sklearn/_comparison/__init__.py (new file; path inferred from its imports)

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+from skore.externals._pandas_accessors import _register_accessor
+from skore.sklearn._comparison.metrics_accessor import _MetricsAccessor
+from skore.sklearn._comparison.report import ComparisonReport
+
+_register_accessor("metrics", ComparisonReport)(_MetricsAccessor)
+
+__all__ = ["ComparisonReport"]
