Skip to content

Commit c0f1d46

Browse files
updated sensitivity analysis
1 parent 47cde3d commit c0f1d46

File tree

2 files changed

+71
-28
lines changed

2 files changed

+71
-28
lines changed

src/sbmlsim/sensitivity/analysis.py

Lines changed: 54 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -105,21 +105,26 @@ class SensitivityAnalysis:
105105

106106
def __init__(self,
107107
sensitivity_simulation: SensitivitySimulation,
108-
parameters: list[SensitivityParameter]
108+
parameters: list[SensitivityParameter],
109+
results_path: Path,
109110
) -> None:
110111
"""Create a sensitivity analysis for given parameter ids.
111112
112113
Based on the results matrix the sensitivity is calculated.
113114
"""
114115
self.sensitivity_simulation = sensitivity_simulation
115116

117+
# outputs to calculate sensitivity on; shape: (num_outputs,)
118+
self.outputs: list[SensitivityOutput] = sensitivity_simulation.outputs
119+
self.output_ids: list[str] = [q.uid for q in self.outputs]
120+
116121
# parameters to vary; shape: (num_parameters,)
117122
self.parameters: list[SensitivityParameter] = parameters
118123
self.parameter_ids: list[str] = [p.uid for p in self.parameters]
119124

120-
# outputs to calculate sensitivity on; shape: (num_outputs,)
121-
self.outputs: list[SensitivityOutput] = sensitivity_simulation.outputs
122-
self.output_ids: list[str] = [q.uid for q in self.outputs]
125+
# storage directory
126+
self.results_path: Path = results_path
127+
results_path.mkdir(parents=True, exist_ok=True)
123128

124129
# parameter samples for sensitivity; shape: (num_samples x num_parameters)
125130
self.samples: Optional[xr.DataArray] = None
@@ -255,17 +260,19 @@ def plot_sensitivity(
255260
cluster_rows: bool = True,
256261
title: Optional[str] = None,
257262
cmap: str = "seismic",
263+
fig_path: Optional[Path] = None,
258264
**kwargs
259265
) -> None:
260266
df = self.sensitivity_df(key=key)
261267
heatmap(
262268
df=df,
263-
parameter_labels={p.uid: p.name for p in self.parameters},
269+
parameter_labels={p.uid: f"{p.uid}: {p.name}" for p in self.parameters},
264270
output_labels={q.uid: q.name for q in self.outputs},
265271
cutoff=cutoff,
266272
cluster_rows=cluster_rows,
267273
title=title,
268274
cmap=cmap,
275+
fig_path=fig_path,
269276
**kwargs
270277
)
271278

@@ -296,9 +303,11 @@ class LocalSensitivityAnalysis(SensitivityAnalysis):
296303
"""
297304

298305
def __init__(self, sensitivity_simulation: SensitivitySimulation,
299-
parameters: list[SensitivityParameter], difference: float = 0.01):
306+
parameters: list[SensitivityParameter],
307+
results_path: Path,
308+
difference: float = 0.01):
300309

301-
super().__init__(sensitivity_simulation, parameters)
310+
super().__init__(sensitivity_simulation, parameters, results_path)
302311

303312
self.difference: float = difference
304313

@@ -345,6 +354,7 @@ def create_samples(self) -> None:
345354
samples[-1, :] = reference_values # reference
346355

347356
self.samples = samples
357+
console.print(self.samples)
348358

349359
def calculate_sensitivity(self):
350360
"""Calculate the two-sided local sensitivity matrix."""
@@ -395,9 +405,12 @@ class SobolSensitivityAnalysis(SensitivityAnalysis):
395405
def __init__(self,
396406
sensitivity_simulation: SensitivitySimulation,
397407
parameters: list[SensitivityParameter],
408+
N: int,
409+
results_path: Path,
398410
):
399411

400-
super().__init__(sensitivity_simulation, parameters)
412+
super().__init__(sensitivity_simulation, parameters, results_path)
413+
self.N: int = N
401414

402415
# define the problem specification
403416
self.ssa_problem: ProblemSpec = ProblemSpec({
@@ -409,7 +422,7 @@ def __init__(self,
409422
# console.print(self.ssa_problem)
410423

411424

412-
def create_samples(self, N: int=1024):
425+
def create_samples(self) -> None:
413426
"""Create samples for sobol.
414427
415428
Generates model inputs using Saltelli's extension of the Sobol' sequence
@@ -419,15 +432,14 @@ def create_samples(self, N: int=1024):
419432
"""
420433

421434
# libsa samples based on definition
422-
ssa_samples = saltelli.sample(self.ssa_problem, N=N, calc_second_order=True)
435+
ssa_samples = saltelli.sample(self.ssa_problem, N=self.N, calc_second_order=True)
423436
self.ssa_problem.set_samples(ssa_samples)
424437

425438
# (num_samples x num_outputs)
426439
# total model evaluations are (2d+2) * N for d input factors
427-
num_samples = (2 * self.num_parameters + 2) * N
440+
num_samples = (2 * self.num_parameters + 2) * self.N
428441

429442
self.samples = xr.DataArray(
430-
# np.full((num_samples, self.num_parameters), np.nan),
431443
ssa_samples,
432444
dims=["sample", "parameter"],
433445
coords={"sample": range(num_samples),
@@ -436,7 +448,7 @@ def create_samples(self, N: int=1024):
436448
)
437449

438450

439-
def calculate_sensitivity(self):
451+
def calculate_sensitivity(self) -> None:
440452
"""Calculate the sensitivity matrices."""
441453

442454
Y = self.results.values
@@ -465,16 +477,38 @@ def calculate_sensitivity(self):
465477
Si = SALib.analyze.sobol.analyze(
466478
self.ssa_problem, Yo,
467479
calc_second_order=True,
468-
print_to_console=True,
480+
print_to_console=False,
481+
n_processors=4,
469482
)
470-
console.print("S1")
471-
console.print(Si["S1"])
472483
for key in sensitivity_keys:
473484
self.sensitivity[key][:, ko] = Si[key]
474485

475486

476-
def plot(self):
477-
Si.plot()
478-
from matplotlib import pyplot as plt
479-
plt.show()
487+
def plot_sobol_indices(
488+
self,
489+
fig_path: Path,
490+
):
491+
"""Barplots for the Sobol indices.
492+
493+
"""
494+
parameter_labels: dict[str, str] = {p.uid: f"{p.uid}: {p.name}" for p in self.parameters}
495+
output_labels: dict[str, str] = {q.uid: q.name for q in self.outputs}
496+
497+
498+
for ko, output in enumerate(self.outputs):
499+
S1 = self.sensitivity["S1"][:, ko]
500+
ST = self.sensitivity["ST"][:, ko]
501+
S1_conf = self.sensitivity["S1_conf"][:, ko]
502+
ST_conf = self.sensitivity["ST_conf"][:, ko]
503+
console.print(S1)
504+
console.print(type(S1))
505+
506+
break
507+
508+
509+
510+
511+
512+
513+
480514

src/sbmlsim/sensitivity/plots.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
FIXME: use patchcollection
44
https://stackoverflow.com/questions/59381273/heatmap-with-circles-indicating-size-of-population
55
"""
6+
from pathlib import Path
7+
68
import xarray as xr
79
from typing import Optional
810

