Skip to content

Commit deaf3b3

Browse files
parallelization
1 parent 0ce7e54 commit deaf3b3

File tree

1 file changed

+61
-24
lines changed

1 file changed

+61
-24
lines changed

src/sbmlsim/sensitivity/analysis.py

Lines changed: 61 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
- [ ] Morris
66
- [ ] Sampling based methods (distribution)
77
"""
8+
import time
9+
import multiprocessing
810
from typing import Optional
911
from pathlib import Path
1012
from dataclasses import dataclass
@@ -152,6 +154,7 @@ def num_samples(self) -> int:
152154

153155
def simulate_samples(self) -> None:
154156
"""Simulate all samples."""
157+
start = time.perf_counter()
155158

156159
# num_samples x num_outputs
157160
self.results = xr.DataArray(
@@ -176,9 +179,12 @@ def simulate_samples(self) -> None:
176179
)
177180
self.results[k, :] = list(outputs.values())
178181

182+
elapsed = time.perf_counter() - start
183+
console.print(f"Serial: {elapsed:.3f} s")
184+
179185
def simulate_samples_parallel(self) -> None:
180186
"""Simulate all samples in parallel."""
181-
import multiprocessing
187+
start = time.perf_counter()
182188

183189
# num_samples x num_outputs
184190
self.results = xr.DataArray(
@@ -188,31 +194,45 @@ def simulate_samples_parallel(self) -> None:
188194
name="results"
189195
)
190196

191-
n_cores = multiprocessing.cpu_count()
192-
193197
# load model
194198
r: roadrunner.RoadRunner = self.sensitivity_simulation.load_model(
195199
model_path=self.sensitivity_simulation.model_path,
196200
selections=self.sensitivity_simulation.selections,
197201
)
198-
# this must be handled via batches
202+
203+
# number of cores
204+
n_cores = multiprocessing.cpu_count()
205+
206+
# create one chunk of samples per core
207+
def split_into_chunks(items, n):
208+
m = len(items)
209+
k, r = divmod(m, n)
210+
chunks = [
211+
items[i * k + min(i, r):(i + 1) * k + min(i + 1, r)]
212+
for i in range(n)
213+
]
214+
chunked_samples = [
215+
[dict(zip(self.parameter_ids, self.samples[k, :].values)) for k in chunk]
216+
for chunk in chunks
217+
]
218+
return chunks, chunked_samples
219+
220+
items = list(range(self.num_samples))
221+
chunks, chunked_samples = split_into_chunks(items, n_cores)
222+
223+
# parameters for multiprocessing
199224
sa_sim = self.sensitivity_simulation
200-
changes_batch = []
201-
rrs = [(sa_sim, r, changes_batch) for i in range(n_cores)]
225+
rrs = [(sa_sim, r, chunked_samples[i]) for i in range(n_cores)]
202226

203227
with multiprocessing.Pool(processes=n_cores) as pool:
204-
results = pool.map(run_simulation, rrs)
228+
outputs_list: list = pool.map(run_simulation, rrs)
205229

206-
# TODO: collect results
207-
# # FIXME: here the parallelization must take place
208-
# for k in track(range(self.num_samples), description="Simulating samples"):
209-
# changes = dict(zip(self.parameter_ids, self.samples[k, :].values))
210-
#
211-
# outputs = self.sensitivity_simulation.simulate(
212-
# r=r,
213-
# changes=changes
214-
# )
215-
# self.results[k, :] = list(outputs.values())
230+
for kc, chunk in enumerate(chunks):
231+
for kp, idx in enumerate(chunk):
232+
self.results[idx, :] = list(outputs_list[kc][kp].values())
233+
234+
elapsed = time.perf_counter() - start
235+
console.print(f"Parallel simulation: {elapsed:.3f} s")
216236

217237

218238
def calculate_sensitivity(self):
@@ -229,20 +249,37 @@ def sensitivity_df(self, key="normalized") -> pd.DataFrame:
229249
index=self.sensitivity[key].coords["parameter"]
230250
)
231251

252+
import os
232253

233254
def run_simulation(
    params_tuple
):
    """Simulate one chunk of samples inside a multiprocessing worker.

    All arguments are packed into a single tuple so the function can be
    dispatched via ``multiprocessing.Pool.map``.

    :param params_tuple: tuple of
        ``(sensitivity_simulation, r, chunked_changes)`` where
        ``sensitivity_simulation`` provides the ``simulate`` method,
        ``r`` is the loaded roadrunner model instance, and
        ``chunked_changes`` is the list of parameter-change dictionaries
        assigned to this worker.
    :return: list of simulation outputs, one entry per change set,
        in the same order as ``chunked_changes``.
    """
    sensitivity_simulation, r, chunked_changes = params_tuple

    outputs = []
    # Iterate the chunk directly (no index bookkeeping needed);
    # `track` renders a per-process progress bar labeled with the PID.
    for changes in track(
        chunked_changes, description=f"Simulate samples PID={os.getpid()}"
    ):
        outputs.append(
            sensitivity_simulation.simulate(
                r=r,
                changes=changes,
            )
        )

    return outputs
272+
273+
274+
275+
276+
277+
278+
279+
280+
239281

240282

241-
console.print("Simulate parallel")
242-
return sensitivity_simulation.simulate(
243-
r=r,
244-
changes={}
245-
)
246283

247284
class LocalSensitivityAnalysis(SensitivityAnalysis):
248285
"""Local sensitivity analysis based on local differences.

0 commit comments

Comments
 (0)