Skip to content

Commit e8b5cdf

Browse files
authored
Merge pull request #266 from alan-turing-institute/docs
Docs
2 parents 5cc7a0f + 8500aae commit e8b5cdf

File tree

9 files changed

+298
-165
lines changed

9 files changed

+298
-165
lines changed

autoemulate/compare.py

+62-27
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
from autoemulate.plotting import _plot_model
2121
from autoemulate.printing import _print_setup
2222
from autoemulate.save import ModelSerialiser
23-
from autoemulate.sensitivity_analysis import plot_sensitivity_analysis
24-
from autoemulate.sensitivity_analysis import sensitivity_analysis
23+
from autoemulate.sensitivity_analysis import _plot_sensitivity_analysis
24+
from autoemulate.sensitivity_analysis import _sensitivity_analysis
2525
from autoemulate.utils import _check_cv
2626
from autoemulate.utils import _ensure_2d
2727
from autoemulate.utils import _get_full_model_name
@@ -33,8 +33,8 @@
3333
class AutoEmulate:
3434
"""
3535
The AutoEmulate class is the main class of the AutoEmulate package. It is used to set up and compare
36-
different emulator models on a given dataset. It can also be used to save and load models, and to
37-
print and plot the results of the comparison.
36+
different emulator models on a given dataset. It can also be used to summarise and visualise results,
37+
to save and load models and to run sensitivity analysis.
3838
"""
3939

4040
def __init__(self):
@@ -178,12 +178,13 @@ def _get_metrics(self, METRIC_REGISTRY):
178178
return [metric for metric in METRIC_REGISTRY.values()]
179179

