Skip to content

Commit 4db7613

Browse files
cleanup sensitivity
1 parent 91243cd commit 4db7613

File tree

13 files changed

+715
-473
lines changed

13 files changed

+715
-473
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ dependencies = [
5050
# dependencies
5151
"dill>=0.4.0",
5252
"SAlib>=1.5.2",
53-
"petab>=0.7.0",
53+
"petab>=0.8.1",
5454
# "tables>=3.10.2",
5555
"statsmodels>=0.14.6",
5656
"typst>=0.14.5",
Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,45 @@
1-
"""Module for sensitivity analysis.
1+
"""
2+
Sensitivity analysis framework for computational models.
3+
4+
This package provides a unified framework for analyzing how uncertainty and
5+
variability in model parameters affect model outputs. It supports multiple
6+
complementary sensitivity analysis strategies, including local, sampling-based,
7+
and global methods, enabling both qualitative and quantitative assessment of
8+
parameter influence.
9+
10+
Sensitivity analyses are performed by systematically perturbing or sampling
11+
model parameters, executing simulations, and evaluating changes in selected
12+
model outputs. The framework is designed for deterministic simulation models
13+
and integrates sampling, caching, statistical evaluation, and visualization
14+
within a consistent workflow.
215
316
TODO implementation of alternative methods:
417
- [ ] FAST
518
- [ ] Morris
619
7-
FIXME: generate simple example => create tests for the sensitivity
8-
FIXME: general documentation
20+
FIXME: generate simple example
21+
FIXME: create unittests for the sensitivity
922
FIXME: add a flag to control resources for parallelization (ncores)
1023
1124
"""
12-
13-
from .sensitivity_sobol import SobolSensitivityAnalysis
14-
from .sensitivity_sampling import SamplingSensitivityAnalysis
25+
from .analysis import (
26+
SensitivitySimulation,
27+
SensitivityOutput,
28+
AnalysisGroup,
29+
)
30+
from .parameters import (
31+
SensitivityParameter,
32+
)
1533
from .sensitivity_local import LocalSensitivityAnalysis
34+
from .sensitivity_sampling import SamplingSensitivityAnalysis
35+
from .sensitivity_sobol import SobolSensitivityAnalysis
1636

17-
37+
__all__ = [
38+
"SensitivityParameter",
39+
"SensitivitySimulation",
40+
"SensitivityOutput",
41+
"AnalysisGroup",
42+
"SobolSensitivityAnalysis",
43+
"SamplingSensitivityAnalysis",
44+
"LocalSensitivityAnalysis",
45+
]

src/sbmlsim/sensitivity/analysis.py

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,33 @@
22
33
44
"""
5+
import multiprocessing
56
import os
67
import time
7-
import multiprocessing
88
from dataclasses import dataclass
9-
from typing import Optional, Any
109
from pathlib import Path
11-
from rich.progress import track
12-
from pymetadata.console import console
10+
from typing import Optional, Any
1311

12+
import dill
1413
import numpy as np
1514
import pandas as pd
16-
import xarray as xr
17-
18-
1915
import roadrunner
20-
import dill
16+
import xarray as xr
17+
from pymetadata.console import console
18+
from rich.progress import track
2119

2220
from sbmlsim.sensitivity.parameters import SensitivityParameter
23-
from sbmlsim.sensitivity.outputs import SensitivityOutput
2421
from sbmlsim.sensitivity.plots import heatmap
2522

2623

24+
@dataclass
25+
class SensitivityOutput:
26+
"""Output measurement for SensitivityAnalysis."""
27+
uid: str
28+
name: str
29+
unit: Optional[str]
30+
31+
2732
@dataclass
2833
class AnalysisGroup:
2934
"""Subgroup for analysis."""
@@ -59,8 +64,8 @@ def __init__(self, model_path: Path, selections: list[str],
5964
outputs_dict = {q.uid for q in self.outputs}
6065
for key in y:
6166
if key not in outputs_dict:
62-
raise ValueError(f"Key '{key}' missing in outputs dictionary: '{outputs_dict}")
63-
67+
raise ValueError(
68+
f"Key '{key}' missing in outputs dictionary: '{outputs_dict}")
6469

6570
@staticmethod
6671
def load_model(model_path: Path, selections: list[str]) -> roadrunner.RoadRunner:
@@ -72,7 +77,8 @@ def load_model(model_path: Path, selections: list[str]) -> roadrunner.RoadRunner
7277
return rr
7378

7479
@staticmethod
75-
def apply_changes(r: roadrunner.RoadRunner, changes: dict[str, float], reset_all: bool=True) -> None:
80+
def apply_changes(r: roadrunner.RoadRunner, changes: dict[str, float],
81+
reset_all: bool = True) -> None:
7682
"""Apply changes after possible reset of the model."""
7783
if reset_all:
7884
r.resetAll()
@@ -81,7 +87,8 @@ def apply_changes(r: roadrunner.RoadRunner, changes: dict[str, float], reset_all
8187
# print(f"{key=} {value=}")
8288
r.setValue(key, value)
8389

84-
def simulate(self, r: roadrunner.RoadRunner, changes: dict[str, float]) -> dict[str, float]:
90+
def simulate(self, r: roadrunner.RoadRunner, changes: dict[str, float]) -> dict[
91+
str, float]:
8592
"""Run a model simulation and return scalar results dictionary."""
8693

8794
raise NotImplemented
@@ -115,7 +122,7 @@ def __init__(self,
115122
parameters: list[SensitivityParameter],
116123
groups: list[AnalysisGroup],
117124
results_path: Path,
118-
seed: Optional[int]=None,
125+
seed: Optional[int] = None,
119126
) -> None:
120127
"""Create a sensitivity analysis for given parameter ids.
121128
@@ -161,20 +168,15 @@ def __init__(self,
161168

162169
# multiple sensitivities are stored
163170
# sensitivity matrix; shape: (num_parameters x num_outputs); could be multiple
164-
self.sensitivity: dict[str, dict[str, xr.DataArray]] = {g.uid: {} for g in self.groups}
171+
self.sensitivity: dict[str, dict[str, xr.DataArray]] = {g.uid: {} for g in
172+
self.groups}
165173

166174
def samples_table(self) -> pd.DataFrame:
167175
return self._data_table(d=self.samples)
168176

169177
def results_table(self) -> pd.DataFrame:
170178
return self._data_table(d=self.results)
171179

172-
# def sensitivity_tables(self) -> dict[str, pd.DataFrame]:
173-
#
174-
# for group in self.groups:
175-
# for key in group.changes.keys():
176-
# return {k: self._data_table(d=d) for k, d in self.sensitivity.items()}
177-
178180
def _data_table(self, d: dict[str, xr.DataArray]) -> pd.DataFrame:
179181
items = []
180182
for group in self.groups:
@@ -188,7 +190,8 @@ def _data_table(self, d: dict[str, xr.DataArray]) -> pd.DataFrame:
188190
return pd.DataFrame(items)
189191

190192
def read_cache(self, cache_filename: str, cache: bool) -> Optional[Any]:
191-
cache_path: Optional[Path] = self.results_path / cache_filename if cache_filename else None
193+
cache_path: Optional[
194+
Path] = self.results_path / cache_filename if cache_filename else None
192195
if cache and not cache_path:
193196
raise ValueError("Cache path is required for caching.")
194197

@@ -202,7 +205,8 @@ def read_cache(self, cache_filename: str, cache: bool) -> Optional[Any]:
202205
return None
203206

204207
def write_cache(self, data: Any, cache_filename: str, cache: bool) -> Optional[Any]:
205-
cache_path: Optional[Path] = self.results_path / cache_filename if cache_filename else None
208+
cache_path: Optional[
209+
Path] = self.results_path / cache_filename if cache_filename else None
206210
if cache_path:
207211
with open(cache_path, 'wb') as f:
208212
console.print(f"Simulated samples written to cache: '{cache_path}'")
@@ -247,7 +251,8 @@ def num_samples(self) -> int:
247251
samples = self.samples[self.group_ids[0]]
248252
return samples.shape[0]
249253

250-
def simulate_samples(self, cache_filename: Optional[str] = None, cache: bool = False) -> None:
254+
def simulate_samples(self, cache_filename: Optional[str] = None,
255+
cache: bool = False) -> None:
251256
"""Simulate all samples in parallel.
252257
253258
:param cache_filename: Path to the cache path.
@@ -320,7 +325,8 @@ def split_into_chunks(items, n):
320325
# write to cache
321326
self.write_cache(data=self.results, cache_filename=cache_filename, cache=cache)
322327

323-
def calculate_sensitivity(self, cache_filename: Optional[str] = None, cache: bool = False):
328+
def calculate_sensitivity(self, cache_filename: Optional[str] = None,
329+
cache: bool = False):
324330
"""Calculate the sensitivity matrices."""
325331

326332
raise NotImplemented
@@ -361,14 +367,14 @@ def plot_sensitivity(
361367
)
362368

363369

364-
365370
def run_simulation(
366371
params_tuple
367372
):
368373
"""Pass all required arguments as parameter tuple."""
369374
sensitivity_simulation, r, chunked_changes = params_tuple
370375
outputs = []
371-
for kc in track(range(len(chunked_changes)), description=f"Simulate samples PID={os.getpid()}"):
376+
for kc in track(range(len(chunked_changes)),
377+
description=f"Simulate samples PID={os.getpid()}"):
372378
changes = chunked_changes[kc]
373379
# console.print(f"PID={os.getpid()} | k={kc}")
374380
Y = sensitivity_simulation.simulate(
@@ -378,6 +384,3 @@ def run_simulation(
378384
outputs.append(Y)
379385

380386
return outputs
381-
382-
383-

src/sbmlsim/sensitivity/example/sensitivity_example.py

Lines changed: 36 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import numpy as np
55
import roadrunner
6-
from roadrunner._roadrunner import NamedArray
76
from pymetadata.console import console
87

98
from sbmlsim.sensitivity.analysis import (
@@ -13,13 +12,11 @@
1312
)
1413
from sbmlsim.sensitivity.parameters import (
1514
SensitivityParameter,
16-
ParameterType,
1715
parameters_for_sensitivity_analysis,
1816
)
1917

2018
model_path: Path = Path(__file__).parent / "simple_chain.xml"
2119

22-
2320
# Subgroups to perform sensitivity analysis on
2421
sensitivity_groups: list[AnalysisGroup] = [
2522
AnalysisGroup(
@@ -48,22 +45,24 @@ class ExampleSensitivitySimulation(SensitivitySimulation):
4845
tend = 1000 #
4946
steps = 1000
5047

51-
def simulate(self, r: roadrunner.RoadRunner, changes: dict[str, float]) -> dict[str, float]:
48+
def simulate(self, r: roadrunner.RoadRunner, changes: dict[str, float]) -> dict[
49+
str, float]:
5250

5351
# apply changes and simulate
5452
all_changes = {
5553
**self.changes_simulation, # model
5654
**changes # sensitivity
5755
}
5856
self.apply_changes(r, all_changes, reset_all=True)
59-
# ensure tolerances
57+
58+
# ensure identical tolerances on all simulations
6059
r.integrator.setValue("absolute_tolerance", self.init_tolerances)
61-
s: NamedArray = r.simulate(start=0, end=self.tend, steps=self.steps)
60+
s = r.simulate(start=0, end=self.tend, steps=self.steps)
6261

63-
# pharmacokinetic parameters
62+
# calculate outputs y (custom functions)
63+
# this can be registered functions calculating scalars based on subsets of the
64+
# timecourse vectors
6465
y: dict[str, float] = {}
65-
66-
# calculate outputs (custom functions)
6766
t = s["time"]
6867
for key in "S1", "S2", "S3":
6968
rr_key = f"[{key}]"
@@ -79,13 +78,8 @@ def simulate(self, r: roadrunner.RoadRunner, changes: dict[str, float]) -> dict[
7978

8079
sensitivity_simulation = ExampleSensitivitySimulation(
8180
model_path=model_path,
82-
selections=[
83-
"time",
84-
"[S1]",
85-
"[S2]",
86-
"[S3]",
87-
],
88-
changes_simulation = {},
81+
selections=["time", "[S1]", "[S2]", "[S3]"],
82+
changes_simulation={},
8983
outputs=[
9084
SensitivityOutput(uid='[S1]_auc', name='[S1] AUC', unit=None),
9185
SensitivityOutput(uid='[S2]_tmax', name='[S2] time maximum', unit=None),
@@ -120,49 +114,46 @@ def _sensitivity_parameters() -> list[SensitivityParameter]:
120114

121115
sensitivity_parameters = _sensitivity_parameters()
122116

123-
124117
if __name__ == "__main__":
125-
126118
from sbmlsim.sensitivity import (
127119
LocalSensitivityAnalysis,
128120
SobolSensitivityAnalysis,
129121
SamplingSensitivityAnalysis,
130122
)
123+
131124
sensitivity_path = Path(__file__).parent / "results"
132125
console.print(SensitivityParameter.parameters_to_df(sensitivity_parameters))
133126

134-
# SamplingSensitivityAnalysis.run_sensitivity_analysis(
135-
# results_path=sensitivity_path / "sampling",
136-
# sensitivity_simulation=sensitivity_simulation,
137-
# parameters=sensitivity_parameters,
138-
# groups=sensitivity_groups,
139-
# # cache_results=False,
140-
# # cache_sensitivity=False,
141-
# N=200,
142-
# seed=1234,
143-
# )
144-
#
145-
# LocalSensitivityAnalysis.run_sensitivity_analysis(
146-
# results_path=sensitivity_path / "local",
147-
# sensitivity_simulation=sensitivity_simulation,
148-
# parameters=sensitivity_parameters,
149-
# groups=[sensitivity_groups[1]],
150-
# # cache_results=False,
151-
# # cache_sensitivity=False,
152-
# difference=0.01,
153-
# seed=1234,
154-
# )
127+
SamplingSensitivityAnalysis.run_sensitivity_analysis(
128+
results_path=sensitivity_path / "sampling",
129+
sensitivity_simulation=sensitivity_simulation,
130+
parameters=sensitivity_parameters,
131+
groups=sensitivity_groups,
132+
cache_results=True,
133+
cache_sensitivity=True,
134+
N=1000,
135+
seed=1234,
136+
)
137+
138+
LocalSensitivityAnalysis.run_sensitivity_analysis(
139+
results_path=sensitivity_path / "local",
140+
sensitivity_simulation=sensitivity_simulation,
141+
parameters=sensitivity_parameters,
142+
groups=[sensitivity_groups[1]],
143+
cache_results=True,
144+
cache_sensitivity=True,
145+
difference=0.01,
146+
seed=1234,
147+
)
155148

156149
SobolSensitivityAnalysis.run_sensitivity_analysis(
157150
results_path=sensitivity_path / "sobol",
158151
sensitivity_simulation=sensitivity_simulation,
159152
parameters=sensitivity_parameters,
160153
groups=[sensitivity_groups[1]],
161-
# cache_results=False,
162-
# cache_sensitivity=False,
163-
# N=2048,
164-
N=8,
154+
cache_results=True,
155+
cache_sensitivity=True,
156+
N=2048,
157+
# N=8,
165158
seed=1234,
166159
)
167-
168-

0 commit comments

Comments
 (0)