22import warnings
33from datetime import datetime
44from pathlib import Path
5+ from typing import Literal
56
67import joblib
78import matplotlib .pyplot as plt
@@ -621,10 +622,13 @@ def fit_from_reinitialized(
621622 def plot ( # noqa: PLR0912, PLR0915
622623 self ,
623624 model_obj : int | Emulator | Result ,
625+ input_names : list [str ] | None = None ,
626+ output_names : list [str ] | None = None ,
624627 input_index : list [int ] | int | None = None ,
625628 output_index : list [int ] | int | None = None ,
626629 input_ranges : dict | None = None ,
627630 output_ranges : dict | None = None ,
631+ error_style : Literal ["bars" , "fill" ] = "bars" ,
628632 figsize = None ,
629633 ncols : int = 3 ,
630634 fname : str | None = None ,
@@ -637,6 +641,10 @@ def plot( # noqa: PLR0912, PLR0915
637641 model_obj: int | Emulator | Result
638642 The model to plot. Can be an integer ID of a Result, an Emulator instance,
639643 or a Result instance.
644+ input_names: list[str] | None
645+ The names of the input features. If None, generic names are used.
646+ output_names: list[str] | None
647+ The names of the output features. If None, generic names are used.
640648 input_index: int
641649 The index of the input feature to plot against the output.
642650 output_index: int
@@ -649,6 +657,9 @@ def plot( # noqa: PLR0912, PLR0915
649657 The ranges of the output features to consider for the plot. Ranges are
650658 combined such that the final subset is the intersection data within
651659 the specified ranges. Defaults to None.
660+ error_style: Literal["bars", "fill"]
661+ The style of error representation in the plots. Can be "bars" for error
662+ bars or "fill" for shaded error regions. Defaults to "bars".
652663 figsize: tuple[int, int] | None
653664 The size of the figure to create. If None, it is set based on the number
654665 of input and output features.
@@ -728,6 +739,26 @@ def plot( # noqa: PLR0912, PLR0915
728739 fig , axs = plt .subplots (nrows , ncols , figsize = figsize , squeeze = False )
729740 axs = axs .flatten ()
730741
742+ if input_names is not None :
743+ if len (input_names ) != n_features :
744+ msg = (
745+ "Length of input_names does not match number of input features. "
746+ f"Expected { n_features } , got { len (input_names )} ."
747+ )
748+ raise ValueError (msg )
749+ else :
750+ input_names = [f"$x_{ i } $" for i in range (n_features )]
751+
752+ if output_names is not None :
753+ if len (output_names ) != n_outputs :
754+ msg = (
755+ "Length of output_names does not match number of outputs. "
756+ f"Expected { n_outputs } , got { len (output_names )} ."
757+ )
758+ raise ValueError (msg )
759+ else :
760+ output_names = [f"$y_{ i } $" for i in range (n_outputs )]
761+
731762 plot_index = 0
732763 for out_idx in output_index :
733764 for in_idx in input_index :
@@ -778,10 +809,11 @@ def subset_outputs(x, y, y_p):
778809 y_pred_subset [:, out_idx ],
779810 y_variance [:, out_idx ] if y_variance is not None else None ,
780811 ax = axs [plot_index ],
781- title = f"$x_ { in_idx } $ vs. $y_ { out_idx } $ " ,
782- input_index = in_idx ,
783- output_index = out_idx ,
812+ title = f"{ input_names [ in_idx ] } vs. { output_names [ out_idx ] } " ,
813+ input_label = input_names [ in_idx ] ,
814+ output_label = output_names [ out_idx ] ,
784815 r2_score = r2_score ,
816+ error_style = error_style ,
785817 )
786818 plot_index += 1
787819
@@ -795,6 +827,116 @@ def subset_outputs(x, y, y_p):
795827 fig .savefig (fname , bbox_inches = "tight" )
796828 return None
797829
830+ def plot_preds ( # noqa: PLR0912
831+ self ,
832+ model_obj : int | Emulator | Result ,
833+ output_names : list [str ] | None = None ,
834+ figsize = None ,
835+ ncols : int = 3 ,
836+ fname : str | None = None ,
837+ ):
838+ """
839+ Plot predicted means (and variances) against observations for all outputs.
840+
841+ Parameters
842+ ----------
843+ model_obj: int | Emulator | Result
844+ The model to plot. Can be an integer ID of a Result, an Emulator instance,
845+ or a Result instance.
846+ output_names: list[str] | None
847+ The names of the outputs to use in the plot titles. If None, generic names
848+ like "y_0", "y_1", etc. are used.
849+ figsize: tuple[int, int] | None
850+ The size of the figure to create. If None, it is set based on the number
851+ of outputs.
852+ ncols: int
853+ The maximum number of columns in the subplot grid. Defaults to 3.
854+ fname: str | None
855+ If provided, the figure will be saved to this file path.
856+ """
857+ result = None
858+ if isinstance (model_obj , int ):
859+ if model_obj not in self ._id_to_result :
860+ raise ValueError (f"No result found with ID: { model_obj } " )
861+ result = self .get_result (model_obj )
862+ model = result .model
863+ elif isinstance (model_obj , Emulator ):
864+ model = model_obj
865+ elif isinstance (model_obj , Result ):
866+ model = model_obj .model
867+
868+ test_x , test_y = self ._convert_to_tensors (self .test )
869+
870+ # Re-run prediction with just this model to get the predictions
871+ y_pred , y_variance = model .predict_mean_and_variance (test_x )
872+ y_std = None
873+ if y_variance is not None :
874+ y_variance , _ = self ._convert_to_numpy (y_variance , None )
875+ y_variance = self ._ensure_numpy_2d (y_variance )
876+ y_std = np .sqrt (y_variance )
877+
878+ # Convert to numpy for plotting
879+ test_x , test_y = self ._convert_to_numpy (test_x , test_y )
880+ assert test_x is not None
881+ assert test_y is not None
882+ assert y_pred is not None
883+ y_pred , _ = self ._convert_to_numpy (y_pred , None )
884+ test_x = self ._ensure_numpy_2d (test_x )
885+ test_y = self ._ensure_numpy_2d (test_y )
886+ y_pred = self ._ensure_numpy_2d (y_pred )
887+
888+ # Figure out layout
889+ n_outputs = test_y .shape [1 ] if test_y .ndim > 1 else 1
890+ nrows , ncols = calculate_subplot_layout (n_outputs , ncols )
891+ if figsize is None :
892+ figsize = (5 * ncols , 4 * nrows )
893+ fig , axs = plt .subplots (nrows , ncols , figsize = figsize , squeeze = False )
894+ axs = axs .flatten ()
895+
896+ if output_names is not None :
897+ if len (output_names ) != n_outputs :
898+ msg = (
899+ "Length of output_names does not match number of outputs. "
900+ f"Expected { n_outputs } , got { len (output_names )} ."
901+ )
902+ raise ValueError (msg )
903+ else :
904+ output_names = [f"$y_{ i } $" for i in range (n_outputs )]
905+
906+ for i in range (n_outputs ):
907+ if y_std is not None :
908+ axs [i ].errorbar (
909+ test_y [:, i ],
910+ y_pred [:, i ],
911+ yerr = 2 * y_std [:, i ],
912+ fmt = "none" ,
913+ alpha = 0.4 ,
914+ capsize = 3 ,
915+ )
916+ axs [i ].scatter (
917+ test_y [:, i ],
918+ y_pred [:, i ],
919+ alpha = 0.6 ,
920+ linewidth = 0.5 ,
921+ )
922+ axs [i ].plot (
923+ [test_y [:, i ].min (), test_y [:, i ].max ()],
924+ [test_y [:, i ].min (), test_y [:, i ].max ()],
925+ linestyle = "--" ,
926+ color = "gray" ,
927+ )
928+ axs [i ].set_title (output_names [i ])
929+ axs [i ].set_xlabel ("True values" )
930+ axs [i ].set_ylabel ("Predicted values ±2\u03c3 " )
931+ plt .tight_layout ()
932+
933+ if figsize is not None :
934+ fig .set_size_inches (figsize )
935+ if fname is None :
936+ return display_figure (fig )
937+ fig .savefig (fname , bbox_inches = "tight" )
938+ return None
939+
798940 def plot_surface (
799941 self ,
800942 model : Emulator ,
0 commit comments