180180
def compare(self):
181-
"""Compares the emulator models on the data. self.setup() must be run first.
181+
"""Compares models using cross-validation, with the option
182+
to perform hyperparameter search. self.setup() must be run first.
182183
183184
Returns
184185
-------
185186
self.best_model : object
186-
Best performing model fitted on full data.
187+
Emulator with the highest cross-validation R2 score.
187188
"""
188189
if not self.is_set_up:
189190
raise RuntimeError("Must run setup() before compare()")
@@ -257,7 +258,8 @@ def get_model(self, name=None, rank=1, metric="r2"):
257258
Parameters
258259
----------
259260
name : str
260-
Name of the model to return.
261+
Name of the model to return. Can be full name or short name, e.g. "GaussianProcess" or "gp".
262+
Short name abbreviations are the uppercase first letter of each word in the full name (e.g. "GaussianProcess" -> "gp").
261263
rank : int
262264
Rank of the model to return. Defaults to 1, which is the best model, 2 is the second best, etc.
263265
metric : str
@@ -298,8 +300,7 @@ def get_model(self, name=None, rank=1, metric="r2"):
298300
return chosen_model
299301

300302
def refit(self, model=None):
301-
"""Refits model on full data.
302-
303+
"""Refits model on full data. This is useful, as `compare()` runs only on the training data.
303304
Parameters
304305
----------
305306
model : model to refit.
@@ -359,7 +360,7 @@ def load(self, path=None):
359360
return serialiser._load_model(path)
360361

361362
def print_setup(self):
362-
"""Print the setup of the AutoEmulate object."""
363+
"""Print the parameters of the AutoEmulate object."""
363364
_print_setup(self)
364365

365366
def summarise_cv(self, model=None, sort_by="r2"):
@@ -408,17 +409,17 @@ def plot_cv(
408409
If a model name is specified, plots all folds of that model.
409410
style : str, optional
410411
The type of plot to draw:
411-
"Xy" observed and predicted values vs. features, including 2σ error bands where available (default).
412-
"actual_vs_predicted" draws the observed values (y-axis) vs. the predicted values (x-axis) (default).
413-
"residual_vs_predicted" draws the residuals, i.e. difference between observed and predicted values, (y-axis) vs. the predicted values (x-axis).
412+
"Xy" for plotting observed and predicted values vs. features, including 2σ error bands where available (default).
413+
"actual_vs_predicted" for plotting observed values (y-axis) vs. the predicted values (x-axis).
414+
"residual_vs_predicted" for plotting the residuals, i.e. difference between observed and predicted values, (y-axis) vs. the predicted values (x-axis).
414415
n_cols : int
415416
Number of columns in the plot grid.
416417
figsize : tuple, optional
417-
Overrides the default figure size.
418+
Overrides the default figure size, in inches, e.g. (6, 4).
418419
output_index : int
419-
Index of the output to plot. Default is 0.
420+
Index of the output to plot. Default is 0. Can be a single index or a list of indices.
420421
input_index : int
421-
Index of the input to plot. Default is 0.
422+
Index of the input to plot. Default is 0. Can be a single index or a list of indices.
422423
"""
423424
model_name = (
424425
_get_full_model_name(model, self.model_names) if model is not None else None
@@ -440,6 +441,8 @@ def evaluate(self, model=None, multioutput="uniform_average"):
440441
"""
441442
Evaluates the model on the test set.
442443
444+
Test set size can be specified in `setup()` with `test_set_size`.
445+
443446
Parameters
444447
----------
445448
model : object
@@ -498,7 +501,7 @@ def plot_eval(
498501
output_index=0,
499502
input_index=0,
500503
):
501-
"""Visualise different model evaluations on the test set.
504+
"""Visualise model predictive performance on the test set.
502505
503506
Parameters
504507
----------
@@ -534,28 +537,60 @@ def sensitivity_analysis(
534537
):
535538
"""Perform Sobol sensitivity analysis on a fitted emulator.
536539
540+
Sobol sensitivity analysis is a variance-based method that decomposes the variance of the model
541+
output into contributions from individual input parameters and their interactions. It calculates:
542+
- First-order indices (S1): Direct contribution of each input parameter
543+
- Second-order indices (S2): Contribution from pairwise interactions between parameters
544+
- Total-order indices (ST): Total contribution of a parameter, including all its interactions
545+
537546
Parameters
538547
----------
539548
model : object, optional
540549
Fitted model. If None, uses the best model from cross-validation.
541550
problem : dict, optional
542-
The problem definition, including 'num_vars', 'names', and 'bounds', optional 'output_names'.
543-
If None, the problem is generated from X using minimum and maximum values of the features as bounds.
551+
The problem definition dictionary. If None, the problem is generated from X using
552+
minimum and maximum values of the features as bounds. The dictionary should contain:
553+
554+
- 'num_vars': Number of input variables (int)
555+
- 'names': List of variable names (list of str)
556+
- 'bounds': List of [min, max] bounds for each variable (list of lists)
557+
- 'output_names': Optional list of output names (list of str)
558+
559+
Example::
544560
545-
Example:
546-
```python
547561
problem = {
548562
"num_vars": 2,
549563
"names": ["x1", "x2"],
550564
"bounds": [[0, 1], [0, 1]],
565+
"output_names": ["y1", "y2"] # optional
551566
}
552-
```
553567
N : int, optional
554-
Number of samples to generate. Default is 1024.
568+
Number of samples to generate for the analysis. Higher values give more accurate
569+
results but increase computation time. Default is 1024.
555570
conf_level : float, optional
556-
Confidence level for the confidence intervals. Default is 0.95.
571+
Confidence level (between 0 and 1) for calculating confidence intervals of the
572+
sensitivity indices. Default is 0.95 (95% confidence).
557573
as_df : bool, optional
558-
If True, return a long-format pandas DataFrame (default is True).
574+
If True, returns results as a long-format pandas DataFrame with columns for
575+
parameters, sensitivity indices, and confidence intervals. If False, returns
576+
the raw SALib results dictionary. Default is True.
577+
578+
Returns
579+
-------
580+
pandas.DataFrame or dict
581+
If as_df=True (default), returns a DataFrame with columns:
582+
583+
- 'parameter': Input parameter name
584+
- 'output': Output variable name
585+
- 'S1', 'S2', 'ST': First, second, and total order sensitivity indices
586+
- 'S1_conf', 'S2_conf', 'ST_conf': Confidence intervals for each index
587+
588+
If as_df=False, returns the raw SALib results dictionary.
589+
590+
Notes
591+
-----
592+
The analysis requires N * (2D + 2) model evaluations, where D is the number of input
593+
parameters. For example, with N=1024 and 5 parameters, this requires 12,288 evaluations.
559594
"""
560595
if model is None:
561596
if not hasattr(self, "best_model"):
@@ -565,7 +600,7 @@ def sensitivity_analysis(
565600
f"No model provided, using {get_model_name(model)}, which had the highest average cross-validation score, refitted on full data."
566601
)
567602

568-
Si = sensitivity_analysis(model, problem, self.X, N, conf_level, as_df)
603+
Si = _sensitivity_analysis(model, problem, self.X, N, conf_level, as_df)
569604
return Si
570605

571606
def plot_sensitivity_analysis(self, results, index="S1", n_cols=None, figsize=None):
@@ -588,4 +623,4 @@ def plot_sensitivity_analysis(self, results, index="S1", n_cols=None, figsize=No
588623
Figure size as (width, height) in inches.If None, automatically calculated.
589624
590625
"""
591-
return plot_sensitivity_analysis(results, index, n_cols, figsize)
626+
return _plot_sensitivity_analysis(results, index, n_cols, figsize)

autoemulate/sensitivity_analysis.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from autoemulate.utils import _ensure_2d
88

99

10-
def sensitivity_analysis(
10+
def _sensitivity_analysis(
1111
model, problem=None, X=None, N=1024, conf_level=0.95, as_df=True
1212
):
1313
"""Perform Sobol sensitivity analysis on a fitted emulator.
@@ -41,10 +41,10 @@ def sensitivity_analysis(
4141
containing the Sobol indices keys ‘S1’, ‘S1_conf’, ‘ST’, and ‘ST_conf’, where each entry
4242
is a list of length corresponding to the number of parameters.
4343
"""
44-
Si = sobol_analysis(model, problem, X, N, conf_level)
44+
Si = _sobol_analysis(model, problem, X, N, conf_level)
4545

4646
if as_df:
47-
return sobol_results_to_df(Si)
47+
return _sobol_results_to_df(Si)
4848
else:
4949
return Si
5050

@@ -101,7 +101,7 @@ def _generate_problem(X):
101101
}
102102

103103

104-
def sobol_analysis(model, problem=None, X=None, N=1024, conf_level=0.95):
104+
def _sobol_analysis(model, problem=None, X=None, N=1024, conf_level=0.95):
105105
"""
106106
Perform Sobol sensitivity analysis on a fitted emulator.
107107
@@ -148,7 +148,7 @@ def sobol_analysis(model, problem=None, X=None, N=1024, conf_level=0.95):
148148
return results
149149

150150

151-
def sobol_results_to_df(results):
151+
def _sobol_results_to_df(results):
152152
"""
153153
Convert Sobol results to a (long-format)pandas DataFrame.
154154
@@ -205,7 +205,7 @@ def sobol_results_to_df(results):
205205

206206
def _validate_input(results, index):
207207
if not isinstance(results, pd.DataFrame):
208-
results = sobol_results_to_df(results)
208+
results = _sobol_results_to_df(results)
209209
# we only want to plot one index type at a time
210210
valid_indices = ["S1", "S2", "ST"]
211211
if index not in valid_indices:
@@ -241,7 +241,7 @@ def _create_bar_plot(ax, output_data, output_name):
241241
ax.set_title(f"Output: {output_name}")
242242

243243

244-
def plot_sensitivity_analysis(results, index="S1", n_cols=None, figsize=None):
244+
def _plot_sensitivity_analysis(results, index="S1", n_cols=None, figsize=None):
245245
"""
246246
Plot the sensitivity analysis results.
247247

docs/_config.yml

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
title: "AutoEmulate: An emulator platform for Digital Twins"
1+
title: "AutoEmulate: A package for semi-automated emulation"
22
author: Martin Stoffel
33
# logo: logo.png
44

@@ -38,7 +38,10 @@ sphinx:
3838
config:
3939
add_module_names: False
4040
autodoc_typehints: none
41+
autodoc_member_order: 'bysource'
4142
autoclass_content: class
43+
autodoc_default_options:
44+
exclude-members: set_score_request
4245
bibtex_reference_style: author_year
4346
intersphinx_mapping:
4447
python:

docs/_toc.yml

+2-10
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,8 @@ chapters:
2929
- file: reference/index
3030
sections:
3131
- file: reference/compare
32-
- file: reference/cross_validate
3332
- file: reference/datasets
34-
- file: reference/experimental_design
35-
- file: reference/hyperparam_searching
36-
- file: reference/logging_config
37-
- file: reference/metrics
38-
- file: reference/model_processing
39-
- file: reference/plotting
40-
- file: reference/printing
41-
- file: reference/save
42-
- file: reference/utils
33+
- file: reference/sensitivity_analysis
4334
- file: reference/simulations/index
4435
sections:
4536
- file: reference/simulations/projectile
@@ -53,6 +44,7 @@ chapters:
5344
- file: reference/emulators/gradient_boosting
5445
- file: reference/emulators/light_gbm
5546
- file: reference/emulators/conditional_neural_process
47+
- file: reference/emulators/conditional_neural_process_attn
5648
- file: reference/emulators/gaussian_process
5749
- file: reference/emulators/gaussian_process_mt
5850
- file: reference/emulators/neural_net_sk

0 commit comments

Comments
 (0)