Skip to content

Conversation

@waridrox
Copy link
Contributor

@waridrox waridrox commented May 19, 2025

Closes #1445. Converting to draft to iterate upon after #1669, #1701, #1709

Implementation

Added subplot support to the main metric display classes: PredictionErrorDisplay, RocCurveDisplay, and PrecisionRecallCurveDisplay. Instead of overlaying all results on a single axis, one can visualize metrics across multiple subplots such as - per fold, per estimator, per class.

  • Parameters:

    • All display .plot() methods now accept subplots, nrows, ncols, and figsize parameters.
    • subplots=True triggers subplot creation; nrows, ncols, and figsize allow custom grid layouts.
  • Automatic Layout Calculation:

    • By default, the grid uses 2 columns (adapting from Nixtla’s utilsforecast approach) and computes rows via ceiling division.
    • Handles edge cases where nrows or ncols is None to avoid type errors.
  • Subplotting based on different contexts:

    • For cross-validation: one subplot per fold.
    • For comparison: one subplot per estimator.
    • For multiclass: one subplot per class.
    • For single-estimator/single-class: falls back to a single plot.
  • Axis Sharing:

    • All subplots share x/y axes for scaling and easier comparison.
  • Backward-compatibility: (TBD: whether to keep this or discard it completely)

    • If subplots=False (which is the default option), behavior is unchanged.
    • self.ax_ param is always set to the first subplot.
  • Tests:

    • Added/adapted tests to verify subplot creation, layout, titles, axis labels, and error handling.

Usage

# One subplot per fold (cross-validation)
display.plot(subplots=True)

# Custom grid layout
display.plot(subplots=True, nrows=2, ncols=3, figsize=(12, 8))

References

TODO:


Sample plots following the above gist code ^^


  1. Prediction Error Display with CV folds (w/o nrows and ncols):
# Get the prediction error display
display = cv_report.metrics.prediction_error()

# Plot with subplots=True (one subplot per fold)
fig = display.plot(kind="actual_vs_predicted", subplots=True, figsize=(14, 10))
fig.suptitle("Actual vs Predicted Values Across CV Folds", fontsize=16, y=0.98)
plt.subplots_adjust(wspace=0.4, hspace=0.5, top=0.92, bottom=0.1, left=0.1, right=0.95)
plt.show()
Screenshot 2025-05-22 at 8 43 55 PM
  1. Prediction Error Display with model comparison (w/ nrows=1, ncols=3):
# Get the prediction error display
display = comp_report.metrics.prediction_error()

# Plot with subplots=True (one subplot per model) and custom layout
fig = display.plot(
    kind="actual_vs_predicted", subplots=True, nrows=1, ncols=3, figsize=(20, 7)
)
fig.suptitle("Comparing Prediction Error Across Different Models", fontsize=16, y=0.98)
plt.subplots_adjust(wspace=0.4, hspace=0.3, top=0.92, bottom=0.15, left=0.07, right=0.98)
plt.show()
Screenshot 2025-05-22 at 8 44 04 PM
  1. ROC Curve Display with CV folds (w/o nrows and ncols):
# Get the ROC curve display
display = cv_report.metrics.roc()

# Plot with subplots=True (one subplot per fold)
fig = display.plot(subplots=True, figsize=(14, 10))
fig.suptitle("ROC Curves Across CV Folds", fontsize=16, y=0.98)
plt.subplots_adjust(wspace=0.4, hspace=0.5, top=0.92, bottom=0.1, left=0.1, right=0.95)
plt.show()
Screenshot 2025-05-22 at 8 44 12 PM
  1. ROC Curve Display with model comparison (w/ nrows=1, ncols=3):
# Get the ROC curve display
display = comp_report.metrics.roc()

# Plot with subplots=True (one subplot per model) and custom layout
fig = display.plot(subplots=True, nrows=1, ncols=3, figsize=(20, 7))
fig.suptitle("Comparing ROC Curves Across Different Models", fontsize=16, y=0.98)
plt.subplots_adjust(wspace=0.4, hspace=0.3, top=0.92, bottom=0.15, left=0.07, right=0.98)
plt.show()
Screenshot 2025-05-22 at 8 44 21 PM
  1. Precision-Recall Curve Display with CV folds (w/o nrows and ncols):
# Using the same CV report as in Part 3
display = cv_report.metrics.precision_recall()

# Plot with subplots=True (one subplot per fold)
fig = display.plot(subplots=True, figsize=(14, 10))
fig.suptitle("Precision-Recall Curves Across CV Folds", fontsize=16, y=0.98)
plt.subplots_adjust(wspace=0.4, hspace=0.5, top=0.92, bottom=0.1, left=0.1, right=0.95)
plt.show()
Screenshot 2025-05-22 at 8 44 28 PM
  1. Precision-Recall Curve Display with model comparison (w/ nrows=1, ncols=3):
# Using the same comparison report as in Part 4
display = comp_report.metrics.precision_recall()

# Plot with subplots=True (one subplot per model) and custom layout
fig = display.plot(subplots=True, nrows=1, ncols=3, figsize=(20, 7))
fig.suptitle("Comparing Precision-Recall Curves Across Different Models", fontsize=16, y=0.98)
plt.subplots_adjust(wspace=0.4, hspace=0.3, top=0.92, bottom=0.15, left=0.07, right=0.98)
plt.show()
Screenshot 2025-05-22 at 8 44 38 PM

This might be too much to review at one go but there is quite some repetitive logic among different classes.
CC: @glemaitre @sylvaincom @auguste-probabl

