@@ -123,6 +123,8 @@ def __init__(self,
123123 groups : list [AnalysisGroup ],
124124 results_path : Path ,
125125 seed : Optional [int ] = None ,
126+ n_cores : Optional [int ] = None ,
127+ cache_results : bool = False ,
126128 ) -> None :
127129 """Create a sensitivity analysis for given parameter ids.
128130
@@ -160,6 +162,15 @@ def __init__(self,
160162 if seed is not None :
161163 np .random .seed (seed )
162164
165+ # caching
166+ self .cache_results : bool = cache_results
167+ self .prefix : str = self .__class__ .__name__
168+
169+ # handle compute resources
170+ if not n_cores :
171+ n_cores = int (round (0.9 * multiprocessing .cpu_count ()))
172+ self .n_cores = n_cores
173+
163174 # parameter samples for sensitivity; shape: (num_samples x num_parameters)
164175 self .samples : dict [str , Optional [xr .DataArray ]] = {}
165176
@@ -171,47 +182,6 @@ def __init__(self,
171182 self .sensitivity : dict [str , dict [str , xr .DataArray ]] = {g .uid : {} for g in
172183 self .groups }
173184
174- def samples_table (self ) -> pd .DataFrame :
175- return self ._data_table (d = self .samples )
176-
177- def results_table (self ) -> pd .DataFrame :
178- return self ._data_table (d = self .results )
179-
180- def _data_table (self , d : dict [str , xr .DataArray ]) -> pd .DataFrame :
181- items = []
182- for group in self .groups :
183- da : xr .DataArray = d [group .uid ]
184- item = {
185- 'group' : group .uid ,
186- # 'group_name': group.name,
187- ** da .sizes ,
188- }
189- items .append (item )
190- return pd .DataFrame (items )
191-
192- def read_cache (self , cache_filename : str , cache : bool ) -> Optional [Any ]:
193- cache_path : Optional [
194- Path ] = self .results_path / cache_filename if cache_filename else None
195- if cache and not cache_path :
196- raise ValueError ("Cache path is required for caching." )
197-
198- # retrieve from cache
199- if cache and cache_path .exists ():
200- with open (cache_path , 'rb' ) as f :
201- data = dill .load (f )
202- console .print (f"Simulated samples loaded from cache: '{ cache_path } '" )
203- return data
204-
205- return None
206-
207- def write_cache (self , data : Any , cache_filename : str , cache : bool ) -> Optional [Any ]:
208- cache_path : Optional [
209- Path ] = self .results_path / cache_filename if cache_filename else None
210- if cache_path :
211- with open (cache_path , 'wb' ) as f :
212- console .print (f"Simulated samples written to cache: '{ cache_path } '" )
213- dill .dump (data , f )
214-
215185 @property
216186 def output_ids (self ) -> list [str ]:
217187 return [o .uid for o in self .outputs ]
@@ -236,6 +206,30 @@ def num_outputs(self) -> int:
236206 def num_groups (self ) -> int :
237207 return len (self .groups )
238208
209+ def execute (self ):
210+ """Execute the sensitivity analysis."""
211+ console .rule (
212+ f"{ self .__class__ .__name__ } " ,
213+ style = "blue bold" ,
214+ align = "center" ,
215+ )
216+ console .rule ("Samples" , style = "white" )
217+ self .create_samples ()
218+ console .print (self .samples_table ())
219+
220+ console .rule ("Results" , style = "white" )
221+ self .simulate_samples (
222+ cache_filename = f"{ self .prefix } _results.pkl" ,
223+ cache = self .cache_results ,
224+ )
225+ console .print (self .results_table ())
226+
227+ console .rule ("Sensitivity" , style = "white" )
228+ self .calculate_sensitivity (
229+ cache_filename = f"{ self .prefix } _sensitivity.pkl" ,
230+ cache = self .cache_results ,
231+ )
232+
239233 def create_samples (self ) -> None :
240234 """Create and set parameter samples."""
241235
@@ -283,8 +277,6 @@ def simulate_samples(self, cache_filename: Optional[str] = None,
283277 )
284278
285279 # number of cores
286- n_cores = multiprocessing .cpu_count ()
287-
288280 samples = self .samples [group .uid ]
289281
290282 # create chunk of samples for core
@@ -305,13 +297,13 @@ def split_into_chunks(items, n):
305297 return chunks , chunked_samples
306298
307299 items = list (range (self .num_samples ))
308- chunks , chunked_samples = split_into_chunks (items , n_cores )
300+ chunks , chunked_samples = split_into_chunks (items , self . n_cores )
309301
310302 # parameters for multiprocessing
311303 sa_sim = self .sensitivity_simulation
312- rrs = [(sa_sim , r , chunked_samples [i ]) for i in range (n_cores )]
304+ rrs = [(sa_sim , r , chunked_samples [i ]) for i in range (self . n_cores )]
313305
314- with multiprocessing .Pool (processes = n_cores ) as pool :
306+ with multiprocessing .Pool (processes = self . n_cores ) as pool :
315307 outputs_list : list = pool .map (run_simulation , rrs )
316308
317309 for kc , chunk in enumerate (chunks ):
@@ -331,6 +323,47 @@ def calculate_sensitivity(self, cache_filename: Optional[str] = None,
331323
332324 raise NotImplemented
333325
326+ def samples_table (self ) -> pd .DataFrame :
327+ return self ._data_table (d = self .samples )
328+
329+ def results_table (self ) -> pd .DataFrame :
330+ return self ._data_table (d = self .results )
331+
332+ def _data_table (self , d : dict [str , xr .DataArray ]) -> pd .DataFrame :
333+ items = []
334+ for group in self .groups :
335+ da : xr .DataArray = d [group .uid ]
336+ item = {
337+ 'group' : group .uid ,
338+ # 'group_name': group.name,
339+ ** da .sizes ,
340+ }
341+ items .append (item )
342+ return pd .DataFrame (items )
343+
344+ def read_cache (self , cache_filename : str , cache : bool ) -> Optional [Any ]:
345+ cache_path : Optional [
346+ Path ] = self .results_path / cache_filename if cache_filename else None
347+ if cache and not cache_path :
348+ raise ValueError ("Cache path is required for caching." )
349+
350+ # retrieve from cache
351+ if cache and cache_path .exists ():
352+ with open (cache_path , 'rb' ) as f :
353+ data = dill .load (f )
354+ console .print (f"Simulated samples loaded from cache: '{ cache_path } '" )
355+ return data
356+
357+ return None
358+
359+ def write_cache (self , data : Any , cache_filename : str , cache : bool ) -> Optional [Any ]:
360+ cache_path : Optional [
361+ Path ] = self .results_path / cache_filename if cache_filename else None
362+ if cache_path :
363+ with open (cache_path , 'wb' ) as f :
364+ console .print (f"Simulated samples written to cache: '{ cache_path } '" )
365+ dill .dump (data , f )
366+
334367 def sensitivity_df (self , group_id : str , key : str ) -> pd .DataFrame :
335368 """Convert sensitivity information to dataframes."""
336369
@@ -341,6 +374,10 @@ def sensitivity_df(self, group_id: str, key: str) -> pd.DataFrame:
341374 index = sensitivity .coords ["parameter" ]
342375 )
343376
377+ def plot (self ):
378+ """Should be implemented by subclass."""
379+ console .rule ("Plotting" , style = "white" )
380+
344381 def plot_sensitivity (
345382 self ,
346383 group_id : str ,
0 commit comments