feat: Design of EstimatorReport
#997
Merged
Commits (146)
fd94974
feat: Use friendly verbose and colorish
glemaitre 1a3d4a6
limit size
glemaitre fe075a4
tweak bold effect
glemaitre 505f2b4
iter
glemaitre c563d45
test: complete tests for new arg in cli
MarieSacksick 801795c
iter
glemaitre 376099c
use context manager as a more explicit way to configure the logger
glemaitre d427f2d
TST add a couple of quick tests for the logger context manager
glemaitre ce273b6
Merge remote-tracking branch 'glemaitre/is/959' into model_report
glemaitre 9cdeacd
feat: EstimatorReport
glemaitre 56b1821
iter
glemaitre acb51e0
fix: Use estimator whenever possible to detect the ML task
glemaitre 92b4e6c
iter
glemaitre 2982798
tests
glemaitre 66a3fd3
tests
glemaitre 4b7124b
DOC add some docstring
glemaitre 7b7c9c8
EXA add an example to present the feature
glemaitre 8de79ae
iter
glemaitre cf0c865
add metrics
glemaitre b1cd767
allow to pass a new set of data to the metrics
glemaitre 97577b4
iter
glemaitre a9d67b4
TST add test for individual metrics
glemaitre 597212b
TST add test for the default scoring in
glemaitre 19fd6d8
TST add test for passing scoring kwargs in report_metrics
glemaitre eeba764
TST add check that we properly hit the cache with arbitrary keywords
glemaitre f845db6
FEA add support for an arbitrary metric
glemaitre 9f219f2
improve example
glemaitre eaeb072
check name and add test with joblib hash
glemaitre 0533abf
allow to add a custom metric in the reporting
glemaitre 6d22b33
mainly refactor the help
glemaitre 7c42cf9
iter
glemaitre 9197928
fix bug according to test
glemaitre 7594551
fix docstring check
glemaitre c94635f
add the EstimatorReport to the API doc
glemaitre be4b265
use literal option by default
glemaitre b491d4e
iter
glemaitre 13297b5
add stubs for solving the problem of auto-completion
glemaitre 4ac4237
only check the cache
glemaitre bb69f8f
TST add test for help and repr of accessor
glemaitre a3d97a0
use pos_label instead of positive_class and test for plotting
glemaitre 152fc21
add rich repr and help for display
glemaitre f56f19f
iter
glemaitre 003039f
TST for the plot repr and help function
glemaitre 36f7a2c
iter
glemaitre 41a67a0
add precision recall curve
glemaitre c87ea6e
TST add more test for the estimator report displays
glemaitre 70fb720
iter
glemaitre 7258f75
rename X_val and y_val to X_test and y_test to simplify
glemaitre 2cbba2f
use a single constructor
glemaitre 914fb1a
Merge remote-tracking branch 'glemaitre/_find_ml_task_estimator_base'…
glemaitre 259794b
accept external data
glemaitre 0f93945
fix docstring
glemaitre f463285
iter
glemaitre d811bc6
add data_source with test
glemaitre 1705252
Merge remote-tracking branch 'origin/main' into model_report
glemaitre aab9bf4
bring the cache to the external data by computing a hash
glemaitre ad28319
use agg backend
glemaitre b99d78c
Merge remote-tracking branch 'origin/main' into model_report
glemaitre 4b23afb
expose .plot under the metrics accessor
glemaitre 2df34d1
rename plot accessors
glemaitre ce4811c
iter
glemaitre 8609881
iter
glemaitre 6cb3b08
small refactoring for plotting
glemaitre 4f9b6fc
commit refactoring
glemaitre ea18335
add multiclass ovr roc curve
glemaitre 8b61fa0
update classification support for plots
glemaitre 9ea0c6e
check as well for regression
glemaitre ef3937d
add a module to test display
glemaitre 54d9fab
more test roc curve
glemaitre 3119809
check error message
glemaitre 6824bef
add test for the kwargs
glemaitre 2a55aec
check chance kwargs
glemaitre c3e86cb
more doc and remove sample_weight for the moment
glemaitre 6745a3d
modify pr curve and align roc curve
glemaitre efaa8b8
iter
glemaitre 3646154
Merge remote-tracking branch 'origin/main' into model_report
glemaitre aaffeb2
add test precision recall curve binary
glemaitre 54b8692
docstring for tests
glemaitre 84bfba1
add test for multiclass precision recall curve
glemaitre 7dad8b1
add test for args
glemaitre a48fb16
fix file
glemaitre 0613274
iter
glemaitre 0df833d
fix bug with default metric
glemaitre 4eee6ab
more coverage
glemaitre 4c5b10c
improve coverage in EstimatorReport
glemaitre b52cffa
do not cover the cross-validation for the moment but raise an error
glemaitre 1a73cf6
update outdated setter
glemaitre a23fbcb
check other data_source display
glemaitre 94e4edd
check plot kwargs
glemaitre 2235230
more coverage for precision recall display
glemaitre 906058f
test providing axis in displays
glemaitre 5f7e303
add test for plotting utils
glemaitre 38182b0
modify example
glemaitre 58d2a33
improve menu accessor
glemaitre 84b5ae3
Merge branch 'main' into model_report
glemaitre 97caf68
iter
glemaitre 4fe9660
fix
glemaitre 6c5eb2c
docstring fix
glemaitre c549a7f
new function to precompute the cache
glemaitre 59e3ab8
test the cache_predictions function
glemaitre c64f259
add plotting error display
glemaitre 0f9d368
Update skore/src/skore/sklearn/_estimator.py
glemaitre 5746811
Apply suggestions from code review
glemaitre 4bc0358
brier do not support multiclass
glemaitre b728222
iter
glemaitre b677035
docstring model
glemaitre 3b1bea4
use unicode visual clue
glemaitre 1cbbb17
iter
glemaitre b91a672
iter
glemaitre 3b5c2eb
integration test
glemaitre 05ff4f3
more test
glemaitre bc92165
simplify API for naming with non-default value needed
glemaitre 096fd93
improve color help
glemaitre e2a3a74
iter
glemaitre aee7270
improve repr of the displays
glemaitre f627593
add estimator name in repr
glemaitre 22f10c3
refactor in plots
glemaitre 14da688
more refactoring
glemaitre 1859092
refactor detecting but
glemaitre 63e3d1c
make pos_label and average consistent
glemaitre 1a3bea6
improve consistency documentation
glemaitre 11d338e
small refactor test
glemaitre e46ad58
more refactor tests
glemaitre bb63a82
refactor
glemaitre 1e7f909
fix
glemaitre e6d2a69
split stubs file
glemaitre e8721d6
update doc
glemaitre 334d121
use accessor from pandas to ease doc building
glemaitre 7fba2ff
add noqa F401 to avoid removing import
glemaitre 9c1abd0
do not inject do and overwrite instead
glemaitre 50b7ce7
use the register_accessor for sub-accessor
glemaitre 249e062
first draft for accessor documentation
glemaitre 5a44fc7
add sphinx_autosummary_accessors as a dependence
glemaitre ad4f408
Update examples/model_evaluation/plot_estimator_report.py
glemaitre dad7211
Update skore/src/skore/sklearn/_estimator/base.py
glemaitre 8cc9f4e
Update skore/src/skore/sklearn/_estimator/report.py
glemaitre 4b8be36
rewrap
glemaitre ad20a1c
add legend in the help
glemaitre 682801c
make matplotlib and pandas a dependency
glemaitre d519bd7
remove unnecessary __init__.py
glemaitre 65a0f5d
vendor the accessor
glemaitre ebb7ab9
iter
glemaitre cb5f210
add attributes
glemaitre b8d4610
check that we support X_y without passing original dataset
glemaitre 866be82
compute brier score for both labels
glemaitre 82f6332
simplify the brier score
glemaitre

examples/model_evaluation/plot_estimator_report.py
| """ | ||
| ============================================ | ||
| Get insights from any scikit-learn estimator | ||
| ============================================ | ||
| This example shows how the :class:`skore.EstimatorReport` class can be used to | ||
| quickly get insights from any scikit-learn estimator. | ||
| """ | ||

# %%
#
# TODO: we need to describe the aim of this classification problem.
from skrub.datasets import fetch_open_payments

dataset = fetch_open_payments()
df = dataset.X
y = dataset.y

# %%
from skrub import TableReport

TableReport(df)

# %%
TableReport(y.to_frame())

# %%
# Looking at the distribution of the target, we observe that this classification
# task is quite imbalanced. This means that we have to be careful when selecting
# a set of statistical metrics to evaluate the classification performance of our
# predictive model. In addition, the class labels are not encoded as integers 0
# or 1 but as the strings "allowed" and "disallowed".
#
# For our application, the label of interest is "allowed".
pos_label, neg_label = "allowed", "disallowed"

# %%
# Before training a predictive model, we need to split our dataset into a
# training and a test set.
from skore import train_test_split

X_train, X_test, y_train, y_test = train_test_split(df, y, random_state=42)

# %%
# TODO: we have a perfect case to show a useful feature of the
# `train_test_split` function from `skore`.
#
# Now, we need to define a predictive model. Fortunately, `skrub` provides a
# convenient function (:func:`skrub.tabular_learner`) to get a strong baseline
# predictive model in a single line of code. Since its feature engineering is
# generic, it does not replace handcrafted, task-specific feature engineering,
# but it is a good starting point.
#
# So let's create a classifier for our task and fit it on the training set.
from skrub import tabular_learner

estimator = tabular_learner("classifier").fit(X_train, y_train)
estimator

# %%
#
# Introducing the :class:`skore.EstimatorReport` class
# ----------------------------------------------------
#
# Now, we would like to get some insights from our predictive model. One way is
# to use the :class:`skore.EstimatorReport` class. Its constructor detects that
# our estimator is already fitted and will not fit it again.
from skore import EstimatorReport

reporter = EstimatorReport(
    estimator, X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
)
reporter
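
# %%
#
# The constructor is described above as detecting whether the estimator is
# fitted. As a hedged sketch (not part of the original example), this suggests
# that passing an unfitted estimator would lead the report to fit it on the
# provided training data.
from sklearn.base import clone

# hypothetical illustration: an unfitted clone should be fitted on the training
# data by the report, per the fitted-detection behavior described above
reporter_unfitted = EstimatorReport(
    clone(estimator), X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test
)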

# %%
#
# Once the reporter is created, we get some information regarding the tools
# available to get insights from our specific model on the specific task.
#
# You can get similar information by calling the
# :meth:`~skore.EstimatorReport.help` method.
reporter.help()

# %%
#
# Be aware that you can access the help for each individual sub-accessor. For
# instance:
reporter.metrics.help()

# %%
reporter.metrics.plot.help()

# %%
#
# Metrics computation with aggressive caching
# -------------------------------------------
#
# At this point, we might be interested in having a first look at the
# statistical performance of our model on the test set that we provided. We can
# access it by calling any of the metrics displayed above. Since we are greedy,
# we want several metrics at once, and we will use the
# :meth:`~skore.EstimatorReport.metrics.report_metrics` method.
import time

start = time.time()
metric_report = reporter.metrics.report_metrics(pos_label=pos_label)
end = time.time()
metric_report

# %%
print(f"Time taken to compute the metrics: {end - start:.2f} seconds")

# %%
#
# An interesting feature provided by the :class:`skore.EstimatorReport` is the
# caching mechanism. Indeed, when we have a large enough dataset, computing the
# predictions for a model is not cheap anymore. For instance, on our smallish
# dataset, it took a couple of seconds to compute the metrics. The reporter
# caches the predictions, and if you are interested in computing a metric
# again, or an alternative metric that requires the same predictions, it will
# be faster. Let's check by requesting the same metrics report again.

start = time.time()
metric_report = reporter.metrics.report_metrics(pos_label=pos_label)
end = time.time()
metric_report

# %%
print(f"Time taken to compute the metrics: {end - start:.2f} seconds")

# %%
#
# Since we obtain a pandas dataframe, we can also use the plotting interface of
# pandas.
import matplotlib.pyplot as plt

ax = metric_report.T.plot.barh()
ax.set_title("Metrics report")
plt.tight_layout()

# %%
#
# Whenever we compute a metric, we check whether the predictions are available
# in the cache and reload them if so. For instance, let's compute the log loss.

start = time.time()
log_loss = reporter.metrics.log_loss()
end = time.time()
log_loss

# %%
print(f"Time taken to compute the log loss: {end - start:.2f} seconds")

# %%
#
# Without the initial cache, it would have taken more time to compute the log
# loss. Let's verify this by cleaning the cache first.
reporter.clean_cache()

start = time.time()
log_loss = reporter.metrics.log_loss()
end = time.time()
log_loss

# %%
print(f"Time taken to compute the log loss: {end - start:.2f} seconds")

# %%
#
# By default, the metrics are computed on the test set. However, if a training
# set is provided, we can also compute the metrics on it by specifying the
# `data_source` parameter.
reporter.metrics.log_loss(data_source="train")
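
# %%
#
# For completeness, a hedged illustration (assuming `"test"` is the accepted
# name of the default data source, given that `"train"` and `"X_y"` exist):
# the default call should be equivalent to spelling the source explicitly.
reporter.metrics.log_loss(data_source="test")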

# %%
#
# If we are interested in computing the metrics on a completely new set of
# data, we can use the `data_source="X_y"` parameter. In addition, we need to
# provide the `X` and `y` parameters.

start = time.time()
metric_report = reporter.metrics.report_metrics(
    data_source="X_y", X=X_test, y=y_test, pos_label=pos_label
)
end = time.time()
metric_report

# %%
print(f"Time taken to compute the metrics: {end - start:.2f} seconds")

# %%
#
# As in the other cases, we rely on the cache to avoid recomputing the
# predictions. Internally, we compute a hash of the input data to make sure
# that we can hit the cache in a consistent way.
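
# %%
#
# As a hedged illustration (not part of the original example), the commit
# history mentions a joblib hash for external data. The key property is that
# identical inputs map to identical hashes, which makes the cache key stable
# across calls; the exact hashing scheme used internally is an assumption here.
import joblib

# same data, same hash: the property that makes external-data caching possible
assert joblib.hash(X_test) == joblib.hash(X_test)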

# %%
start = time.time()
metric_report = reporter.metrics.report_metrics(
    data_source="X_y", X=X_test, y=y_test, pos_label=pos_label
)
end = time.time()
metric_report

# %%
print(f"Time taken to compute the metrics: {end - start:.2f} seconds")

# %%
#
# .. warning::
#    In this last example, we rely on computing the hash of the input data.
#    Therefore, there is a trade-off: computing the hash is not free and it
#    might sometimes be faster to recompute the predictions instead.
#
# Be aware that you can also benefit from the caching mechanism with your own
# custom metrics. We only expect your metric function to take `y_true` and
# `y_pred` as the first two positional arguments; it can take any other
# arguments after that. Let's see an example.


def operational_decision_cost(y_true, y_pred, amount):
    mask_true_positive = (y_true == pos_label) & (y_pred == pos_label)
    mask_true_negative = (y_true == neg_label) & (y_pred == neg_label)
    mask_false_positive = (y_true == neg_label) & (y_pred == pos_label)
    mask_false_negative = (y_true == pos_label) & (y_pred == neg_label)
    # FIXME: we need to make sense of the cost-sensitive part with the right naming
    fraudulent_refuse = mask_true_positive.sum() * 50
    fraudulent_accept = -amount[mask_false_negative].sum()
    legitimate_refuse = mask_false_positive.sum() * -5
    legitimate_accept = (amount[mask_true_negative] * 0.02).sum()
    return (
        fraudulent_refuse + fraudulent_accept + legitimate_refuse + legitimate_accept
    )

# %%
#
# In our use case, we have an operational decision to make that translates the
# classification outcome into a cost. It translates the confusion matrix into a
# cost matrix based on an amount linked to each sample in the dataset, which is
# provided to us. Here, we randomly generate some amounts as an illustration.
import numpy as np

rng = np.random.default_rng(42)
amount = rng.integers(low=100, high=1000, size=len(y_test))

# %%
#
# Let's first make sure that the `predict` method has been called and its
# result cached. We compute the accuracy metric, which calls `predict` under
# the hood.
reporter.metrics.accuracy()
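
# %%
#
# Alternatively, the commit history mentions a helper to precompute the cache
# (`cache_predictions`). As a hedged sketch whose exact signature is an
# assumption, it could warm the cache in a single call::
#
#     reporter.cache_predictions()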

# %%
#
# We can now compute the cost of our operational decision.
start = time.time()
cost = reporter.metrics.custom_metric(
    metric_function=operational_decision_cost,
    metric_name="Operational Decision Cost",
    response_method="predict",
    amount=amount,
)
end = time.time()
cost

# %%
print(f"Time taken to compute the cost: {end - start:.2f} seconds")

# %%
#
# Let's now clean the cache and check that the computation is slower without it.
reporter.clean_cache()

# %%
start = time.time()
cost = reporter.metrics.custom_metric(
    metric_function=operational_decision_cost,
    metric_name="Operational Decision Cost",
    response_method="predict",
    amount=amount,
)
end = time.time()
cost

# %%
print(f"Time taken to compute the cost: {end - start:.2f} seconds")

# %%
#
# We observe that the caching works as expected. It is really handy because it
# means that you can compute some additional metrics without having to
# recompute the predictions.
reporter.metrics.report_metrics(
    scoring=["precision", "recall", operational_decision_cost],
    pos_label=pos_label,
    scoring_kwargs={
        "amount": amount,
        "response_method": "predict",
        "metric_name": "Operational Decision Cost",
    },
)

# %%
#
# It could happen that you are interested in providing several custom metrics
# that do not necessarily share the same parameters. In this more complex case,
# we require you to provide a scorer using the
# :func:`sklearn.metrics.make_scorer` function.
from sklearn.metrics import make_scorer, f1_score

f1_scorer = make_scorer(
    f1_score,
    response_method="predict",
    metric_name="F1 Score",
    pos_label=pos_label,
)
operational_decision_cost_scorer = make_scorer(
    operational_decision_cost,
    response_method="predict",
    metric_name="Operational Decision Cost",
    amount=amount,
)
reporter.metrics.report_metrics(scoring=[f1_scorer, operational_decision_cost_scorer])

# %%
#
# Effortless one-liner plotting
# -----------------------------
#
# The :class:`skore.EstimatorReport` class also provides a plotting interface
# that allows plotting the most common plots de facto. As for the metrics, we
# only expose the set of plots that is meaningful for the provided estimator.
reporter.metrics.plot.help()

# %%
#
# Let's start by plotting the ROC curve for our binary classification task.
display = reporter.metrics.plot.roc(pos_label=pos_label)
plt.tight_layout()

# %%
#
# The plot functionality is built upon the scikit-learn display objects. We
# return those displays (slightly modified to improve the UI) in case you want
# to tweak some of the plot properties. You can have a quick look at the
# available attributes and methods by calling the `help` method or simply by
# printing the display.
display

# %%
display.help()

# %%
display.plot()
display.ax_.set_title("Example of a ROC curve")
display.figure_
plt.tight_layout()

# %%
#
# Similarly to the metrics, we aggressively use caching to avoid recomputing
# the predictions of the model. We also cache the plot display object by
# detecting whether the input parameters are the same as in the previous call.
# Let's demonstrate the kind of performance gain we can get.
start = time.time()
# we already triggered the computation of the predictions in a previous call
reporter.metrics.plot.roc(pos_label=pos_label)
plt.tight_layout()
end = time.time()

# %%
print(f"Time taken to compute the ROC curve: {end - start:.2f} seconds")

# %%
#
# Now, let's clean the cache and check that we get a slowdown.
reporter.clean_cache()

# %%
start = time.time()
reporter.metrics.plot.roc(pos_label=pos_label)
plt.tight_layout()
end = time.time()

# %%
print(f"Time taken to compute the ROC curve: {end - start:.2f} seconds")

# %%
# As expected, since we need to recompute the predictions, it takes more time.