@waridrox waridrox marked this pull request as draft May 19, 2025 15:14
@waridrox waridrox marked this pull request as ready for review May 20, 2025 18:53
report_type=self.report_type,
)

# Handle subplot creation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would make more sense to put the subplots logic in each _plot_<report_type> method

Copy link
Contributor Author

@waridrox waridrox May 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, in that case, will look into prediction_error and precision_recall_curve metrics as well

Copy link
Contributor

@auguste-probabl auguste-probabl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks; Some refactoring is needed

@auguste-probabl

This comment was marked as outdated.

@auguste-probabl
Copy link
Contributor

In your second example with the folds it looks like there are only 2 columns, but the code says ncol=3, is this a bug?

@waridrox
Copy link
Contributor Author

waridrox commented May 22, 2025

In your second example with the folds it looks like there are only 2 columns, but the code says ncol=3, is this a bug?

No, because actually in that example, there isn't any mention of the number of columns to begin with. For plotting the prediction error display with subplots for cross-validation, this line is responsible:

fig = display.plot(kind="actual_vs_predicted", subplots=True, figsize=(14, 10))
fig.suptitle("Actual vs Predicted Values Across CV Folds", fontsize=16, y=0.98)

Notice how there isn't any mention of the nrows or ncols. So the layout gets calculated automatically and it appears square-like as a 2×2 grid, with 3 plots.

@auguste-probabl
Copy link
Contributor

In your second example with the folds it looks like there are only 2 columns, but the code says ncol=3, is this a bug?

No, because actually in that example, there isn't any mention of the number of columns to begin with.

Can you update the PR description to make it clear which code leads to which plot?

auguste-probabl added a commit to auguste-probabl/skore that referenced this pull request May 23, 2025
Starting from probabl-ai#1669, it
is sometimes possible to draw plots onto (a `ndarray` of) several
Axes rather than just a single Axes. In fact, starting from
probabl-ai#1720, the same data might be
drawn on one or the other depending on user parameters.

In this context, allowing the user to pass `ax` might add too much
complexity compared to how useful it might be; it is considered an
advanced use-case which is difficult to motivate.

This PR thus removes the ability to pass user-created `ax`.
auguste-probabl added a commit to auguste-probabl/skore that referenced this pull request May 23, 2025
Starting from probabl-ai#1669, it
is sometimes possible to draw plots onto (a `ndarray` of) several
Axes rather than just a single Axes. In fact, starting from
probabl-ai#1720, the same data might be
drawn on one or the other depending on user parameters.

In this context, allowing the user to pass `ax` might add too much
complexity compared to how useful it might be; it is considered an
advanced use-case which is difficult to motivate.

This PR thus removes the ability to pass user-created `ax`.
auguste-probabl added a commit to auguste-probabl/skore that referenced this pull request May 23, 2025
Starting from probabl-ai#1669, it
is sometimes possible to draw plots onto (a `ndarray` of) several
Axes rather than just a single Axes. In fact, starting from
probabl-ai#1720, the same data might be
drawn on one or the other depending on user parameters.

In this context, allowing the user to pass `ax` might add too much
complexity compared to how useful it might be; it is considered an
advanced use-case which is difficult to motivate.

This PR thus removes the ability to pass user-created `ax`.
auguste-probabl added a commit to auguste-probabl/skore that referenced this pull request May 23, 2025
Starting from probabl-ai#1669, it
is sometimes possible to draw plots onto (a `ndarray` of) several
Axes rather than just a single Axes. In fact, starting from
probabl-ai#1720, the same data might be
drawn on one or the other depending on user parameters.

In this context, allowing the user to pass `ax` might add too much
complexity compared to how useful it might be; it is considered an
advanced use-case which is difficult to motivate.

This PR thus removes the ability to pass user-created `ax`.
@waridrox waridrox marked this pull request as draft May 26, 2025 05:45
@thomass-dev
Copy link
Collaborator

thomass-dev commented May 26, 2025

[automated comment] Please update your PR with main, so that the pytest workflow status will be reported.

@MarieSacksick
Copy link
Contributor

@waridrox, do you want to continue this PR, or should we take it over and iterate from here?

@waridrox
Copy link
Contributor Author

waridrox commented Oct 8, 2025

@waridrox, do you want to continue this PR, or should we take it over and iterate from here?

Hi @MarieSacksick, I'm willing to continue if the current implementation is fine. I was waiting on some follow ups from other PRs hence I converted this one to draft :)

@waridrox
Copy link
Contributor Author

waridrox commented Nov 14, 2025

Seems like a lot has been iterated over the past few months. Is this still something to be worked upon @MarieSacksick? I can update the implementation to resolve merge conflicts or @GaetandeCast do you have any plans to work on this?

@glemaitre
Copy link
Member

While working on #2152, then I think that the scope for the subplot is a bit clearer.

@waridrox I think it would be easier to go step by step:

  • let's forget for the API of figsize and layout for a moment and let's just focus on subplot.
  • We now see that depending of the type of report, we will be interested in creating subplots based on a specific variable (e.g. estimator, class label, training vs testing (not yet implemented)).
  • So since that we might encounter many little details that we did not thing yet, I would suggest to start with one specific display only. Let's start with the ROC curve.

The API would be something like:

report = *Report(...)
display = report.metrics.roc()
display.plot()  # have a reasonable default
display.plot(subplots_by="label")  # one class label per axis
display.plot(subplots_by="estimator_name")  # one estimator by axis

The PR that I linked is still work in progress but it should provide a good overview of what we can do even if it is on the feature importance.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

enh(skore): Add an option to plot into several subplots some display information

5 participants