Skip to content

Commit 91243cd

Browse files
updated sensitivity analysis
1 parent 2162d63 commit 91243cd

File tree

3 files changed

+156
-131
lines changed

3 files changed

+156
-131
lines changed

src/sbmlsim/sensitivity/example/sensitivity_example.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -128,29 +128,39 @@ def _sensitivity_parameters() -> list[SensitivityParameter]:
128128
SobolSensitivityAnalysis,
129129
SamplingSensitivityAnalysis,
130130
)
131-
132131
sensitivity_path = Path(__file__).parent / "results"
133132
console.print(SensitivityParameter.parameters_to_df(sensitivity_parameters))
134133

135-
SamplingSensitivityAnalysis.run_sensitivity_analysis(
136-
results_path=sensitivity_path / "sampling",
137-
sensitivity_simulation=sensitivity_simulation,
138-
parameters=sensitivity_parameters,
139-
groups=sensitivity_groups,
140-
# cache_results=False,
141-
# cache_sensitivity=False,
142-
N=200,
143-
seed=1234,
144-
)
134+
# SamplingSensitivityAnalysis.run_sensitivity_analysis(
135+
# results_path=sensitivity_path / "sampling",
136+
# sensitivity_simulation=sensitivity_simulation,
137+
# parameters=sensitivity_parameters,
138+
# groups=sensitivity_groups,
139+
# # cache_results=False,
140+
# # cache_sensitivity=False,
141+
# N=200,
142+
# seed=1234,
143+
# )
144+
#
145+
# LocalSensitivityAnalysis.run_sensitivity_analysis(
146+
# results_path=sensitivity_path / "local",
147+
# sensitivity_simulation=sensitivity_simulation,
148+
# parameters=sensitivity_parameters,
149+
# groups=[sensitivity_groups[1]],
150+
# # cache_results=False,
151+
# # cache_sensitivity=False,
152+
# difference=0.01,
153+
# seed=1234,
154+
# )
145155

146156
SobolSensitivityAnalysis.run_sensitivity_analysis(
147157
results_path=sensitivity_path / "sobol",
148158
sensitivity_simulation=sensitivity_simulation,
149159
parameters=sensitivity_parameters,
150-
groups=sensitivity_groups,
160+
groups=[sensitivity_groups[1]],
151161
# cache_results=False,
152162
# cache_sensitivity=False,
153-
# N=2048
163+
# N=2048,
154164
N=8,
155165
seed=1234,
156166
)

src/sbmlsim/sensitivity/sensitivity_local.py

Lines changed: 60 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from pathlib import Path
2+
from typing import Optional
23

4+
import numpy as np
35
import xarray as xr
6+
from pymetadata.console import console
47

