10 changes: 5 additions & 5 deletions examples/getting_started/plot_skore_getting_started.py
@@ -122,7 +122,7 @@
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# skore has also (re-)implemented a :class:`skore.CrossValidationReport` class that
# contains several :class:`skore.EstimatorReport`, one for each fold.
# contains several :class:`skore.EstimatorReport`, one for each split.

# %%
from skore import CrossValidationReport
@@ -142,21 +142,21 @@
cv_report.metrics.summarize().frame()

# %%
# or by individual fold:
# or by individual split:

# %%
cv_report.metrics.summarize(aggregate=None).frame()

# %%
# We display the ROC curves for each fold:
# We display the ROC curves for each split:

# %%
roc_plot_cv = cv_report.metrics.roc()
roc_plot_cv.plot()

# %%
# We can retrieve the estimator report of a specific fold to investigate further,
# for example getting the report metrics for the first fold only:
# We can retrieve the estimator report of a specific split to investigate further,
# for example getting the report metrics for the first split only:

# %%
cv_report.estimator_reports_[0].metrics.summarize().frame()
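To make the renamed wording concrete for reviewers, here is a rough, self-contained sketch of the per-split workflow the updated getting-started example walks through. The constructor keywords (`X=`, `y=`, `splitter=`) are assumed from the `CrossValidationReport` docstring touched later in this diff; the example's actual construction call is collapsed above.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

from skore import CrossValidationReport

# Placeholder data and estimator; keyword names are assumptions, see note above.
X, y = make_classification(n_samples=500, random_state=0)
cv_report = CrossValidationReport(LogisticRegression(), X=X, y=y, splitter=5)

# Metrics aggregated across splits, then one row per split.
print(cv_report.metrics.summarize().frame())
print(cv_report.metrics.summarize(aggregate=None).frame())

# One EstimatorReport per split; drill into the first split only.
first_split_report = cv_report.estimator_reports_[0]
print(first_split_report.metrics.summarize().frame())
```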
4 changes: 2 additions & 2 deletions examples/technical_details/plot_cache_mechanism.py
@@ -250,7 +250,7 @@
# Caching with :class:`~skore.CrossValidationReport`
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# :class:`~skore.CrossValidationReport` uses the same caching system for each fold
# :class:`~skore.CrossValidationReport` uses the same caching system for each split
# in cross-validation by leveraging the previous :class:`~skore.EstimatorReport`:
from skore import CrossValidationReport

@@ -262,7 +262,7 @@
# Since a :class:`~skore.CrossValidationReport` uses many
# :class:`~skore.EstimatorReport`, we will observe the same behaviour as we previously
# exposed.
# The first call will be slow because it computes the predictions for each fold.
# The first call will be slow because it computes the predictions for each split.
start = time.time()
result = report.metrics.summarize().frame()
end = time.time()
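As a sketch of the behaviour the reworded comment describes: the first metrics call computes predictions for every split, while a repeat call is served from the cache. The dataset, estimator, and constructor keywords below are placeholders, not the ones used in the example.

```python
import time

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

from skore import CrossValidationReport

X, y = make_classification(n_samples=1_000, random_state=0)
report = CrossValidationReport(LogisticRegression(), X=X, y=y, splitter=5)

# First call: predictions are computed for each split, so it is slow.
start = time.time()
report.metrics.summarize().frame()
print(f"first call:  {time.time() - start:.2f}s")

# Second call: the per-split predictions are cached, so it is fast.
start = time.time()
report.metrics.summarize().frame()
print(f"second call: {time.time() - start:.2f}s")
```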
4 changes: 2 additions & 2 deletions examples/use_cases/plot_employee_salaries.py
@@ -319,9 +319,9 @@ def periodic_spline_transformer(period, n_splines=None, degree=3):
).frame()

# %%
# Finally, we can even get a deeper understanding by analyzing each fold in the
# Finally, we can even get a deeper understanding by analyzing each split in the
# :class:`~skore.CrossValidationReport`.
# Here, we plot the actual-vs-predicted values for each fold.
# Here, we plot the actual-vs-predicted values for each split.

# %%
linear_model_report.metrics.prediction_error().plot(kind="actual_vs_predicted")
12 changes: 5 additions & 7 deletions skore/src/skore/_sklearn/_comparison/metrics_accessor.py
@@ -1259,7 +1259,7 @@ def _get_display(
y_true.append(
YPlotData(
estimator_name=report_name,
split_index=None,
split=None,
y=report_y,
)
)
@@ -1280,7 +1280,7 @@
y_pred.append(
YPlotData(
estimator_name=report_name,
split_index=None,
split=None,
y=value,
)
)
@@ -1301,9 +1301,7 @@
for report, report_name in zip(
self._parent.reports_, self._parent.report_names_, strict=False
):
for split_index, estimator_report in enumerate(
report.estimator_reports_
):
for split, estimator_report in enumerate(report.estimator_reports_):
report_X, report_y, _ = (
estimator_report.metrics._get_X_y_and_data_source_hash(
data_source=data_source,
@@ -1315,7 +1313,7 @@
y_true.append(
YPlotData(
estimator_name=report_name,
split_index=split_index,
split=split,
y=report_y,
)
)
@@ -1337,7 +1335,7 @@
y_pred.append(
YPlotData(
estimator_name=report_name,
split_index=split_index,
split=split,
y=value,
)
)
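To show the renamed keyword in isolation: entries built from a plain estimator report now carry `split=None`, while cross-validated entries carry the enumeration index. `YPlotData` is an internal skore container whose import path is not part of this diff, so the sketch below uses a local stand-in dataclass with the same three fields.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class YPlotData:  # local stand-in, mirroring the fields used in this diff
    estimator_name: str
    split: Optional[int]  # None for a plain estimator report, split index otherwise
    y: Sequence


# Estimator-level entry: there is no split to refer to.
estimator_entry = YPlotData(estimator_name="LogisticRegression", split=None, y=[0, 1, 1])

# Cross-validation entries: one per split, indexed by enumeration order.
cv_entries = [
    YPlotData(estimator_name="LogisticRegression", split=split, y=[0, 1, 1])
    for split in range(5)
]
```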
@@ -1146,7 +1146,7 @@ def _get_display(
y_true.append(
YPlotData(
estimator_name=self._parent.estimator_name_,
split_index=report_idx,
split=report_idx,
y=y,
)
)
@@ -1167,7 +1167,7 @@
y_pred.append(
YPlotData(
estimator_name=self._parent.estimator_name_,
split_index=report_idx,
split=report_idx,
y=value,
)
)
12 changes: 6 additions & 6 deletions skore/src/skore/_sklearn/_cross_validation/report.py
@@ -92,7 +92,7 @@ class CrossValidationReport(_BaseReport, DirNamesMixin):
Determines the cross-validation splitting strategy.
Possible inputs for `splitter` are:

- int, to specify the number of folds in a `(Stratified)KFold`,
- int, to specify the number of splits in a `(Stratified)KFold`,
- a scikit-learn :term:`CV splitter`,
- An iterable yielding (train, test) splits as arrays of indices.

@@ -339,19 +339,19 @@ def cache_predictions(
total_estimators = len(self.estimator_reports_)
progress.update(main_task, total=total_estimators)

for fold_idx, estimator_report in enumerate(self.estimator_reports_, 1):
for split_idx, estimator_report in enumerate(self.estimator_reports_, 1):
# Share the parent's progress bar with child report
estimator_report._progress_info = {
"current_progress": progress,
"fold_info": {"current": fold_idx, "total": total_estimators},
"split_info": {"current": split_idx, "total": total_estimators},
}

# Update the progress bar description to include the fold number
# Update the progress bar description to include the split number
progress.update(
main_task,
description=(
"Cross-validation predictions for fold "
f"#{fold_idx}/{total_estimators}"
"Cross-validation predictions for split "
f"#{split_idx}/{total_estimators}"
),
)

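The reworded docstring still covers the same three accepted `splitter` forms; a short sketch of each is below. Only `splitter` appears in the hunk above, so the `X=`/`y=` keyword names and the placeholder data are assumptions.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold

from skore import CrossValidationReport

X, y = make_classification(n_samples=200, random_state=0)
clf = LogisticRegression()

# 1. An int: the number of splits in a (Stratified)KFold.
CrossValidationReport(clf, X=X, y=y, splitter=5)

# 2. A scikit-learn CV splitter object.
CrossValidationReport(
    clf, X=X, y=y, splitter=KFold(n_splits=5, shuffle=True, random_state=0)
)

# 3. An iterable yielding (train, test) index arrays.
CrossValidationReport(clf, X=X, y=y, splitter=KFold(n_splits=3).split(X, y))
```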
4 changes: 2 additions & 2 deletions skore/src/skore/_sklearn/_estimator/metrics_accessor.py
@@ -1692,14 +1692,14 @@ def _get_display(
y_true=[
YPlotData(
estimator_name=self._parent.estimator_name_,
split_index=None,
split=None,
y=y,
)
],
y_pred=[
YPlotData(
estimator_name=self._parent.estimator_name_,
split_index=None,
split=None,
y=y_pred,
)
],
34 changes: 17 additions & 17 deletions skore/src/skore/_sklearn/_plot/metrics/precision_recall_curve.py
@@ -60,7 +60,7 @@ class PrecisionRecallCurveDisplay(
The precision-recall curve data to display. The columns are

- `estimator_name`
- `split_index` (may be null)
- `split` (may be null)
- `label`
- `threshold`
- `precision`
@@ -70,7 +70,7 @@
The average precision data to display. The columns are

- `estimator_name`
- `split_index` (may be null)
- `split` (may be null)
- `label`
- `average_precision`.

@@ -274,8 +274,8 @@ def _plot_cross_validated_estimator(
line_kwargs: dict[str, Any] = {"drawstyle": "steps-post"}

if self.ml_task == "binary-classification":
for split_idx in self.precision_recall["split_index"].cat.categories:
query = f"label == {self.pos_label!r} & split_index == {split_idx}"
for split_idx in self.precision_recall["split"].cat.categories:
query = f"label == {self.pos_label!r} & split == {split_idx}"
precision_recall = self.precision_recall.query(query)
average_precision = self.average_precision.query(query)[
"average_precision"
Expand All @@ -285,7 +285,7 @@ def _plot_cross_validated_estimator(
line_kwargs, pr_curve_kwargs[split_idx]
)
line_kwargs_validated["label"] = (
f"Fold #{split_idx + 1} (AP = {average_precision:0.2f})"
f"Split #{split_idx + 1} (AP = {average_precision:0.2f})"
)

(line,) = self.ax_.plot(
@@ -314,8 +314,8 @@
# average_precision_class = self.average_precision[class_]
pr_curve_kwargs_class = pr_curve_kwargs[class_idx]

for split_idx in self.precision_recall["split_index"].cat.categories:
query = f"label == {class_label!r} & split_index == {split_idx}"
for split_idx in self.precision_recall["split"].cat.categories:
query = f"label == {class_label!r} & split == {split_idx}"
precision_recall = self.precision_recall.query(query)
average_precision = self.average_precision.query(query)[
"average_precision"
@@ -519,7 +519,7 @@ def _plot_comparison_cross_validation(
precision_recall = self.precision_recall.query(query)

for split_idx, segment in precision_recall.groupby(
"split_index", observed=True
"split", observed=True
):
if split_idx == 0:
label_kwargs = {
@@ -586,7 +586,7 @@
precision_recall = self.precision_recall.query(query)

for split_idx, segment in precision_recall.groupby(
"split_index", observed=True
"split", observed=True
):
if split_idx == 0:
label_kwargs = {
@@ -836,7 +836,7 @@ def _compute_data_for_display(
precision_recall_records.append(
{
"estimator_name": y_true_i.estimator_name,
"split_index": y_true_i.split_index,
"split": y_true_i.split,
"label": pos_label_validated,
"threshold": threshold,
"precision": precision,
@@ -846,7 +846,7 @@
average_precision_records.append(
{
"estimator_name": y_true_i.estimator_name,
"split_index": y_true_i.split_index,
"split": y_true_i.split,
"label": pos_label_validated,
"average_precision": average_precision_i,
}
@@ -879,7 +879,7 @@ def _compute_data_for_display(
precision_recall_records.append(
{
"estimator_name": y_true_i.estimator_name,
"split_index": y_true_i.split_index,
"split": y_true_i.split,
"label": class_,
"threshold": threshold,
"precision": precision,
@@ -889,7 +889,7 @@
average_precision_records.append(
{
"estimator_name": y_true_i.estimator_name,
"split_index": y_true_i.split_index,
"split": y_true_i.split,
"label": class_,
"average_precision": average_precision_class_i,
}
)

dtypes = {
"estimator_name": "category",
"split_index": "category",
"split": "category",
"label": "category",
}

@@ -929,7 +929,7 @@ def frame(self, with_average_precision: bool = False) -> DataFrame:
depending on the report type:

- `estimator_name`: Name of the estimator (when comparing estimators)
- `split_index`: Cross-validation fold ID (when doing cross-validation)
- `split`: Cross-validation split ID (when doing cross-validation)
- `label`: Class label (for multiclass-classification)
- `threshold`: Decision threshold
- `precision`: Precision score at threshold
@@ -966,11 +966,11 @@ def frame(self, with_average_precision: bool = False) -> DataFrame:
if self.report_type == "estimator":
indexing_columns = []
elif self.report_type == "cross-validation":
indexing_columns = ["split_index"]
indexing_columns = ["split"]
elif self.report_type == "comparison-estimator":
indexing_columns = ["estimator_name"]
else: # self.report_type == "comparison-cross-validation"
indexing_columns = ["estimator_name", "split_index"]
indexing_columns = ["estimator_name", "split"]

if self.ml_task == "binary-classification":
columns = indexing_columns + statistical_columns
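Putting the updated `frame()` documentation together: for a cross-validation report, the returned frame is now indexed by a `split` column instead of `split_index`. A sketch under two assumptions: the display is obtained through `report.metrics.precision_recall()` (that accessor call is not shown in this diff), and the data and constructor keywords are placeholders.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

from skore import CrossValidationReport

X, y = make_classification(n_samples=500, random_state=0)
report = CrossValidationReport(LogisticRegression(), X=X, y=y, splitter=5)

display = report.metrics.precision_recall()  # assumed accessor name
df = display.frame(with_average_precision=True)

# The indexing column for a cross-validation report is now `split` (was `split_index`).
print(df.columns.tolist())
print(df[df["split"] == 0].head())  # rows for the first split only
```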