Skip to content

Commit 64e16ea

Browse files
working on sensitivity
1 parent 6640e3b commit 64e16ea

File tree

1 file changed

+124
-3
lines changed

1 file changed

+124
-3
lines changed

src/sbmlsim/sensitivity/global_sensitivity.py

Lines changed: 124 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
- [ ] parallelization ? (benchmark)
2020
2121
"""
22+
from typing import Optional
23+
2224
import SALib
2325
from SALib import ProblemSpec
2426
from SALib.sample import saltelli
@@ -35,7 +37,6 @@
3537
from roadrunner._roadrunner import NamedArray
3638

3739

38-
3940
@dataclass
4041
class SensitivitySimulation:
4142
"""Base class for sensitivity calculation.
@@ -49,8 +50,9 @@ class SensitivitySimulation:
4950
selections: list[str]
5051
rr: roadrunner.RoadRunner = None
5152
outputs: list[str] = None
53+
changes_simulation: dict[str, float] = None
5254

53-
def __init__(self, model_path: Path, selections: list[str]):
55+
def __init__(self, model_path: Path, selections: list[str], changes_simulation: dict[str, float]):
5456
self.model_path = model_path
5557
self.selections = selections
5658
self.rr: roadrunner.RoadRunner = roadrunner.RoadRunner(str(model_path))
@@ -59,6 +61,9 @@ def __init__(self, model_path: Path, selections: list[str]):
5961
integrator.setSetting("variable_step_size", True)
6062
# state = rr.saveStateS()
6163

64+
# store the simulation changes
65+
self.changes_simulation = changes_simulation
66+
6267
# get the outputs from the simulation
6368
y = self.simulate(changes={})
6469
self.outputs = list(y.keys())
@@ -71,6 +76,15 @@ def simulate(self, changes: dict[str, float]) -> dict[str, float]:
7176
"""
7277
raise NotImplemented
7378

79+
def parameter_values(self, parameters: list[str], changes: dict[str, float]) -> dict[str, float]:
80+
"""Get the parameter values for a given set of changes."""
81+
self.apply_changes(changes, reset_all=True)
82+
values: dict[str, float] = {}
83+
for pid in parameters:
84+
values[pid] = self.rr.getValue(pid)
85+
return values
86+
87+
7488
def plot(self) -> None:
7589
"""Plots the model simulation for debugging."""
7690
raise NotImplemented
@@ -84,9 +98,116 @@ def apply_changes(self, changes: dict[str, float], reset_all: bool=True) -> None
8498
self.rr.setValue(key, value)
8599

86100

101+
@dataclass
102+
class SensitivityAnalysis:
103+
"""Parent class for all sensitivity analysis.
104+
105+
TODO: additional metadata for the outputs and the parameters; i.e. name, units, bounds, ....
106+
"""
107+
108+
sensitivity_simulation: SensitivitySimulation
109+
110+
def __init__(self, sensitivity_simulation: SensitivitySimulation,
111+
parameters: list[str]) -> None:
112+
"""Create a sensitivity analysis for given parameter ids.
113+
114+
Based on the results matrix the sensitivity is calculated.
115+
"""
116+
self.sensitivity_simulation = sensitivity_simulation
117+
118+
# parameters to vary; shape: (num_parameters,)
119+
self.parameters: list[str] = parameters
120+
# outputs to calculate sensitivity on; shape: (num_outputs,)
121+
self.outputs: list[str] = sensitivity_simulation.outputs
122+
# parameter samples for sensitivity; shape: (num_samples x num_parameters)
123+
self.samples: Optional[np.ndarray] = None
124+
# outputs for given samples; shape: (num_samples x num_outputs)
125+
self.results: Optional[np.ndarray] = None
126+
# sensitivity matrix; shape: (num_parameters x num_outputs)
127+
self.sensitivity_results: Optional[np.ndarray] = None
128+
129+
@property
130+
def num_parameters(self) -> int:
131+
return len(self.parameters)
132+
133+
@property
134+
def num_outputs(self) -> int:
135+
return len(self.outputs)
136+
137+
def create_samples(self) -> None:
138+
"""Create and set parameter samples."""
139+
140+
raise NotImplemented
141+
142+
def num_samples(self) -> int:
143+
"""Number of samples.
144+
145+
Requires that samples have been created.
146+
"""
147+
return self.samples.shape[0]
148+
149+
def simulate_samples(self) -> None:
150+
"""Simulate all samples."""
151+
self.samples = np.zeros(shape=(self.num_samples, self.num_parameters))
152+
self.outputs = np.zeros(shape=(self.num_samples, self.num_outputs))
153+
154+
for k in range(self.num_samples()):
155+
changes = dict(zip(self.parameters, self.samples[k, :]))
156+
outputs = self.sensitivity_simulation.simulate(changes=changes)
157+
self.outputs[k, :] = outputs
158+
159+
def calculate_sensitivity(self):
160+
"""Calculate the sensitivity matrix."""
161+
162+
raise NotImplemented
163+
164+
@dataclass
165+
class LocalSensitivityAnalysis(SensitivityAnalysis):
166+
"""Local sensitivity analysis based on local differences."""
167+
168+
difference: float
169+
sensitivity: np.ndarray = None
170+
171+
def __init__(self, sensitivity_simulation: SensitivitySimulation,
172+
parameters: list[str], difference: float = 0.1):
173+
174+
self.sensitivity = np.zeros(shape=(self.num_parameters, self.num_outputs))
175+
self.difference = difference
176+
self.samples = self.create_samples()
177+
178+
@property
179+
def num_samples(self) -> int:
180+
"""Number of parameter samples to simulate."""
181+
return 2 * self.num_parameters
182+
183+
def create_samples(self) -> np.ndarray:
184+
185+
for key, value in p_ref.items():
186+
values = np.ones(shape=(2 * num_pars,)) * value.magnitude
187+
# change parameters in correct position
188+
values[index] = value.magnitude * (1.0 + difference)
189+
values[index + num_pars] = value.magnitude * (1.0 - difference)
190+
changes[key] = Q_(values, value.units)
191+
index += 1
192+
193+
def calculate_sensitivity(self):
194+
195+
pass
196+
197+
def plot_sensitivity(self):
198+
199+
pass
200+
201+
202+
@dataclass
203+
class SamplingSensitivityAnalysis(SensitivityAnalysis):
204+
"""Sample from provided parameter distributions."""
205+
206+
# TODO: implement
207+
pass
87208

88209
@dataclass
89-
class SBMLSensitivityAnalysis:
210+
class GlobalSobolSensitivityAnalysis:
90211
"""Parent class for sensitivity analysis."""
91212

92213
sensitivity_simulation: SensitivitySimulation

0 commit comments

Comments
 (0)