
Commit 1a4151a

feat: Design of EstimatorReport (#997)
closes #834 Investigate an API for a `EstimatorReport`.

#### TODO

- [x] Metrics
  - [x] handle string metrics as specified in the accessor
  - [x] handle callable metrics
  - [x] handle scikit-learn scorers
  - [x] use the cache as efficiently as possible
  - [x] add testing for all of those features
  - [x] allow passing a new validation set to functions instead of using the internal validation set
  - [x] add a proper help and rich `__repr__`
- [x] Plots
  - [x] add the ROC curve display
  - [x] add the precision-recall curve display
  - [x] add the prediction error display for regressors
  - [x] add proper testing for those displays
  - [x] add a proper `__repr__` for those displays
- [x] Documentation
  - [x] (done for the checked part) add an example to showcase all the different features
  - [x] find a way to show the accessors' documentation in the page of `EstimatorReport`. It could be a bit tricky because they are only defined once the instance is created.
    - We need to have a look at the `series.rst` page from pandas to see how they document this sort of pattern.
  - [x] check the autocompletion: when typing `report.metrics.->tab` it should provide the autocompletion. **edit**: having a stub file is actually working. I prefer this to type hints directly in the file.
- Open questions
  - [x] we use hashing to retrieve an external set.
    - use the caching for the external validation set? To make it work we need to compute the hash of potentially big arrays. This might be more costly than making the model predict.

#### Notes

This PR builds upon:

- #962 to reuse the `skore.console`
- #998 to be able to detect clusterers in a consistent manner.
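For quick reference, the API introduced here is used along these lines (a minimal sketch distilled from the example file added in this diff; see the full walkthrough below):

```python
from skore import EstimatorReport

# build a report around an already fitted estimator and its train/test splits
reporter = EstimatorReport(
    estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
)
reporter.help()                                        # discover the available tools
reporter.metrics.report_metrics(pos_label="allowed")   # cached metric computation
reporter.metrics.plot.roc(pos_label="allowed")         # one-liner displays
```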
1 parent 1e0b605 commit 1a4151a

34 files changed: +5713 −7 lines
Lines changed: 385 additions & 0 deletions
@@ -0,0 +1,385 @@
"""
============================================
Get insights from any scikit-learn estimator
============================================

This example shows how the :class:`skore.EstimatorReport` class can be used to
quickly get insights from any scikit-learn estimator.
"""

# %%
#
# TODO: we need to describe the aim of this classification problem.
from skrub.datasets import fetch_open_payments

dataset = fetch_open_payments()
df = dataset.X
y = dataset.y

# %%
from skrub import TableReport

TableReport(df)

# %%
TableReport(y.to_frame())

# %%
# Looking at the distribution of the target, we observe that this classification
# task is quite imbalanced. It means that we have to be careful when selecting a set
# of statistical metrics to evaluate the classification performance of our predictive
# model. In addition, we see that the class labels are not encoded as the integers
# 0 or 1 but instead as the strings "allowed" or "disallowed".
#
# For our application, the label of interest is "allowed".
pos_label, neg_label = "allowed", "disallowed"

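# %%
# As a quick sanity check (not something the report requires), we can look at the
# class proportions directly with pandas; the exact values depend on the fetched
# dataset.
y.value_counts(normalize=True)
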
# %%
# Before training a predictive model, we need to split our dataset into a training
# and a validation set.
from skore import train_test_split

X_train, X_test, y_train, y_test = train_test_split(df, y, random_state=42)

# %%
# TODO: we have a perfect case to show useful features of the `train_test_split`
# function from `skore`.
#
# Now, we need to define a predictive model. Fortunately, `skrub` provides a
# convenient function (:func:`skrub.tabular_learner`) to get strong baseline
# predictive models with a single line of code. Since its feature engineering is
# generic, it does not replace handcrafted, task-specific feature engineering, but it
# still provides a good starting point.
#
# So let's create a classifier for our task and fit it on the training set.
from skrub import tabular_learner

estimator = tabular_learner("classifier").fit(X_train, y_train)
estimator

# %%
#
# Introducing the :class:`skore.EstimatorReport` class
# ----------------------------------------------------
#
# Now, we would like to get some insights from our predictive model. One way is to
# use the :class:`skore.EstimatorReport` class. Its constructor will detect that our
# estimator is already fitted and will not fit it again.
from skore import EstimatorReport

reporter = EstimatorReport(
    estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
)
reporter

# %%
#
# Once the reporter is created, we get some information regarding the tools available
# to get insights from our specific model on this specific task.
#
# You can get similar information by calling the :meth:`~skore.EstimatorReport.help`
# method.
reporter.help()

# %%
#
# Be aware that you can access the help for each individual sub-accessor. For
# instance:
reporter.metrics.help()

# %%
reporter.metrics.plot.help()

# %%
#
# Metrics computation with aggressive caching
# -------------------------------------------
#
# At this point, we might be interested in having a first look at the statistical
# performance of our model on the validation set that we provided. We can access it
# by calling any of the metrics displayed above. Since we are greedy, we want to get
# several metrics at once, so we use the
# :meth:`~skore.EstimatorReport.metrics.report_metrics` method.
import time

start = time.time()
metric_report = reporter.metrics.report_metrics(pos_label=pos_label)
end = time.time()
metric_report

# %%
print(f"Time taken to compute the metrics: {end - start:.2f} seconds")

# %%
#
# An interesting feature provided by the :class:`skore.EstimatorReport` is the
# caching mechanism. Indeed, once the dataset is large enough, computing the
# predictions of a model is not cheap anymore. For instance, on our smallish dataset,
# it took a couple of seconds to compute the metrics. The reporter caches the
# predictions, so if you compute a metric again, or an alternative metric that
# requires the same predictions, it will be faster. Let's check this by requesting
# the same metrics report again.

start = time.time()
metric_report = reporter.metrics.report_metrics(pos_label=pos_label)
end = time.time()
metric_report

# %%
print(f"Time taken to compute the metrics: {end - start:.2f} seconds")

# %%
#
# Since we obtain a pandas dataframe, we can also use the plotting interface of
# pandas.
import matplotlib.pyplot as plt

ax = metric_report.T.plot.barh()
ax.set_title("Metrics report")
plt.tight_layout()

# %%
#
# Whenever a metric is computed, we check whether the predictions are available in
# the cache and reuse them if they are. For instance, let's compute the log loss.

start = time.time()
log_loss = reporter.metrics.log_loss()
end = time.time()
log_loss

# %%
print(f"Time taken to compute the log loss: {end - start:.2f} seconds")

# %%
#
# We can show that, without the cache already populated, computing the log loss takes
# more time.
reporter.clean_cache()

start = time.time()
log_loss = reporter.metrics.log_loss()
end = time.time()
log_loss

# %%
print(f"Time taken to compute the log loss: {end - start:.2f} seconds")

# %%
#
# By default, the metrics are computed on the test set. However, if a training set
# is provided, we can also compute the metrics on it by specifying the `data_source`
# parameter.
reporter.metrics.log_loss(data_source="train")

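# %%
# The same `data_source` parameter works for the full metrics report as well; this is
# a small additional illustration combining the two features shown above.
reporter.metrics.report_metrics(data_source="train", pos_label=pos_label)
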
# %%
#
# If we are interested in computing the metrics on a completely new set of data, we
# can use the `data_source="X_y"` parameter. In that case, we also need to provide
# the `X` and `y` parameters.

start = time.time()
metric_report = reporter.metrics.report_metrics(
    data_source="X_y", X=X_test, y=y_test, pos_label=pos_label
)
end = time.time()
metric_report

# %%
print(f"Time taken to compute the metrics: {end - start:.2f} seconds")

# %%
#
# As in the previous cases, we rely on the cache to avoid recomputing the predictions.
# Internally, we compute a hash of the input data to make sure that we can hit the
# cache in a consistent way.

# %%
start = time.time()
metric_report = reporter.metrics.report_metrics(
    data_source="X_y", X=X_test, y=y_test, pos_label=pos_label
)
end = time.time()
metric_report

# %%
print(f"Time taken to compute the metrics: {end - start:.2f} seconds")

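# %%
# To get a rough sense of the trade-off discussed in the warning below, we can time
# the hashing of the validation data on its own. Here we use :func:`joblib.hash` as a
# stand-in; the exact hashing used internally by the reporter may differ.
import joblib

start = time.time()
joblib.hash((X_test, y_test))
end = time.time()
print(f"Time taken to hash the validation data: {end - start:.2f} seconds")
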
# %%
#
# .. warning::
#     In this last example, we rely on computing the hash of the input data. There is
#     therefore a trade-off: computing the hash is not free and it might be faster to
#     just recompute the predictions instead.
#
# Be aware that you can also benefit from the caching mechanism with your own custom
# metrics. The only requirement is that your metric function takes `y_true` and
# `y_pred` as its first two positional arguments; it can take any other arguments
# after that. Let's see an example.


def operational_decision_cost(y_true, y_pred, amount):
    mask_true_positive = (y_true == pos_label) & (y_pred == pos_label)
    mask_true_negative = (y_true == neg_label) & (y_pred == neg_label)
    mask_false_positive = (y_true == neg_label) & (y_pred == pos_label)
    mask_false_negative = (y_true == pos_label) & (y_pred == neg_label)
    # FIXME: we need to make sense of the cost sensitive part with the right naming
    fraudulent_refuse = mask_true_positive.sum() * 50
    fraudulent_accept = -amount[mask_false_negative].sum()
    legitimate_refuse = mask_false_positive.sum() * -5
    legitimate_accept = (amount[mask_true_negative] * 0.02).sum()
    return fraudulent_refuse + fraudulent_accept + legitimate_refuse + legitimate_accept


# %%
#
# In our use case, we have an operational decision to make that translates the
# classification outcome into a cost. It turns the confusion matrix into a cost
# matrix based on an amount linked to each sample in the dataset that is provided to
# us. Here, we randomly generate some amounts as an illustration.
import numpy as np

rng = np.random.default_rng(42)
amount = rng.integers(low=100, high=1000, size=len(y_test))

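# %%
# As a quick illustration outside of the reporter, the cost function can be called
# directly on the estimator's predictions. This recomputes `predict` and is only
# shown here to make the metric concrete.
operational_decision_cost(y_test, estimator.predict(X_test), amount)
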
# %%
#
# Let's first make sure that the `predict` method has been called and its result
# cached. Computing the accuracy metric guarantees that `predict` is called.
reporter.metrics.accuracy()

# %%
#
# We can now compute the cost of our operational decision.
start = time.time()
cost = reporter.metrics.custom_metric(
    metric_function=operational_decision_cost,
    metric_name="Operational Decision Cost",
    response_method="predict",
    amount=amount,
)
end = time.time()
cost

# %%
print(f"Time taken to compute the cost: {end - start:.2f} seconds")

# %%
#
# Let's now clean the cache and measure how long the same computation takes without
# it.
reporter.clean_cache()

# %%
start = time.time()
cost = reporter.metrics.custom_metric(
    metric_function=operational_decision_cost,
    metric_name="Operational Decision Cost",
    response_method="predict",
    amount=amount,
)
end = time.time()
cost

# %%
print(f"Time taken to compute the cost: {end - start:.2f} seconds")

# %%
#
# We observe that caching works as expected. It is really handy because it means that
# you can compute some additional metrics without having to recompute the
# predictions.
reporter.metrics.report_metrics(
    scoring=["precision", "recall", operational_decision_cost],
    pos_label=pos_label,
    scoring_kwargs={
        "amount": amount,
        "response_method": "predict",
        "metric_name": "Operational Decision Cost",
    },
)

# %%
#
# It could happen that you want to provide several custom metrics that do not
# necessarily share the same parameters. In this more complex case, we require you to
# provide scorers built with the :func:`sklearn.metrics.make_scorer` function.
from sklearn.metrics import make_scorer, f1_score

f1_scorer = make_scorer(
    f1_score,
    response_method="predict",
    metric_name="F1 Score",
    pos_label=pos_label,
)
operational_decision_cost_scorer = make_scorer(
    operational_decision_cost,
    response_method="predict",
    metric_name="Operational Decision Cost",
    amount=amount,
)
reporter.metrics.report_metrics(scoring=[f1_scorer, operational_decision_cost_scorer])

# %%
#
# Effortless one-liner plotting
# -----------------------------
#
# The :class:`skore.EstimatorReport` class also provides a plotting interface that
# lets you produce the most common plots out of the box. As for the metrics, we only
# expose the set of plots that is meaningful for the provided estimator.
reporter.metrics.plot.help()

# %%
#
# Let's start by plotting the ROC curve for our binary classification task.
display = reporter.metrics.plot.roc(pos_label=pos_label)
plt.tight_layout()

# %%
#
# The plot functionality is built upon the scikit-learn display objects. We return
# those displays (slightly modified to improve the UI) in case you want to tweak some
# of the plot properties. You can have a quick look at the available attributes and
# methods by calling the `help` method or simply by printing the display.
display

# %%
display.help()

# %%
display.plot()
display.ax_.set_title("Example of a ROC curve")
display.figure_
plt.tight_layout()
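
# %%
# Since the display exposes the underlying Matplotlib objects, the customized figure
# can also be saved to disk; the file name below is just an example.
display.figure_.savefig("roc_curve.png")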

# %%
#
# Similarly to the metrics, we aggressively use caching to avoid recomputing the
# predictions of the model. We also cache the plot display object by detecting
# whether the input parameters are the same as in the previous call. Let's
# demonstrate the kind of performance gain we can get.
start = time.time()
# we already triggered the computation of the predictions in a previous call
reporter.metrics.plot.roc(pos_label=pos_label)
plt.tight_layout()
end = time.time()

# %%
print(f"Time taken to compute the ROC curve: {end - start:.2f} seconds")

# %%
#
# Now, let's clean the cache and check whether we get a slowdown.
reporter.clean_cache()

# %%
start = time.time()
reporter.metrics.plot.roc(pos_label=pos_label)
plt.tight_layout()
end = time.time()

# %%
print(f"Time taken to compute the ROC curve: {end - start:.2f} seconds")

# %%
# As expected, since we need to recompute the predictions, it takes more time.
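
# %%
# The plot accessor also exposes the other displays added in this PR (for instance
# the precision-recall curve); the exact method name below is an assumption based on
# the PR description rather than on this example.
reporter.metrics.plot.precision_recall(pos_label=pos_label)
plt.tight_layout()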
