 from autoemulate.plotting import _plot_model
 from autoemulate.printing import _print_setup
 from autoemulate.save import ModelSerialiser
-from autoemulate.sensitivity_analysis import plot_sensitivity_analysis
-from autoemulate.sensitivity_analysis import sensitivity_analysis
+from autoemulate.sensitivity_analysis import _plot_sensitivity_analysis
+from autoemulate.sensitivity_analysis import _sensitivity_analysis
 from autoemulate.utils import _check_cv
 from autoemulate.utils import _ensure_2d
 from autoemulate.utils import _get_full_model_name

 class AutoEmulate:
     """
     The AutoEmulate class is the main class of the AutoEmulate package. It is used to set up and compare
-    different emulator models on a given dataset. It can also be used to save and load models, and to
-    print and plot the results of the comparison.
+    different emulator models on a given dataset. It can also be used to summarise and visualise results,
+    to save and load models, and to run sensitivity analysis.
     """

     def __init__(self):
@@ -178,12 +178,13 @@ def _get_metrics(self, METRIC_REGISTRY):
         return [metric for metric in METRIC_REGISTRY.values()]

     def compare(self):
-        """Compares the emulator models on the data. self.setup() must be run first.
+        """Compares models using cross-validation, with the option
+        to perform hyperparameter search. self.setup() must be run first.

         Returns
         -------
         self.best_model : object
-            Best performing model fitted on full data.
+            Emulator with the highest cross-validation R2 score.
         """
         if not self.is_set_up:
             raise RuntimeError("Must run setup() before compare()")
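For orientation, a minimal end-to-end sketch of the workflow this docstring describes might look as follows; the `autoemulate.compare` import path and the `setup(X, y)` call are assumptions based on the class shown here, not part of this diff:

import numpy as np
from autoemulate.compare import AutoEmulate  # assumed import path for the class above

# toy data: 200 samples of a 2-input, 1-output simulator
X = np.random.uniform(0, 1, size=(200, 2))
y = np.sin(X[:, 0]) + X[:, 1] ** 2

ae = AutoEmulate()
ae.setup(X, y)             # assumed signature; holds out a test set internally
best_model = ae.compare()  # cross-validates all emulators, returns the one with the highest mean R2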
@@ -257,7 +258,8 @@ def get_model(self, name=None, rank=1, metric="r2"):
         Parameters
         ----------
         name : str
-            Name of the model to return.
+            Name of the model to return. Can be the full name or a short name, e.g. "GaussianProcess" or "gp".
+            Short names are formed from the first letter of each capitalised word in the full name, lowercased (e.g. "GaussianProcess" -> "gp").
         rank : int
             Rank of the model to return. Defaults to 1, which is the best model, 2 is the second best, etc.
         metric : str
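Continuing the sketch above, `get_model()` could then be called with either naming style; the exact model names available are not shown in this diff and are assumed here:

gp = ae.get_model("GaussianProcess")             # full name
gp = ae.get_model("gp")                          # equivalent short name
second_best = ae.get_model(rank=2, metric="r2")  # second-best model by cross-validation R2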
@@ -298,8 +300,7 @@ def get_model(self, name=None, rank=1, metric="r2"):
         return chosen_model

     def refit(self, model=None):
-        """Refits model on full data.
+        """Refits model on full data. This is useful because `compare()` only fits models on the training data.

         Parameters
         ----------
         model : model to refit.
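A brief continuation showing `refit()`; whether it returns the refitted emulator is an assumption, as the Returns section is not shown in this diff:

gp_full = ae.refit(gp)  # refit the chosen emulator on the full dataset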
@@ -359,7 +360,7 @@ def load(self, path=None):
         return serialiser._load_model(path)

     def print_setup(self):
-        """Print the setup of the AutoEmulate object."""
+        """Print the parameters of the AutoEmulate object."""
         _print_setup(self)

     def summarise_cv(self, model=None, sort_by="r2"):
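For completeness, a short sketch of the two inspection helpers touched here, continuing from the `ae` object above; accepted values for `model` beyond the default are an assumption:

ae.print_setup()                        # prints the setup parameters
cv_results = ae.summarise_cv()          # per-model cross-validation metrics, sorted by R2
gp_folds = ae.summarise_cv(model="gp")  # fold-level results for a single model (assumed usage)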
@@ -408,17 +409,17 @@ def plot_cv(
             If a model name is specified, plots all folds of that model.
         style : str, optional
             The type of plot to draw:
-            "Xy" observed and predicted values vs. features, including 2σ error bands where available (default).
-            "actual_vs_predicted" draws the observed values (y-axis) vs. the predicted values (x-axis) (default).
-            "residual_vs_predicted" draws the residuals, i.e. difference between observed and predicted values, (y-axis) vs. the predicted values (x-axis).
+            "Xy" for plotting observed and predicted values vs. features, including 2σ error bands where available (default).
+            "actual_vs_predicted" for plotting observed values (y-axis) vs. predicted values (x-axis).
+            "residual_vs_predicted" for plotting the residuals (observed minus predicted values, y-axis) vs. predicted values (x-axis).
         n_cols : int
             Number of columns in the plot grid.
         figsize : tuple, optional
-            Overrides the default figure size.
+            Overrides the default figure size, in inches, e.g. (6, 4).
         output_index : int
-            Index of the output to plot. Default is 0.
+            Index of the output to plot. Can be a single index or a list of indices. Default is 0.
         input_index : int
-            Index of the input to plot. Default is 0.
+            Index of the input to plot. Can be a single index or a list of indices. Default is 0.
         """
         model_name = (
             _get_full_model_name(model, self.model_names) if model is not None else None
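Continuing the same sketch, the three `style` options documented above could be exercised like this:

ae.plot_cv()                                       # default "Xy" plot
ae.plot_cv(style="actual_vs_predicted", n_cols=2)
ae.plot_cv(model="gp", style="residual_vs_predicted",
           output_index=0, input_index=1, figsize=(8, 5))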
@@ -440,6 +441,8 @@ def evaluate(self, model=None, multioutput="uniform_average"):
         """
         Evaluates the model on the test set.

+        Test set size can be specified in `setup()` with `test_set_size`.
+
         Parameters
         ----------
         model : object
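A one-line continuation of the sketch above, evaluating a refitted emulator on the held-out test set:

scores = ae.evaluate(gp_full)  # test-set metrics for the refitted emulator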
@@ -498,7 +501,7 @@ def plot_eval(
         output_index=0,
         input_index=0,
     ):
-        """Visualise different model evaluations on the test set.
+        """Visualise model predictive performance on the test set.

         Parameters
         ----------
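And the corresponding plot on the test set, continuing the sketch:

ae.plot_eval(gp_full, output_index=0, input_index=0)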
@@ -534,28 +537,60 @@ def sensitivity_analysis(
     ):
         """Perform Sobol sensitivity analysis on a fitted emulator.

+        Sobol sensitivity analysis is a variance-based method that decomposes the variance of the model
+        output into contributions from individual input parameters and their interactions. It calculates:
+
+        - First-order indices (S1): Direct contribution of each input parameter
+        - Second-order indices (S2): Contribution from pairwise interactions between parameters
+        - Total-order indices (ST): Total contribution of a parameter, including all its interactions
+
         Parameters
         ----------
         model : object, optional
             Fitted model. If None, uses the best model from cross-validation.
         problem : dict, optional
-            The problem definition, including 'num_vars', 'names', and 'bounds', optional 'output_names'.
-            If None, the problem is generated from X using minimum and maximum values of the features as bounds.
+            The problem definition dictionary. If None, the problem is generated from X using
+            minimum and maximum values of the features as bounds. The dictionary should contain:
+
+            - 'num_vars': Number of input variables (int)
+            - 'names': List of variable names (list of str)
+            - 'bounds': List of [min, max] bounds for each variable (list of lists)
+            - 'output_names': Optional list of output names (list of str)
+
+            Example::

-            Example:
-            ```python
             problem = {
                 "num_vars": 2,
                 "names": ["x1", "x2"],
                 "bounds": [[0, 1], [0, 1]],
+                "output_names": ["y1", "y2"]  # optional
             }
-            ```
         N : int, optional
-            Number of samples to generate. Default is 1024.
+            Number of samples to generate for the analysis. Higher values give more accurate
+            results but increase computation time. Default is 1024.
         conf_level : float, optional
-            Confidence level for the confidence intervals. Default is 0.95.
+            Confidence level (between 0 and 1) for calculating confidence intervals of the
+            sensitivity indices. Default is 0.95 (95% confidence).
         as_df : bool, optional
-            If True, return a long-format pandas DataFrame (default is True).
+            If True, returns results as a long-format pandas DataFrame with columns for
+            parameters, sensitivity indices, and confidence intervals. If False, returns
+            the raw SALib results dictionary. Default is True.
+
+        Returns
+        -------
+        pandas.DataFrame or dict
+            If as_df=True (default), returns a DataFrame with columns:
+
+            - 'parameter': Input parameter name
+            - 'output': Output variable name
+            - 'S1', 'S2', 'ST': First, second, and total-order sensitivity indices
+            - 'S1_conf', 'S2_conf', 'ST_conf': Confidence intervals for each index
+
+            If as_df=False, returns the raw SALib results dictionary.
+
+        Notes
+        -----
+        The analysis requires N * (2D + 2) model evaluations, where D is the number of input
+        parameters. For example, with N=1024 and 5 parameters, this requires 12,288 evaluations.
         """
         if model is None:
             if not hasattr(self, "best_model"):
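Tying the docstring together, a hedged usage sketch continuing from the `ae` and `gp_full` objects above; the problem dictionary mirrors the example in the docstring, and the evaluation count follows the formula in the Notes section:

problem = {
    "num_vars": 2,
    "names": ["x1", "x2"],
    "bounds": [[0, 1], [0, 1]],
}
si = ae.sensitivity_analysis(model=gp_full, problem=problem, N=256)
# with N=256 and D=2 inputs this needs 256 * (2*2 + 2) = 1536 emulator evaluations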
@@ -565,7 +600,7 @@ def sensitivity_analysis(
                 f"No model provided, using {get_model_name(model)}, which had the highest average cross-validation score, refitted on full data."
             )

-        Si = sensitivity_analysis(model, problem, self.X, N, conf_level, as_df)
+        Si = _sensitivity_analysis(model, problem, self.X, N, conf_level, as_df)
         return Si

     def plot_sensitivity_analysis(self, results, index="S1", n_cols=None, figsize=None):
@@ -588,4 +623,4 @@ def plot_sensitivity_analysis(self, results, index="S1", n_cols=None, figsize=None):
             Figure size as (width, height) in inches. If None, automatically calculated.

         """
-        return plot_sensitivity_analysis(results, index, n_cols, figsize)
+        return _plot_sensitivity_analysis(results, index, n_cols, figsize)
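Finally, the renamed plotting helper would be reached through the public method, continuing from `si` above; plotting "ST" is assumed to be supported since total-order indices are part of the results:

ae.plot_sensitivity_analysis(si, index="S1", n_cols=2)
ae.plot_sensitivity_analysis(si, index="ST", figsize=(10, 4))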