Skip to content

Commit 0ce7e54

Browse files
first prototype parallelization
1 parent 93fa4ae commit 0ce7e54

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

src/sbmlsim/sensitivity/analysis.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,45 @@ def simulate_samples(self) -> None:
176176
)
177177
self.results[k, :] = list(outputs.values())
178178

179+
def simulate_samples_parallel(self) -> None:
180+
"""Simulate all samples in parallel."""
181+
import multiprocessing
182+
183+
# num_samples x num_outputs
184+
self.results = xr.DataArray(
185+
np.full((self.num_samples, self.num_outputs), np.nan),
186+
dims=["sample", "output"],
187+
coords={"sample": range(self.num_samples), "output": self.outputs},
188+
name="results"
189+
)
190+
191+
n_cores = multiprocessing.cpu_count()
192+
193+
# load model
194+
r: roadrunner.RoadRunner = self.sensitivity_simulation.load_model(
195+
model_path=self.sensitivity_simulation.model_path,
196+
selections=self.sensitivity_simulation.selections,
197+
)
198+
# this must be handled via batches
199+
sa_sim = self.sensitivity_simulation
200+
changes_batch = []
201+
rrs = [(sa_sim, r, changes_batch) for i in range(n_cores)]
202+
203+
with multiprocessing.Pool(processes=n_cores) as pool:
204+
results = pool.map(run_simulation, rrs)
205+
206+
# TODO: collect results
207+
# # FIXME: here the parallelization must take place
208+
# for k in track(range(self.num_samples), description="Simulating samples"):
209+
# changes = dict(zip(self.parameter_ids, self.samples[k, :].values))
210+
#
211+
# outputs = self.sensitivity_simulation.simulate(
212+
# r=r,
213+
# changes=changes
214+
# )
215+
# self.results[k, :] = list(outputs.values())
216+
217+
179218
def calculate_sensitivity(self):
180219
"""Calculate the sensitivity matrices."""
181220

@@ -191,6 +230,20 @@ def sensitivity_df(self, key="normalized") -> pd.DataFrame:
191230
)
192231

193232

233+
def run_simulation(
234+
params_tuple
235+
):
236+
"""Pass all required arguments as parameter tuple."""
237+
# FIXME: this must run a batch of simulations
238+
sensitivity_simulation, r, changes_batch = params_tuple
239+
240+
241+
console.print("Simulate parallel")
242+
return sensitivity_simulation.simulate(
243+
r=r,
244+
changes={}
245+
)
246+
194247
class LocalSensitivityAnalysis(SensitivityAnalysis):
195248
"""Local sensitivity analysis based on local differences.
196249

0 commit comments

Comments
 (0)