Skip to content

Commit 7056f9e

Browse files
committed
update plot label
1 parent 947d83d commit 7056f9e

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

autoemulate/core/compare.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -873,13 +873,6 @@ def plot_preds( # noqa: PLR0912
873873
output_names = [f"$y_{i}$" for i in range(n_outputs)]
874874

875875
for i in range(n_outputs):
876-
axs[i].scatter(
877-
test_y[:, i],
878-
y_pred[:, i],
879-
alpha=0.6,
880-
linewidth=0.5,
881-
label="predicted",
882-
)
883876
if y_std is not None:
884877
axs[i].errorbar(
885878
test_y[:, i],
@@ -889,6 +882,12 @@ def plot_preds( # noqa: PLR0912
889882
alpha=0.4,
890883
capsize=3,
891884
)
885+
axs[i].scatter(
886+
test_y[:, i],
887+
y_pred[:, i],
888+
alpha=0.6,
889+
linewidth=0.5,
890+
)
892891
axs[i].plot(
893892
[test_y[:, i].min(), test_y[:, i].max()],
894893
[test_y[:, i].min(), test_y[:, i].max()],
@@ -897,7 +896,8 @@ def plot_preds( # noqa: PLR0912
897896
)
898897
axs[i].set_title(output_names[i])
899898
axs[i].set_xlabel("True values")
900-
axs[i].set_ylabel("Predicted values")
899+
axs[i].set_ylabel("Predicted values ±2\u03c3")
900+
plt.tight_layout()
901901

902902
if figsize is not None:
903903
fig.set_size_inches(figsize)

0 commit comments

Comments
 (0)