@@ -26,6 +28,7 @@ def heatmap(
2628
vcenter: float = 0.0,
2729
vmin: float = -2.0,
2830
vmax: float = 2.0,
31+
fig_path: Optional[Path] = None,
2932
):
3033
"""Creates heatmap of model sensitivity"""
3134

@@ -63,12 +66,12 @@ def calculate_subset(df, cutoff=0.01) -> pd.DataFrame:
6366
# parameters
6467
yticklabels = [pid for pid in df_subset.index]
6568
if parameter_labels:
66-
yticklabels = [f"{pid}: {parameter_labels[pid]}" for pid in yticklabels]
69+
yticklabels = [parameter_labels[pid] for pid in yticklabels]
6770

6871
n_outputs = df_subset.shape[1]
6972
n_parameters = df_subset.shape[0]
70-
figsize = (int(n_outputs/n_parameters*15), 15)
71-
73+
# (width, height)
74+
figsize = (15, int(n_parameters/n_outputs*15))
7275

7376
# plot heatmap
7477
cg = sns.clustermap(
@@ -79,7 +82,6 @@ def calculate_subset(df, cutoff=0.01) -> pd.DataFrame:
7982
xticklabels=xticklabels,
8083
yticklabels=yticklabels,
8184
cmap=cmap,
82-
# cbar_pos=(0.0, 0.0, 0.6, 0.05), # (left, bottom, width, height),
8385
cbar_pos=(0.0, 0.4, 0.03, 0.2), # (left, bottom, width, height),
8486
cbar_kws={
8587
"orientation": "vertical",
@@ -100,13 +102,20 @@ def calculate_subset(df, cutoff=0.01) -> pd.DataFrame:
100102
horizontalalignment="right",
101103
size=20,
102104
)
103-
label_fontsize=13
104-
plt.setp(cg.ax_heatmap.get_yticklabels(), size=label_fontsize)
105-
plt.setp(cg.ax_heatmap.get_xticklabels(), size=label_fontsize)
105+
label_fontsize=15
106+
plt.setp(cg.ax_heatmap.get_yticklabels(), size=label_fontsize, weight="bold")
107+
plt.setp(cg.ax_heatmap.get_xticklabels(), size=label_fontsize, weight="bold")
106108
cg.ax_cbar.tick_params(labelsize=label_fontsize)
107109
cg.ax_row_dendrogram.set_visible(False)
108110
cg.ax_col_dendrogram.set_visible(False)
109111

110112
if title:
111-
plt.suptitle(title)
113+
plt.suptitle(title, fontsize=40, fontweight="bold")
114+
115+
if fig_path:
116+
plt.savefig(fig_path, dpi=300, bbox_inches="tight")
117+
plt.show()
118+
112119

120+
def sobol_barplot():
121+
pass

0 commit comments

Comments
 (0)