Skip to content

Commit 47cde3d

Browse files
visualization of sobol indices
1 parent deaf3b3 commit 47cde3d

File tree

2 files changed

+52
-124
lines changed

2 files changed

+52
-124
lines changed

src/sbmlsim/sensitivity/analysis.py

Lines changed: 36 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -249,16 +249,34 @@ def sensitivity_df(self, key="normalized") -> pd.DataFrame:
249249
index=self.sensitivity[key].coords["parameter"]
250250
)
251251

252+
def plot_sensitivity(
253+
self,
254+
key: str, cutoff=0.1,
255+
cluster_rows: bool = True,
256+
title: Optional[str] = None,
257+
cmap: str = "seismic",
258+
**kwargs
259+
) -> None:
260+
df = self.sensitivity_df(key=key)
261+
heatmap(
262+
df=df,
263+
parameter_labels={p.uid: p.name for p in self.parameters},
264+
output_labels={q.uid: q.name for q in self.outputs},
265+
cutoff=cutoff,
266+
cluster_rows=cluster_rows,
267+
title=title,
268+
cmap=cmap,
269+
**kwargs
270+
)
271+
252272
import os
253273

254274
def run_simulation(
255275
params_tuple
256276
):
257277
"""Pass all required arguments as parameter tuple."""
258278
sensitivity_simulation, r, chunked_changes = params_tuple
259-
260279
outputs = []
261-
262280
for kc in track(range(len(chunked_changes)), description=f"Simulate samples PID={os.getpid()}"):
263281
changes = chunked_changes[kc]
264282
# console.print(f"PID={os.getpid()} | k={kc}")
@@ -271,16 +289,6 @@ def run_simulation(
271289
return outputs
272290

273291

274-
275-
276-
277-
278-
279-
280-
281-
282-
283-
284292
class LocalSensitivityAnalysis(SensitivityAnalysis):
285293
"""Local sensitivity analysis based on local differences.
286294
@@ -371,35 +379,7 @@ def calculate_sensitivity(self):
371379
sensitivity_normalized[kp, ko] = sensitivity_raw[kp, ko] * p_ref/q_ref
372380

373381

374-
def plot_sensitivity(self, cutoff=0.1, cluster_rows: bool = True, title: Optional[str] = None):
375-
df = self.sensitivity_df(key="normalized")
376-
self.plot_sensitivity_df(
377-
df=df,
378-
parameter_labels={p.uid: p.name for p in self.parameters},
379-
output_labels={q.uid: q.name for q in self.outputs},
380-
cutoff=cutoff,
381-
cluster_rows=cluster_rows,
382-
title=title
383-
)
384-
385-
@staticmethod
386-
def plot_sensitivity_df(
387-
df: pd.DataFrame,
388-
parameter_labels: dict[str, str],
389-
output_labels: dict[str, str],
390-
cutoff=0.1, cluster_rows: bool = True,
391-
title: Optional[str] = None,
392-
):
393-
console.print(df)
394382

395-
heatmap(
396-
df,
397-
parameter_labels=parameter_labels,
398-
output_labels=output_labels,
399-
cutoff=cutoff,
400-
cluster_rows=False,
401-
title=title,
402-
)
403383

404384

405385
@dataclass
@@ -457,50 +437,41 @@ def create_samples(self, N: int=1024):
457437

458438

459439
def calculate_sensitivity(self):
460-
# transfer results in libsa results format
440+
"""Calculate the sensitivity matrices."""
461441

462442
Y = self.results.values
463443
self.ssa_problem.set_results(Y)
464444

445+
# num_parameters x num_outputs
446+
sensitivity_keys = ["S1", "ST", "S1_conf", "ST_conf"]
447+
for key in sensitivity_keys:
448+
self.sensitivity[key] = xr.DataArray(
449+
np.full((self.num_parameters, self.num_outputs), np.nan),
450+
dims=["parameter", "output"],
451+
coords={"parameter": self.parameter_ids,
452+
"output": self.output_ids},
453+
name=key
454+
)
455+
465456
# Perform Analysis
466457
# Si is a Python dict-like with the keys "S1", "S2", "ST",
467458
# "S1_conf", "S2_conf", and "ST_conf".
468459
# The _conf keys store the corresponding confidence intervals,
469460
# typically with a confidence level of 95%.
470461

471462
# Calculate Sobol indices for every output
472-
Si_all = []
473463
for ko in range(self.num_outputs):
474464
Yo = Y[:, ko]
475465
Si = SALib.analyze.sobol.analyze(
476466
self.ssa_problem, Yo,
477467
calc_second_order=True,
478468
print_to_console=True,
479469
)
480-
Si_all.append(Si)
481-
482-
Si.plot()
483-
from matplotlib import pyplot as plt
484-
plt.show()
485-
486-
# Si = SALib.analyze.sobol.analyze(
487-
# self.ssa_problem, Y,
488-
# calc_second_order=True,
489-
# print_to_console=True,
490-
# )
491-
492-
# Store the sensitivity matrices
470+
console.print("S1")
471+
console.print(Si["S1"])
472+
for key in sensitivity_keys:
473+
self.sensitivity[key][:, ko] = Si[key]
493474

494-
sensitivity_total = Si['ST']
495-
sensitivity_first = Si['S1']
496-
print(Si['S1'])
497-
print(Si['ST'])
498-
499-
500-
501-
Si.plot()
502-
from matplotlib import pyplot as plt
503-
plt.show()
504475

505476
def plot(self):
506477
Si.plot()

src/sbmlsim/sensitivity/plots.py

Lines changed: 16 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@ def heatmap(
1717
df: pd.DataFrame,
1818
parameter_labels: Optional[dict[str, str]] = None,
1919
output_labels: Optional[dict[str, str]] = None,
20-
cutoff: float=0.01,
20+
cutoff: float=0.1,
2121
annotate_values=True,
2222
cluster_rows: bool = True, # cluster parameters
2323
cluster_cols: bool = False, # cluster outputs
24-
transpose: bool=False,
2524
title: Optional[str] = None,
25+
cmap: str = "seismic",
26+
vcenter: float = 0.0,
27+
vmin: float = -2.0,
28+
vmax: float = 2.0,
2629
):
2730
"""Creates heatmap of model sensitivity"""
2831

@@ -47,7 +50,9 @@ def calculate_subset(df, cutoff=0.01) -> pd.DataFrame:
4750

4851
if cutoff > 0:
4952
df_subset = calculate_subset(df, cutoff=cutoff)
50-
df_subset_mask = calculate_mask(df_subset, cutoff)
53+
else:
54+
df_subset = df
55+
df_subset_mask = calculate_mask(df_subset, cutoff)
5156

5257
# outputs
5358
xticklabels = [qid for qid in df_subset.columns]
@@ -62,19 +67,18 @@ def calculate_subset(df, cutoff=0.01) -> pd.DataFrame:
6267

6368
n_outputs = df_subset.shape[1]
6469
n_parameters = df_subset.shape[0]
65-
figsize = (int(n_outputs/n_parameters*30), 15)
70+
figsize = (int(n_outputs/n_parameters*15), 15)
6671

67-
colorbar_range = 2.0
6872

6973
# plot heatmap
7074
cg = sns.clustermap(
7175
df_subset,
72-
center=0,
73-
vmin=-colorbar_range,
74-
vmax=colorbar_range,
76+
center=vcenter,
77+
vmin=vmin,
78+
vmax=vmax,
7579
xticklabels=xticklabels,
7680
yticklabels=yticklabels,
77-
cmap="seismic",
81+
cmap=cmap,
7882
# cbar_pos=(0.0, 0.0, 0.6, 0.05), # (left, bottom, width, height),
7983
cbar_pos=(0.0, 0.4, 0.03, 0.2), # (left, bottom, width, height),
8084
cbar_kws={
@@ -85,8 +89,8 @@ def calculate_subset(df, cutoff=0.01) -> pd.DataFrame:
8589
fmt="1.2f",
8690
annot_kws={"size": 11},
8791
mask=df_subset_mask,
88-
col_cluster=False,
89-
row_cluster=True,
92+
col_cluster=cluster_cols,
93+
row_cluster=cluster_rows,
9094
method="single",
9195
figsize=figsize,
9296
)
@@ -96,7 +100,7 @@ def calculate_subset(df, cutoff=0.01) -> pd.DataFrame:
96100
horizontalalignment="right",
97101
size=20,
98102
)
99-
label_fontsize=10
103+
label_fontsize=13
100104
plt.setp(cg.ax_heatmap.get_yticklabels(), size=label_fontsize)
101105
plt.setp(cg.ax_heatmap.get_xticklabels(), size=label_fontsize)
102106
cg.ax_cbar.tick_params(labelsize=label_fontsize)
@@ -106,50 +110,3 @@ def calculate_subset(df, cutoff=0.01) -> pd.DataFrame:
106110
if title:
107111
plt.suptitle(title)
108112

109-
# for label in cg.ax_heatmap.get_xticklabels():
110-
# label.set_bbox(dict(facecolor='tab:blue', edgecolor='black', alpha=0.8))
111-
#
112-
# for label in cg.ax_heatmap.get_yticklabels():
113-
# label.set_bbox(dict(facecolor='tab:orange', edgecolor='black', alpha=0.8))
114-
115-
# create custom legend containing yticklabels and their description
116-
# handles = [t.get_text() for t in ax.ax_heatmap.get_yticklabels()]
117-
# labels = [pnames[pid]["label"] for pid in handles]
118-
#
119-
# # FIXME: update after defining labels
120-
# idx = [pnames[pid]["idx"] for pid in handles]
121-
# # idx = [k for k, pid in enumerate(handles)]
122-
#
123-
# labels = [label for _, label in sorted(zip(idx, labels))]
124-
# handles = [f"{handle}:" for _, handle in sorted(zip(idx, handles))]
125-
# handles = [handle.replace("_", "\_") for handle in handles]
126-
127-
# mid = int(np.ceil(len(handles) / 2))
128-
# legend1 = plt.legend(
129-
# handles[:mid],
130-
# labels[:mid],
131-
# handler_map={str: LegendTitle({"fontsize": 16})},
132-
# fontsize=16,
133-
# frameon=False,
134-
# bbox_to_anchor=(1.2, -0.6),
135-
# loc="upper left",
136-
# handlelength=14,
137-
# )
138-
# legend2 = plt.legend(
139-
# handles[mid:],
140-
# labels[mid:],
141-
# handler_map={str: LegendTitle({"fontsize": 16})},
142-
# fontsize=16,
143-
# frameon=False,
144-
# bbox_to_anchor=(13, -0.6),
145-
# loc="upper left",
146-
# handlelength=19,
147-
# )
148-
# plt.gca().add_artist(legend1)
149-
150-
# plt.savefig(
151-
# results_dir / "parameter.sensitivity_cluster.png", dpi=300, bbox_inches="tight"
152-
# )
153-
# plt.savefig(results_dir / "parameter.sensitivity_cluster.svg", bbox_inches="tight")
154-
155-
# plt.show()

0 commit comments

Comments
 (0)