Skip to content

Commit 2fecaa9

Browse files
sensitivity analysis
1 parent d39b922 commit 2fecaa9

File tree

1 file changed

+138
-15
lines changed

1 file changed

+138
-15
lines changed

src/sbmlsim/sensitivity/global_sensitivity.py

Lines changed: 138 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
2121
"""
2222
from typing import Optional
23+
import xarray as xr
2324

2425
import SALib
2526
from SALib import ProblemSpec
@@ -79,9 +80,11 @@ def simulate(self, changes: dict[str, float]) -> dict[str, float]:
7980
def parameter_values(self, parameters: list[str], changes: dict[str, float]) -> dict[str, float]:
8081
"""Get the parameter values for a given set of changes."""
8182
self.apply_changes(changes, reset_all=True)
83+
8284
values: dict[str, float] = {}
8385
for pid in parameters:
8486
values[pid] = self.rr.getValue(pid)
87+
8588
return values
8689

8790

@@ -121,11 +124,11 @@ def __init__(self, sensitivity_simulation: SensitivitySimulation,
121124
# outputs to calculate sensitivity on; shape: (num_outputs,)
122125
self.outputs: list[str] = sensitivity_simulation.outputs
123126
# parameter samples for sensitivity; shape: (num_samples x num_parameters)
124-
self.samples: Optional[np.ndarray] = None
127+
self.samples: Optional[xr.DataArray] = None
125128
# outputs for given samples; shape: (num_samples x num_outputs)
126-
self.results: Optional[np.ndarray] = None
129+
self.results: Optional[xr.DataArray] = None
127130
# sensitivity matrix; shape: (num_parameters x num_outputs); could be multiple
128-
self.sensitivity_results: Optional[np.ndarray] = None
131+
self.sensitivity_results: Optional[xr.DataArray] = None
129132

130133
@property
131134
def num_parameters(self) -> int:
@@ -152,7 +155,7 @@ def simulate_samples(self) -> None:
152155
self.samples = np.zeros(shape=(self.num_samples, self.num_parameters))
153156
self.outputs = np.zeros(shape=(self.num_samples, self.num_outputs))
154157

155-
for k in range(self.num_samples()):
158+
for k in range(self.num_samples):
156159
changes = dict(zip(self.parameters, self.samples[k, :]))
157160
outputs = self.sensitivity_simulation.simulate(changes=changes)
158161
self.outputs[k, :] = outputs
@@ -185,26 +188,35 @@ def num_samples(self) -> int:
185188
"""Number of parameter samples to simulate."""
186189
return 2 * self.num_parameters
187190

188-
def create_samples(self) -> np.ndarray:
191+
def create_samples(self) -> None:
189192

190193
# Calculate the parameter values in the reference state
191-
parameter_values = self.sensitivity_simulation.parameter_values(
194+
parameter_values: dict[str, float] = self.sensitivity_simulation.parameter_values(
195+
parameters=self.parameters,
192196
changes=self.sensitivity_simulation.changes_simulation
193197
)
194198

195199
# (num_samples x num_outputs)
196-
samples = np.empty(shape=(self.num_samples, self.num_parameters))
197-
198-
for key, value in :
199-
values = np.ones(shape=(2 * num_pars,)) * value.magnitude
200+
num_samples = 2*self.num_parameters
201+
samples = np.empty(shape=(num_samples, self.num_parameters))
202+
samples = xr.DataArray(
203+
np.full((num_samples, self.num_parameters), np.nan),
204+
dims=["sample", "parameter"],
205+
coords={"sample": range(num_samples), "parameter": self.parameters},
206+
name="samples"
207+
)
200208

209+
reference_values = np.array(list(parameter_values.values()))
210+
for kp, pid in enumerate(parameter_values):
211+
value = parameter_values[pid]
201212

213+
# right sided changes
214+
samples[2*kp, :] = reference_values
215+
samples[2*kp, kp] = value * (1.0 + self.difference)
216+
samples[2 * kp + 1 , :] = reference_values
217+
samples[2 * kp + 1, :] = value * (1.0 - self.difference)
202218

203-
# change parameters in correct position
204-
values[index] = value.magnitude * (1.0 + difference)
205-
values[index + num_pars] = value.magnitude * (1.0 - difference)
206-
changes[key] = Q_(values, value.units)
207-
index += 1
219+
self.samples = samples
208220

209221
def calculate_sensitivity(self):
210222

@@ -214,6 +226,117 @@ def plot_sensitivity(self):
214226

215227
pass
216228

229+
from matplotlib import pyplot as plt
230+
import seaborn as sns
231+
import numpy as np
232+
233+
def heatmap(da: xr.DataArray, cutoff: float=0.01, annotate_values=True, transpose: bool=False):
234+
"""Creates heatmap of model sensitivity"""
235+
236+
def calculate_mask(df, cutoff=0.01):
237+
"""Calculates a boolean mask DataFrame for the heatmap based on cutoff."""
238+
mask = np.empty(shape=df.shape, dtype="bool")
239+
for index, value in np.ndenumerate(df):
240+
if np.abs(value) < cutoff:
241+
mask[index] = True
242+
else:
243+
mask[index] = False
244+
return pd.DataFrame(data=mask, columns=df.COLUMNS, index=df.index)
245+
246+
def calculate_subset(df, cutoff=0.01):
247+
"""Calculates subset of data frame consisting of rows where at least
248+
one value is above cutoff."""
249+
return df[(df.abs() >= cutoff).any(axis=1)]
250+
251+
252+
253+
# filter rows
254+
# X.drop(pk_exclude, axis=1, inplace=True)
255+
256+
# if cutoff > 0:
257+
# X_subset = calculate_subset(X, cutoff=cutoff)
258+
# X_subset_mask = calculate_mask(X_subset, cutoff)
259+
da_subset = da
260+
261+
# yticklabels = ["{}".format(pid) for pid in X_subset.index]
262+
# xticklabels = ["{}".format(pnames[pid]["label"]) for pid in X_subset.COLUMNS]
263+
264+
xticklabels = da.coords[da.dims[1]]
265+
yticklabels = da.coords[da.dims[0]]
266+
267+
# plot heatmap
268+
ax = sns.clustermap(
269+
da_subset,
270+
center=0,
271+
# vmin=-0.2,
272+
# vmax=0.2,
273+
xticklabels=xticklabels,
274+
yticklabels=yticklabels,
275+
cmap="seismic",
276+
cbar_pos=(0.05, 0.25, 0.03, 0.4),
277+
annot=annotate_values,
278+
fmt="1.2f",
279+
annot_kws={"size": 13},
280+
# mask=X_subset_mask,
281+
col_cluster=False,
282+
method="single",
283+
figsize=(20, 20),
284+
)
285+
plt.setp(
286+
ax.ax_heatmap.get_xticklabels(),
287+
rotation=45,
288+
horizontalalignment="right",
289+
size=20,
290+
)
291+
plt.setp(ax.ax_heatmap.get_yticklabels(), size=20)
292+
ax.ax_cbar.tick_params(labelsize=20)
293+
ax.ax_row_dendrogram.set_visible(False)
294+
ax.ax_col_dendrogram.set_visible(False)
295+
296+
# create custom legend containing yticklabels and their description
297+
# handles = [t.get_text() for t in ax.ax_heatmap.get_yticklabels()]
298+
# labels = [pnames[pid]["label"] for pid in handles]
299+
#
300+
# # FIXME: update after defining labels
301+
# idx = [pnames[pid]["idx"] for pid in handles]
302+
# # idx = [k for k, pid in enumerate(handles)]
303+
#
304+
# labels = [label for _, label in sorted(zip(idx, labels))]
305+
# handles = [f"{handle}:" for _, handle in sorted(zip(idx, handles))]
306+
# handles = [handle.replace("_", "\_") for handle in handles]
307+
308+
# mid = int(np.ceil(len(handles) / 2))
309+
# legend1 = plt.legend(
310+
# handles[:mid],
311+
# labels[:mid],
312+
# handler_map={str: LegendTitle({"fontsize": 16})},
313+
# fontsize=16,
314+
# frameon=False,
315+
# bbox_to_anchor=(1.2, -0.6),
316+
# loc="upper left",
317+
# handlelength=14,
318+
# )
319+
# legend2 = plt.legend(
320+
# handles[mid:],
321+
# labels[mid:],
322+
# handler_map={str: LegendTitle({"fontsize": 16})},
323+
# fontsize=16,
324+
# frameon=False,
325+
# bbox_to_anchor=(13, -0.6),
326+
# loc="upper left",
327+
# handlelength=19,
328+
# )
329+
# plt.gca().add_artist(legend1)
330+
331+
# plt.savefig(
332+
# results_dir / "parameter.sensitivity_cluster.png", dpi=300, bbox_inches="tight"
333+
# )
334+
# plt.savefig(results_dir / "parameter.sensitivity_cluster.svg", bbox_inches="tight")
335+
336+
plt.show()
337+
338+
339+
217340

218341
@dataclass
219342
class SamplingSensitivityAnalysis(SensitivityAnalysis):

0 commit comments

Comments
 (0)