Skip to content

Commit c5c3e0b

Browse files
support for FAST sensitivity
1 parent efcc49e commit c5c3e0b

File tree

5 files changed

+326
-101
lines changed

5 files changed

+326
-101
lines changed

src/sbmlsim/sensitivity/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
within a consistent workflow.
1515
1616
TODO implementation of alternative methods:
17-
- [ ] FAST
1817
- [ ] Morris
1918
2019
FIXME: generate simple example
@@ -33,6 +32,7 @@
3332
from .sensitivity_local import LocalSensitivityAnalysis
3433
from .sensitivity_sampling import SamplingSensitivityAnalysis
3534
from .sensitivity_sobol import SobolSensitivityAnalysis
35+
from .sensitivity_fast import FASTSensitivityAnalysis
3636

3737
__all__ = [
3838
"SensitivityParameter",
@@ -42,4 +42,5 @@
4242
"SobolSensitivityAnalysis",
4343
"SamplingSensitivityAnalysis",
4444
"LocalSensitivityAnalysis",
45+
"FASTSensitivityAnalysis",
4546
]

src/sbmlsim/sensitivity/example/sensitivity_example.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,19 @@
1818
# subgroups to perform sensitivity analysis on
1919
sensitivity_groups: list[AnalysisGroup] = [
2020
AnalysisGroup(
21-
uid="low S1",
21+
uid="lowS1",
2222
name="Low S1",
2323
changes={"[S1]": 0.1},
2424
color="tab:red",
2525
),
2626
AnalysisGroup(
27-
uid="reference S1",
27+
uid="refS1",
2828
name="Reference S1",
2929
changes={"[S1]": 1},
3030
color="dimgrey",
3131
),
3232
AnalysisGroup(
33-
uid="high S1",
33+
uid="highS1",
3434
name="High S1",
3535
changes={"[S1]": 10},
3636
color="tab:blue",
@@ -116,18 +116,20 @@ def _sensitivity_parameters() -> list[SensitivityParameter]:
116116
LocalSensitivityAnalysis,
117117
SobolSensitivityAnalysis,
118118
SamplingSensitivityAnalysis,
119+
FASTSensitivityAnalysis,
119120
)
120121

121122
sensitivity_path = Path(__file__).parent / "results"
122123
console.print(SensitivityParameter.parameters_to_df(sensitivity_parameters))
124+
cache = False
123125

124126
SamplingSensitivityAnalysis.run_sensitivity_analysis(
125127
results_path=sensitivity_path / "sampling",
126128
sensitivity_simulation=sensitivity_simulation,
127129
parameters=sensitivity_parameters,
128130
groups=sensitivity_groups,
129-
cache_results=True,
130-
cache_sensitivity=True,
131+
cache_results=cache,
132+
cache_sensitivity=cache,
131133
N=1000,
132134
seed=1234,
133135
)
@@ -137,8 +139,8 @@ def _sensitivity_parameters() -> list[SensitivityParameter]:
137139
sensitivity_simulation=sensitivity_simulation,
138140
parameters=sensitivity_parameters,
139141
groups=[sensitivity_groups[1]],
140-
cache_results=True,
141-
cache_sensitivity=True,
142+
cache_results=cache,
143+
cache_sensitivity=cache,
142144
difference=0.01,
143145
seed=1234,
144146
)
@@ -148,9 +150,20 @@ def _sensitivity_parameters() -> list[SensitivityParameter]:
148150
sensitivity_simulation=sensitivity_simulation,
149151
parameters=sensitivity_parameters,
150152
groups=[sensitivity_groups[1]],
151-
cache_results=True,
152-
cache_sensitivity=True,
153-
N=2048,
153+
cache_results=cache,
154+
cache_sensitivity=cache,
155+
N=4096,
154156
# N=8,
155157
seed=1234,
156158
)
159+
160+
FASTSensitivityAnalysis.run_sensitivity_analysis(
161+
results_path=sensitivity_path / "fast",
162+
sensitivity_simulation=sensitivity_simulation,
163+
parameters=sensitivity_parameters,
164+
groups=[sensitivity_groups[1]],
165+
cache_results=cache,
166+
cache_sensitivity=cache,
167+
N=1000,
168+
seed=1234,
169+
)

