Skip to content

Commit 6de5dc2

Browse files
work on sensitivity
1 parent 60d320a commit 6de5dc2

File tree

3 files changed

+90
-34
lines changed

3 files changed

+90
-34
lines changed

src/sbmlsim/sensitivity/analysis.py

Lines changed: 52 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646

4747
from sbmlsim.sensitivity.parameters import SensitivityParameter
4848
from sbmlsim.sensitivity.outputs import SensitivityOutput
49-
49+
import pandas as pd
5050

5151
@dataclass
5252
class SensitivitySimulation:
@@ -69,7 +69,7 @@ def __init__(self, model_path: Path, selections: list[str], changes_simulation:
6969
self.rr: roadrunner.RoadRunner = roadrunner.RoadRunner(str(model_path))
7070
self.rr.selections = self.selections
7171
integrator: roadrunner.Integrator = self.rr.integrator
72-
integrator.setSetting("variable_step_size", True)
72+
# integrator.setSetting("variable_step_size", True)
7373
# state = rr.saveStateS()
7474

7575
# store the simulation changes
@@ -148,6 +148,7 @@ def __init__(self, sensitivity_simulation: SensitivitySimulation,
148148
self.results: Optional[xr.DataArray] = None
149149
# sensitivity matrix; shape: (num_parameters x num_outputs); could be multiple
150150
self.sensitivity: Optional[xr.DataArray] = None
151+
self.sensitivity_normalized: Optional[xr.DataArray] = None
151152

152153

153154
@property
@@ -192,25 +193,22 @@ def simulate_samples(self) -> None:
192193

193194
def calculate_sensitivity(self):
194195
"""Calculate the sensitivity matrix."""
196+
pass
195197

196-
self.sensitivity = xr.DataArray(
197-
np.full((self.num_parameters, self.num_outputs), np.nan),
198-
dims=["parameter", "output"],
199-
coords={"parameter": [p.uid for p in self.parameters],
200-
"output": self.outputs},
201-
name="sensitivity"
202-
)
203198

204199

205200
@dataclass
206201
class LocalSensitivityAnalysis(SensitivityAnalysis):
207-
"""Local sensitivity analysis based on local differences."""
202+
"""Local sensitivity analysis based on local differences.
203+
204+
param difference: change for calculation of local sensitivity (0.01 = 1% change)
205+
"""
208206

209207
difference: float
210208
sensitivity: np.ndarray = None
211209

212210
def __init__(self, sensitivity_simulation: SensitivitySimulation,
213-
parameters: list[SensitivityParameter], difference: float = 0.1):
211+
parameters: list[SensitivityParameter], difference: float = 0.01):
214212

215213
super().__init__(sensitivity_simulation, parameters)
216214
self.sensitivity = np.zeros(shape=(self.num_parameters, self.num_outputs))
@@ -260,34 +258,62 @@ def calculate_sensitivity(self):
260258
"""Calculate the two-sided local sensitivity matrix."""
261259

262260
# num_parameters x num_outputs
263-
super().calculate_sensitivity()
261+
# empty sensitivity
262+
self.sensitivity = xr.DataArray(
263+
np.full((self.num_parameters, self.num_outputs), np.nan),
264+
dims=["parameter", "output"],
265+
coords={"parameter": [p.uid for p in self.parameters],
266+
"output": self.outputs},
267+
name="sensitivity"
268+
)
269+
self.sensitivity_normalized = xr.DataArray(
270+
np.full((self.num_parameters, self.num_outputs), np.nan),
271+
dims=["parameter", "output"],
272+
coords={"parameter": [p.uid for p in self.parameters],
273+
"output": self.outputs},
274+
name="sensitivity"
275+
)
264276

265277
for kp, p in enumerate(self.parameters):
266278
pid = self.parameters[kp].uid
279+
p_ref = self.samples[-1, kp]
280+
p_up = self.samples[2*kp, kp]
281+
p_down = self.samples[2 * kp + 1, kp]
282+
267283
for ko, oid in enumerate(self.outputs):
268284
# num_samples x num_outputs
269-
value_ref = self.results[-1, ko]
270-
value_up = self.results[2*kp, ko]
271-
value_down = self.results[2 * kp + 1, ko]
285+
q_ref = self.results[-1, ko]
286+
q_up = self.results[2*kp, ko]
287+
q_down = self.results[2 * kp + 1, ko]
272288

273-
# midpoint method, two-sided sensitivity
274-
self.sensitivity[kp, ko] = (value_up - value_down) / (2.0 * value_ref)
289+
# two-sided sensitivity
290+
self.sensitivity[kp, ko] = (q_up - q_down) / (p_up - p_down)
291+
# normalized: relative change in output per relative change in parameter
292+
self.sensitivity_normalized[kp, ko] = self.sensitivity[kp, ko] * p_ref/q_ref
275293

276-
277-
def plot_sensitivity(self):
278-
from sbmlsim.sensitivity.plots import heatmap
279-
import pandas as pd
280-
281-
df = pd.DataFrame(
282-
self.sensitivity.values,
294+
@property
295+
def sensitivity_df(self) -> pd.DataFrame:
296+
"""Convert sensitivity information to dataframe."""
297+
return pd.DataFrame(
298+
self.sensitivity_normalized.values,
283299
columns=self.sensitivity.coords["output"],
284300
index=self.sensitivity.coords["parameter"]
285301
)
286-
console.print(df)
287-
heatmap(df, cutoff=0.01)
288302

303+
def plot_sensitivity(self):
304+
df = self.sensitivity_df
305+
self.plot_sensitivity_df(df)
306+
307+
@staticmethod
308+
def plot_sensitivity_df(df: pd.DataFrame, cutoff=0.1, cluster_rows: bool = True):
309+
from sbmlsim.sensitivity.plots import heatmap
310+
console.print(df)
289311

312+
# TODO: labels of parameters
313+
# TODO: labels of outputs
314+
# TODO: better position of colorbar
290315

316+
heatmap(df, cutoff=cutoff, cluster_rows=False)
291317

292318

293319

src/sbmlsim/sensitivity/parameters.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sbmlutils.console import console
99
from sbmlutils.factory import ValueWithUnit
1010

11+
import roadrunner
1112

1213

1314
@dataclass
@@ -27,6 +28,7 @@ def parameters_for_sensitivity_analysis(
2728
sbml_path: Path,
2829
exclude_ids: Optional[set[str]] = None,
2930
exclude_na: bool = True,
31+
exclude_zero: bool = True,
3032
) -> list[SensitivityParameter]:
3133
"""Retrieve parameters from model for the sensitivity analysis.
3234
@@ -50,7 +52,7 @@ def parameter_from_sbase(sbase: libsbml.SBase) -> SensitivityParameter:
5052

5153
return parameter
5254

53-
55+
r: roadrunner.RoadRunner = roadrunner.RoadRunner(str(sbml_path))
5456
doc: libsbml.SBMLDocument = libsbml.readSBMLFromFile(str(sbml_path))
5557
sbml_model: libsbml.Model = doc.getModel()
5658

@@ -64,6 +66,8 @@ def parameter_from_sbase(sbase: libsbml.SBase) -> SensitivityParameter:
6466
if p.getConstant() is True:
6567
if exclude_na and np.isnan(p.getValue()):
6668
exclude_ids.add(sid)
69+
if exclude_zero and np.isclose(r.getValue(sid), 0.0):
70+
exclude_ids.add(sid)
6771
parameters.append(parameter_from_sbase(p))
6872

6973
# constant compartments
@@ -73,6 +77,8 @@ def parameter_from_sbase(sbase: libsbml.SBase) -> SensitivityParameter:
7377
if c.getConstant() is True:
7478
if exclude_na and np.isnan(c.getSize()):
7579
exclude_ids.add(sid)
80+
if exclude_zero and np.isclose(r.getValue(sid), 0.0):
81+
exclude_ids.add(sid)
7682
parameters.append(parameter_from_sbase(c))
7783

7884
# constant species or boundaryCondition == True
@@ -87,6 +93,11 @@ def parameter_from_sbase(sbase: libsbml.SBase) -> SensitivityParameter:
8793
exclude_ids.add(sid)
8894
elif s.isSetInitialConcentration() and np.isnan(s.getInitialConcentration()):
8995
exclude_ids.add(sid)
96+
if exclude_zero:
97+
if s.isSetInitialAmount() and np.isclose(s.getInitialAmount(), 0.0):
98+
exclude_ids.add(sid)
99+
elif s.isSetInitialConcentration() and np.isclose(s.getInitialConcentration(), 0.0):
100+
exclude_ids.add(sid)
90101

91102
if s.getConstant() is True or s.getBoundaryCondition() is True:
92103
parameters.append(parameter_from_sbase(s))

src/sbmlsim/sensitivity/plots.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,14 @@
1111
import numpy as np
1212
import pandas as pd
1313

14-
def heatmap(df: pd.DataFrame, cutoff: float=0.01, annotate_values=True, transpose: bool=False):
14+
def heatmap(
15+
df: pd.DataFrame,
16+
cutoff: float=0.01,
17+
annotate_values=True,
18+
cluster_rows: bool = True, # cluster parameters
19+
cluster_cols: bool = False, # cluster outputs
20+
transpose: bool=False
21+
):
1522
"""Creates heatmap of model sensitivity"""
1623

1724
def calculate_mask(df, cutoff=0.01):
@@ -24,7 +31,7 @@ def calculate_mask(df, cutoff=0.01):
2431
mask[index] = False
2532
return pd.DataFrame(data=mask, columns=df.columns, index=df.index)
2633

27-
def calculate_subset(df, cutoff=0.01):
34+
def calculate_subset(df, cutoff=0.01) -> pd.DataFrame:
2835
"""Calculates subset of data frame consisting of rows where at least
2936
one value is above cutoff."""
3037
return df[(df.abs() >= cutoff).any(axis=1)]
@@ -40,23 +47,35 @@ def calculate_subset(df, cutoff=0.01):
4047
yticklabels = [pid for pid in df_subset.index]
4148
xticklabels = [pid for pid in df_subset.columns]
4249

50+
n_outputs = df_subset.shape[1]
51+
n_parameters = df_subset.shape[0]
52+
figsize = (7, int(n_parameters / n_outputs * 7)/2)
53+
54+
colorbar_range = 2.0
55+
4356
# plot heatmap
4457
ax = sns.clustermap(
4558
df_subset,
4659
center=0,
47-
vmin=-0.2,
48-
vmax=0.2,
60+
vmin=-colorbar_range,
61+
vmax=colorbar_range,
4962
xticklabels=xticklabels,
5063
yticklabels=yticklabels,
5164
cmap="seismic",
52-
cbar_pos=(0.05, 0.25, 0.03, 0.4),
65+
# cbar_pos=(0.0, 0.0, 0.6, 0.05), # (left, bottom, width, height),
66+
cbar_pos=(0.0, 0.4, 0.03, 0.2), # (left, bottom, width, height),
67+
cbar_kws={
68+
"orientation": "vertical",
69+
"label": "sensitivity"
70+
},
5371
annot=annotate_values,
5472
fmt="1.2f",
55-
annot_kws={"size": 13},
73+
annot_kws={"size": 11},
5674
mask=df_subset_mask,
5775
col_cluster=False,
76+
row_cluster=True,
5877
method="single",
59-
figsize=(20, 20),
78+
figsize=figsize,
6079
)
6180
plt.setp(
6281
ax.ax_heatmap.get_xticklabels(),

0 commit comments

Comments
 (0)