Skip to content

Commit 580f340

Browse files
committed
refit model before sensitivity analysis
1 parent 7c60829 commit 580f340

File tree

3 files changed

+21
-9
lines changed

3 files changed

+21
-9
lines changed

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,8 @@ Thumbs.db
2727
*.DS_Store
2828

2929
# VScode settings
30-
.vscode/
30+
.vscode/
31+
32+
# Quarto
33+
README.html
34+
README_files/

README.md

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ There's currently a lot of development, so we recommend installing the most curr
2222
pip install git+https://github.com/alan-turing-institute/autoemulate.git
2323
```
2424

25-
There's also a release available on PyPI (will not contain the most recent features and models)
25+
There's also a release available on PyPI (note: currently and older version and out of date with the documentation)
2626
```bash
2727
pip install autoemulate
2828
```
@@ -47,19 +47,27 @@ from autoemulate.simulations.projectile import simulate_projectile
4747
lhd = LatinHypercube([(-5., 1.), (0., 1000.)])
4848
X = lhd.sample(100)
4949
y = np.array([simulate_projectile(x) for x in X])
50+
5051
# compare emulator models
5152
ae = AutoEmulate()
5253
ae.setup(X, y)
53-
best_model = ae.compare()
54+
best_emulator = ae.compare()
55+
5456
# training set cross-validation results
5557
ae.summarise_cv()
5658
ae.plot_cv()
59+
5760
# test set results for the best model
58-
ae.evaluate(best_model)
59-
ae.plot_eval(best_model)
61+
ae.evaluate(best_emulator)
62+
ae.plot_eval(best_emulator)
63+
6064
# refit on full data and emulate!
61-
best_model = ae.refit(best_model)
62-
best_model.predict(X)
65+
emulator = ae.refit(best_emulator)
66+
emulator.predict(X)
67+
68+
# global sensitivity analysis
69+
si = ae.sensitivity_analysis(emulator)
70+
ae.plot_sensitivity_analysis(si)
6371
```
6472

6573
## documentation

autoemulate/compare.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -557,9 +557,9 @@ def sensitivity_analysis(
557557
if model is None:
558558
if not hasattr(self, "best_model"):
559559
raise RuntimeError("Must run compare() before sensitivity_analysis()")
560-
model = self.best_model
560+
model = self.refit(self.best_model)
561561
self.logger.info(
562-
f"No model provided, using {get_model_name(model)}, which had the highest average cross-validation score."
562+
f"No model provided, using {get_model_name(model)}, which had the highest average cross-validation score, refitted on full data."
563563
)
564564

565565
Si = sensitivity_analysis(model, problem, self.X, N, conf_level, as_df)

0 commit comments

Comments
 (0)