@@ -20,8 +20,8 @@
 from autoemulate.plotting import _plot_model
 from autoemulate.printing import _print_setup
 from autoemulate.save import ModelSerialiser
-from autoemulate.sensitivity_analysis import plot_sensitivity_analysis
-from autoemulate.sensitivity_analysis import sensitivity_analysis
+from autoemulate.sensitivity_analysis import _plot_sensitivity_analysis
+from autoemulate.sensitivity_analysis import _sensitivity_analysis
 from autoemulate.utils import _check_cv
 from autoemulate.utils import _ensure_2d
 from autoemulate.utils import _get_full_model_name
@@ -33,8 +33,8 @@
 class AutoEmulate:
     """
     The AutoEmulate class is the main class of the AutoEmulate package. It is used to set up and compare
-    different emulator models on a given dataset. It can also be used to save and load models, and to
-    print and plot the results of the comparison.
+    different emulator models on a given dataset. It can also be used to summarise and visualise results,
+    to save and load models, and to run sensitivity analysis.
     """
 
     def __init__(self):
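For context, here is a minimal sketch of the workflow the updated class docstring describes. It assumes the class is importable as `from autoemulate.compare import AutoEmulate` and that `setup()` takes `X` and `y` arrays (the docstrings below reference `setup()` and `self.X`); the toy data is purely illustrative.

```python
import numpy as np

from autoemulate.compare import AutoEmulate

# Toy dataset: 100 samples, 2 inputs, 1 output (illustrative only).
X = np.random.uniform(0, 1, size=(100, 2))
y = np.sin(X[:, 0]) + np.cos(X[:, 1])

ae = AutoEmulate()
ae.setup(X, y)       # must be called before compare()
best = ae.compare()  # cross-validates emulators, returns the best one
```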
@@ -178,12 +178,13 @@ def _get_metrics(self, METRIC_REGISTRY):
         return [metric for metric in METRIC_REGISTRY.values()]
 
     def compare(self):
-        """Compares the emulator models on the data. self.setup() must be run first.
+        """Compares models using cross-validation, with the option
+        to perform hyperparameter search. self.setup() must be run first.
 
         Returns
         -------
         self.best_model : object
-            Best performing model fitted on full data.
+            Emulator with the highest cross-validation R2 score.
         """
         if not self.is_set_up:
             raise RuntimeError("Must run setup() before compare()")
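Per the reworded docstring, `compare()` ranks emulators by mean cross-validation R2 and returns the top one. A sketch of inspecting that ranking via `summarise_cv()`, which appears later in this file:

```python
best = ae.compare()       # emulator with the highest mean CV R2
print(ae.summarise_cv())  # per-model CV metrics, sorted by "r2" by default
```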
@@ -257,7 +258,8 @@ def get_model(self, name=None, rank=1, metric="r2"):
         Parameters
         ----------
         name : str
-            Name of the model to return.
+            Name of the model to return. Can be the full name or a short name, e.g. "GaussianProcess" or "gp".
+            Short names are the first letters of each word in the full name, lowercased (e.g. "GaussianProcess" -> "gp").
         rank : int
             Rank of the model to return. Defaults to 1, which is the best model, 2 is the second best, etc.
         metric : str
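A sketch of the naming convention the new docstring text describes, using the parameters from the `get_model` signature in the hunk header:

```python
gp = ae.get_model("GaussianProcess")           # full name
gp = ae.get_model("gp")                        # equivalent short name
runner_up = ae.get_model(rank=2, metric="r2")  # second-best model by R2
```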
@@ -298,8 +300,7 @@ def get_model(self, name=None, rank=1, metric="r2"):
         return chosen_model
 
     def refit(self, model=None):
-        """Refits model on full data.
-
+        """Refits model on full data. This is useful, as `compare()` runs only on the training data.
         Parameters
         ----------
         model : model to refit.
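A sketch of the flow the added sentence alludes to: `compare()` only sees the training split, so refit before using an emulator downstream. It assumes `refit()` returns the refitted model, which the docstring implies but does not state.

```python
gp = ae.get_model("gp")  # fitted on training data only
gp = ae.refit(gp)        # refit on the full dataset before downstream use
```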
@@ -359,7 +360,7 @@ def load(self, path=None):
         return serialiser._load_model(path)
 
     def print_setup(self):
-        """Print the setup of the AutoEmulate object."""
+        """Print the parameters of the AutoEmulate object."""
         _print_setup(self)
 
     def summarise_cv(self, model=None, sort_by="r2"):
@@ -408,17 +409,17 @@ def plot_cv(
             If a model name is specified, plots all folds of that model.
         style : str, optional
             The type of plot to draw:
-            "Xy" observed and predicted values vs. features, including 2σ error bands where available (default).
-            "actual_vs_predicted" draws the observed values (y-axis) vs. the predicted values (x-axis) (default).
-            "residual_vs_predicted" draws the residuals, i.e. difference between observed and predicted values, (y-axis) vs. the predicted values (x-axis).
+            "Xy" for plotting observed and predicted values vs. features, including 2σ error bands where available (default).
+            "actual_vs_predicted" for plotting observed values (y-axis) vs. the predicted values (x-axis).
+            "residual_vs_predicted" for plotting the residuals, i.e. difference between observed and predicted values, (y-axis) vs. the predicted values (x-axis).
         n_cols : int
             Number of columns in the plot grid.
         figsize : tuple, optional
-            Overrides the default figure size.
+            Overrides the default figure size, in inches, e.g. (6, 4).
         output_index : int
-            Index of the output to plot. Default is 0.
+            Index of the output to plot. Default is 0. Can be a single index or a list of indices.
         input_index : int
-            Index of the input to plot. Default is 0.
+            Index of the input to plot. Default is 0. Can be a single index or a list of indices.
         """
         model_name = (
             _get_full_model_name(model, self.model_names) if model is not None else None
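A sketch exercising the documented `plot_cv` options; the index values and model name are illustrative:

```python
# Default "Xy" style, two inputs against the first output.
ae.plot_cv(style="Xy", output_index=0, input_index=[0, 1], figsize=(6, 4))

# All CV folds of a single model, observed vs. predicted.
ae.plot_cv(model="gp", style="actual_vs_predicted")
```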
@@ -440,6 +441,8 @@ def evaluate(self, model=None, multioutput="uniform_average"):
         """
         Evaluates the model on the test set.
 
+        Test set size can be specified in `setup()` with `test_set_size`.
+
         Parameters
         ----------
         model : object
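A sketch tying the new note back to `setup()`. The assumption that `test_set_size` is a fraction of the data is mine, not the diff's; check the `setup()` docstring for the exact semantics.

```python
ae.setup(X, y, test_set_size=0.2)  # hold out data for evaluation (assumed fraction)
ae.compare()
print(ae.evaluate())               # metrics on the held-out test set
```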
@@ -498,7 +501,7 @@ def plot_eval(
         output_index=0,
         input_index=0,
     ):
-        """Visualise different model evaluations on the test set.
+        """Visualise model predictive performance on the test set.
 
         Parameters
         ----------
@@ -534,28 +537,60 @@ def sensitivity_analysis(
     ):
         """Perform Sobol sensitivity analysis on a fitted emulator.
 
+        Sobol sensitivity analysis is a variance-based method that decomposes the variance of the model
+        output into contributions from individual input parameters and their interactions. It calculates:
+        - First-order indices (S1): Direct contribution of each input parameter
+        - Second-order indices (S2): Contribution from pairwise interactions between parameters
+        - Total-order indices (ST): Total contribution of a parameter, including all its interactions
+
         Parameters
         ----------
         model : object, optional
             Fitted model. If None, uses the best model from cross-validation.
         problem : dict, optional
-            The problem definition, including 'num_vars', 'names', and 'bounds', optional 'output_names'.
-            If None, the problem is generated from X using minimum and maximum values of the features as bounds.
+            The problem definition dictionary. If None, the problem is generated from X using
+            minimum and maximum values of the features as bounds. The dictionary should contain:
+
+            - 'num_vars': Number of input variables (int)
+            - 'names': List of variable names (list of str)
+            - 'bounds': List of [min, max] bounds for each variable (list of lists)
+            - 'output_names': Optional list of output names (list of str)
+
+            Example::
 
-            Example:
-            ```python
                 problem = {
                     "num_vars": 2,
                     "names": ["x1", "x2"],
                     "bounds": [[0, 1], [0, 1]],
+                    "output_names": ["y1", "y2"]  # optional
                 }
-            ```
         N : int, optional
-            Number of samples to generate. Default is 1024.
+            Number of samples to generate for the analysis. Higher values give more accurate
+            results but increase computation time. Default is 1024.
         conf_level : float, optional
-            Confidence level for the confidence intervals. Default is 0.95.
+            Confidence level (between 0 and 1) for calculating confidence intervals of the
+            sensitivity indices. Default is 0.95 (95% confidence).
         as_df : bool, optional
-            If True, return a long-format pandas DataFrame (default is True).
+            If True, returns results as a long-format pandas DataFrame with columns for
+            parameters, sensitivity indices, and confidence intervals. If False, returns
+            the raw SALib results dictionary. Default is True.
+
+        Returns
+        -------
+        pandas.DataFrame or dict
+            If as_df=True (default), returns a DataFrame with columns:
+
+            - 'parameter': Input parameter name
+            - 'output': Output variable name
+            - 'S1', 'S2', 'ST': First, second, and total order sensitivity indices
+            - 'S1_conf', 'S2_conf', 'ST_conf': Confidence intervals for each index
+
+            If as_df=False, returns the raw SALib results dictionary.
+
+        Notes
+        -----
+        The analysis requires N * (2D + 2) model evaluations, where D is the number of input
+        parameters. For example, with N=1024 and 5 parameters, this requires 12,288 evaluations.
         """
         if model is None:
             if not hasattr(self, "best_model"):
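A sketch of a full call using the documented `problem` and `N` parameters, with the evaluation-count arithmetic from the new Notes section:

```python
problem = {
    "num_vars": 2,
    "names": ["x1", "x2"],
    "bounds": [[0, 1], [0, 1]],
}
si = ae.sensitivity_analysis(problem=problem, N=256)

# Cost per the Notes: N * (2D + 2) evaluations.
# Here: 256 * (2 * 2 + 2) = 1536 emulator evaluations — cheap, since
# the emulator stands in for the expensive simulation.
```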
@@ -565,7 +600,7 @@ def sensitivity_analysis(
                 f"No model provided, using {get_model_name(model)}, which had the highest average cross-validation score, refitted on full data."
             )
 
-        Si = sensitivity_analysis(model, problem, self.X, N, conf_level, as_df)
+        Si = _sensitivity_analysis(model, problem, self.X, N, conf_level, as_df)
         return Si
 
     def plot_sensitivity_analysis(self, results, index="S1", n_cols=None, figsize=None):
@@ -588,4 +623,4 @@ def plot_sensitivity_analysis(self, results, index="S1", n_cols=None, figsize=None):
             Figure size as (width, height) in inches. If None, automatically calculated.
 
         """
-        return plot_sensitivity_analysis(results, index, n_cols, figsize)
+        return _plot_sensitivity_analysis(results, index, n_cols, figsize)
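Finally, a sketch chaining the two renamed entry points, with defaults as documented above:

```python
si = ae.sensitivity_analysis()                # best model, N=1024, as_df=True
ae.plot_sensitivity_analysis(si, index="ST")  # plot total-order indices
```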