Skip to content

Commit a6f7795

Browse files
committed
removing sensitivity analysis variable fix from here, this is in a separate PR now
1 parent 0cf3b45 commit a6f7795

File tree

1 file changed

+304
-0
lines changed

1 file changed

+304
-0
lines changed

sensitivity_analysis.py

Lines changed: 304 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,304 @@
1+
import matplotlib.pyplot as plt
2+
import numpy as np
3+
import pandas as pd
4+
from SALib.analyze.sobol import analyze
5+
from SALib.sample.sobol import sample
6+
7+
from autoemulate.utils import _ensure_2d
8+
9+
10+
def _sensitivity_analysis(
    model, problem=None, X=None, N=1024, conf_level=0.95, as_df=True
):
    """Perform Sobol sensitivity analysis on a fitted emulator.

    Thin entry point: delegates the analysis to `_sobol_analysis` and
    optionally converts the result with `_sobol_results_to_df`.

    Parameters:
    -----------
    model : fitted emulator model
        The emulator model to analyze.
    problem : dict, optional
        The problem definition, including 'num_vars', 'names', and 'bounds',
        optional 'output_names'.
        Example:
        ```python
        problem = {
            "num_vars": 2,
            "names": ["x1", "x2"],
            "bounds": [[0, 1], [0, 1]],
        }
        ```
    X : array-like, optional
        Design matrix, forwarded to `_sobol_analysis`. It is only consulted
        there when `problem` is None, to derive a problem definition.
    N : int, optional
        The sample size passed to the Sobol sampler (default is 1024).
    conf_level : float, optional
        The confidence level for the confidence intervals (default is 0.95).
    as_df : bool, optional
        If True, return a pandas DataFrame (default is True).

    Returns:
    --------
    pd.DataFrame or dict
        If as_df is True, returns a long-format DataFrame with the sensitivity
        indices. Otherwise, returns a dictionary where each key is the name of
        an output variable and each value is a dictionary containing the Sobol
        indices keys 'S1', 'S1_conf', 'ST', and 'ST_conf', where each entry
        is a list of length corresponding to the number of parameters.
    """
    Si = _sobol_analysis(model, problem, X, N, conf_level)
    return _sobol_results_to_df(Si) if as_df else Si
