diff --git a/examples/getting_started/plot_skore_getting_started.py b/examples/getting_started/plot_skore_getting_started.py index 77ef219b0c..82283f9cc3 100644 --- a/examples/getting_started/plot_skore_getting_started.py +++ b/examples/getting_started/plot_skore_getting_started.py @@ -122,7 +122,7 @@ # ^^^^^^^^^^^^^^^^^^^^^^^^^^^ # # skore has also (re-)implemented a :class:`skore.CrossValidationReport` class that -# contains several :class:`skore.EstimatorReport`, one for each fold. +# contains several :class:`skore.EstimatorReport`, one for each split. # %% from skore import CrossValidationReport @@ -142,21 +142,21 @@ cv_report.metrics.summarize().frame() # %% -# or by individual fold: +# or by individual split: # %% cv_report.metrics.summarize(aggregate=None).frame() # %% -# We display the ROC curves for each fold: +# We display the ROC curves for each split: # %% roc_plot_cv = cv_report.metrics.roc() roc_plot_cv.plot() # %% -# We can retrieve the estimator report of a specific fold to investigate further, -# for example getting the report metrics for the first fold only: +# We can retrieve the estimator report of a specific split to investigate further, +# for example getting the report metrics for the first split only: # %% cv_report.estimator_reports_[0].metrics.summarize().frame() diff --git a/examples/technical_details/plot_cache_mechanism.py b/examples/technical_details/plot_cache_mechanism.py index d3eeb9f6ec..9b3865cd29 100644 --- a/examples/technical_details/plot_cache_mechanism.py +++ b/examples/technical_details/plot_cache_mechanism.py @@ -250,7 +250,7 @@ # Caching with :class:`~skore.CrossValidationReport` # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # -# :class:`~skore.CrossValidationReport` uses the same caching system for each fold +# :class:`~skore.CrossValidationReport` uses the same caching system for each split # in cross-validation by leveraging the previous :class:`~skore.EstimatorReport`: from skore import CrossValidationReport @@ -262,7 +262,7 @@ # Since a :class:`~skore.CrossValidationReport` uses many # :class:`~skore.EstimatorReport`, we will observe the same behaviour as we previously # exposed. -# The first call will be slow because it computes the predictions for each fold. +# The first call will be slow because it computes the predictions for each split. start = time.time() result = report.metrics.summarize().frame() end = time.time() diff --git a/examples/use_cases/plot_employee_salaries.py b/examples/use_cases/plot_employee_salaries.py index c9e19854fe..260cbbbeae 100644 --- a/examples/use_cases/plot_employee_salaries.py +++ b/examples/use_cases/plot_employee_salaries.py @@ -319,9 +319,9 @@ def periodic_spline_transformer(period, n_splines=None, degree=3): ).frame() # %% -# Finally, we can even get a deeper understanding by analyzing each fold in the +# Finally, we can even get a deeper understanding by analyzing each split in the # :class:`~skore.CrossValidationReport`. -# Here, we plot the actual-vs-predicted values for each fold. +# Here, we plot the actual-vs-predicted values for each split. 
# %% linear_model_report.metrics.prediction_error().plot(kind="actual_vs_predicted") diff --git a/skore/src/skore/_sklearn/_comparison/metrics_accessor.py b/skore/src/skore/_sklearn/_comparison/metrics_accessor.py index 4ece9ee710..94a4c8f731 100644 --- a/skore/src/skore/_sklearn/_comparison/metrics_accessor.py +++ b/skore/src/skore/_sklearn/_comparison/metrics_accessor.py @@ -1259,7 +1259,7 @@ def _get_display( y_true.append( YPlotData( estimator_name=report_name, - split_index=None, + split=None, y=report_y, ) ) @@ -1280,7 +1280,7 @@ def _get_display( y_pred.append( YPlotData( estimator_name=report_name, - split_index=None, + split=None, y=value, ) ) @@ -1301,9 +1301,7 @@ def _get_display( for report, report_name in zip( self._parent.reports_, self._parent.report_names_, strict=False ): - for split_index, estimator_report in enumerate( - report.estimator_reports_ - ): + for split, estimator_report in enumerate(report.estimator_reports_): report_X, report_y, _ = ( estimator_report.metrics._get_X_y_and_data_source_hash( data_source=data_source, @@ -1315,7 +1313,7 @@ def _get_display( y_true.append( YPlotData( estimator_name=report_name, - split_index=split_index, + split=split, y=report_y, ) ) @@ -1337,7 +1335,7 @@ def _get_display( y_pred.append( YPlotData( estimator_name=report_name, - split_index=split_index, + split=split, y=value, ) ) diff --git a/skore/src/skore/_sklearn/_cross_validation/metrics_accessor.py b/skore/src/skore/_sklearn/_cross_validation/metrics_accessor.py index 9e659d354f..34f1c7d30e 100644 --- a/skore/src/skore/_sklearn/_cross_validation/metrics_accessor.py +++ b/skore/src/skore/_sklearn/_cross_validation/metrics_accessor.py @@ -1146,7 +1146,7 @@ def _get_display( y_true.append( YPlotData( estimator_name=self._parent.estimator_name_, - split_index=report_idx, + split=report_idx, y=y, ) ) @@ -1167,7 +1167,7 @@ def _get_display( y_pred.append( YPlotData( estimator_name=self._parent.estimator_name_, - split_index=report_idx, + split=report_idx, y=value, ) ) diff --git a/skore/src/skore/_sklearn/_cross_validation/report.py b/skore/src/skore/_sklearn/_cross_validation/report.py index 8c33b9568e..ecd730d9ba 100644 --- a/skore/src/skore/_sklearn/_cross_validation/report.py +++ b/skore/src/skore/_sklearn/_cross_validation/report.py @@ -92,7 +92,7 @@ class CrossValidationReport(_BaseReport, DirNamesMixin): Determines the cross-validation splitting strategy. Possible inputs for `splitter` are: - - int, to specify the number of folds in a `(Stratified)KFold`, + - int, to specify the number of splits in a `(Stratified)KFold`, - a scikit-learn :term:`CV splitter`, - An iterable yielding (train, test) splits as arrays of indices. 
@@ -339,19 +339,19 @@ def cache_predictions( total_estimators = len(self.estimator_reports_) progress.update(main_task, total=total_estimators) - for fold_idx, estimator_report in enumerate(self.estimator_reports_, 1): + for split_idx, estimator_report in enumerate(self.estimator_reports_, 1): # Share the parent's progress bar with child report estimator_report._progress_info = { "current_progress": progress, - "fold_info": {"current": fold_idx, "total": total_estimators}, + "split_info": {"current": split_idx, "total": total_estimators}, } - # Update the progress bar description to include the fold number + # Update the progress bar description to include the split number progress.update( main_task, description=( - "Cross-validation predictions for fold " - f"#{fold_idx}/{total_estimators}" + "Cross-validation predictions for split " + f"#{split_idx}/{total_estimators}" ), ) diff --git a/skore/src/skore/_sklearn/_estimator/metrics_accessor.py b/skore/src/skore/_sklearn/_estimator/metrics_accessor.py index ec6b7322e6..0b61753096 100644 --- a/skore/src/skore/_sklearn/_estimator/metrics_accessor.py +++ b/skore/src/skore/_sklearn/_estimator/metrics_accessor.py @@ -1692,14 +1692,14 @@ def _get_display( y_true=[ YPlotData( estimator_name=self._parent.estimator_name_, - split_index=None, + split=None, y=y, ) ], y_pred=[ YPlotData( estimator_name=self._parent.estimator_name_, - split_index=None, + split=None, y=y_pred, ) ], diff --git a/skore/src/skore/_sklearn/_plot/metrics/precision_recall_curve.py b/skore/src/skore/_sklearn/_plot/metrics/precision_recall_curve.py index d9e1eba3c8..2115d6c3fc 100644 --- a/skore/src/skore/_sklearn/_plot/metrics/precision_recall_curve.py +++ b/skore/src/skore/_sklearn/_plot/metrics/precision_recall_curve.py @@ -60,7 +60,7 @@ class PrecisionRecallCurveDisplay( The precision-recall curve data to display. The columns are - `estimator_name` - - `split_index` (may be null) + - `split` (may be null) - `label` - `threshold` - `precision` @@ -70,7 +70,7 @@ class PrecisionRecallCurveDisplay( The average precision data to display. The columns are - `estimator_name` - - `split_index` (may be null) + - `split` (may be null) - `label` - `average_precision`. 
@@ -274,8 +274,8 @@ def _plot_cross_validated_estimator( line_kwargs: dict[str, Any] = {"drawstyle": "steps-post"} if self.ml_task == "binary-classification": - for split_idx in self.precision_recall["split_index"].cat.categories: - query = f"label == {self.pos_label!r} & split_index == {split_idx}" + for split_idx in self.precision_recall["split"].cat.categories: + query = f"label == {self.pos_label!r} & split == {split_idx}" precision_recall = self.precision_recall.query(query) average_precision = self.average_precision.query(query)[ "average_precision" @@ -285,7 +285,7 @@ def _plot_cross_validated_estimator( line_kwargs, pr_curve_kwargs[split_idx] ) line_kwargs_validated["label"] = ( - f"Fold #{split_idx + 1} (AP = {average_precision:0.2f})" + f"Split #{split_idx + 1} (AP = {average_precision:0.2f})" ) (line,) = self.ax_.plot( @@ -314,8 +314,8 @@ def _plot_cross_validated_estimator( # average_precision_class = self.average_precision[class_] pr_curve_kwargs_class = pr_curve_kwargs[class_idx] - for split_idx in self.precision_recall["split_index"].cat.categories: - query = f"label == {class_label!r} & split_index == {split_idx}" + for split_idx in self.precision_recall["split"].cat.categories: + query = f"label == {class_label!r} & split == {split_idx}" precision_recall = self.precision_recall.query(query) average_precision = self.average_precision.query(query)[ "average_precision" @@ -519,7 +519,7 @@ def _plot_comparison_cross_validation( precision_recall = self.precision_recall.query(query) for split_idx, segment in precision_recall.groupby( - "split_index", observed=True + "split", observed=True ): if split_idx == 0: label_kwargs = { @@ -586,7 +586,7 @@ def _plot_comparison_cross_validation( precision_recall = self.precision_recall.query(query) for split_idx, segment in precision_recall.groupby( - "split_index", observed=True + "split", observed=True ): if split_idx == 0: label_kwargs = { @@ -836,7 +836,7 @@ def _compute_data_for_display( precision_recall_records.append( { "estimator_name": y_true_i.estimator_name, - "split_index": y_true_i.split_index, + "split": y_true_i.split, "label": pos_label_validated, "threshold": threshold, "precision": precision, @@ -846,7 +846,7 @@ def _compute_data_for_display( average_precision_records.append( { "estimator_name": y_true_i.estimator_name, - "split_index": y_true_i.split_index, + "split": y_true_i.split, "label": pos_label_validated, "average_precision": average_precision_i, } @@ -879,7 +879,7 @@ def _compute_data_for_display( precision_recall_records.append( { "estimator_name": y_true_i.estimator_name, - "split_index": y_true_i.split_index, + "split": y_true_i.split, "label": class_, "threshold": threshold, "precision": precision, @@ -889,7 +889,7 @@ def _compute_data_for_display( average_precision_records.append( { "estimator_name": y_true_i.estimator_name, - "split_index": y_true_i.split_index, + "split": y_true_i.split, "label": class_, "average_precision": average_precision_class_i, } @@ -897,7 +897,7 @@ def _compute_data_for_display( dtypes = { "estimator_name": "category", - "split_index": "category", + "split": "category", "label": "category", } @@ -929,7 +929,7 @@ def frame(self, with_average_precision: bool = False) -> DataFrame: depending on the report type: - `estimator_name`: Name of the estimator (when comparing estimators) - - `split_index`: Cross-validation fold ID (when doing cross-validation) + - `split`: Cross-validation split ID (when doing cross-validation) - `label`: Class label (for multiclass-classification) - 
`threshold`: Decision threshold - `precision`: Precision score at threshold @@ -966,11 +966,11 @@ def frame(self, with_average_precision: bool = False) -> DataFrame: if self.report_type == "estimator": indexing_columns = [] elif self.report_type == "cross-validation": - indexing_columns = ["split_index"] + indexing_columns = ["split"] elif self.report_type == "comparison-estimator": indexing_columns = ["estimator_name"] else: # self.report_type == "comparison-cross-validation" - indexing_columns = ["estimator_name", "split_index"] + indexing_columns = ["estimator_name", "split"] if self.ml_task == "binary-classification": columns = indexing_columns + statistical_columns diff --git a/skore/src/skore/_sklearn/_plot/metrics/prediction_error.py b/skore/src/skore/_sklearn/_plot/metrics/prediction_error.py index 0cc82b8e89..d489c16062 100644 --- a/skore/src/skore/_sklearn/_plot/metrics/prediction_error.py +++ b/skore/src/skore/_sklearn/_plot/metrics/prediction_error.py @@ -42,7 +42,7 @@ class PredictionErrorDisplay(StyleDisplayMixin, HelpDisplayMixin, PlotBackendMix The prediction error data to display. The columns are - `estimator_name` - - `split_index` (may be null) + - `split` (may be null) - `y_true` - `y_pred` - `residuals`. @@ -148,7 +148,7 @@ def _validate_data_points_kwargs( n_scatter_groups = 1 elif self.report_type == "cross-validation": allow_single_dict = True - n_scatter_groups = len(self._prediction_error["split_index"].cat.categories) + n_scatter_groups = len(self._prediction_error["split"].cat.categories) elif self.report_type in ( "comparison-estimator", "comparison-cross-validation", @@ -296,25 +296,25 @@ def _plot_cross_validated_estimator( """ scatter = [] data_points_kwargs: dict[str, Any] = {"alpha": 0.3, "s": 10} - n_splits = len(self._prediction_error["split_index"].cat.categories) + n_splits = len(self._prediction_error["split"].cat.categories) colors_markers = sample_mpl_colormap( colormaps.get_cmap("tab10"), n_splits if n_splits > 10 else 10, ) for split_idx, prediction_error_split in self._prediction_error.groupby( - "split_index", observed=True + "split", observed=True ): - data_points_kwargs_fold = { + data_points_kwargs_split = { "color": colors_markers[split_idx], **data_points_kwargs, } data_points_kwargs_validated = _validate_style_kwargs( - data_points_kwargs_fold, samples_kwargs[split_idx] + data_points_kwargs_split, samples_kwargs[split_idx] ) - label = f"Fold #{split_idx + 1}" + label = f"Split #{split_idx + 1}" if kind == "actual_vs_predicted": scatter.append( @@ -390,13 +390,13 @@ def _plot_comparison_estimator( for idx, (estimator_name, prediction_error_estimator) in enumerate( self._prediction_error.groupby("estimator_name", observed=True) ): - data_points_kwargs_fold = { + data_points_kwargs_split = { "color": colors_markers[idx], **data_points_kwargs, } data_points_kwargs_validated = _validate_style_kwargs( - data_points_kwargs_fold, samples_kwargs[idx] + data_points_kwargs_split, samples_kwargs[idx] ) if kind == "actual_vs_predicted": @@ -473,13 +473,13 @@ def _plot_comparison_cross_validation( for idx, (estimator_name, prediction_error_estimator) in enumerate( self._prediction_error.groupby("estimator_name", observed=True) ): - data_points_kwargs_fold = { + data_points_kwargs_split = { "color": colors_markers[idx], **data_points_kwargs, } data_points_kwargs_validated = _validate_style_kwargs( - data_points_kwargs_fold, samples_kwargs[idx] + data_points_kwargs_split, samples_kwargs[idx] ) if kind == "actual_vs_predicted": @@ -799,7 +799,7 @@ def 
_compute_data_for_display( prediction_error_records.append( { "estimator_name": y_true_i.estimator_name, - "split_index": y_true_i.split_index, + "split": y_true_i.split, "y_true": y_true_sample_i, "y_pred": y_pred_sample_i, "residuals": residuals_sample_i, @@ -816,7 +816,7 @@ def _compute_data_for_display( prediction_error_records.append( { "estimator_name": y_true_i.estimator_name, - "split_index": y_true_i.split_index, + "split": y_true_i.split, "y_true": y_true_sample_i, "y_pred": y_pred_sample_i, "residuals": residuals_sample_i, @@ -836,7 +836,7 @@ def _compute_data_for_display( return cls( prediction_error=DataFrame.from_records(prediction_error_records).astype( - {"estimator_name": "category", "split_index": "category"} + {"estimator_name": "category", "split": "category"} ), range_y_true=range_y_true, range_y_pred=range_y_pred, @@ -856,7 +856,7 @@ def frame(self) -> DataFrame: the report type: - `estimator_name`: Name of the estimator (when comparing estimators) - - `split_index`: Cross-validation fold ID (when doing cross-validation) + - `split`: Cross-validation split ID (when doing cross-validation) - `y_true`: True target values - `y_pred`: Predicted target values - `residuals`: Difference between true and predicted values @@ -879,10 +879,10 @@ def frame(self) -> DataFrame: if self.report_type == "estimator": columns = statistical_columns elif self.report_type == "cross-validation": - columns = ["split_index"] + statistical_columns + columns = ["split"] + statistical_columns elif self.report_type == "comparison-estimator": columns = ["estimator_name"] + statistical_columns else: # self.report_type == "comparison-cross-validation" - columns = ["estimator_name", "split_index"] + statistical_columns + columns = ["estimator_name", "split"] + statistical_columns return self._prediction_error[columns] diff --git a/skore/src/skore/_sklearn/_plot/metrics/roc_curve.py b/skore/src/skore/_sklearn/_plot/metrics/roc_curve.py index 71415b8bfb..33a8288a3a 100644 --- a/skore/src/skore/_sklearn/_plot/metrics/roc_curve.py +++ b/skore/src/skore/_sklearn/_plot/metrics/roc_curve.py @@ -77,7 +77,7 @@ class RocCurveDisplay( The ROC curve data to display. The columns are - `estimator_name` - - `split_index` (may be null) + - `split` (may be null) - `label` - `threshold` - `fpr` @@ -87,7 +87,7 @@ class RocCurveDisplay( The ROC AUC data to display. The columns are - `estimator_name` - - `split_index` (may be null) + - `split` (may be null) - `label` - `roc_auc`. 
@@ -316,8 +316,8 @@ def _plot_cross_validated_estimator( line_kwargs: dict[str, Any] = {} if self.ml_task == "binary-classification": - for split_idx in self.roc_curve["split_index"].cat.categories: - query = f"label == {self.pos_label!r} & split_index == {split_idx}" + for split_idx in self.roc_curve["split"].cat.categories: + query = f"label == {self.pos_label!r} & split == {split_idx}" roc_curve = self.roc_curve.query(query) roc_auc = self.roc_auc.query(query)["roc_auc"].item() @@ -325,7 +325,7 @@ def _plot_cross_validated_estimator( line_kwargs, roc_curve_kwargs[split_idx] ) line_kwargs_validated["label"] = ( - f"Fold #{split_idx + 1} (AUC = {roc_auc:0.2f})" + f"Split #{split_idx + 1} (AUC = {roc_auc:0.2f})" ) (line,) = self.ax_.plot( @@ -351,9 +351,9 @@ def _plot_cross_validated_estimator( roc_auc = self.roc_auc.query(f"label == {class_label}")["roc_auc"] roc_curve_kwargs_class = roc_curve_kwargs[class_idx] - for split_idx in self.roc_curve["split_index"].cat.categories: + for split_idx in self.roc_curve["split"].cat.categories: roc_curve_label = self.roc_curve.query( - f"label == {class_label} & split_index == {split_idx}" + f"label == {class_label} & split == {split_idx}" ) line_kwargs_validated = _validate_style_kwargs( @@ -583,9 +583,7 @@ def _plot_comparison_cross_validation( line_kwargs, roc_curve_kwargs[report_idx] ) - for split_idx, segment in roc_curve.groupby( - "split_index", observed=True - ): + for split_idx, segment in roc_curve.groupby("split", observed=True): if split_idx == 0: label_kwargs = { "label": ( @@ -649,9 +647,7 @@ def _plot_comparison_cross_validation( roc_auc = self.roc_auc.query(query)["roc_auc"] - for split_idx, segment in roc_curve.groupby( - "split_index", observed=True - ): + for split_idx, segment in roc_curve.groupby("split", observed=True): if split_idx == 0: label_kwargs = { "label": ( @@ -913,7 +909,7 @@ def _compute_data_for_display( roc_curve_records.append( { "estimator_name": y_true_i.estimator_name, - "split_index": y_true_i.split_index, + "split": y_true_i.split, "label": pos_label_validated, "threshold": threshold, "fpr": fpr, @@ -924,7 +920,7 @@ def _compute_data_for_display( roc_auc_records.append( { "estimator_name": y_true_i.estimator_name, - "split_index": y_true_i.split_index, + "split": y_true_i.split, "label": pos_label_validated, "roc_auc": roc_auc_i, } @@ -952,7 +948,7 @@ def _compute_data_for_display( roc_curve_records.append( { "estimator_name": y_true_i.estimator_name, - "split_index": y_true_i.split_index, + "split": y_true_i.split, "label": class_, "threshold": threshold, "fpr": fpr, @@ -963,7 +959,7 @@ def _compute_data_for_display( roc_auc_records.append( { "estimator_name": y_true_i.estimator_name, - "split_index": y_true_i.split_index, + "split": y_true_i.split, "label": class_, "roc_auc": roc_auc_class_i, } @@ -971,7 +967,7 @@ def _compute_data_for_display( dtypes = { "estimator_name": "category", - "split_index": "category", + "split": "category", "label": "category", } @@ -999,7 +995,7 @@ def frame(self, with_roc_auc: bool = False) -> DataFrame: report type: - `estimator_name`: Name of the estimator (when comparing estimators) - - `split_index`: Cross-validation fold ID (when doing cross-validation) + - `split`: Cross-validation split ID (when doing cross-validation) - `label`: Class label (for multiclass-classification) - `threshold`: Decision threshold - `fpr`: False Positive Rate @@ -1035,11 +1031,11 @@ def frame(self, with_roc_auc: bool = False) -> DataFrame: if self.report_type == "estimator": indexing_columns = 
[] elif self.report_type == "cross-validation": - indexing_columns = ["split_index"] + indexing_columns = ["split"] elif self.report_type == "comparison-estimator": indexing_columns = ["estimator_name"] else: # self.report_type == "comparison-cross-validation" - indexing_columns = ["estimator_name", "split_index"] + indexing_columns = ["estimator_name", "split"] if self.ml_task == "binary-classification": columns = indexing_columns + statistical_columns diff --git a/skore/src/skore/_sklearn/types.py b/skore/src/skore/_sklearn/types.py index b54cc1265d..4108ff8b41 100644 --- a/skore/src/skore/_sklearn/types.py +++ b/skore/src/skore/_sklearn/types.py @@ -41,7 +41,7 @@ class YPlotData: """ estimator_name: str - split_index: int | None + split: int | None y: ArrayLike diff --git a/skore/src/skore/_utils/_testing.py b/skore/src/skore/_utils/_testing.py index 0d536b1725..b83b1d3551 100644 --- a/skore/src/skore/_utils/_testing.py +++ b/skore/src/skore/_utils/_testing.py @@ -53,7 +53,7 @@ def check_roc_curve_display_data(display: RocCurveDisplay): """Check the structure of the display's internal data.""" assert list(display.roc_curve.columns) == [ "estimator_name", - "split_index", + "split", "label", "threshold", "fpr", @@ -61,7 +61,7 @@ def check_roc_curve_display_data(display: RocCurveDisplay): ] assert list(display.roc_auc.columns) == [ "estimator_name", - "split_index", + "split", "label", "roc_auc", ] @@ -71,7 +71,7 @@ def check_precision_recall_curve_display_data(display: PrecisionRecallCurveDispl """Check the structure of the display's internal data.""" assert list(display.precision_recall.columns) == [ "estimator_name", - "split_index", + "split", "label", "threshold", "precision", @@ -79,7 +79,7 @@ def check_precision_recall_curve_display_data(display: PrecisionRecallCurveDispl ] assert list(display.average_precision.columns) == [ "estimator_name", - "split_index", + "split", "label", "average_precision", ] @@ -104,7 +104,7 @@ def check_frame_structure(df, expected_index, expected_data_columns): df : DataFrame The DataFrame to check. expected_index : list of str - The expected index column names (e.g., `estimator_name`, `split_index`, + The expected index column names (e.g., `estimator_name`, `split`, `label`). These columns should be of categorical type. 
expected_data_columns : list of str diff --git a/skore/tests/unit/displays/precision_recall_curve/test_comparison_cross_validation.py b/skore/tests/unit/displays/precision_recall_curve/test_comparison_cross_validation.py index 3adca38651..5a464a01d1 100644 --- a/skore/tests/unit/displays/precision_recall_curve/test_comparison_cross_validation.py +++ b/skore/tests/unit/displays/precision_recall_curve/test_comparison_cross_validation.py @@ -210,14 +210,14 @@ def test_binary_classification_constructor(forest_binary_classification_data): ) display = report.metrics.precision_recall() - index_columns = ["estimator_name", "split_index", "label"] + index_columns = ["estimator_name", "split", "label"] for df in [display.precision_recall, display.average_precision]: assert all(col in df.columns for col in index_columns) assert df.query("estimator_name == 'estimator_1'")[ - "split_index" + "split" ].unique().tolist() == list(range(cv)) assert df.query("estimator_name == 'estimator_2'")[ - "split_index" + "split" ].unique().tolist() == list(range(cv + 1)) assert df["estimator_name"].unique().tolist() == report.report_names_ assert df["label"].unique() == 1 @@ -235,15 +235,15 @@ def test_multiclass_classification_constructor(forest_multiclass_classification_ ) display = report.metrics.precision_recall() - index_columns = ["estimator_name", "split_index", "label"] + index_columns = ["estimator_name", "split", "label"] classes = np.unique(y) for df in [display.precision_recall, display.average_precision]: assert all(col in df.columns for col in index_columns) assert df.query("estimator_name == 'estimator_1'")[ - "split_index" + "split" ].unique().tolist() == list(range(cv)) assert df.query("estimator_name == 'estimator_2'")[ - "split_index" + "split" ].unique().tolist() == list(range(cv + 1)) assert df["estimator_name"].unique().tolist() == report.report_names_ np.testing.assert_array_equal(df["label"].unique(), classes) @@ -261,7 +261,7 @@ def test_frame_binary_classification( display = report.metrics.precision_recall() df = display.frame(with_average_precision=with_average_precision) - expected_index = ["estimator_name", "split_index"] + expected_index = ["estimator_name", "split"] expected_columns = ["threshold", "precision", "recall"] if with_average_precision: expected_columns.append("average_precision") @@ -270,9 +270,7 @@ def test_frame_binary_classification( assert df["estimator_name"].nunique() == len(report.reports_) if with_average_precision: - for (_, _), group in df.groupby( - ["estimator_name", "split_index"], observed=True - ): + for (_, _), group in df.groupby(["estimator_name", "split"], observed=True): assert group["average_precision"].nunique() == 1 @@ -287,7 +285,7 @@ def test_frame_multiclass_classification( display = report.metrics.precision_recall() df = display.frame(with_average_precision=with_average_precision) - expected_index = ["estimator_name", "split_index", "label"] + expected_index = ["estimator_name", "split", "label"] expected_columns = ["threshold", "precision", "recall"] if with_average_precision: expected_columns.append("average_precision") @@ -297,6 +295,6 @@ def test_frame_multiclass_classification( if with_average_precision: for (_, _, _), group in df.groupby( - ["estimator_name", "split_index", "label"], observed=True + ["estimator_name", "split", "label"], observed=True ): assert group["average_precision"].nunique() == 1 diff --git a/skore/tests/unit/displays/precision_recall_curve/test_comparison_estimator.py 
b/skore/tests/unit/displays/precision_recall_curve/test_comparison_estimator.py index 037e0e4121..3c833f288e 100644 --- a/skore/tests/unit/displays/precision_recall_curve/test_comparison_estimator.py +++ b/skore/tests/unit/displays/precision_recall_curve/test_comparison_estimator.py @@ -347,11 +347,11 @@ def test_binary_classification_constructor( ) display = report.metrics.precision_recall() - index_columns = ["estimator_name", "split_index", "label"] + index_columns = ["estimator_name", "split", "label"] for df in [display.precision_recall, display.average_precision]: assert all(col in df.columns for col in index_columns) assert df["estimator_name"].unique().tolist() == report.report_names_ - assert df["split_index"].isnull().all() + assert df["split"].isnull().all() assert df["label"].unique() == 1 assert len(display.average_precision) == 2 @@ -375,11 +375,11 @@ def test_multiclass_classification_constructor( ) display = report.metrics.precision_recall() - index_columns = ["estimator_name", "split_index", "label"] + index_columns = ["estimator_name", "split", "label"] for df in [display.precision_recall, display.average_precision]: assert all(col in df.columns for col in index_columns) assert df["estimator_name"].unique().tolist() == report.report_names_ - assert df["split_index"].isnull().all() + assert df["split"].isnull().all() np.testing.assert_array_equal(df["label"].unique(), np.unique(y_train)) assert len(display.average_precision) == len(np.unique(y_train)) * 2 diff --git a/skore/tests/unit/displays/precision_recall_curve/test_cross_validation.py b/skore/tests/unit/displays/precision_recall_curve/test_cross_validation.py index 1879583e8a..7b8c0b7b00 100644 --- a/skore/tests/unit/displays/precision_recall_curve/test_cross_validation.py +++ b/skore/tests/unit/displays/precision_recall_curve/test_cross_validation.py @@ -43,11 +43,11 @@ def test_binary_classification( for split_idx, line in enumerate(display.lines_): assert isinstance(line, mpl.lines.Line2D) average_precision = display.average_precision.query( - f"label == {pos_label} & split_index == {split_idx}" + f"label == {pos_label} & split == {split_idx}" )["average_precision"].item() assert line.get_label() == ( - f"Fold #{split_idx + 1} (AP = {average_precision:0.2f})" + f"Split #{split_idx + 1} (AP = {average_precision:0.2f})" ) assert mpl.colors.to_rgba(line.get_color()) == expected_colors[split_idx] @@ -99,7 +99,7 @@ def test_multiclass_classification( assert isinstance(precision_recall_curve_mpl, mpl.lines.Line2D) if split_idx == 0: average_precision = display.average_precision.query( - f"label == {class_label} & split_index == {split_idx}" + f"label == {class_label} & split == {split_idx}" )["average_precision"] assert precision_recall_curve_mpl.get_label() == ( f"{str(class_label).title()} " @@ -155,16 +155,16 @@ def test_frame_binary_classification( df = report.metrics.precision_recall().frame( with_average_precision=with_average_precision ) - expected_index = ["split_index"] + expected_index = ["split"] expected_columns = ["threshold", "precision", "recall"] if with_average_precision: expected_columns.append("average_precision") check_frame_structure(df, expected_index, expected_columns) - assert df["split_index"].nunique() == cv + assert df["split"].nunique() == cv if with_average_precision: - for (_), group in df.groupby(["split_index"], observed=True): + for (_), group in df.groupby(["split"], observed=True): assert group["average_precision"].nunique() == 1 @@ -178,17 +178,17 @@ def 
test_frame_multiclass_classification( df = report.metrics.precision_recall().frame( with_average_precision=with_average_precision ) - expected_index = ["split_index", "label"] + expected_index = ["split", "label"] expected_columns = ["threshold", "precision", "recall"] if with_average_precision: expected_columns.append("average_precision") check_frame_structure(df, expected_index, expected_columns) - assert df["split_index"].nunique() == cv + assert df["split"].nunique() == cv assert df["label"].nunique() == len(np.unique(y)) if with_average_precision: - for (_, _), group in df.groupby(["split_index", "label"], observed=True): + for (_, _), group in df.groupby(["split", "label"], observed=True): assert group["average_precision"].nunique() == 1 @@ -197,14 +197,14 @@ def test_legend( ): """Check the rendering of the legend for with an `CrossValidationReport`.""" - # binary classification <= 5 folds + # binary classification <= 5 splits estimator, X, y = logistic_binary_classification_data report = CrossValidationReport(estimator, X=X, y=y, splitter=5) display = report.metrics.precision_recall() display.plot() check_legend_position(display.ax_, loc="lower left", position="inside") - # binary classification > 5 folds + # binary classification > 5 splits estimator, X, y = logistic_binary_classification_data report = CrossValidationReport(estimator, X=X, y=y, splitter=10) display = report.metrics.precision_recall() @@ -239,11 +239,11 @@ def test_binary_classification_constructor(logistic_binary_classification_data): report = CrossValidationReport(estimator, X=X, y=y, splitter=cv) display = report.metrics.precision_recall() - index_columns = ["estimator_name", "split_index", "label"] + index_columns = ["estimator_name", "split", "label"] for df in [display.precision_recall, display.average_precision]: assert all(col in df.columns for col in index_columns) assert df["estimator_name"].unique() == report.estimator_name_ - assert df["split_index"].nunique() == cv + assert df["split"].nunique() == cv assert df["label"].unique() == 1 assert len(display.average_precision) == cv @@ -255,11 +255,11 @@ def test_multiclass_classification_constructor(logistic_multiclass_classificatio report = CrossValidationReport(estimator, X=X, y=y, splitter=cv) display = report.metrics.precision_recall() - index_columns = ["estimator_name", "split_index", "label"] + index_columns = ["estimator_name", "split", "label"] for df in [display.precision_recall, display.average_precision]: assert all(col in df.columns for col in index_columns) assert df["estimator_name"].unique() == report.estimator_name_ - assert df["split_index"].unique().tolist() == list(range(cv)) + assert df["split"].unique().tolist() == list(range(cv)) np.testing.assert_array_equal(df["label"].unique(), np.unique(y)) assert len(display.average_precision) == len(np.unique(y)) * cv diff --git a/skore/tests/unit/displays/precision_recall_curve/test_estimator.py b/skore/tests/unit/displays/precision_recall_curve/test_estimator.py index 48c1c9aef0..a349d1c60b 100644 --- a/skore/tests/unit/displays/precision_recall_curve/test_estimator.py +++ b/skore/tests/unit/displays/precision_recall_curve/test_estimator.py @@ -379,11 +379,11 @@ def test_binary_classification_constructor( ) display = report.metrics.precision_recall() - index_columns = ["estimator_name", "split_index", "label"] + index_columns = ["estimator_name", "split", "label"] for df in [display.precision_recall, display.average_precision]: assert all(col in df.columns for col in index_columns) assert 
df["estimator_name"].unique() == report.estimator_name_ - assert df["split_index"].isnull().all() + assert df["split"].isnull().all() assert df["label"].unique() == 1 assert len(display.average_precision) == 1 @@ -401,11 +401,11 @@ def test_multiclass_classification_constructor( ) display = report.metrics.precision_recall() - index_columns = ["estimator_name", "split_index", "label"] + index_columns = ["estimator_name", "split", "label"] for df in [display.precision_recall, display.average_precision]: assert all(col in df.columns for col in index_columns) assert df["estimator_name"].unique() == report.estimator_name_ - assert df["split_index"].isnull().all() + assert df["split"].isnull().all() np.testing.assert_array_equal(df["label"].unique(), estimator.classes_) assert len(display.average_precision) == len(estimator.classes_) diff --git a/skore/tests/unit/displays/prediction_error/test_comparison_cross_validation.py b/skore/tests/unit/displays/prediction_error/test_comparison_cross_validation.py index 12edd8e2d9..eafc45ad9a 100644 --- a/skore/tests/unit/displays/prediction_error/test_comparison_cross_validation.py +++ b/skore/tests/unit/displays/prediction_error/test_comparison_cross_validation.py @@ -133,14 +133,14 @@ def test_constructor(linear_regression_data): ) display = report.metrics.prediction_error() - index_columns = ["estimator_name", "split_index"] + index_columns = ["estimator_name", "split"] df = display._prediction_error assert all(col in df.columns for col in index_columns) assert df.query("estimator_name == 'estimator_1'")[ - "split_index" + "split" ].unique().tolist() == list(range(cv)) assert df.query("estimator_name == 'estimator_2'")[ - "split_index" + "split" ].unique().tolist() == list(range(cv + 1)) assert df["estimator_name"].unique().tolist() == report.report_names_ @@ -151,7 +151,7 @@ def test_frame(comparison_cross_validation_reports_regression): display = report.metrics.prediction_error() df = display.frame() - expected_index = ["estimator_name", "split_index"] + expected_index = ["estimator_name", "split"] expected_columns = ["y_true", "y_pred", "residuals"] check_frame_structure(df, expected_index, expected_columns) diff --git a/skore/tests/unit/displays/prediction_error/test_comparison_estimator.py b/skore/tests/unit/displays/prediction_error/test_comparison_estimator.py index 686393e028..dd7eaf0dce 100644 --- a/skore/tests/unit/displays/prediction_error/test_comparison_estimator.py +++ b/skore/tests/unit/displays/prediction_error/test_comparison_estimator.py @@ -272,8 +272,8 @@ def test_constructor(linear_regression_with_train_test): ) display = report.metrics.prediction_error() - index_columns = ["estimator_name", "split_index"] + index_columns = ["estimator_name", "split"] df = display._prediction_error assert all(col in df.columns for col in index_columns) assert df["estimator_name"].unique().tolist() == report.report_names_ - assert df["split_index"].isnull().all() + assert df["split"].isnull().all() diff --git a/skore/tests/unit/displays/prediction_error/test_cross_validation.py b/skore/tests/unit/displays/prediction_error/test_cross_validation.py index da43e24472..6fba37c942 100644 --- a/skore/tests/unit/displays/prediction_error/test_cross_validation.py +++ b/skore/tests/unit/displays/prediction_error/test_cross_validation.py @@ -24,7 +24,7 @@ def test_regression(pyplot, linear_regression_data, data_source): # check the structure of the attributes assert isinstance(display._prediction_error, pd.DataFrame) - assert 
display._prediction_error["split_index"].nunique() == cv + assert display._prediction_error["split"].nunique() == cv assert display.data_source == data_source assert isinstance(display.range_y_true, RangeData) assert isinstance(display.range_y_pred, RangeData) @@ -68,7 +68,7 @@ def test_regression_actual_vs_predicted(pyplot, linear_regression_data): # check the structure of the attributes assert isinstance(display._prediction_error, pd.DataFrame) - assert display._prediction_error["split_index"].nunique() == cv + assert display._prediction_error["split"].nunique() == cv assert display.data_source == "test" assert isinstance(display.line_, mpl.lines.Line2D) @@ -135,11 +135,11 @@ def test_frame(linear_regression_data): display = report.metrics.prediction_error() df = display.frame() - expected_index = ["split_index"] + expected_index = ["split"] expected_columns = ["y_true", "y_pred", "residuals"] check_frame_structure(df, expected_index, expected_columns) - assert df["split_index"].nunique() == cv + assert df["split"].nunique() == cv def test_legend(pyplot, linear_regression_data): @@ -171,8 +171,8 @@ def test_constructor(linear_regression_data): report = CrossValidationReport(estimator, X=X, y=y, splitter=cv) display = report.metrics.prediction_error() - index_columns = ["estimator_name", "split_index"] + index_columns = ["estimator_name", "split"] df = display._prediction_error assert all(col in df.columns for col in index_columns) assert df["estimator_name"].unique() == report.estimator_name_ - assert df["split_index"].unique().tolist() == list(range(cv)) + assert df["split"].unique().tolist() == list(range(cv)) diff --git a/skore/tests/unit/displays/prediction_error/test_estimator.py b/skore/tests/unit/displays/prediction_error/test_estimator.py index daaa330621..aa64b54f1d 100644 --- a/skore/tests/unit/displays/prediction_error/test_estimator.py +++ b/skore/tests/unit/displays/prediction_error/test_estimator.py @@ -284,11 +284,11 @@ def test_constructor(linear_regression_with_train_test): ) display = report.metrics.prediction_error() - index_columns = ["estimator_name", "split_index"] + index_columns = ["estimator_name", "split"] df = display._prediction_error assert all(col in df.columns for col in index_columns) assert df["estimator_name"].unique() == report.estimator_name_ - assert df["split_index"].isnull().all() + assert df["split"].isnull().all() np.testing.assert_allclose(df["y_true"], y_test) np.testing.assert_allclose(df["y_pred"], estimator.predict(X_test)) np.testing.assert_allclose(df["residuals"], y_test - estimator.predict(X_test)) diff --git a/skore/tests/unit/displays/roc_curve/test_comparison_cross_validation.py b/skore/tests/unit/displays/roc_curve/test_comparison_cross_validation.py index 3c24d020d6..e86887356e 100644 --- a/skore/tests/unit/displays/roc_curve/test_comparison_cross_validation.py +++ b/skore/tests/unit/displays/roc_curve/test_comparison_cross_validation.py @@ -224,14 +224,14 @@ def test_binary_classification_constructor(logistic_binary_classification_data): ) display = report.metrics.roc() - index_columns = ["estimator_name", "split_index", "label"] + index_columns = ["estimator_name", "split", "label"] for df in [display.roc_curve, display.roc_auc]: assert all(col in df.columns for col in index_columns) assert df.query("estimator_name == 'estimator_1'")[ - "split_index" + "split" ].unique().tolist() == list(range(cv)) assert df.query("estimator_name == 'estimator_2'")[ - "split_index" + "split" ].unique().tolist() == list(range(cv + 1)) assert 
df["estimator_name"].unique().tolist() == report.report_names_ assert df["label"].unique() == 1 @@ -249,15 +249,15 @@ def test_multiclass_classification_constructor(logistic_multiclass_classificatio ) display = report.metrics.roc() - index_columns = ["estimator_name", "split_index", "label"] + index_columns = ["estimator_name", "split", "label"] classes = np.unique(y) for df in [display.roc_curve, display.roc_auc]: assert all(col in df.columns for col in index_columns) assert df.query("estimator_name == 'estimator_1'")[ - "split_index" + "split" ].unique().tolist() == list(range(cv)) assert df.query("estimator_name == 'estimator_2'")[ - "split_index" + "split" ].unique().tolist() == list(range(cv + 1)) assert df["estimator_name"].unique().tolist() == report.report_names_ np.testing.assert_array_equal(df["label"].unique(), classes) @@ -276,7 +276,7 @@ def test_frame_binary_classification( display = report.metrics.roc() df = display.frame(with_roc_auc=with_roc_auc) - expected_index = ["estimator_name", "split_index"] + expected_index = ["estimator_name", "split"] expected_columns = ["threshold", "fpr", "tpr"] if with_roc_auc: expected_columns.append("roc_auc") @@ -285,9 +285,7 @@ def test_frame_binary_classification( assert df["estimator_name"].nunique() == len(report.reports_) if with_roc_auc: - for (_, _), group in df.groupby( - ["estimator_name", "split_index"], observed=True - ): + for (_, _), group in df.groupby(["estimator_name", "split"], observed=True): assert group["roc_auc"].nunique() == 1 @@ -302,7 +300,7 @@ def test_frame_multiclass_classification( display = report.metrics.roc() df = display.frame(with_roc_auc=with_roc_auc) - expected_index = ["estimator_name", "split_index", "label"] + expected_index = ["estimator_name", "split", "label"] expected_columns = ["threshold", "fpr", "tpr"] if with_roc_auc: expected_columns.append("roc_auc") @@ -312,6 +310,6 @@ def test_frame_multiclass_classification( if with_roc_auc: for (_, _, _), group in df.groupby( - ["estimator_name", "split_index", "label"], observed=True + ["estimator_name", "split", "label"], observed=True ): assert group["roc_auc"].nunique() == 1 diff --git a/skore/tests/unit/displays/roc_curve/test_comparison_estimator.py b/skore/tests/unit/displays/roc_curve/test_comparison_estimator.py index 4dce0b6362..ba97b6edfa 100644 --- a/skore/tests/unit/displays/roc_curve/test_comparison_estimator.py +++ b/skore/tests/unit/displays/roc_curve/test_comparison_estimator.py @@ -368,11 +368,11 @@ def test_binary_classification_constructor( ) display = report.metrics.roc() - index_columns = ["estimator_name", "split_index", "label"] + index_columns = ["estimator_name", "split", "label"] for df in [display.roc_curve, display.roc_auc]: assert all(col in df.columns for col in index_columns) assert df["estimator_name"].unique().tolist() == report.report_names_ - assert df["split_index"].isnull().all() + assert df["split"].isnull().all() assert df["label"].unique() == 1 assert len(display.roc_auc) == 2 @@ -396,11 +396,11 @@ def test_multiclass_classification_constructor( ) display = report.metrics.roc() - index_columns = ["estimator_name", "split_index", "label"] + index_columns = ["estimator_name", "split", "label"] for df in [display.roc_curve, display.roc_auc]: assert all(col in df.columns for col in index_columns) assert df["estimator_name"].unique().tolist() == report.report_names_ - assert df["split_index"].isnull().all() + assert df["split"].isnull().all() np.testing.assert_array_equal(df["label"].unique(), np.unique(y_train)) assert 
len(display.roc_auc) == len(np.unique(y_train)) * 2 diff --git a/skore/tests/unit/displays/roc_curve/test_cross_validation.py b/skore/tests/unit/displays/roc_curve/test_cross_validation.py index e59f2ae0b2..fa8c52c4c9 100644 --- a/skore/tests/unit/displays/roc_curve/test_cross_validation.py +++ b/skore/tests/unit/displays/roc_curve/test_cross_validation.py @@ -35,9 +35,7 @@ def test_binary_classification( == [display.pos_label] ) assert ( - display.roc_curve["split_index"].nunique() - == display.roc_auc["split_index"].nunique() - == cv + display.roc_curve["split"].nunique() == display.roc_auc["split"].nunique() == cv ) display.plot() @@ -47,10 +45,10 @@ def test_binary_classification( for split_idx, line in enumerate(display.lines_): assert isinstance(line, mpl.lines.Line2D) roc_auc_split = display.roc_auc.query( - f"label == {pos_label} & split_index == {split_idx}" + f"label == {pos_label} & split == {split_idx}" )["roc_auc"].item() assert line.get_label() == ( - f"Fold #{split_idx + 1} (AUC = {roc_auc_split:0.2f})" + f"Split #{split_idx + 1} (AUC = {roc_auc_split:0.2f})" ) assert mpl.colors.to_rgba(line.get_color()) == expected_colors[split_idx] @@ -97,9 +95,7 @@ def test_multiclass_classification( == list(class_labels) ) assert ( - display.roc_curve["split_index"].nunique() - == display.roc_auc["split_index"].nunique() - == cv + display.roc_curve["split"].nunique() == display.roc_auc["split"].nunique() == cv ) display.plot() @@ -190,16 +186,16 @@ def test_frame_binary_classification(logistic_binary_classification_data, with_r (estimator, X, y), cv = logistic_binary_classification_data, 3 report = CrossValidationReport(estimator, X=X, y=y, splitter=cv) df = report.metrics.roc().frame(with_roc_auc=with_roc_auc) - expected_index = ["split_index"] + expected_index = ["split"] expected_columns = ["threshold", "fpr", "tpr"] if with_roc_auc: expected_columns.append("roc_auc") check_frame_structure(df, expected_index, expected_columns) - assert df["split_index"].nunique() == cv + assert df["split"].nunique() == cv if with_roc_auc: - for (_), group in df.groupby(["split_index"], observed=True): + for (_), group in df.groupby(["split"], observed=True): assert group["roc_auc"].nunique() == 1 @@ -211,17 +207,17 @@ def test_frame_multiclass_classification( (estimator, X, y), cv = logistic_multiclass_classification_data, 3 report = CrossValidationReport(estimator, X=X, y=y, splitter=cv) df = report.metrics.roc().frame(with_roc_auc=with_roc_auc) - expected_index = ["split_index", "label"] + expected_index = ["split", "label"] expected_columns = ["threshold", "fpr", "tpr"] if with_roc_auc: expected_columns.append("roc_auc") check_frame_structure(df, expected_index, expected_columns) - assert df["split_index"].nunique() == cv + assert df["split"].nunique() == cv assert df["label"].nunique() == len(np.unique(y)) if with_roc_auc: - for (_, _), group in df.groupby(["split_index", "label"], observed=True): + for (_, _), group in df.groupby(["split", "label"], observed=True): assert group["roc_auc"].nunique() == 1 @@ -231,14 +227,14 @@ def test_legend( """Check the rendering of the legend for ROC curves with a `CrossValidationReport`.""" - # binary classification <= 5 folds + # binary classification <= 5 splits estimator, X, y = logistic_binary_classification_data report = CrossValidationReport(estimator, X=X, y=y, splitter=5) display = report.metrics.roc() display.plot() check_legend_position(display.ax_, loc="lower right", position="inside") - # binary classification > 5 folds + # binary classification > 
5 splits estimator, X, y = logistic_binary_classification_data report = CrossValidationReport(estimator, X=X, y=y, splitter=10) display = report.metrics.roc() @@ -273,11 +269,11 @@ def test_binary_classification_constructor(logistic_binary_classification_data): report = CrossValidationReport(estimator, X=X, y=y, splitter=cv) display = report.metrics.roc() - index_columns = ["estimator_name", "split_index", "label"] + index_columns = ["estimator_name", "split", "label"] for df in [display.roc_curve, display.roc_auc]: assert all(col in df.columns for col in index_columns) assert df["estimator_name"].unique() == report.estimator_name_ - assert df["split_index"].nunique() == cv + assert df["split"].nunique() == cv assert df["label"].unique() == 1 assert len(display.roc_auc) == cv @@ -289,11 +285,11 @@ def test_multiclass_classification_constructor(logistic_multiclass_classificatio report = CrossValidationReport(estimator, X=X, y=y, splitter=cv) display = report.metrics.roc() - index_columns = ["estimator_name", "split_index", "label"] + index_columns = ["estimator_name", "split", "label"] for df in [display.roc_curve, display.roc_auc]: assert all(col in df.columns for col in index_columns) assert df["estimator_name"].unique() == report.estimator_name_ - assert df["split_index"].unique().tolist() == list(range(cv)) + assert df["split"].unique().tolist() == list(range(cv)) np.testing.assert_array_equal(df["label"].unique(), np.unique(y)) assert len(display.roc_auc) == len(np.unique(y)) * cv diff --git a/skore/tests/unit/displays/roc_curve/test_estimator.py b/skore/tests/unit/displays/roc_curve/test_estimator.py index 1b05a02b0b..5ce5b27c2d 100644 --- a/skore/tests/unit/displays/roc_curve/test_estimator.py +++ b/skore/tests/unit/displays/roc_curve/test_estimator.py @@ -383,11 +383,11 @@ def test_binary_classification_constructor( ) display = report.metrics.roc() - index_columns = ["estimator_name", "split_index", "label"] + index_columns = ["estimator_name", "split", "label"] for df in [display.roc_curve, display.roc_auc]: assert all(col in df.columns for col in index_columns) assert df["estimator_name"].unique() == report.estimator_name_ - assert df["split_index"].isnull().all() + assert df["split"].isnull().all() assert df["label"].unique() == 1 assert len(display.roc_auc) == 1 @@ -405,11 +405,11 @@ def test_multiclass_classification_constructor( ) display = report.metrics.roc() - index_columns = ["estimator_name", "split_index", "label"] + index_columns = ["estimator_name", "split", "label"] for df in [display.roc_curve, display.roc_auc]: assert all(col in df.columns for col in index_columns) assert df["estimator_name"].unique() == report.estimator_name_ - assert df["split_index"].isnull().all() + assert df["split"].isnull().all() np.testing.assert_array_equal(df["label"].unique(), estimator.classes_) assert len(display.roc_auc) == len(estimator.classes_) diff --git a/sphinx/index.rst b/sphinx/index.rst index 39217f9a62..e6a45134ba 100644 --- a/sphinx/index.rst +++ b/sphinx/index.rst @@ -32,7 +32,7 @@ Key features of Skore Lib All in just one line of code. Under the hood, we use efficient caching to make the computations blazing fast. -- :class:`skore.CrossValidationReport`: get a skore estimator report for each fold +- :class:`skore.CrossValidationReport`: get a skore estimator report for each split of your cross-validation. - :class:`skore.ComparisonReport`: benchmark your skore estimator reports. 
diff --git a/sphinx/reference/report/cross_validation_report.rst b/sphinx/reference/report/cross_validation_report.rst index 435ed6ec32..ae8c2b3ae0 100644 --- a/sphinx/reference/report/cross_validation_report.rst +++ b/sphinx/reference/report/cross_validation_report.rst @@ -38,7 +38,7 @@ Metrics ------- The `metrics` accessor helps you to evaluate the statistical performance of your -estimator across cross-validation folds. +estimator across cross-validation splits. .. autosummary:: :toctree: ../api/
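For reviewers, the user-facing effect of this rename is that cross-validation dataframes expose a `split` column (previously `split_index`) and plot legends read "Split #N" instead of "Fold #N". The snippet below is a minimal sketch of that surface, assuming a toy make_classification dataset and a LogisticRegression estimator (neither appears in this diff); the skore calls themselves (CrossValidationReport with an integer splitter, metrics.roc(), frame(), plot()) are the ones touched above.

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

from skore import CrossValidationReport

# Toy data and estimator, chosen only for illustration.
X, y = make_classification(n_samples=500, random_state=0)
report = CrossValidationReport(LogisticRegression(), X=X, y=y, splitter=5)

# The per-split ROC frame now carries a `split` column (one value per
# cross-validation split) instead of `split_index`.
display = report.metrics.roc()
df = display.frame(with_roc_auc=True)
print(df["split"].nunique())  # 5, one value per split

# Legend entries now read "Split #1 (AUC = ...)" rather than "Fold #1 (...)".
display.plot()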