Skip to content

Commit dcfb148

Browse files
authored
Merge pull request #260 from alan-turing-institute/sensitivity-analysis
Sensitivity analysis
2 parents 48129e4 + 580f340 commit dcfb148

File tree

8 files changed

+745
-66
lines changed

8 files changed

+745
-66
lines changed

.github/workflows/ci.yaml

+13-2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ jobs:
2525
with:
2626
python-version: ${{ matrix.python-version }}
2727

28+
# Clean python cache files
29+
- name: Clean cache files
30+
run: |
31+
find . -type f -name "*.pyc" -delete
32+
find . -type d -name "__pycache__" -delete
33+
rm -f .coverage* coverage.xml
2834
# Cache Poetry dependencies
2935
# - name: Cache dependencies
3036
# uses: actions/cache@v2
@@ -56,10 +62,15 @@ jobs:
5662
run: |
5763
poetry run python -c "import torch; print(torch.__version__); print('CUDA available:', torch.cuda.is_available())"
5864
65+
- name: Debug Coverage Config
66+
run: |
67+
cat pyproject.toml
68+
poetry run coverage debug config
69+
5970
- name: Run Tests with Coverage
6071
run: |
61-
poetry run coverage run -m pytest
62-
poetry run coverage xml -o coverage.xml
72+
poetry run coverage run --source=. -m pytest
73+
poetry run coverage xml -i -o coverage.xml
6374
env:
6475
COVERAGE_FILE: ".coverage.${{ matrix.python-version }}"
6576

.gitignore

+5-1
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,8 @@ Thumbs.db
2727
*.DS_Store
2828

2929
# VScode settings
30-
.vscode/
30+
.vscode/
31+
32+
# Quarto
33+
README.html
34+
README_files/

README.md

+14-6
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ There's currently a lot of development, so we recommend installing the most curr
2222
pip install git+https://github.com/alan-turing-institute/autoemulate.git
2323
```
2424

25-
There's also a release available on PyPI (will not contain the most recent features and models)
25+
There's also a release available on PyPI (note: this is currently an older version and is out of date with the documentation)
2626
```bash
2727
pip install autoemulate
2828
```
@@ -47,19 +47,27 @@ from autoemulate.simulations.projectile import simulate_projectile
4747
lhd = LatinHypercube([(-5., 1.), (0., 1000.)])
4848
X = lhd.sample(100)
4949
y = np.array([simulate_projectile(x) for x in X])
50+
5051
# compare emulator models
5152
ae = AutoEmulate()
5253
ae.setup(X, y)
53-
best_model = ae.compare()
54+
best_emulator = ae.compare()
55+
5456
# training set cross-validation results
5557
ae.summarise_cv()
5658
ae.plot_cv()
59+
5760
# test set results for the best model
58-
ae.evaluate(best_model)
59-
ae.plot_eval(best_model)
61+
ae.evaluate(best_emulator)
62+
ae.plot_eval(best_emulator)
63+
6064
# refit on full data and emulate!
61-
best_model = ae.refit(best_model)
62-
best_model.predict(X)
65+
emulator = ae.refit(best_emulator)
66+
emulator.predict(X)
67+
68+
# global sensitivity analysis
69+
si = ae.sensitivity_analysis(emulator)
70+
ae.plot_sensitivity_analysis(si)
6371
```
6472

6573
## documentation

autoemulate/compare.py

+64
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@
2020
from autoemulate.plotting import _plot_model
2121
from autoemulate.printing import _print_setup
2222
from autoemulate.save import ModelSerialiser
23+
from autoemulate.sensitivity_analysis import plot_sensitivity_analysis
24+
from autoemulate.sensitivity_analysis import sensitivity_analysis
25+
from autoemulate.utils import _ensure_2d
2326
from autoemulate.utils import _get_full_model_name
2427
from autoemulate.utils import _redirect_warnings
2528
from autoemulate.utils import get_model_name
@@ -522,3 +525,64 @@ def plot_eval(
522525
)
523526

524527
return fig
528+
529+
def sensitivity_analysis(
530+
self, model=None, problem=None, N=1024, conf_level=0.95, as_df=True
531+
):
532+
"""Perform Sobol sensitivity analysis on a fitted emulator.
533+
534+
Parameters
535+
----------
536+
model : object, optional
537+
Fitted model. If None, uses the best model from cross-validation.
538+
problem : dict, optional
539+
The problem definition, including 'num_vars', 'names', and 'bounds', and optionally 'output_names'.
540+
If None, the problem is generated from X using minimum and maximum values of the features as bounds.
541+
542+
Example:
543+
```python
544+
problem = {
545+
"num_vars": 2,
546+
"names": ["x1", "x2"],
547+
"bounds": [[0, 1], [0, 1]],
548+
}
549+
```
550+
N : int, optional
551+
Number of samples to generate. Default is 1024.
552+
conf_level : float, optional
553+
Confidence level for the confidence intervals. Default is 0.95.
554+
as_df : bool, optional
555+
If True, return a long-format pandas DataFrame (default is True).
556+
"""
557+
if model is None:
558+
if not hasattr(self, "best_model"):
559+
raise RuntimeError("Must run compare() before sensitivity_analysis()")
560+
model = self.refit(self.best_model)
561+
self.logger.info(
562+
f"No model provided, using {get_model_name(model)}, which had the highest average cross-validation score, refitted on full data."
563+
)
564+
565+
Si = sensitivity_analysis(model, problem, self.X, N, conf_level, as_df)
566+
return Si
567+
568+
def plot_sensitivity_analysis(self, results, index="S1", n_cols=None, figsize=None):
569+
"""
570+
Plot the sensitivity analysis results.
571+
572+
Parameters
573+
----------
574+
results : pd.DataFrame
575+
The results from sobol_results_to_df.
576+
index : str, default "S1"
577+
The type of sensitivity index to plot.
578+
- "S1": first-order indices
579+
- "S2": second-order/interaction indices
580+
- "ST": total-order indices
581+
n_cols : int, optional
582+
The number of columns in the plot. Defaults to 3 if there are 3 or more outputs,
583+
otherwise the number of outputs.
584+
figsize : tuple, optional
585+
Figure size as (width, height) in inches. If None, it is calculated automatically.
586+
587+
"""
588+
return plot_sensitivity_analysis(results, index, n_cols, figsize)

0 commit comments

Comments
 (0)