Skip to content

Commit b42df60

Browse files
local sensitivity analysis
1 parent 6de5dc2 commit b42df60

File tree

3 files changed

+70
-29
lines changed

3 files changed

+70
-29
lines changed

src/sbmlsim/sensitivity/analysis.py

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@
4848
from sbmlsim.sensitivity.outputs import SensitivityOutput
4949
import pandas as pd
5050

51+
@dataclass
52+
class SensitivityOutput:
53+
"""Output for SensitivityAnalysis"""
54+
uid: str
55+
name: str
56+
unit: Optional[str] = None
57+
58+
def __hash__(self):
59+
return hash(self.uid)
60+
5161
@dataclass
5262
class SensitivitySimulation:
5363
"""Base class for sensitivity calculation.
@@ -60,10 +70,10 @@ class SensitivitySimulation:
6070
model_path: Path
6171
selections: list[str]
6272
rr: roadrunner.RoadRunner = None
63-
outputs: list[str] = None
73+
outputs: list[SensitivityOutput] = None
6474
changes_simulation: dict[str, float] = None
6575

66-
def __init__(self, model_path: Path, selections: list[str], changes_simulation: dict[str, float]):
76+
def __init__(self, model_path: Path, selections: list[str], changes_simulation: dict[str, float], outputs: list[SensitivityOutput]):
6777
self.model_path = model_path
6878
self.selections = selections
6979
self.rr: roadrunner.RoadRunner = roadrunner.RoadRunner(str(model_path))
@@ -74,10 +84,15 @@ def __init__(self, model_path: Path, selections: list[str], changes_simulation:
7484

7585
# store the simulation changes
7686
self.changes_simulation = changes_simulation
87+
self.outputs: list[SensitivityOutput] = outputs
7788

78-
# get the outputs from the simulation
89+
# validate the outputs from the simulation
7990
y = self.simulate(changes={})
80-
self.outputs = list(y.keys())
91+
outputs_dict = {q.uid for q in self.outputs}
92+
for key in y:
93+
if key not in outputs_dict:
94+
raise ValueError(f"Key '{key}' missing in outputs dictionary: '{outputs_dict}")
95+
8196

8297

8398
# def output_definitions(self) -> list[SensitivityOutput]:
@@ -118,17 +133,13 @@ def apply_changes(self, changes: dict[str, float], reset_all: bool=True) -> None
118133

119134

120135

121-
@dataclass
136+
122137
class SensitivityAnalysis:
123138
"""Parent class for all sensitivity analysis.
124139
125140
TODO: additional metadata for the outputs and the parameters; i.e. name, units, bounds, ....
126141
"""
127142

128-
sensitivity_simulation: SensitivitySimulation
129-
parameters: list[SensitivityParameter]
130-
outputs: list[str]
131-
132143
def __init__(self, sensitivity_simulation: SensitivitySimulation,
133144
parameters: SensitivityParameter) -> None:
134145
"""Create a sensitivity analysis for given parameter ids.
@@ -139,8 +150,11 @@ def __init__(self, sensitivity_simulation: SensitivitySimulation,
139150

140151
# parameters to vary; shape: (num_parameters,)
141152
self.parameters: list[SensitivityParameter] = parameters
153+
self.parameter_ids: list[str] = [p.uid for p in self.parameters]
154+
142155
# outputs to calculate sensitivity on; shape: (num_outputs,)
143-
self.outputs: list[output] = sensitivity_simulation.outputs
156+
self.outputs: list[SensitivityOutput] = sensitivity_simulation.outputs
157+
self.output_ids: list[str] = [q.uid for q in self.outputs]
144158

145159
# parameter samples for sensitivity; shape: (num_samples x num_parameters)
146160
self.samples: Optional[xr.DataArray] = None
@@ -235,7 +249,7 @@ def create_samples(self) -> None:
235249
samples = xr.DataArray(
236250
np.full((num_samples, self.num_parameters), np.nan),
237251
dims=["sample", "parameter"],
238-
coords={"sample": range(num_samples), "parameter": self.parameters},
252+
coords={"sample": range(num_samples), "parameter": [p.uid for p in self.parameters]},
239253
name="samples"
240254
)
241255

@@ -262,15 +276,15 @@ def calculate_sensitivity(self):
262276
self.sensitivity = xr.DataArray(
263277
np.full((self.num_parameters, self.num_outputs), np.nan),
264278
dims=["parameter", "output"],
265-
coords={"parameter": [p.uid for p in self.parameters],
266-
"output": self.outputs},
279+
coords={"parameter": self.parameter_ids,
280+
"output": self.output_ids},
267281
name="sensitivity"
268282
)
269283
self.sensitivity_normalized = xr.DataArray(
270284
np.full((self.num_parameters, self.num_outputs), np.nan),
271285
dims=["parameter", "output"],
272-
coords={"parameter": [p.uid for p in self.parameters],
273-
"output": self.outputs},
286+
coords={"parameter": self.parameter_ids,
287+
"output": self.output_ids},
274288
name="sensitivity"
275289
)
276290

@@ -302,21 +316,33 @@ def sensitivity_df(self) -> pd.DataFrame:
302316

303317
def plot_sensitivity(self):
304318
df = self.sensitivity_df
305-
self.plot_sensitivity_df(df)
319+
self.plot_sensitivity_df(
320+
df=df,
321+
parameter_labels={p.uid: p.name for p in self.parameters},
322+
output_labels={q.uid: q.name for q in self.outputs},
323+
)
306324

307325
@staticmethod
308-
def plot_sensitivity_df(df: pd.DataFrame, cutoff=0.1, cluster_rows: bool = True):
326+
def plot_sensitivity_df(
327+
df: pd.DataFrame,
328+
parameter_labels: dict[str, str],
329+
output_labels: dict[str, str],
330+
cutoff=0.1, cluster_rows: bool = True
331+
):
309332
from sbmlsim.sensitivity.plots import heatmap
310333
console.print(df)
311334

312335
# TODO: labels of parameters
313336
# TODO: labels of outputs
314337
# TODO: better position of colorbar
315338

316-
heatmap(df, cutoff=cutoff, cluster_rows=False)
317-
318-
319-
339+
heatmap(
340+
df,
341+
parameter_labels=parameter_labels,
342+
output_labels=output_labels,
343+
cutoff=cutoff,
344+
cluster_rows=False
345+
)
320346

321347
@dataclass
322348
class SamplingSensitivityAnalysis(SensitivityAnalysis):

src/sbmlsim/sensitivity/parameters.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def __hash__(self):
2424
return hash(self.uid)
2525

2626

27+
28+
2729
def parameters_for_sensitivity_analysis(
2830
sbml_path: Path,
2931
exclude_ids: Optional[set[str]] = None,
@@ -113,7 +115,7 @@ def parameter_from_sbase(sbase: libsbml.SBase) -> SensitivityParameter:
113115
else:
114116
parameters_filtered.append(sp)
115117

116-
console.print(f"Excluded parameters: {parameters_excluded}")
118+
console.print(f"Excluded parameters: {[sp.uid for sp in parameters_excluded]}")
117119

118120
return parameters_filtered
119121

@@ -128,5 +130,4 @@ def parameter_from_sbase(sbase: libsbml.SBase) -> SensitivityParameter:
128130
}
129131
)
130132