50+
51+
52+
def _check_problem(problem):
53+
"""
54+
Check that the problem definition is valid.
55+
"""
56+
if not isinstance(problem, dict):
57+
raise ValueError("problem must be a dictionary.")
58+
59+
if "num_vars" not in problem:
60+
raise ValueError("problem must contain 'num_vars'.")
61+
if "names" not in problem:
62+
raise ValueError("problem must contain 'names'.")
63+
if "bounds" not in problem:
64+
raise ValueError("problem must contain 'bounds'.")
65+
66+
if len(problem["names"]) != problem["num_vars"]:
67+
raise ValueError("Length of 'names' must match 'num_vars'.")
68+
if len(problem["bounds"]) != problem["num_vars"]:
69+
raise ValueError("Length of 'bounds' must match 'num_vars'.")
70+
71+
return problem
72+
73+
74+
def _get_output_names(problem, num_outputs):
75+
"""
76+
Get the output names from the problem definition or generate default names.
77+
"""
78+
# check if output_names is given
79+
if "output_names" not in problem:
80+
output_names = [f"y{i+1}" for i in range(num_outputs)]
81+
else:
82+
if isinstance(problem["output_names"], list):
83+
output_names = problem["output_names"]
84+
else:
85+
raise ValueError("'output_names' must be a list of strings.")
86+
87+
return output_names
88+
89+
90+
def _generate_problem(X):
91+
"""
92+
Generate a problem definition from a design matrix.
93+
"""
94+
if X.ndim == 1:
95+
raise ValueError("X must be a 2D array.")
96+
97+
return {
98+
"num_vars": X.shape[1],
99+
"names": [f"x{i+1}" for i in range(X.shape[1])],
100+
"bounds": [[X[:, i].min(), X[:, i].max()] for i in range(X.shape[1])],
101+
}
102+
103+
104+
def _sobol_analysis(model, problem=None, X=None, N=1024, conf_level=0.95):
    """
    Perform Sobol sensitivity analysis on a fitted emulator.

    Parameters:
    -----------
    model : fitted emulator model
        The emulator model to analyze; its `predict` method is called on the
        sampled parameter values.
    problem : dict, optional
        The problem definition, including 'num_vars', 'names', and 'bounds'.
        Takes precedence over `X` when both are given.
    X : array-like, optional
        Design matrix used to generate a problem definition when `problem`
        is None.
    N : int, optional
        The sample size passed to the Sobol sampler (default is 1024).
    conf_level : float, optional
        The confidence level passed to the Sobol analysis (default is 0.95).

    Returns:
    --------
    dict
        A dictionary where each key is the name of an output variable and each
        value is the result of the Sobol analysis for that output, containing
        the Sobol indices keys 'S1', 'S1_conf', 'ST', and 'ST_conf', where
        each entry is a list of length corresponding to the number of
        parameters.

    Raises:
    -------
    ValueError
        If neither `problem` nor `X` is provided.
    """
    # get problem
    if problem is not None:
        problem = _check_problem(problem)
    elif X is not None:
        problem = _generate_problem(X)
    else:
        raise ValueError("Either problem or X must be provided.")

    # saltelli sampling
    param_values = sample(problem, N)

    # evaluate; presumably _ensure_2d yields one column per output, since
    # Y.shape[1] is read as the number of outputs below
    Y = _ensure_2d(model.predict(param_values))

    num_outputs = Y.shape[1]
    output_names = _get_output_names(problem, num_outputs)

    # single or multiple output sobol analysis: one analysis per output column
    return {
        output_names[i]: analyze(problem, Y[:, i], conf_level=conf_level)
        for i in range(num_outputs)
    }
149+
150+
151+
def _sobol_results_to_df(results):
152+
"""
153+
Convert Sobol results to a (long-format)pandas DataFrame.
154+
155+
Parameters:
156+
-----------
157+
results : dict
158+
The Sobol indices returned by sobol_analysis.
159+
160+
Returns:
161+
--------
162+
pd.DataFrame
163+
A DataFrame with columns: 'output', 'parameter', 'index', 'value', 'confidence'.
164+
"""
165+
rows = []
166+
for output, indices in results.items():
167+
for index_type in ["S1", "ST", "S2"]:
168+
values = indices.get(index_type)
169+
conf_values = indices.get(f"{index_type}_conf")
170+
if values is None or conf_values is None:
171+
continue
172+
173+
if index_type in ["S1", "ST"]:
174+
rows.extend(
175+
{
176+
"output": output,
177+
"parameter": f"X{i+1}",
178+
"index": index_type,
179+
"value": value,
180+
"confidence": conf,
181+
}
182+
for i, (value, conf) in enumerate(zip(values, conf_values))
183+
)
184+
185+
elif index_type == "S2":
186+
n = values.shape[0]
187+
rows.extend(
188+
{
189+
"output": output,
190+
"parameter": f"X{i+1}-X{j+1}",
191+
"index": index_type,
192+
"value": values[i, j],
193+
"confidence": conf_values[i, j],
194+
}
195+
for i in range(n)
196+
for j in range(i + 1, n)
197+
if not np.isnan(values[i, j])
198+
)
199+
200+
return pd.DataFrame(rows)
201+
202+
203+
# plotting --------------------------------------------------------------------
204+
205+
206+
def _validate_input(results, index):
207+
if not isinstance(results, pd.DataFrame):
208+
results = _sobol_results_to_df(results)
209+
# we only want to plot one index type at a time
210+
valid_indices = ["S1", "S2", "ST"]
211+
if index not in valid_indices:
212+
raise ValueError(
213+
f"Invalid index type: {index}. Must be one of {valid_indices}."
214+
)
215+
return results[results["index"].isin([index])]
216+
217+
218+
def _calculate_layout(n_outputs, n_cols=None):
219+
if n_cols is None:
220+
n_cols = 3 if n_outputs >= 3 else n_outputs
221+
n_rows = int(np.ceil(n_outputs / n_cols))
222+
return n_rows, n_cols
223+
224+
225+
def _create_bar_plot(ax, output_data, output_name):
    """
    Create a bar plot for a single output.

    Parameters:
    -----------
    ax : matplotlib axes
        The axes to draw on.
    output_data : pd.DataFrame
        Rows for one output; the 'value', 'confidence' and 'parameter'
        columns are read.
    output_name : str
        Name of the output, shown in the panel title.
    """
    bar_color = "#4C4B63"
    x_pos = np.arange(len(output_data))

    # NOTE(review): SALib computes the '*_conf' values as z * std of the
    # bootstrap resamples, i.e. they are already the half-width of the
    # confidence interval. They are therefore passed to `yerr` as-is (the
    # bars span value +/- confidence) rather than being halved again --
    # confirm against the installed SALib version.
    ax.bar(
        x_pos,
        output_data["value"],
        color=bar_color,
        yerr=output_data["confidence"].to_numpy(),
        capsize=3,
    )

    ax.set_xticks(x_pos)
    ax.set_xticklabels(output_data["parameter"], rotation=45, ha="right")
    ax.set_ylabel("Sobol Index")
    ax.set_title(f"Output: {output_name}")
