Skip to content

Commit 93fa4ae

Browse files
starting parallelization
1 parent ade56ab commit 93fa4ae

File tree

1 file changed

+39
-20
lines changed

1 file changed

+39
-20
lines changed

src/sbmlsim/sensitivity/analysis.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -35,50 +35,60 @@ class SensitivitySimulation:
3535
This function is called repeatedly during the sensitivity calculation.
3636
"""
3737

38-
def __init__(self, model_path: Path, selections: list[str], changes_simulation: dict[str, float], outputs: list[SensitivityOutput]):
38+
def __init__(self, model_path: Path, selections: list[str],
39+
changes_simulation: dict[str, float],
40+
outputs: list[SensitivityOutput]):
3941
self.model_path = model_path
4042
self.selections = selections
41-
42-
self.rr: roadrunner.RoadRunner = roadrunner.RoadRunner(str(model_path))
43-
self.rr.selections = self.selections
44-
# integrator: roadrunner.Integrator = self.rr.integrator
45-
# integrator.setSetting("variable_step_size", True)
43+
self.changes_simulation = changes_simulation
4644

4745
# store the simulation changes
48-
self.changes_simulation: dict[str, float] = changes_simulation
4946
self.outputs: list[SensitivityOutput] = outputs
5047

5148
# validate the outputs from the simulation
52-
y = self.simulate(changes={})
49+
rr = self.load_model(model_path=model_path, selections=self.selections)
50+
y = self.simulate(r=rr, changes={})
5351
outputs_dict = {q.uid for q in self.outputs}
5452
for key in y:
5553
if key not in outputs_dict:
5654
raise ValueError(f"Key '{key}' missing in outputs dictionary: '{outputs_dict}")
5755

58-
def apply_changes(self, changes: dict[str, float], reset_all: bool=True) -> None:
56+
57+
@staticmethod
58+
def load_model(model_path: Path, selections: list[str]) -> roadrunner.RoadRunner:
59+
"""Load roadrunner model."""
60+
rr: roadrunner.RoadRunner = roadrunner.RoadRunner(str(model_path))
61+
rr.selections = selections
62+
# integrator: roadrunner.Integrator = self.rr.integrator
63+
# integrator.setSetting("variable_step_size", True)
64+
return rr
65+
66+
@staticmethod
67+
def apply_changes(r: roadrunner.RoadRunner, changes: dict[str, float], reset_all: bool=True) -> None:
5968
"""Apply changes after possible reset of the model."""
6069
if reset_all:
61-
self.rr.resetAll()
70+
r.resetAll()
6271
for key, value in changes.items():
6372
# print(f"{key=} {value=}")
64-
self.rr.setValue(key, value)
73+
r.setValue(key, value)
6574

66-
def simulate(self, changes: dict[str, float]) -> dict[str, float]:
75+
def simulate(self, r: roadrunner.RoadRunner, changes: dict[str, float]) -> dict[str, float]:
6776
"""Run a model simulation and return scalar results dictionary."""
6877

6978
raise NotImplemented
7079

71-
def parameter_values(self,
80+
@classmethod
81+
def parameter_values(cls, r: roadrunner.RoadRunner,
7282
parameters: list[SensitivityParameter],
7383
changes: dict[str, float]
7484
) -> dict[str, float]:
7585
"""Get the parameter values for a given set of changes."""
76-
self.apply_changes(changes, reset_all=True)
86+
cls.apply_changes(r, changes, reset_all=True)
7787

7888
values: dict[str, float] = {}
7989
p: SensitivityParameter
8090
for p in parameters:
81-
values[p.uid] = self.rr.getValue(p.uid)
91+
values[p.uid] = r.getValue(p.uid)
8292

8393
return values
8494

@@ -151,12 +161,19 @@ def simulate_samples(self) -> None:
151161
name="results"
152162
)
153163

154-
pids = [p.uid for p in self.parameters]
164+
# load the integrators
165+
r: roadrunner.RoadRunner = self.sensitivity_simulation.load_model(
166+
model_path=self.sensitivity_simulation.model_path,
167+
selections=self.sensitivity_simulation.selections,
168+
)
169+
170+
# FIXME: here the parallelization must take place
155171
for k in track(range(self.num_samples), description="Simulating samples"):
156-
# console.print(f"{k}/{self.num_samples}")
157-
changes = dict(zip(pids, self.samples[k, :].values))
158-
# console.print(changes)
159-
outputs = self.sensitivity_simulation.simulate(changes=changes)
172+
changes = dict(zip(self.parameter_ids, self.samples[k, :].values))
173+
outputs = self.sensitivity_simulation.simulate(
174+
r=r,
175+
changes=changes
176+
)
160177
self.results[k, :] = list(outputs.values())
161178

162179
def calculate_sensitivity(self):
@@ -200,7 +217,9 @@ def create_samples(self) -> None:
200217
with increase and decrease of the respective parameter.
201218
"""
202219
# Calculate the parameter values in the reference state
220+
r = self.sensitivity_simulation.load_model(self.sensitivity_simulation.model_path, selections=self.sensitivity_simulation.selections)
203221
parameter_values: dict[str, float] = self.sensitivity_simulation.parameter_values(
222+
r=r,
204223
parameters=self.parameters,
205224
changes=self.sensitivity_simulation.changes_simulation
206225
)

0 commit comments

Comments
 (0)