src/sbmlsim/sensitivity/plots.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,76 @@ def calculate_subset(df, cutoff=0.01) -> pd.DataFrame:
108108
if fig_path:
109109
plt.savefig(fig_path, dpi=300, bbox_inches="tight")
110110
plt.show()
111+
112+
113+
def plot_S1_ST_indices(
114+
sa, # SensitivityAnalysis,
115+
fig_path: Path,
116+
):
117+
"""Barplots for the S1 and ST indices."""
118+
parameter_labels: dict[str, str] = {p.uid: p.uid for p in sa.parameters}
119+
output_labels: dict[str, str] = {q.uid: q.name for q in sa.outputs}
120+
121+
for group in sa.groups:
122+
gid = group.uid
123+
ymax = sa.sensitivity[gid]["ST"].max(dim=None)
124+
ymin = sa.sensitivity[gid]["S1"].min(dim=None)
125+
126+
for ko, output in enumerate(sa.outputs):
127+
f_path = fig_path.parent / f"{fig_path.stem}_{ko:>03}_{output.uid}{fig_path.suffix}"
128+
129+
S1 = sa.sensitivity[gid]["S1"][:, ko]
130+
ST = sa.sensitivity[gid]["ST"][:, ko]
131+
S1_conf = sa.sensitivity[gid]["S1_conf"][:, ko]
132+
ST_conf = sa.sensitivity[gid]["ST_conf"][:, ko]
133+
S1_ST_barplot(
134+
S1=S1,
135+
ST=ST,
136+
S1_conf=S1_conf,
137+
ST_conf=ST_conf,
138+
title=f"{output_labels[output.uid]} ({group.name})",
139+
fig_path=f_path,
140+
parameter_labels=parameter_labels,
141+
ymax=np.max([1.05, ymax]),
142+
ymin=np.min([-0.05, ymin]),
143+
)
144+
145+
def S1_ST_barplot(
146+
S1, ST, S1_conf, ST_conf,
147+
parameter_labels: dict[str, str],
148+
fig_path: Optional[Path] = None,
149+
title: Optional[str] = None,
150+
ymax: float = 1.1,
151+
ymin: float = -0.1,
152+
):
153+
# width
154+
figsize = (15, 3)
155+
label_fontsize = 15
156+
157+
categories: list[str] = list(parameter_labels.values())
158+
f, ax = plt.subplots(figsize=figsize)
159+
160+
ax.bar(categories, ST, label='ST',
161+
color="black",
162+
alpha=1.0,
163+
edgecolor="black",
164+
yerr=ST_conf, capsize=5
165+
)
166+
167+
ax.bar(categories, S1, label='S1', color="tab:blue",
168+
edgecolor="black", yerr=S1_conf, capsize=5)
169+
170+
# ax.set_xlabel('Parameter', fontsize=label_fontsize, fontweight="bold")
171+
ax.set_ylabel('Sensitivity', fontsize=label_fontsize, fontweight="bold")
172+
ax.set_ylim(bottom=ymin, top=ymax)
173+
ax.grid(True, axis="y")
174+
ax.tick_params(axis='x', labelrotation=90)
175+
# ax.tick_params(axis='x', labelweight='bold')
176+
ax.legend()
177+
178+
if title:
179+
plt.suptitle(title, fontsize=20, fontweight="bold")
180+
181+
if fig_path:
182+
plt.savefig(fig_path, dpi=300, bbox_inches="tight")
183+
plt.show()
Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
"""
2+
Global sensitivity analysis using FAST - Fourier Amplitude Sensitivity Test.
3+
4+
Cukier, R.I., Fortuin, C.M., Shuler, K.E., Petschek, A.G., Schaibly,
5+
J.H., 1973. Study of the sensitivity of coupled reaction systems to uncertainties
6+
in rate coefficients. I theory. Journal of Chemical Physics 59, 3873-3878. https://doi.org/10.1063/1.1680571
7+
Saltelli, A., S. Tarantola, and K. P.-S. Chan (1999). A Quantitative
8+
Model-Independent Method for Global Sensitivity Analysis of Model Output.
9+
Technometrics, 41(1):39-56, doi:10.1080/00401706.1999.10485594.
10+
11+
"""
12+
13+
from pathlib import Path
14+
from typing import Optional
15+
16+
import SALib
17+
import numpy as np
18+
import xarray as xr
19+
from SALib import ProblemSpec
20+
from SALib.analyze import fast
21+
from SALib.sample import fast_sampler
22+
from matplotlib import pyplot as plt
23+
from pymetadata.console import console
24+
25+
from sbmlsim.sensitivity.analysis import SensitivityAnalysis, SensitivitySimulation, \
26+
AnalysisGroup
27+
from sbmlsim.sensitivity.parameters import SensitivityParameter
28+
from sbmlsim.sensitivity.plots import plot_S1_ST_indices
29+
30+
31+
class FASTSensitivityAnalysis(SensitivityAnalysis):
32+
"""Global sensitivity analysis based Fourier Amplitude Sensitivity Test (FAST)
33+
(Cukier et al. 1973, Saltelli et al. 1999)."""
34+
35+
sensitivity_keys = ["S1", "ST", "S1_conf", "ST_conf"]
36+
37+
def __init__(
38+
self,
39+
sensitivity_simulation: SensitivitySimulation,
40+
parameters: list[SensitivityParameter],
41+
groups: list[AnalysisGroup],
42+
results_path: Path,
43+
N: int,
44+
M: int = 4,
45+
**kwargs,
46+
):
47+
"""
48+
N (int) – The number of samples to generate
49+
M (int) – The interference parameter, i.e., the number of harmonics to sum
50+
in the Fourier series decomposition (default 4)
51+
52+
The Sobol' sequence is a popular quasi-random low-discrepancy sequence used
53+
to generate uniform samples of parameter space.
54+
"""
55+
56+
super().__init__(sensitivity_simulation, parameters, groups, results_path,
57+
**kwargs)
58+
self.N: int = N
59+
self.M: int = M
60+
61+
# define the problem specification
62+
self.ssa_problems: dict[str, ProblemSpec] = {}
63+
for group in self.groups:
64+
self.ssa_problems[group.uid] = ProblemSpec({
65+
'num_vars': self.num_parameters,
66+
'names': self.parameter_ids,
67+
'bounds': [[p.lower_bound, p.upper_bound] for p in self.parameters],
68+
"outputs": self.output_ids,
69+
})
70+
71+
def create_samples(self) -> None:
72+
"""Create samples for FAST."""
73+
# (num_samples x num_outputs)
74+
# total model evaluations are N * num_parameters
75+
num_samples = self.N * self.num_parameters
76+
77+
for gid in self.group_ids:
78+
# libssa samples based on definition
79+
ssa_samples = fast_sampler.sample(
80+
self.ssa_problems[gid], N=self.N, M=self.M,
81+
)
82+
self.ssa_problems[gid].set_samples(ssa_samples)
83+
84+
self.samples[gid] = xr.DataArray(
85+
ssa_samples,
86+
dims=["sample", "parameter"],
87+
coords={"sample": range(num_samples),
88+
"parameter": self.parameter_ids},
89+
name="samples"
90+
)
91+
92+
def calculate_sensitivity(self, cache_filename: Optional[str] = None,
93+
cache: bool = False):
94+
""" Perform extended Fourier Amplitude Sensitivity Test on model outputs.
95+
96+
Returns a dictionary with keys 'S1' and 'ST', where each entry is a list of
97+
size D (the number of parameters) containing the indices in the same order
98+
as the parameter file.
99+
"""
100+
101+
data = self.read_cache(cache_filename, cache)
102+
if data:
103+
self.sensitivity = data
104+
return
105+
106+
for gid in self.group_ids:
107+
Y = self.results[gid].values
108+
self.ssa_problems[gid].set_results(Y)
109+
110+
# num_parameters x num_outputs
111+
for key in self.sensitivity_keys:
112+
self.sensitivity[gid][key] = xr.DataArray(
113+
np.full((self.num_parameters, self.num_outputs), np.nan),
114+
dims=["parameter", "output"],
115+
coords={"parameter": self.parameter_ids,
116+
"output": self.output_ids},
117+
name=key
118+
)
119+
120+
# Calculate FAST indices
121+
for ko in range(self.num_outputs):
122+
Yo = Y[:, ko]
123+
Si = SALib.analyze.fast.analyze(
124+
self.ssa_problems[gid], Yo,
125+
M=self.M,
126+
num_resamples=100,
127+
conf_level=0.95,
128+
print_to_console=False,
129+
)
130+
for key in self.sensitivity_keys:
131+
self.sensitivity[gid][key][:, ko] = Si[key]
132+
133+
# write to cache
134+
self.write_cache(data=self.sensitivity, cache_filename=cache_filename,
135+
cache=cache)
136+
137+
@staticmethod
138+
def run_sensitivity_analysis(
139+
results_path: Path,
140+
sensitivity_simulation: SensitivitySimulation,
141+
parameters: list[SensitivityParameter],
142+
groups: list[AnalysisGroup],
143+
N: int,
144+
seed: int,
145+
M: int = 4,
146+
cache_results: bool = False,
147+
cache_sensitivity: bool = False,
148+
) -> None:
149+
"""FAST sensitivity analysis.
150+
151+
First-order FAST (main effects only):
152+
100 × num_pars samples is usually sufficient
153+
154+
Extended FAST (eFAST, total effects):
155+
200–500 × k samples recommended
156+
(higher frequencies needed to separate interactions)
157+
158+
:param sensitivity_simulation: Sensitivity simulation.
159+
:param parameters: Sensitivity parameters.
160+
:param groups: Sensitivity groups.
161+
N (int) – The number of samples to generate
162+
M (int) – The interference parameter, i.e., the number of harmonics to sum
163+
:param seed: Random seed.
164+
"""
165+
prefix = "fast"
166+
console.rule(f"{prefix.upper()} SENSITIVITY ANALYSIS", style="blue bold", align="center")
167+
if cache_sensitivity and not cache_results:
168+
# sensitivity must be recalculated for new results
169+
cache_sensitivity = False
170+
171+
sa = FASTSensitivityAnalysis(
172+
sensitivity_simulation=sensitivity_simulation,
173+
parameters=parameters,
174+
groups=groups,
175+
results_path=results_path,
176+
N=N,
177+
M=M,
178+
seed=seed,
179+
)
180+
181+
console.rule("Samples", style="white")
182+
sa.create_samples()
183+
console.print(sa.samples_table())
184+
185+
console.rule("Results", style="white")
186+
sa.simulate_samples(cache_filename=f"{prefix}_results_N{sa.N}.pkl",
187+
cache=cache_results)
188+
console.print(sa.results_table())
189+
190+
console.rule("Sensitivity", style="white")
191+
sa.calculate_sensitivity(cache_filename=f"{prefix}_sensitivity_N{sa.N}.pkl",
192+
cache=cache_sensitivity)
193+
194+
console.rule("Plotting", style="white")
195+
for kg, group in enumerate(sa.groups):
196+
# heatmaps
197+
for key in ["ST", "S1"]:
198+
sa.plot_sensitivity(
199+
group_id=group.uid,
200+
sensitivity_key=key,
201+
# title=f"{key} {group.name}",
202+
cutoff=0.05,
203+
cluster_rows=False,
204+
cmap="viridis",
205+
vcenter=0.5,
206+
vmin=0.0,
207+
vmax=1.0,
208+
fig_path=sa.results_path / f"{prefix}_sensitivity_N{sa.N}_{kg:>02}_{group.uid}_{key}.png"
209+
)
210+
211+
# barplots
212+
plot_S1_ST_indices(
213+
sa=sa,
214+
fig_path=sa.results_path / f"{prefix}_sensitivity_N{sa.N}_{kg:>02}_{group.uid}.png",
215+
)

0 commit comments

Comments
 (0)