@@ -249,16 +249,34 @@ def sensitivity_df(self, key="normalized") -> pd.DataFrame:
249249 index = self .sensitivity [key ].coords ["parameter" ]
250250 )
251251
252+ def plot_sensitivity (
253+ self ,
254+ key : str , cutoff = 0.1 ,
255+ cluster_rows : bool = True ,
256+ title : Optional [str ] = None ,
257+ cmap : str = "seismic" ,
258+ ** kwargs
259+ ) -> None :
260+ df = self .sensitivity_df (key = key )
261+ heatmap (
262+ df = df ,
263+ parameter_labels = {p .uid : p .name for p in self .parameters },
264+ output_labels = {q .uid : q .name for q in self .outputs },
265+ cutoff = cutoff ,
266+ cluster_rows = cluster_rows ,
267+ title = title ,
268+ cmap = cmap ,
269+ ** kwargs
270+ )
271+
252272import os
253273
254274def run_simulation (
255275 params_tuple
256276):
257277 """Pass all required arguments as parameter tuple."""
258278 sensitivity_simulation , r , chunked_changes = params_tuple
259-
260279 outputs = []
261-
262280 for kc in track (range (len (chunked_changes )), description = f"Simulate samples PID={ os .getpid ()} " ):
263281 changes = chunked_changes [kc ]
264282 # console.print(f"PID={os.getpid()} | k={kc}")
@@ -271,16 +289,6 @@ def run_simulation(
271289 return outputs
272290
273291
274-
275-
276-
277-
278-
279-
280-
281-
282-
283-
284292class LocalSensitivityAnalysis (SensitivityAnalysis ):
285293 """Local sensitivity analysis based on local differences.
286294
@@ -371,35 +379,7 @@ def calculate_sensitivity(self):
371379 sensitivity_normalized [kp , ko ] = sensitivity_raw [kp , ko ] * p_ref / q_ref
372380
373381
374- def plot_sensitivity (self , cutoff = 0.1 , cluster_rows : bool = True , title : Optional [str ] = None ):
375- df = self .sensitivity_df (key = "normalized" )
376- self .plot_sensitivity_df (
377- df = df ,
378- parameter_labels = {p .uid : p .name for p in self .parameters },
379- output_labels = {q .uid : q .name for q in self .outputs },
380- cutoff = cutoff ,
381- cluster_rows = cluster_rows ,
382- title = title
383- )
384-
385- @staticmethod
386- def plot_sensitivity_df (
387- df : pd .DataFrame ,
388- parameter_labels : dict [str , str ],
389- output_labels : dict [str , str ],
390- cutoff = 0.1 , cluster_rows : bool = True ,
391- title : Optional [str ] = None ,
392- ):
393- console .print (df )
394382
395- heatmap (
396- df ,
397- parameter_labels = parameter_labels ,
398- output_labels = output_labels ,
399- cutoff = cutoff ,
400- cluster_rows = False ,
401- title = title ,
402- )
403383
404384
405385@dataclass
@@ -457,50 +437,41 @@ def create_samples(self, N: int=1024):
457437
458438
459439 def calculate_sensitivity (self ):
460- # transfer results in libsa results format
440+ """Calculate the sensitivity matrices."""
461441
462442 Y = self .results .values
463443 self .ssa_problem .set_results (Y )
464444
445+ # num_parameters x num_outputs
446+ sensitivity_keys = ["S1" , "ST" , "S1_conf" , "ST_conf" ]
447+ for key in sensitivity_keys :
448+ self .sensitivity [key ] = xr .DataArray (
449+ np .full ((self .num_parameters , self .num_outputs ), np .nan ),
450+ dims = ["parameter" , "output" ],
451+ coords = {"parameter" : self .parameter_ids ,
452+ "output" : self .output_ids },
453+ name = key
454+ )
455+
465456 # Perform Analysis
466457 # Si is a Python dict-like with the keys "S1", "S2", "ST",
467458 # "S1_conf", "S2_conf", and "ST_conf".
468459 # The _conf keys store the corresponding confidence intervals,
469460 # typically with a confidence level of 95%.
470461
471462 # Calculate Sobol indices for every output
472- Si_all = []
473463 for ko in range (self .num_outputs ):
474464 Yo = Y [:, ko ]
475465 Si = SALib .analyze .sobol .analyze (
476466 self .ssa_problem , Yo ,
477467 calc_second_order = True ,
478468 print_to_console = True ,
479469 )
480- Si_all .append (Si )
481-
482- Si .plot ()
483- from matplotlib import pyplot as plt
484- plt .show ()
485-
486- # Si = SALib.analyze.sobol.analyze(
487- # self.ssa_problem, Y,
488- # calc_second_order=True,
489- # print_to_console=True,
490- # )
491-
492- # Store the sensitivity matrices
470+ console .print ("S1" )
471+ console .print (Si ["S1" ])
472+ for key in sensitivity_keys :
473+ self .sensitivity [key ][:, ko ] = Si [key ]
493474
494- sensitivity_total = Si ['ST' ]
495- sensitivity_first = Si ['S1' ]
496- print (Si ['S1' ])
497- print (Si ['ST' ])
498-
499-
500-
501- Si .plot ()
502- from matplotlib import pyplot as plt
503- plt .show ()
504475
505476 def plot (self ):
506477 Si .plot ()
0 commit comments