Skip to content

Commit 3daa3ed

Browse files
authored
Merge pull request #264 from alan-turing-institute/plotting
fix wrong R2's in plot_cv due to not subsetting training data
2 parents ca6274e + 4dd28c4 commit 3daa3ed

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

autoemulate/compare.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -422,8 +422,8 @@ def plot_cv(
422422
)
423423
figure = _plot_cv(
424424
self.cv_results,
425-
self.X,
426-
self.y,
425+
self.X[self.train_idxs],
426+
self.y[self.train_idxs],
427427
model_name=model_name,
428428
n_cols=n_cols,
429429
style=style,

autoemulate/plotting.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import matplotlib.pyplot as plt
44
import numpy as np
55
from sklearn.metrics import PredictionErrorDisplay
6+
from sklearn.metrics import r2_score
67
from sklearn.pipeline import Pipeline
78

89
from autoemulate.utils import _ensure_2d
@@ -482,6 +483,7 @@ def _plot_Xy(X, y, y_pred, y_std=None, ax=None, title="Xy"):
482483
"""
483484
Plots observed and predicted values vs. features, including 2σ error bands where available.
484485
"""
486+
485487
# Sort the data
486488
sort_idx = np.argsort(X).flatten()
487489
X_sorted = X[sort_idx]
@@ -558,10 +560,10 @@ def _plot_Xy(X, y, y_pred, y_std=None, ax=None, title="Xy"):
558560
columnspacing=0,
559561
ncol=2,
560562
)
563+
561564
# Calculate R2 score
562-
r2 = 1 - np.sum((y_sorted - y_pred_sorted) ** 2) / np.sum(
563-
(y_sorted - np.mean(y_sorted)) ** 2
564-
)
565+
r2 = r2_score(y, y_pred)
566+
565567
ax.text(
566568
0.05,
567569
0.05,

0 commit comments

Comments
 (0)