242+
243+
244+
def _plot_sensitivity_analysis(
    results, index="S1", n_cols=None, figsize=None, conf_level=0.95
):
    """
    Plot the sensitivity analysis results.

    Parameters:
    -----------
    results : pd.DataFrame or dict
        The results from sobol_results_to_df, or the raw Sobol results
        (which are converted first).
    index : str, default "S1"
        The type of sensitivity index to plot.
        - "S1": first-order indices
        - "S2": second-order/interaction indices
        - "ST": total-order indices
    n_cols : int, optional
        The number of columns in the plot. Defaults to 3 if there are 3 or
        more outputs, otherwise the number of outputs.
    figsize : tuple, optional
        Figure size as (width, height) in inches. If None, automatically
        calculated.
    conf_level : float, default 0.95
        Confidence level shown in the figure title. It only affects the
        label, so it should match the level used for the analysis.

    Returns:
    --------
    matplotlib figure
        The figure, one panel per output. It is closed before being
        returned to prevent double plotting in notebooks.

    Raises:
    -------
    ValueError
        If `index` is invalid or the results hold no rows for it.
    """
    index_names = {
        "S1": "First-Order",
        "S2": "Second-order/Interaction",
        "ST": "Total-Order",
    }

    with plt.style.context("fast"):
        # prepare data
        results = _validate_input(results, index)
        unique_outputs = results["output"].unique()
        n_outputs = len(unique_outputs)
        if n_outputs == 0:
            raise ValueError(f"No results available for index type: {index}.")

        # layout
        n_rows, n_cols = _calculate_layout(n_outputs, n_cols)
        figsize = figsize or (4.5 * n_cols, 4 * n_rows)

        # squeeze=False always yields a 2D array of axes, so the single-panel
        # case needs no special handling.
        fig, axes = plt.subplots(
            n_rows, n_cols, figsize=figsize, squeeze=False
        )
        axes = axes.ravel()

        for ax, output in zip(axes, unique_outputs):
            output_data = results[results["output"] == output]
            _create_bar_plot(ax, output_data, output)

        # remove any empty subplots
        for ax in axes[n_outputs:]:
            fig.delaxes(ax)

        # title
        fig.suptitle(
            f"{index_names[index]} indices and {conf_level * 100:g}% CI",
            fontsize=14,
        )

        # Lay out this figure explicitly instead of whichever figure happens
        # to be current.
        fig.tight_layout()
        # prevent double plotting in notebooks
        plt.close(fig)

        return fig

0 commit comments

Comments
 (0)