2020
2121"""
2222from typing import Optional
23+ import xarray as xr
2324
2425import SALib
2526from SALib import ProblemSpec
@@ -79,9 +80,11 @@ def simulate(self, changes: dict[str, float]) -> dict[str, float]:
7980 def parameter_values (self , parameters : list [str ], changes : dict [str , float ]) -> dict [str , float ]:
8081 """Get the parameter values for a given set of changes."""
8182 self .apply_changes (changes , reset_all = True )
83+
8284 values : dict [str , float ] = {}
8385 for pid in parameters :
8486 values [pid ] = self .rr .getValue (pid )
87+
8588 return values
8689
8790
@@ -121,11 +124,11 @@ def __init__(self, sensitivity_simulation: SensitivitySimulation,
121124 # outputs to calculate sensitivity on; shape: (num_outputs,)
122125 self .outputs : list [str ] = sensitivity_simulation .outputs
123126 # parameter samples for sensitivity; shape: (num_samples x num_parameters)
124- self .samples : Optional [np . ndarray ] = None
127+ self .samples : Optional [xr . DataArray ] = None
125128 # outputs for given samples; shape: (num_samples x num_outputs)
126- self .results : Optional [np . ndarray ] = None
129+ self .results : Optional [xr . DataArray ] = None
127130 # sensitivity matrix; shape: (num_parameters x num_outputs); could be multiple
128- self .sensitivity_results : Optional [np . ndarray ] = None
131+ self .sensitivity_results : Optional [xr . DataArray ] = None
129132
130133 @property
131134 def num_parameters (self ) -> int :
@@ -152,7 +155,7 @@ def simulate_samples(self) -> None:
152155 self .samples = np .zeros (shape = (self .num_samples , self .num_parameters ))
153156 self .outputs = np .zeros (shape = (self .num_samples , self .num_outputs ))
154157
155- for k in range (self .num_samples () ):
158+ for k in range (self .num_samples ):
156159 changes = dict (zip (self .parameters , self .samples [k , :]))
157160 outputs = self .sensitivity_simulation .simulate (changes = changes )
158161 self .outputs [k , :] = outputs
@@ -185,26 +188,35 @@ def num_samples(self) -> int:
185188 """Number of parameter samples to simulate."""
186189 return 2 * self .num_parameters
187190
188- def create_samples (self ) -> np . ndarray :
191+ def create_samples (self ) -> None :
189192
190193 # Calculate the parameter values in the reference state
191- parameter_values = self .sensitivity_simulation .parameter_values (
194+ parameter_values : dict [str , float ] = self .sensitivity_simulation .parameter_values (
195+ parameters = self .parameters ,
192196 changes = self .sensitivity_simulation .changes_simulation
193197 )
194198
195199 # (num_samples x num_outputs)
196- samples = np .empty (shape = (self .num_samples , self .num_parameters ))
197-
198- for key , value in :
199- values = np .ones (shape = (2 * num_pars ,)) * value .magnitude
200+ num_samples = 2 * self .num_parameters
201+ samples = np .empty (shape = (num_samples , self .num_parameters ))
202+ samples = xr .DataArray (
203+ np .full ((num_samples , self .num_parameters ), np .nan ),
204+ dims = ["sample" , "parameter" ],
205+ coords = {"sample" : range (num_samples ), "parameter" : self .parameters },
206+ name = "samples"
207+ )
200208
209+ reference_values = np .array (list (parameter_values .values ()))
210+ for kp , pid in enumerate (parameter_values ):
211+ value = parameter_values [pid ]
201212
213+ # right sided changes
214+ samples [2 * kp , :] = reference_values
215+ samples [2 * kp , kp ] = value * (1.0 + self .difference )
216+ samples [2 * kp + 1 , :] = reference_values
217+ samples [2 * kp + 1 , :] = value * (1.0 - self .difference )
202218
203- # change parameters in correct position
204- values [index ] = value .magnitude * (1.0 + difference )
205- values [index + num_pars ] = value .magnitude * (1.0 - difference )
206- changes [key ] = Q_ (values , value .units )
207- index += 1
219+ self .samples = samples
208220
209221 def calculate_sensitivity (self ):
210222
@@ -214,6 +226,117 @@ def plot_sensitivity(self):
214226
215227 pass
216228
229+ from matplotlib import pyplot as plt
230+ import seaborn as sns
231+ import numpy as np
232+
233+ def heatmap (da : xr .DataArray , cutoff : float = 0.01 , annotate_values = True , transpose : bool = False ):
234+ """Creates heatmap of model sensitivity"""
235+
236+ def calculate_mask (df , cutoff = 0.01 ):
237+ """Calculates a boolean mask DataFrame for the heatmap based on cutoff."""
238+ mask = np .empty (shape = df .shape , dtype = "bool" )
239+ for index , value in np .ndenumerate (df ):
240+ if np .abs (value ) < cutoff :
241+ mask [index ] = True
242+ else :
243+ mask [index ] = False
244+ return pd .DataFrame (data = mask , columns = df .COLUMNS , index = df .index )
245+
246+ def calculate_subset (df , cutoff = 0.01 ):
247+ """Calculates subset of data frame consisting of rows where at least
248+ one value is above cutoff."""
249+ return df [(df .abs () >= cutoff ).any (axis = 1 )]
250+
251+
252+
253+ # filter rows
254+ # X.drop(pk_exclude, axis=1, inplace=True)
255+
256+ # if cutoff > 0:
257+ # X_subset = calculate_subset(X, cutoff=cutoff)
258+ # X_subset_mask = calculate_mask(X_subset, cutoff)
259+ da_subset = da
260+
261+ # yticklabels = ["{}".format(pid) for pid in X_subset.index]
262+ # xticklabels = ["{}".format(pnames[pid]["label"]) for pid in X_subset.COLUMNS]
263+
264+ xticklabels = da .coords [da .dims [1 ]]
265+ yticklabels = da .coords [da .dims [0 ]]
266+
267+ # plot heatmap
268+ ax = sns .clustermap (
269+ da_subset ,
270+ center = 0 ,
271+ # vmin=-0.2,
272+ # vmax=0.2,
273+ xticklabels = xticklabels ,
274+ yticklabels = yticklabels ,
275+ cmap = "seismic" ,
276+ cbar_pos = (0.05 , 0.25 , 0.03 , 0.4 ),
277+ annot = annotate_values ,
278+ fmt = "1.2f" ,
279+ annot_kws = {"size" : 13 },
280+ # mask=X_subset_mask,
281+ col_cluster = False ,
282+ method = "single" ,
283+ figsize = (20 , 20 ),
284+ )
285+ plt .setp (
286+ ax .ax_heatmap .get_xticklabels (),
287+ rotation = 45 ,
288+ horizontalalignment = "right" ,
289+ size = 20 ,
290+ )
291+ plt .setp (ax .ax_heatmap .get_yticklabels (), size = 20 )
292+ ax .ax_cbar .tick_params (labelsize = 20 )
293+ ax .ax_row_dendrogram .set_visible (False )
294+ ax .ax_col_dendrogram .set_visible (False )
295+
296+ # create custom legend containing yticklabels and their description
297+ # handles = [t.get_text() for t in ax.ax_heatmap.get_yticklabels()]
298+ # labels = [pnames[pid]["label"] for pid in handles]
299+ #
300+ # # FIXME: update after defining labels
301+ # idx = [pnames[pid]["idx"] for pid in handles]
302+ # # idx = [k for k, pid in enumerate(handles)]
303+ #
304+ # labels = [label for _, label in sorted(zip(idx, labels))]
305+ # handles = [f"{handle}:" for _, handle in sorted(zip(idx, handles))]
306+ # handles = [handle.replace("_", "\_") for handle in handles]
307+
308+ # mid = int(np.ceil(len(handles) / 2))
309+ # legend1 = plt.legend(
310+ # handles[:mid],
311+ # labels[:mid],
312+ # handler_map={str: LegendTitle({"fontsize": 16})},
313+ # fontsize=16,
314+ # frameon=False,
315+ # bbox_to_anchor=(1.2, -0.6),
316+ # loc="upper left",
317+ # handlelength=14,
318+ # )
319+ # legend2 = plt.legend(
320+ # handles[mid:],
321+ # labels[mid:],
322+ # handler_map={str: LegendTitle({"fontsize": 16})},
323+ # fontsize=16,
324+ # frameon=False,
325+ # bbox_to_anchor=(13, -0.6),
326+ # loc="upper left",
327+ # handlelength=19,
328+ # )
329+ # plt.gca().add_artist(legend1)
330+
331+ # plt.savefig(
332+ # results_dir / "parameter.sensitivity_cluster.png", dpi=300, bbox_inches="tight"
333+ # )
334+ # plt.savefig(results_dir / "parameter.sensitivity_cluster.svg", bbox_inches="tight")
335+
336+ plt .show ()
337+
338+
339+
217340
218341@dataclass
219342class SamplingSensitivityAnalysis (SensitivityAnalysis ):
0 commit comments