diff --git a/examples/model_evaluation/plot_estimator_report.py b/examples/model_evaluation/plot_estimator_report.py
index 44b0306134..970faa3457 100644
--- a/examples/model_evaluation/plot_estimator_report.py
+++ b/examples/model_evaluation/plot_estimator_report.py
@@ -291,8 +291,11 @@ def operational_decision_cost(y_true, y_pred, amount):
 # that we can compute some additional metrics without having to recompute the
 # the predictions.
 report.metrics.summarize(
-    scoring=["precision", "recall", operational_decision_cost],
-    scoring_names=["Precision", "Recall", "Operational Decision Cost"],
+    scoring={
+        "Precision": "precision",
+        "Recall": "recall",
+        "Operational Decision Cost": operational_decision_cost,
+    },
     scoring_kwargs={"amount": amount, "response_method": "predict"},
 ).frame()
 
@@ -309,8 +312,10 @@ def operational_decision_cost(y_true, y_pred, amount):
     operational_decision_cost, response_method="predict", amount=amount
 )
 report.metrics.summarize(
-    scoring=[f1_scorer, operational_decision_cost_scorer],
-    scoring_names=["F1 Score", "Operational Decision Cost"],
+    scoring={
+        "F1 Score": f1_scorer,
+        "Operational Decision Cost": operational_decision_cost_scorer,
+    },
 ).frame()
 
 # %%
diff --git a/examples/use_cases/plot_employee_salaries.py b/examples/use_cases/plot_employee_salaries.py
index 83fc2e4a16..758440f485 100644
--- a/examples/use_cases/plot_employee_salaries.py
+++ b/examples/use_cases/plot_employee_salaries.py
@@ -308,15 +308,10 @@ def periodic_spline_transformer(period, n_splines=None, degree=3):
 # %%
 from sklearn.metrics import get_scorer
 
-scoring = ["r2", "rmse", get_scorer("neg_mean_absolute_error")]
+scoring = {"R²": "r2", "RMSE": "rmse", "MAE": get_scorer("neg_mean_absolute_error")}
 scoring_kwargs = {"response_method": "predict"}
-scoring_names = ["R²", "RMSE", "MAE"]
 
-comparator.metrics.summarize(
-    scoring=scoring,
-    scoring_kwargs=scoring_kwargs,
-    scoring_names=scoring_names,
-).frame()
+comparator.metrics.summarize(scoring=scoring, scoring_kwargs=scoring_kwargs).frame()
 
 # %%
 # Finally, we can even get a deeper understanding by analyzing each split in the
diff --git a/skore/src/skore/_sklearn/_comparison/metrics_accessor.py b/skore/src/skore/_sklearn/_comparison/metrics_accessor.py
index eaa8b4de56..b432474f7c 100644
--- a/skore/src/skore/_sklearn/_comparison/metrics_accessor.py
+++ b/skore/src/skore/_sklearn/_comparison/metrics_accessor.py
@@ -26,7 +26,6 @@
     Aggregate,
     PositiveLabel,
     Scoring,
-    ScoringName,
     YPlotData,
 )
 from skore._utils._accessor import (
@@ -57,8 +56,7 @@ def summarize(
         data_source: DataSource = "test",
         X: ArrayLike | None = None,
         y: ArrayLike | None = None,
-        scoring: Scoring | list[Scoring] | None = None,
-        scoring_names: ScoringName | list[ScoringName] | None = None,
+        scoring: Scoring | list[Scoring] | dict[str, Scoring] | None = None,
         scoring_kwargs: dict[str, Any] | None = None,
         pos_label: PositiveLabel | None = _DEFAULT,
         indicator_favorability: bool = False,
@@ -84,7 +82,8 @@ def summarize(
             New target on which to compute the metric. By default, we use the target
             provided when creating the report.
 
-        scoring : str, callable, scorer or list of such instances, default=None
+        scoring : str, callable, scorer, or list or dict of such instances, \
+            default=None
             The metrics to report. The possible values (whether or not in a list) are:
 
             - if a string, either one of the built-in metrics or a scikit-learn scorer
@@ -101,10 +100,6 @@ def summarize(
               the metric favorability will only be displayed if it is given explicitly
               via `make_scorer`'s `greater_is_better` parameter.
 
-        scoring_names : str, None or list of such instances, default=None
-            Used to overwrite the default scoring names in the report. It should be of
-            the same length as the ``scoring`` parameter.
-
         scoring_kwargs : dict, default=None
             The keyword arguments to pass to the scoring functions.
 
@@ -164,7 +159,6 @@ class is set to the one provided when creating the report. If `None`,
             scoring=scoring,
             pos_label=pos_label,
             scoring_kwargs=scoring_kwargs,
-            scoring_names=scoring_names,
             indicator_favorability=indicator_favorability,
             aggregate=aggregate,
         )
@@ -1125,12 +1119,12 @@ def custom_metric(
             response_method=response_method,
             **kwargs,
         )
+        scoring = {metric_name: scorer} if metric_name is not None else [scorer]
         return self.summarize(
-            scoring=[scorer],
+            scoring=scoring,
             data_source=data_source,
             X=X,
             y=y,
-            scoring_names=[metric_name] if metric_name is not None else None,
             aggregate=aggregate,
         ).frame()
 
diff --git a/skore/src/skore/_sklearn/_cross_validation/metrics_accessor.py b/skore/src/skore/_sklearn/_cross_validation/metrics_accessor.py
index 077fc477c8..0677e5172b 100644
--- a/skore/src/skore/_sklearn/_cross_validation/metrics_accessor.py
+++ b/skore/src/skore/_sklearn/_cross_validation/metrics_accessor.py
@@ -26,7 +26,6 @@
     Aggregate,
     PositiveLabel,
     Scoring,
-    ScoringName,
     YPlotData,
 )
 from skore._utils._accessor import _check_estimator_report_has_method
@@ -55,8 +54,7 @@ def summarize(
         data_source: DataSource = "test",
         X: ArrayLike | None = None,
         y: ArrayLike | None = None,
-        scoring: Scoring | list[Scoring] | None = None,
-        scoring_names: ScoringName | list[ScoringName] | None = None,
+        scoring: Scoring | list[Scoring] | dict[str, Scoring] | None = None,
         scoring_kwargs: dict[str, Any] | None = None,
         pos_label: PositiveLabel | None = _DEFAULT,
         indicator_favorability: bool = False,
@@ -83,8 +81,9 @@ def summarize(
             New target on which to compute the metric. By default, we use the target
             provided when creating the report.
 
-        scoring : str, callable, scorer or list of such instances, default=None
-            The metrics to report. The possible values (whether or not in a list) are:
+        scoring : str, callable, scorer, or list or dict of such instances, \
+            default=None
+            The metrics to report. The possible values are:
 
             - if a string, either one of the built-in metrics or a scikit-learn scorer
@@ -99,10 +98,10 @@ def summarize(
               scorers as provided by :func:`sklearn.metrics.make_scorer`. In this case,
               the metric favorability will only be displayed if it is given explicitly
               via `make_scorer`'s `greater_is_better` parameter.
-
-        scoring_names : str, None or list of such instances, default=None
-            Used to overwrite the default scoring names in the report. If a list,
-            it should be of the same length as the `scoring` parameter.
+            - if a dict, the keys are used as metric names and the values are the
+              scoring functions (strings, callables, or scorers as described above).
+            - if a list, each element can be any of the above types (strings,
+              callables, scorers).
 
         scoring_kwargs : dict, default=None
             The keyword arguments to pass to the scoring functions.
@@ -161,7 +160,6 @@ class is set to the one provided when creating the report. If `None`,
             scoring=scoring,
             pos_label=pos_label,
             scoring_kwargs=scoring_kwargs,
-            scoring_names=scoring_names,
             indicator_favorability=indicator_favorability,
         )
         if flat_index:
@@ -1035,13 +1033,13 @@ def custom_metric(
             response_method=response_method,
             **kwargs,
         )
+        scoring = {metric_name: scorer} if metric_name is not None else [scorer]
         return self.summarize(
-            scoring=[scorer],
+            scoring=scoring,
             data_source=data_source,
             X=X,
             y=y,
             aggregate=aggregate,
-            scoring_names=[metric_name] if metric_name is not None else None,
             pos_label=pos_label,
         ).frame()
 
diff --git a/skore/src/skore/_sklearn/_estimator/metrics_accessor.py b/skore/src/skore/_sklearn/_estimator/metrics_accessor.py
index 0f7eb03c6a..0c01a4c618 100644
--- a/skore/src/skore/_sklearn/_estimator/metrics_accessor.py
+++ b/skore/src/skore/_sklearn/_estimator/metrics_accessor.py
@@ -30,7 +30,6 @@
     _DEFAULT,
     PositiveLabel,
     Scoring,
-    ScoringName,
     YPlotData,
 )
 from skore._utils._accessor import (
@@ -66,8 +65,7 @@ def summarize(
         data_source: DataSource = "test",
         X: ArrayLike | None = None,
         y: ArrayLike | None = None,
-        scoring: Scoring | list[Scoring] | None = None,
-        scoring_names: ScoringName | list[ScoringName] | None = None,
+        scoring: Scoring | list[Scoring] | dict[str, Scoring] | None = None,
         scoring_kwargs: dict[str, Any] | None = None,
         pos_label: PositiveLabel | None = _DEFAULT,
         indicator_favorability: bool = False,
@@ -92,8 +90,9 @@ def summarize(
             New target on which to compute the metric. By default, we use the target
             provided when creating the report.
 
-        scoring : str, callable, scorer or list of such instances, default=None
-            The metrics to report. The possible values (whether or not in a list) are:
+        scoring : str, callable, scorer, or list or dict of such instances, \
+            default=None
+            The metrics to report. The possible values are:
 
             - if a string, either one of the built-in metrics or a scikit-learn scorer
@@ -108,10 +107,10 @@ def summarize(
               scorers as provided by :func:`sklearn.metrics.make_scorer`. In this case,
               the metric favorability will only be displayed if it is given explicitly
               via `make_scorer`'s `greater_is_better` parameter.
-
-        scoring_names : str, None or list of such instances, default=None
-            Used to overwrite the default scoring names in the report. It should be of
-            the same length as the `scoring` parameter.
+            - if a dict, the keys are used as metric names and the values are the
+              scoring functions (strings, callables, or scorers as described above).
+            - if a list, each element can be any of the above types (strings,
+              callables, scorers).
 
         scoring_kwargs : dict, default=None
             The keyword arguments to pass to the scoring functions.
@@ -164,12 +163,14 @@ class is set to the one provided when creating the report. If `None`,
         if pos_label is _DEFAULT:
             pos_label = self._parent.pos_label
 
-        if scoring is not None and not isinstance(scoring, list):
+        # Handle dictionary scoring: the keys become the displayed metric names
+        scoring_names = None
+        if isinstance(scoring, dict):
+            scoring_names = list(scoring.keys())
+            scoring = list(scoring.values())
+        elif scoring is not None and not isinstance(scoring, list):
             scoring = [scoring]
 
-        if scoring_names is not None and not isinstance(scoring_names, list):
-            scoring_names = [scoring_names]
-
         if data_source == "X_y":
             # optimization of the hash computation to avoid recomputing it
             # FIXME: we are still recomputing the hash for all the metrics that we
@@ -181,7 +182,6 @@ class is set to the one provided when creating the report. If `None`,
         else:
             data_source_hash = None
 
-        scoring_was_none = scoring is None
         if scoring is None:
             # Equivalent to _get_scorers_to_add
             if self._parent._ml_task == "binary-classification":
@@ -196,23 +196,8 @@ class is set to the one provided when creating the report. If `None`,
                 scoring = ["_r2", "_rmse"]
             scoring += ["_fit_time", "_predict_time"]
 
-        if scoring_names is not None and len(scoring_names) != len(scoring):
-            if scoring_was_none:
-                # we raise a better error message since we decide the default scores
-                raise ValueError(
-                    "The `scoring_names` parameter should be of the same length as "
-                    "the `scoring` parameter. In your case, `scoring` was set to None "
-                    f"and you are using our default scores that are {len(scoring)}. "
-                    f"The list is the following: {scoring}."
-                )
-            else:
-                raise ValueError(
-                    "The `scoring_names` parameter should be of the same length as "
-                    f"the `scoring` parameter. Got {len(scoring_names)} names for "
-                    f"{len(scoring)} scoring functions."
-                )
-        elif scoring_names is None:
-            scoring_names = [None] * len(scoring)
+        if scoring_names is None:
+            scoring_names = [None] * len(scoring)  # type: ignore
 
         scores = []
         favorability_indicator = []
diff --git a/skore/tests/unit/displays/metrics_summary/test_cross_validation.py b/skore/tests/unit/displays/metrics_summary/test_cross_validation.py
index 999d0feaf1..9f60c0bd82 100644
--- a/skore/tests/unit/displays/metrics_summary/test_cross_validation.py
+++ b/skore/tests/unit/displays/metrics_summary/test_cross_validation.py
@@ -313,16 +313,28 @@ def test_scoring_kwargs_multi_class(
 
 
 @pytest.mark.parametrize(
-    "fixture_name, scoring_names, expected_index",
+    "fixture_name, scoring, expected_index",
     [
         (
             "linear_regression_data",
-            ["R2", "RMSE", "FIT_TIME", "PREDICT_TIME"],
+            {
+                "R2": "r2",
+                "RMSE": "rmse",
+                "FIT_TIME": "_fit_time",
+                "PREDICT_TIME": "_predict_time",
+            },
             ["R2", "RMSE", "FIT_TIME", "PREDICT_TIME"],
         ),
         (
             "forest_multiclass_classification_data",
-            ["Precision", "Recall", "ROC AUC", "Log Loss", "Fit Time", "Predict Time"],
+            {
+                "Precision": "_precision",
+                "Recall": "_recall",
+                "ROC AUC": "_roc_auc",
+                "Log Loss": "_log_loss",
+                "Fit Time": "_fit_time",
+                "Predict Time": "_predict_time",
+            },
             [
                 "Precision",
                 "Precision",
@@ -340,11 +352,11 @@ def test_scoring_kwargs_multi_class(
         ),
     ],
 )
-def test_overwrite_scoring_names(request, fixture_name, scoring_names, expected_index):
+def test_overwrite_scoring_names(request, fixture_name, scoring, expected_index):
     """Test that we can overwrite the scoring names in `MetricsSummaryDisplay`."""
     estimator, X, y = request.getfixturevalue(fixture_name)
     report = CrossValidationReport(estimator, X, y, splitter=2)
-    result = report.metrics.summarize(scoring_names=scoring_names).frame()
+    result = report.metrics.summarize(scoring=scoring).frame()
     assert result.shape == (len(expected_index), 2)
 
     # Get level 0 names if MultiIndex, otherwise get column names
@@ -480,3 +492,31 @@ def test_indicator_favorability(forest_binary_classification_data, aggregate):
     assert indicator["Brier score"].tolist() == ["(↘︎)"]
     assert indicator["Fit time (s)"].tolist() == ["(↘︎)"]
     assert indicator["Predict time (s)"].tolist() == ["(↘︎)"]
+
+
+def test_overwrite_scoring_names_with_dict_cross_validation(
+    forest_multiclass_classification_data,
+):
+    """Test that we can overwrite the scoring names using dict scoring in
+    CrossValidationReport."""
+    estimator, X, y = forest_multiclass_classification_data
+    report = CrossValidationReport(estimator, X, y, splitter=2)
+
+    scoring_dict = {
+        "Custom Precision": "_precision",
+        "Custom Recall": "_recall",
+        "Custom ROC AUC": "_roc_auc",
+    }
+
+    result = report.metrics.summarize(scoring=scoring_dict).frame()
+
+    # Check that custom names are used
+    result_index = (
+        result.index.get_level_values(0).tolist()
+        if isinstance(result.index, pd.MultiIndex)
+        else result.index.tolist()
+    )
+
+    assert "Custom Precision" in result_index
+    assert "Custom Recall" in result_index
+    assert "Custom ROC AUC" in result_index
diff --git a/skore/tests/unit/displays/metrics_summary/test_estimator.py b/skore/tests/unit/displays/metrics_summary/test_estimator.py
index ae624d1972..4d4de1f03a 100644
--- a/skore/tests/unit/displays/metrics_summary/test_estimator.py
+++ b/skore/tests/unit/displays/metrics_summary/test_estimator.py
@@ -192,16 +192,28 @@ def test_scoring_kwargs(
 
 
 @pytest.mark.parametrize(
-    "fixture_name, scoring_names, expected_columns",
+    "fixture_name, scoring_dict, expected_columns",
     [
         (
             "linear_regression_with_test",
-            ["R2", "RMSE", "FIT_TIME", "PREDICT_TIME"],
+            {
+                "R2": "r2",
+                "RMSE": "rmse",
+                "FIT_TIME": "_fit_time",
+                "PREDICT_TIME": "_predict_time",
+            },
             ["R2", "RMSE", "FIT_TIME", "PREDICT_TIME"],
         ),
        (
             "forest_multiclass_classification_with_test",
-            ["Precision", "Recall", "ROC AUC", "Log Loss", "Fit Time", "Predict Time"],
+            {
+                "Precision": "_precision",
+                "Recall": "_recall",
+                "ROC AUC": "_roc_auc",
+                "Log Loss": "_log_loss",
+                "Fit Time": "_fit_time",
+                "Predict Time": "_predict_time",
+            },
             [
                 "Precision",
                 "Precision",
@@ -219,13 +231,13 @@ def test_scoring_kwargs(
         ),
     ],
 )
-def test_overwrite_scoring_names(
-    request, fixture_name, scoring_names, expected_columns
+def test_overwrite_scoring_names_with_dict(
+    request, fixture_name, scoring_dict, expected_columns
 ):
-    """Test that we can overwrite the scoring names."""
+    """Test that we can overwrite the scoring names using dict scoring."""
     estimator, X_test, y_test = request.getfixturevalue(fixture_name)
     report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)
-    result = report.metrics.summarize(scoring_names=scoring_names).frame()
+    result = report.metrics.summarize(scoring=scoring_dict).frame()
     assert result.shape == (len(expected_columns), 1)
 
     # Get level 0 names if MultiIndex, otherwise get column names
@@ -253,28 +265,27 @@ def test_indicator_favorability(
 
 
 @pytest.mark.parametrize(
-    "scoring, scoring_names, scoring_kwargs",
+    "scoring, scoring_kwargs",
     [
-        ("accuracy", "this_is_a_test", None),
-        ("neg_log_loss", "this_is_a_test", None),
-        (accuracy_score, "this_is_a_test", {"response_method": "predict"}),
-        (get_scorer("accuracy"), "this_is_a_test", None),
+        ("accuracy", None),
+        ("neg_log_loss", None),
+        (accuracy_score, {"response_method": "predict"}),
+        (get_scorer("accuracy"), None),
     ],
 )
 def test_scoring_single_list_equivalence(
-    forest_binary_classification_with_test, scoring, scoring_names, scoring_kwargs
+    forest_binary_classification_with_test, scoring, scoring_kwargs
 ):
     """Check that passing a single string, callable, scorer is equivalent to passing a
-    list with a single element, and it's possible to overwrite col name."""
+    list with a single element."""
     estimator, X_test, y_test = forest_binary_classification_with_test
     report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)
 
     result_single = report.metrics.summarize(
-        scoring=scoring, scoring_names=scoring_names, scoring_kwargs=scoring_kwargs
+        scoring=scoring, scoring_kwargs=scoring_kwargs
     ).frame()
     result_list = report.metrics.summarize(
-        scoring=[scoring], scoring_names=scoring_names, scoring_kwargs=scoring_kwargs
+        scoring=[scoring], scoring_kwargs=scoring_kwargs
     ).frame()
 
-    assert result_single.index[0] == "this_is_a_test"
     assert result_single.equals(result_list)
diff --git a/skore/tests/unit/reports/estimator/metrics/test_numeric.py b/skore/tests/unit/reports/estimator/metrics/test_numeric.py
index 1217181a44..3509483d0a 100644
--- a/skore/tests/unit/reports/estimator/metrics/test_numeric.py
+++ b/skore/tests/unit/reports/estimator/metrics/test_numeric.py
@@ -3,12 +3,14 @@
 
 import joblib
 import numpy as np
+import pandas as pd
 import pytest
 from sklearn.base import BaseEstimator
 from sklearn.cluster import KMeans
 from sklearn.datasets import make_classification
 from sklearn.linear_model import LogisticRegression
 from sklearn.metrics import (
+    get_scorer,
     precision_score,
     rand_score,
     recall_score,
@@ -515,3 +517,50 @@ def test_roc_multiclass_requires_predict_proba(
     report = EstimatorReport(classifier, X_test=X_test, y_test=y_test)
     assert hasattr(report.metrics, "roc_auc")
     report.metrics.roc_auc()
+
+
+def test_summarize_scoring_dict(forest_binary_classification_with_test):
+    """Test that scoring can be passed as a dictionary with custom names."""
+    estimator, X_test, y_test = forest_binary_classification_with_test
+    report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)
+
+    # Test with dictionary scoring
+    scoring_dict = {
+        "Custom Accuracy": "accuracy",
+        "Custom Precision": "precision",
+        "Custom Neg MAE": get_scorer("neg_mean_absolute_error"),
+    }
+
+    result = report.metrics.summarize(scoring=scoring_dict).frame()
+
+    # Check that custom names are used
+    assert "Custom Accuracy" in result.index
+    assert "Custom Precision" in result.index
+    assert "Custom Neg MAE" in result.index
+
+    # Verify the result structure
+    assert isinstance(result, pd.DataFrame)
+    assert len(result.index) >= 3  # At least our 3 custom metrics
+
+
+def test_summarize_scoring_dict_with_callables(linear_regression_with_test):
+    """Test that scoring dict works with callable functions."""
+    estimator, X_test, y_test = linear_regression_with_test
+    report = EstimatorReport(estimator, X_test=X_test, y_test=y_test)
+
+    def custom_metric(y_true, y_pred):
+        return np.mean(np.abs(y_true - y_pred))
+
+    scoring_dict = {"R Squared": "r2", "Custom MAE": custom_metric}
+
+    result = report.metrics.summarize(
+        scoring=scoring_dict, scoring_kwargs={"response_method": "predict"}
+    ).frame()
+
+    # Check that custom names are used
+    assert "R Squared" in result.index
+    assert "Custom MAE" in result.index
+
+    # Verify the result structure
+    assert isinstance(result, pd.DataFrame)
+    assert len(result.index) == 2
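
Usage example (not part of the patch). A minimal sketch of the dict-based `scoring` API that this diff introduces, mirroring the updated `plot_estimator_report.py` example and the new tests; the dataset, classifier, and variable names below are illustrative assumptions, not taken from the repository:

```python
# Illustrative sketch only: assumes skore and scikit-learn are installed.
# The data and model are made up; only the `summarize(scoring={...})` call
# reflects the API change in this diff.
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

from skore import EstimatorReport

X, y = make_classification(n_samples=500, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

classifier = RandomForestClassifier(random_state=0).fit(X_train, y_train)
report = EstimatorReport(classifier, X_test=X_test, y_test=y_test)

# Dict keys become the metric names shown in the resulting frame; values may be
# built-in metric names, scikit-learn scorer names, scorers, or callables.
summary = report.metrics.summarize(
    scoring={"Precision": "precision", "Recall": "recall"},
).frame()
print(summary)
```

With the previous API the same call needed `scoring=["precision", "recall"]` plus `scoring_names=["Precision", "Recall"]`; the dict form replaces that pair of parameters, which is why `scoring_names` is removed throughout this diff.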