58
from sbmlsim.sensitivity.analysis import SensitivitySimulation, AnalysisGroup, \
69
SensitivityAnalysis
@@ -136,45 +139,61 @@ def calculate_sensitivity(self, cache_filename: Optional[str] = None, cache: boo
136139
# write to cache
137140
self.write_cache(data=self.sensitivity, cache_filename=cache_filename, cache=cache)
138141

139-
140-
def local_sensitivity_analysis():
141-
"""Local sensitivity analysis"""
142-
console.rule("LOCAL SENSITIVITY ANALYSIS", style="blue bold", align="center")
143-
144-
sensitivity_simulation = CanagliflozinSensitivitySimulation.sensitivity_simulation()
145-
parameters = sensitivity_simulation.sensitivity_parameters()
146-
147-
sa = LocalSensitivityAnalysis(
148-
sensitivity_simulation=sensitivity_simulation,
149-
parameters=sensitivity_parameters,
150-
groups=sensitivity_groups,
151-
results_path=RESULTS_PATH / "sensitivity",
152-
seed=1234,
153-
difference=0.01, # 1% change
154-
)
155-
156-
console.rule("Samples", style="white")
157-
sa.create_samples()
158-
159-
console.rule("Results", style="white")
160-
sa.simulate_samples()
161-
console.print(sa.results)
162-
163-
console.rule("Sensitivity", style="white")
164-
sa.calculate_sensitivity()
165-
console.print(sa.sensitivity)
166-
167-
console.rule("Plotting", style="white")
168-
for kg, group in enumerate(sa.groups):
169-
sa.plot_sensitivity(
170-
group_id=group.uid,
171-
sensitivity_key="normalized",
172-
# title=f"{group.name}",
173-
cutoff=0.05,
174-
cluster_rows = False,
175-
cmap = "seismic",
176-
vcenter=0.0,
177-
vmin=-2.0,
178-
vmax=2.0,
179-
fig_path=sa.results_path / f"local_sensitivity_{kg:>02}_{group.uid}_{sa.difference}.png",
142+
@staticmethod
143+
def run_sensitivity_analysis(
144+
results_path: Path,
145+
sensitivity_simulation: SensitivitySimulation,
146+
parameters: list[SensitivityParameter],
147+
groups: list[AnalysisGroup],
148+
seed: int,
149+
difference: float = 0.01,
150+
cache_results: bool = False,
151+
cache_sensitivity: bool = False,
152+
) -> None:
153+
"""Local sensitivity analysis.
154+
155+
:param sensitivity_simulation: Sensitivity simulation.
156+
:param parameters: Sensitivity parameters.
157+
:param groups: Sensitivity groups.
158+
:param difference: relative change of parameters.
159+
:param seed: Random seed.
160+
"""
161+
console.rule("LOCAL SENSITIVITY ANALYSIS", style="blue bold", align="center")
162+
if cache_sensitivity and not cache_results:
163+
# sensitivity must be recalculated for new results
164+
cache_sensitivity = False
165+
166+
sa = LocalSensitivityAnalysis(
167+
sensitivity_simulation=sensitivity_simulation,
168+
parameters=parameters,
169+
groups=groups,
170+
results_path=results_path,
171+
seed=1234,
172+
difference=difference,
180173
)
174+
console.rule("Samples", style="white")
175+
sa.create_samples()
176+
console.print(sa.samples_table())
177+
178+
console.rule("Results", style="white")
179+
sa.simulate_samples(cache_filename=f"local_results_difference{sa.difference}.pkl", cache=cache_results)
180+
console.print(sa.results_table())
181+
182+
console.rule("Sensitivity", style="white")
183+
sa.calculate_sensitivity(cache_filename=f"local_sensitivity_difference{sa.difference}.pkl", cache=cache_sensitivity)
184+
# console.print(sa.sensitivity_tables())
185+
186+
console.rule("Plotting", style="white")
187+
for kg, group in enumerate(sa.groups):
188+
sa.plot_sensitivity(
189+
group_id=group.uid,
190+
sensitivity_key="normalized",
191+
# title=f"{group.name}",
192+
cutoff=0.05,
193+
cluster_rows=False,
194+
cmap="seismic",
195+
vcenter=0.0,
196+
vmin=-2.0,
197+
vmax=2.0,
198+
fig_path=sa.results_path / f"local_sensitivity_{kg:>02}_{group.uid}_{sa.difference}.png",
199+
)

src/sbmlsim/sensitivity/sensitivity_sobol.py

Lines changed: 73 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ def create_samples(self) -> None:
6464
The Sobol' sequence is a popular quasi-random low-discrepancy sequence used
6565
to generate uniform samples of parameter space.
6666
"""
67-
console.rule("Samples", style="white")
6867
# (num_samples x num_outputs)
6968
# total model evaluations are (2d+2) * N for d input factors
7069
num_samples = (2 * self.num_parameters + 2) * self.N
@@ -81,7 +80,6 @@ def create_samples(self) -> None:
8180
"parameter": self.parameter_ids},
8281
name="samples"
8382
)
84-
console.print(self.samples)
8583

8684

8785
def calculate_sensitivity(self, cache_filename: Optional[str] = None, cache: bool = False):
@@ -128,40 +126,6 @@ def calculate_sensitivity(self, cache_filename: Optional[str] = None, cache: boo
128126
# write to cache
129127
self.write_cache(data=self.sensitivity, cache_filename=cache_filename, cache=cache)
130128

131-
def plot_sobol_indices(
132-
self,
133-
fig_path: Path,
134-
):
135-
"""Barplots for the Sobol indices."""
136-
# parameter_labels: dict[str, str] = {p.uid: f"{p.uid}: {p.name}" for p in self.parameters}
137-
parameter_labels: dict[str, str] = {p.uid: p.uid for p in self.parameters}
138-
output_labels: dict[str, str] = {q.uid: q.name for q in self.outputs}
139-
140-
for group in self.groups:
141-
gid = group.uid
142-
ymax = self.sensitivity[gid]["ST"].max(dim=None)
143-
ymin = self.sensitivity[gid]["S1"].min(dim=None)
144-
145-
for ko, output in enumerate(self.outputs):
146-
# f_path = fig_path.parent / f"FigS{ko+22}_{fig_path.stem}_{ko:>03}_{output.uid}{fig_path.suffix}"
147-
f_path = fig_path.parent / f"{fig_path.stem}_{ko:>03}_{output.uid}{fig_path.suffix}"
148-
149-
S1 = self.sensitivity[gid]["S1"][:, ko]
150-
ST = self.sensitivity[gid]["ST"][:, ko]
151-
S1_conf = self.sensitivity[gid]["S1_conf"][:, ko]
152-
ST_conf = self.sensitivity[gid]["ST_conf"][:, ko]
153-
sobol_barplot(
154-
S1=S1,
155-
ST=ST,
156-
S1_conf=S1_conf,
157-
ST_conf=ST_conf,
158-
title=f"{output_labels[output.uid]} ({group.name})",
159-
fig_path=f_path,
160-
parameter_labels=parameter_labels,
161-
ymax=np.max([1.05, ymax]),
162-
ymin=np.min([-0.05, ymin]),
163-
)
164-
165129
@staticmethod
166130
def run_sensitivity_analysis(
167131
results_path: Path,
@@ -192,9 +156,8 @@ def run_sensitivity_analysis(
192156
parameters=parameters,
193157
groups=groups,
194158
results_path=results_path,
195-
# N=4096,
196-
N=16,
197-
seed=1234,
159+
N=N,
160+
seed=seed,
198161
)
199162

200163
console.rule("Samples", style="white")
@@ -232,46 +195,79 @@ def run_sensitivity_analysis(
232195
fig_path=sa.results_path / f"sobol_sensitivity_N{sa.N}_{kg:>02}_{group.uid}.png",
233196
)
234197

198+
def plot_sobol_indices(
199+
self,
200+
fig_path: Path,
201+
):
202+
"""Barplots for the Sobol indices."""
203+
# parameter_labels: dict[str, str] = {p.uid: f"{p.uid}: {p.name}" for p in self.parameters}
204+
parameter_labels: dict[str, str] = {p.uid: p.uid for p in self.parameters}
205+
output_labels: dict[str, str] = {q.uid: q.name for q in self.outputs}
235206

207+
for group in self.groups:
208+
gid = group.uid
209+
ymax = self.sensitivity[gid]["ST"].max(dim=None)
210+
ymin = self.sensitivity[gid]["S1"].min(dim=None)
236211

237-
def sobol_barplot(
238-
S1, ST, S1_conf, ST_conf,
239-
parameter_labels: dict[str, str],
240-
fig_path: Optional[Path] = None,
241-
title: Optional[str] = None,
242-
ymax: float = 1.1,
243-
ymin: float = -0.1,
244-
):
245-
# width
246-
figsize = (15, 3)
247-
label_fontsize = 15
248-
249-
categories: list[str] = list(parameter_labels.values())
250-
f, ax = plt.subplots(figsize=figsize)
251-
252-
ax.bar(categories, ST, label='ST',
253-
color="tab:orange",
254-
alpha=1.0,
255-
edgecolor="black",
256-
yerr=ST_conf, capsize=5
257-
)
258-
259-
ax.bar(categories, S1, label='S1', color="tab:blue",
260-
edgecolor="black", yerr=S1_conf, capsize=5)
261-
262-
263-
# ax.set_xlabel('Parameter', fontsize=label_fontsize, fontweight="bold")
264-
ax.set_ylabel('Sobol Index', fontsize=label_fontsize, fontweight="bold")
265-
ax.set_ylim(bottom=ymin, top=ymax)
266-
ax.grid(True, axis="y")
267-
ax.tick_params(axis='x', labelrotation=90)
268-
# ax.tick_params(axis='x', labelweight='bold')
269-
ax.legend()
212+
for ko, output in enumerate(self.outputs):
213+
# f_path = fig_path.parent / f"FigS{ko+22}_{fig_path.stem}_{ko:>03}_{output.uid}{fig_path.suffix}"
214+
f_path = fig_path.parent / f"{fig_path.stem}_{ko:>03}_{output.uid}{fig_path.suffix}"
270215

271-
if title:
272-
plt.suptitle(title, fontsize=20, fontweight="bold")
216+
S1 = self.sensitivity[gid]["S1"][:, ko]
217+
ST = self.sensitivity[gid]["ST"][:, ko]
218+
S1_conf = self.sensitivity[gid]["S1_conf"][:, ko]
219+
ST_conf = self.sensitivity[gid]["ST_conf"][:, ko]
220+
self.sobol_barplot(
221+
S1=S1,
222+
ST=ST,
223+
S1_conf=S1_conf,
224+
ST_conf=ST_conf,
225+
title=f"{output_labels[output.uid]} ({group.name})",
226+
fig_path=f_path,
227+
parameter_labels=parameter_labels,
228+
ymax=np.max([1.05, ymax]),
229+
ymin=np.min([-0.05, ymin]),
230+
)
273231

274-
if fig_path:
275-
plt.savefig(fig_path, dpi=300, bbox_inches="tight")
276-
plt.show()
232+
@staticmethod
233+
def sobol_barplot(
234+
S1, ST, S1_conf, ST_conf,
235+
parameter_labels: dict[str, str],
236+
fig_path: Optional[Path] = None,
237+
title: Optional[str] = None,
238+
ymax: float = 1.1,
239+
ymin: float = -0.1,
240+
):
241+
# width
242+
figsize = (15, 3)
243+
label_fontsize = 15
244+
245+
categories: list[str] = list(parameter_labels.values())
246+
f, ax = plt.subplots(figsize=figsize)
247+
248+
ax.bar(categories, ST, label='ST',
249+
color="tab:orange",
250+
alpha=1.0,
251+
edgecolor="black",
252+
yerr=ST_conf, capsize=5
253+
)
254+
255+
ax.bar(categories, S1, label='S1', color="tab:blue",
256+
edgecolor="black", yerr=S1_conf, capsize=5)
257+
258+
259+
# ax.set_xlabel('Parameter', fontsize=label_fontsize, fontweight="bold")
260+
ax.set_ylabel('Sobol Index', fontsize=label_fontsize, fontweight="bold")
261+
ax.set_ylim(bottom=ymin, top=ymax)
262+
ax.grid(True, axis="y")
263+
ax.tick_params(axis='x', labelrotation=90)
264+
# ax.tick_params(axis='x', labelweight='bold')
265+
ax.legend()
266+
267+
if title:
268+
plt.suptitle(title, fontsize=20, fontweight="bold")
269+
270+
if fig_path:
271+
plt.savefig(fig_path, dpi=300, bbox_inches="tight")
272+
plt.show()
277273

0 commit comments

Comments
 (0)