22
33
44"""
5+ import multiprocessing
56import os
67import time
7- import multiprocessing
88from dataclasses import dataclass
9- from typing import Optional , Any
109from pathlib import Path
11- from rich .progress import track
12- from pymetadata .console import console
10+ from typing import Optional , Any
1311
12+ import dill
1413import numpy as np
1514import pandas as pd
16- import xarray as xr
17-
18-
1915import roadrunner
20- import dill
16+ import xarray as xr
17+ from pymetadata .console import console
18+ from rich .progress import track
2119
2220from sbmlsim .sensitivity .parameters import SensitivityParameter
23- from sbmlsim .sensitivity .outputs import SensitivityOutput
2421from sbmlsim .sensitivity .plots import heatmap
2522
2623
24+ @dataclass
25+ class SensitivityOutput :
26+ """Output measurement for SensitivityAnalysis."""
27+ uid : str
28+ name : str
29+ unit : Optional [str ]
30+
31+
2732@dataclass
2833class AnalysisGroup :
2934 """Subgroup for analysis."""
@@ -59,8 +64,8 @@ def __init__(self, model_path: Path, selections: list[str],
5964 outputs_dict = {q .uid for q in self .outputs }
6065 for key in y :
6166 if key not in outputs_dict :
62- raise ValueError (f"Key ' { key } ' missing in outputs dictionary: ' { outputs_dict } " )
63-
67+ raise ValueError (
68+ f"Key ' { key } ' missing in outputs dictionary: ' { outputs_dict } " )
6469
6570 @staticmethod
6671 def load_model (model_path : Path , selections : list [str ]) -> roadrunner .RoadRunner :
@@ -72,7 +77,8 @@ def load_model(model_path: Path, selections: list[str]) -> roadrunner.RoadRunner
7277 return rr
7378
7479 @staticmethod
75- def apply_changes (r : roadrunner .RoadRunner , changes : dict [str , float ], reset_all : bool = True ) -> None :
80+ def apply_changes (r : roadrunner .RoadRunner , changes : dict [str , float ],
81+ reset_all : bool = True ) -> None :
7682 """Apply changes after possible reset of the model."""
7783 if reset_all :
7884 r .resetAll ()
@@ -81,7 +87,8 @@ def apply_changes(r: roadrunner.RoadRunner, changes: dict[str, float], reset_all
8187 # print(f"{key=} {value=}")
8288 r .setValue (key , value )
8389
84- def simulate (self , r : roadrunner .RoadRunner , changes : dict [str , float ]) -> dict [str , float ]:
90+ def simulate (self , r : roadrunner .RoadRunner , changes : dict [str , float ]) -> dict [
91+ str , float ]:
8592 """Run a model simulation and return scalar results dictionary."""
8693
8794 raise NotImplemented
@@ -115,7 +122,7 @@ def __init__(self,
115122 parameters : list [SensitivityParameter ],
116123 groups : list [AnalysisGroup ],
117124 results_path : Path ,
118- seed : Optional [int ]= None ,
125+ seed : Optional [int ] = None ,
119126 ) -> None :
120127 """Create a sensitivity analysis for given parameter ids.
121128
@@ -161,20 +168,15 @@ def __init__(self,
161168
162169 # multiple sensitivities are stored
163170 # sensitivity matrix; shape: (num_parameters x num_outputs); could be multiple
164- self .sensitivity : dict [str , dict [str , xr .DataArray ]] = {g .uid : {} for g in self .groups }
171+ self .sensitivity : dict [str , dict [str , xr .DataArray ]] = {g .uid : {} for g in
172+ self .groups }
165173
166174 def samples_table (self ) -> pd .DataFrame :
167175 return self ._data_table (d = self .samples )
168176
169177 def results_table (self ) -> pd .DataFrame :
170178 return self ._data_table (d = self .results )
171179
172- # def sensitivity_tables(self) -> dict[str, pd.DataFrame]:
173- #
174- # for group in self.groups:
175- # for key in group.changes.keys():
176- # return {k: self._data_table(d=d) for k, d in self.sensitivity.items()}
177-
178180 def _data_table (self , d : dict [str , xr .DataArray ]) -> pd .DataFrame :
179181 items = []
180182 for group in self .groups :
@@ -188,7 +190,8 @@ def _data_table(self, d: dict[str, xr.DataArray]) -> pd.DataFrame:
188190 return pd .DataFrame (items )
189191
190192 def read_cache (self , cache_filename : str , cache : bool ) -> Optional [Any ]:
191- cache_path : Optional [Path ] = self .results_path / cache_filename if cache_filename else None
193+ cache_path : Optional [
194+ Path ] = self .results_path / cache_filename if cache_filename else None
192195 if cache and not cache_path :
193196 raise ValueError ("Cache path is required for caching." )
194197
@@ -202,7 +205,8 @@ def read_cache(self, cache_filename: str, cache: bool) -> Optional[Any]:
202205 return None
203206
204207 def write_cache (self , data : Any , cache_filename : str , cache : bool ) -> Optional [Any ]:
205- cache_path : Optional [Path ] = self .results_path / cache_filename if cache_filename else None
208+ cache_path : Optional [
209+ Path ] = self .results_path / cache_filename if cache_filename else None
206210 if cache_path :
207211 with open (cache_path , 'wb' ) as f :
208212 console .print (f"Simulated samples written to cache: '{ cache_path } '" )
@@ -247,7 +251,8 @@ def num_samples(self) -> int:
247251 samples = self .samples [self .group_ids [0 ]]
248252 return samples .shape [0 ]
249253
250- def simulate_samples (self , cache_filename : Optional [str ] = None , cache : bool = False ) -> None :
254+ def simulate_samples (self , cache_filename : Optional [str ] = None ,
255+ cache : bool = False ) -> None :
251256 """Simulate all samples in parallel.
252257
253258 :param cache_filename: Path to the cache path.
@@ -320,7 +325,8 @@ def split_into_chunks(items, n):
320325 # write to cache
321326 self .write_cache (data = self .results , cache_filename = cache_filename , cache = cache )
322327
323- def calculate_sensitivity (self , cache_filename : Optional [str ] = None , cache : bool = False ):
328+ def calculate_sensitivity (self , cache_filename : Optional [str ] = None ,
329+ cache : bool = False ):
324330 """Calculate the sensitivity matrices."""
325331
326332 raise NotImplemented
@@ -361,14 +367,14 @@ def plot_sensitivity(
361367 )
362368
363369
364-
365370def run_simulation (
366371 params_tuple
367372):
368373 """Pass all required arguments as parameter tuple."""
369374 sensitivity_simulation , r , chunked_changes = params_tuple
370375 outputs = []
371- for kc in track (range (len (chunked_changes )), description = f"Simulate samples PID={ os .getpid ()} " ):
376+ for kc in track (range (len (chunked_changes )),
377+ description = f"Simulate samples PID={ os .getpid ()} " ):
372378 changes = chunked_changes [kc ]
373379 # console.print(f"PID={os.getpid()} | k={kc}")
374380 Y = sensitivity_simulation .simulate (
@@ -378,6 +384,3 @@ def run_simulation(
378384 outputs .append (Y )
379385
380386 return outputs
381-
382-
383-
0 commit comments