131-
console.print("finished")
132133
console.print(parameters)

src/sbmlsim/sensitivity/plots.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,19 @@
44
https://stackoverflow.com/questions/59381273/heatmap-with-circles-indicating-size-of-population
55
"""
66
import xarray as xr
7-
7+
from typing import Optional
88

99
from matplotlib import pyplot as plt
1010
import seaborn as sns
1111
import numpy as np
1212
import pandas as pd
13+
from sbmlutils.console import console
14+
1315

1416
def heatmap(
1517
df: pd.DataFrame,
18+
parameter_labels: Optional[dict[str, str]] = None,
19+
output_labels: Optional[dict[str, str]] = None,
1620
cutoff: float=0.01,
1721
annotate_values=True,
1822
cluster_rows: bool = True, # cluster parameters
@@ -44,12 +48,20 @@ def calculate_subset(df, cutoff=0.01) -> pd.DataFrame:
4448
df_subset = calculate_subset(df, cutoff=cutoff)
4549
df_subset_mask = calculate_mask(df_subset, cutoff)
4650

51+
# outputs
52+
xticklabels = [qid for qid in df_subset.columns]
53+
if output_labels:
54+
console.print(output_labels)
55+
xticklabels = [output_labels[qid] for qid in xticklabels]
56+
57+
# parameters
4758
yticklabels = [pid for pid in df_subset.index]
48-
xticklabels = [pid for pid in df_subset.columns]
59+
if parameter_labels:
60+
yticklabels = [f"{pid}: {parameter_labels[pid]}" for pid in yticklabels]
4961

5062
n_outputs = df_subset.shape[1]
5163
n_parameters = df_subset.shape[0]
52-
figsize = (7, int(n_parameters / n_outputs * 7)/2)
64+
figsize = (10, 15)
5365

5466
colorbar_range = 2.0
5567

@@ -66,7 +78,7 @@ def calculate_subset(df, cutoff=0.01) -> pd.DataFrame:
6678
cbar_pos=(0.0, 0.4, 0.03, 0.2), # (left, bottom, width, height),
6779
cbar_kws={
6880
"orientation": "vertical",
69-
"label": "sensitivity"
81+
# "label": "sensitivity"
7082
},
7183
annot=annotate_values,
7284
fmt="1.2f",
@@ -83,8 +95,10 @@ def calculate_subset(df, cutoff=0.01) -> pd.DataFrame:
8395
horizontalalignment="right",
8496
size=20,
8597
)
86-
plt.setp(ax.ax_heatmap.get_yticklabels(), size=20)
87-
ax.ax_cbar.tick_params(labelsize=20)
98+
label_fontsize=10
99+
plt.setp(ax.ax_heatmap.get_yticklabels(), size=label_fontsize)
100+
plt.setp(ax.ax_heatmap.get_xticklabels(), size=label_fontsize)
101+
ax.ax_cbar.tick_params(labelsize=label_fontsize)
88102
ax.ax_row_dendrogram.set_visible(False)
89103
ax.ax_col_dendrogram.set_visible(False)
90104

0 commit comments

Comments
 (0)