Skip to content

Commit 0cf3b45

Browse files
committed
remove sensitivity analysis fix from this PR
1 parent da65bf0 commit 0cf3b45

File tree

1 file changed

+10
-18
lines changed

1 file changed

+10
-18
lines changed

autoemulate/sensitivity_analysis.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _sensitivity_analysis(
4444
Si = _sobol_analysis(model, problem, X, N, conf_level)
4545

4646
if as_df:
47-
return _sobol_results_to_df(Si, problem)
47+
return _sobol_results_to_df(Si)
4848
else:
4949
return Si
5050

@@ -148,30 +148,21 @@ def _sobol_analysis(model, problem=None, X=None, N=1024, conf_level=0.95):
148148
return results
149149

150150

151-
def _sobol_results_to_df(results, problem=None):
151+
def _sobol_results_to_df(results):
152152
"""
153-
Convert Sobol results to a (long-format) pandas DataFrame.
153+
Convert Sobol results to a (long-format)pandas DataFrame.
154154
155155
Parameters:
156156
-----------
157157
results : dict
158158
The Sobol indices returned by sobol_analysis.
159-
problem : dict, optional
160-
The problem definition, including 'names'.
161159
162160
Returns:
163161
--------
164162
pd.DataFrame
165163
A DataFrame with columns: 'output', 'parameter', 'index', 'value', 'confidence'.
166164
"""
167165
rows = []
168-
# Use custom names if provided, else default to "x1", "x2", etc.
169-
parameter_names = (
170-
problem["names"]
171-
if problem is not None
172-
else [f"x{i+1}" for i in range(len(next(iter(results.values()))["S1"]))]
173-
)
174-
175166
for output, indices in results.items():
176167
for index_type in ["S1", "ST", "S2"]:
177168
values = indices.get(index_type)
@@ -183,7 +174,7 @@ def _sobol_results_to_df(results, problem=None):
183174
rows.extend(
184175
{
185176
"output": output,
186-
"parameter": parameter_names[i], # Use appropriate names
177+
"parameter": f"X{i+1}",
187178
"index": index_type,
188179
"value": value,
189180
"confidence": conf,
@@ -196,7 +187,7 @@ def _sobol_results_to_df(results, problem=None):
196187
rows.extend(
197188
{
198189
"output": output,
199-
"parameter": f"{parameter_names[i]}-{parameter_names[j]}", # Use appropriate names
190+
"parameter": f"X{i+1}-X{j+1}",
200191
"index": index_type,
201192
"value": values[i, j],
202193
"confidence": conf_values[i, j],
@@ -205,15 +196,16 @@ def _sobol_results_to_df(results, problem=None):
205196
for j in range(i + 1, n)
206197
if not np.isnan(values[i, j])
207198
)
199+
208200
return pd.DataFrame(rows)
209201

210202

211203
# plotting --------------------------------------------------------------------
212204

213205

214-
def _validate_input(results, problem, index):
206+
def _validate_input(results, index):
215207
if not isinstance(results, pd.DataFrame):
216-
results = _sobol_results_to_df(results, problem=problem)
208+
results = _sobol_results_to_df(results)
217209
# we only want to plot one index type at a time
218210
valid_indices = ["S1", "S2", "ST"]
219211
if index not in valid_indices:
@@ -249,7 +241,7 @@ def _create_bar_plot(ax, output_data, output_name):
249241
ax.set_title(f"Output: {output_name}")
250242

251243

252-
def _plot_sensitivity_analysis(results, problem, index="S1", n_cols=None, figsize=None):
244+
def _plot_sensitivity_analysis(results, index="S1", n_cols=None, figsize=None):
253245
"""
254246
Plot the sensitivity analysis results.
255247
@@ -271,7 +263,7 @@ def _plot_sensitivity_analysis(results, problem, index="S1", n_cols=None, figsiz
271263
"""
272264
with plt.style.context("fast"):
273265
# prepare data
274-
results = _validate_input(results, problem, index)
266+
results = _validate_input(results, index)
275267
unique_outputs = results["output"].unique()
276268
n_outputs = len(unique_outputs)
277269

0 commit comments

Comments
 (0)