Skip to content

Commit 0915f09

Browse files
run sensitivity analysis
1 parent 197185e commit 0915f09

File tree

8 files changed

+210
-387
lines changed

8 files changed

+210
-387
lines changed

src/sbmlsim/sensitivity/__init__.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,6 @@
1212
model outputs. The framework is designed for deterministic simulation models
1313
and integrates sampling, caching, statistical evaluation, and visualization
1414
within a consistent workflow.
15-
16-
TODO implementation of alternative methods:
17-
- [ ] Morris
18-
19-
FIXME: generate simple example
20-
FIXME: create unittests for the sensitivity
21-
FIXME: add a flag to control resources for parallelization (ncores)
22-
2315
"""
2416
from .analysis import (
2517
SensitivityAnalysis,
@@ -30,10 +22,10 @@
3022
from .parameters import (
3123
SensitivityParameter,
3224
)
25+
from .sensitivity_fast import FASTSensitivityAnalysis
3326
from .sensitivity_local import LocalSensitivityAnalysis
3427
from .sensitivity_sampling import SamplingSensitivityAnalysis
3528
from .sensitivity_sobol import SobolSensitivityAnalysis
36-
from .sensitivity_fast import FASTSensitivityAnalysis
3729

3830
__all__ = [
3931
"SensitivityParameter",

src/sbmlsim/sensitivity/analysis.py

Lines changed: 83 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ def __init__(self,
123123
groups: list[AnalysisGroup],
124124
results_path: Path,
125125
seed: Optional[int] = None,
126+
n_cores: Optional[int] = None,
127+
cache_results: bool = False,
126128
) -> None:
127129
"""Create a sensitivity analysis for given parameter ids.
128130
@@ -160,6 +162,15 @@ def __init__(self,
160162
if seed is not None:
161163
np.random.seed(seed)
162164

165+
# caching
166+
self.cache_results: bool = cache_results
167+
self.prefix: str = self.__class__.__name__
168+
169+
# handle compute resources
170+
if not n_cores:
171+
n_cores = int(round(0.9 * multiprocessing.cpu_count()))
172+
self.n_cores = n_cores
173+
163174
# parameter samples for sensitivity; shape: (num_samples x num_parameters)
164175
self.samples: dict[str, Optional[xr.DataArray]] = {}
165176

@@ -171,47 +182,6 @@ def __init__(self,
171182
self.sensitivity: dict[str, dict[str, xr.DataArray]] = {g.uid: {} for g in
172183
self.groups}
173184

174-
def samples_table(self) -> pd.DataFrame:
175-
return self._data_table(d=self.samples)
176-
177-
def results_table(self) -> pd.DataFrame:
178-
return self._data_table(d=self.results)
179-
180-
def _data_table(self, d: dict[str, xr.DataArray]) -> pd.DataFrame:
181-
items = []
182-
for group in self.groups:
183-
da: xr.DataArray = d[group.uid]
184-
item = {
185-
'group': group.uid,
186-
# 'group_name': group.name,
187-
**da.sizes,
188-
}
189-
items.append(item)
190-
return pd.DataFrame(items)
191-
192-
def read_cache(self, cache_filename: str, cache: bool) -> Optional[Any]:
193-
cache_path: Optional[
194-
Path] = self.results_path / cache_filename if cache_filename else None
195-
if cache and not cache_path:
196-
raise ValueError("Cache path is required for caching.")
197-
198-
# retrieve from cache
199-
if cache and cache_path.exists():
200-
with open(cache_path, 'rb') as f:
201-
data = dill.load(f)
202-
console.print(f"Simulated samples loaded from cache: '{cache_path}'")
203-
return data
204-
205-
return None
206-
207-
def write_cache(self, data: Any, cache_filename: str, cache: bool) -> Optional[Any]:
208-
cache_path: Optional[
209-
Path] = self.results_path / cache_filename if cache_filename else None
210-
if cache_path:
211-
with open(cache_path, 'wb') as f:
212-
console.print(f"Simulated samples written to cache: '{cache_path}'")
213-
dill.dump(data, f)
214-
215185
@property
216186
def output_ids(self) -> list[str]:
217187
return [o.uid for o in self.outputs]
@@ -236,6 +206,30 @@ def num_outputs(self) -> int:
236206
def num_groups(self) -> int:
237207
return len(self.groups)
238208

209+
def execute(self):
210+
"""Execute the sensitivity analysis."""
211+
console.rule(
212+
f"{self.__class__.__name__}",
213+
style="blue bold",
214+
align="center",
215+
)
216+
console.rule("Samples", style="white")
217+
self.create_samples()
218+
console.print(self.samples_table())
219+
220+
console.rule("Results", style="white")
221+
self.simulate_samples(
222+
cache_filename=f"{self.prefix}_results.pkl",
223+
cache=self.cache_results,
224+
)
225+
console.print(self.results_table())
226+
227+
console.rule("Sensitivity", style="white")
228+
self.calculate_sensitivity(
229+
cache_filename=f"{self.prefix}_sensitivity.pkl",
230+
cache=self.cache_results,
231+
)
232+
239233
def create_samples(self) -> None:
240234
"""Create and set parameter samples."""
241235

@@ -283,8 +277,6 @@ def simulate_samples(self, cache_filename: Optional[str] = None,
283277
)
284278

285279
# number of cores
286-
n_cores = multiprocessing.cpu_count()
287-
288280
samples = self.samples[group.uid]
289281

290282
# create chunk of samples for core
@@ -305,13 +297,13 @@ def split_into_chunks(items, n):
305297
return chunks, chunked_samples
306298

307299
items = list(range(self.num_samples))
308-
chunks, chunked_samples = split_into_chunks(items, n_cores)
300+
chunks, chunked_samples = split_into_chunks(items, self.n_cores)
309301

310302
# parameters for multiprocessing
311303
sa_sim = self.sensitivity_simulation
312-
rrs = [(sa_sim, r, chunked_samples[i]) for i in range(n_cores)]
304+
rrs = [(sa_sim, r, chunked_samples[i]) for i in range(self.n_cores)]
313305

314-
with multiprocessing.Pool(processes=n_cores) as pool:
306+
with multiprocessing.Pool(processes=self.n_cores) as pool:
315307
outputs_list: list = pool.map(run_simulation, rrs)
316308

317309
for kc, chunk in enumerate(chunks):
@@ -331,6 +323,47 @@ def calculate_sensitivity(self, cache_filename: Optional[str] = None,
331323

332324
raise NotImplemented
333325

326+
def samples_table(self) -> pd.DataFrame:
327+
return self._data_table(d=self.samples)
328+
329+
def results_table(self) -> pd.DataFrame:
330+
return self._data_table(d=self.results)
331+
332+
def _data_table(self, d: dict[str, xr.DataArray]) -> pd.DataFrame:
333+
items = []
334+
for group in self.groups:
335+
da: xr.DataArray = d[group.uid]
336+
item = {
337+
'group': group.uid,
338+
# 'group_name': group.name,
339+
**da.sizes,
340+
}
341+
items.append(item)
342+
return pd.DataFrame(items)
343+
344+
def read_cache(self, cache_filename: str, cache: bool) -> Optional[Any]:
345+
cache_path: Optional[
346+
Path] = self.results_path / cache_filename if cache_filename else None
347+
if cache and not cache_path:
348+
raise ValueError("Cache path is required for caching.")
349+
350+
# retrieve from cache
351+
if cache and cache_path.exists():
352+
with open(cache_path, 'rb') as f:
353+
data = dill.load(f)
354+
console.print(f"Simulated samples loaded from cache: '{cache_path}'")
355+
return data
356+
357+
return None
358+
359+
def write_cache(self, data: Any, cache_filename: str, cache: bool) -> Optional[Any]:
360+
cache_path: Optional[
361+
Path] = self.results_path / cache_filename if cache_filename else None
362+
if cache_path:
363+
with open(cache_path, 'wb') as f:
364+
console.print(f"Simulated samples written to cache: '{cache_path}'")
365+
dill.dump(data, f)
366+
334367
def sensitivity_df(self, group_id: str, key: str) -> pd.DataFrame:
335368
"""Convert sensitivity information to dataframes."""
336369

@@ -341,6 +374,10 @@ def sensitivity_df(self, group_id: str, key: str) -> pd.DataFrame:
341374
index=sensitivity.coords["parameter"]
342375
)
343376

377+
def plot(self):
378+
"""Should be implemented by subclass."""
379+
console.rule("Plotting", style="white")
380+
344381
def plot_sensitivity(
345382
self,
346383
group_id: str,

src/sbmlsim/sensitivity/example/sensitivity_example.py

Lines changed: 30 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -112,58 +112,63 @@ def _sensitivity_parameters() -> list[SensitivityParameter]:
112112
sensitivity_parameters = _sensitivity_parameters()
113113

114114
if __name__ == "__main__":
115+
import multiprocessing
115116
from sbmlsim.sensitivity import (
116-
LocalSensitivityAnalysis,
117117
SobolSensitivityAnalysis,
118+
LocalSensitivityAnalysis,
118119
SamplingSensitivityAnalysis,
119120
FASTSensitivityAnalysis,
120121
)
121122

122123
sensitivity_path = Path(__file__).parent / "results"
123124
console.print(SensitivityParameter.parameters_to_df(sensitivity_parameters))
124-
cache = False
125125

126-
SamplingSensitivityAnalysis.run_sensitivity_analysis(
127-
results_path=sensitivity_path / "sampling",
126+
settings = {
127+
"cache_results": False,
128+
"n_cores": int(round(0.9 * multiprocessing.cpu_count())),
129+
"seed": 1234
130+
}
131+
132+
sa_sampling = SamplingSensitivityAnalysis(
128133
sensitivity_simulation=sensitivity_simulation,
129134
parameters=sensitivity_parameters,
130135
groups=sensitivity_groups,
131-
cache_results=cache,
132-
cache_sensitivity=cache,
136+
results_path=sensitivity_path / "sampling",
133137
N=1000,
134-
seed=1234,
138+
**settings,
135139
)
140+
sa_sampling.execute()
141+
sa_sampling.plot()
136142

137-
LocalSensitivityAnalysis.run_sensitivity_analysis(
138-
results_path=sensitivity_path / "local",
143+
sa_local = LocalSensitivityAnalysis(
139144
sensitivity_simulation=sensitivity_simulation,
140145
parameters=sensitivity_parameters,
141-
groups=[sensitivity_groups[1]],
142-
cache_results=cache,
143-
cache_sensitivity=cache,
146+
groups=sensitivity_groups,
147+
results_path=sensitivity_path / "local",
144148
difference=0.01,
145-
seed=1234,
149+
**settings,
146150
)
151+
sa_local.execute()
152+
sa_local.plot()
147153

148-
SobolSensitivityAnalysis.run_sensitivity_analysis(
149-
results_path=sensitivity_path / "sobol",
154+
sa_sobol = SobolSensitivityAnalysis(
150155
sensitivity_simulation=sensitivity_simulation,
151156
parameters=sensitivity_parameters,
152157
groups=[sensitivity_groups[1]],
153-
cache_results=cache,
154-
cache_sensitivity=cache,
158+
results_path=sensitivity_path / "sobol",
155159
N=4096,
156-
# N=8,
157-
seed=1234,
160+
**settings,
158161
)
162+
sa_sobol.execute()
163+
sa_sobol.plot()
159164

160-
FASTSensitivityAnalysis.run_sensitivity_analysis(
161-
results_path=sensitivity_path / "fast",
165+
sa_fast = FASTSensitivityAnalysis(
162166
sensitivity_simulation=sensitivity_simulation,
163167
parameters=sensitivity_parameters,
164-
groups=[sensitivity_groups[1]],
165-
cache_results=cache,
166-
cache_sensitivity=cache,
168+
groups=sensitivity_groups,
169+
results_path=sensitivity_path / "fast",
167170
N=1000,
168-
seed=1234,
171+
**settings,
169172
)
173+
sa_fast.execute()
174+
sa_fast.plot()

src/sbmlsim/sensitivity/plots.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def plot_S1_ST_indices(
142142
ymin=np.min([-0.05, ymin]),
143143
)
144144

145+
145146
def S1_ST_barplot(
146147
S1, ST, S1_conf, ST_conf,
147148
parameter_labels: dict[str, str],

0 commit comments

Comments (0)