4848from sbmlsim .sensitivity .outputs import SensitivityOutput
4949import pandas as pd
5050
51+ @dataclass
52+ class SensitivityOutput :
53+ """Output for SensitivityAnalysis"""
54+ uid : str
55+ name : str
56+ unit : Optional [str ] = None
57+
58+ def __hash__ (self ):
59+ return hash (self .uid )
60+
5161@dataclass
5262class SensitivitySimulation :
5363 """Base class for sensitivity calculation.
@@ -60,10 +70,10 @@ class SensitivitySimulation:
6070 model_path : Path
6171 selections : list [str ]
6272 rr : roadrunner .RoadRunner = None
63- outputs : list [str ] = None
73+ outputs : list [SensitivityOutput ] = None
6474 changes_simulation : dict [str , float ] = None
6575
66- def __init__ (self , model_path : Path , selections : list [str ], changes_simulation : dict [str , float ]):
76+ def __init__ (self , model_path : Path , selections : list [str ], changes_simulation : dict [str , float ], outputs : list [ SensitivityOutput ] ):
6777 self .model_path = model_path
6878 self .selections = selections
6979 self .rr : roadrunner .RoadRunner = roadrunner .RoadRunner (str (model_path ))
@@ -74,10 +84,15 @@ def __init__(self, model_path: Path, selections: list[str], changes_simulation:
7484
7585 # store the simulation changes
7686 self .changes_simulation = changes_simulation
87+ self .outputs : list [SensitivityOutput ] = outputs
7788
78- # get the outputs from the simulation
89+ # validate the outputs from the simulation
7990 y = self .simulate (changes = {})
80- self .outputs = list (y .keys ())
91+ outputs_dict = {q .uid for q in self .outputs }
92+ for key in y :
93+ if key not in outputs_dict :
94+ raise ValueError (f"Key '{ key } ' missing in outputs dictionary: '{ outputs_dict } " )
95+
8196
8297
8398 # def output_definitions(self) -> list[SensitivityOutput]:
@@ -118,17 +133,13 @@ def apply_changes(self, changes: dict[str, float], reset_all: bool=True) -> None
118133
119134
120135
121- @ dataclass
136+
122137class SensitivityAnalysis :
123138 """Parent class for all sensitivity analysis.
124139
125140 TODO: additional metadata for the outputs and the parameters; i.e. name, units, bounds, ....
126141 """
127142
128- sensitivity_simulation : SensitivitySimulation
129- parameters : list [SensitivityParameter ]
130- outputs : list [str ]
131-
132143 def __init__ (self , sensitivity_simulation : SensitivitySimulation ,
133144 parameters : SensitivityParameter ) -> None :
134145 """Create a sensitivity analysis for given parameter ids.
@@ -139,8 +150,11 @@ def __init__(self, sensitivity_simulation: SensitivitySimulation,
139150
140151 # parameters to vary; shape: (num_parameters,)
141152 self .parameters : list [SensitivityParameter ] = parameters
153+ self .parameter_ids : list [str ] = [p .uid for p in self .parameters ]
154+
142155 # outputs to calculate sensitivity on; shape: (num_outputs,)
143- self .outputs : list [output ] = sensitivity_simulation .outputs
156+ self .outputs : list [SensitivityOutput ] = sensitivity_simulation .outputs
157+ self .output_ids : list [str ] = [q .uid for q in self .outputs ]
144158
145159 # parameter samples for sensitivity; shape: (num_samples x num_parameters)
146160 self .samples : Optional [xr .DataArray ] = None
@@ -235,7 +249,7 @@ def create_samples(self) -> None:
235249 samples = xr .DataArray (
236250 np .full ((num_samples , self .num_parameters ), np .nan ),
237251 dims = ["sample" , "parameter" ],
238- coords = {"sample" : range (num_samples ), "parameter" : self .parameters },
252+ coords = {"sample" : range (num_samples ), "parameter" : [ p . uid for p in self .parameters ] },
239253 name = "samples"
240254 )
241255
@@ -262,15 +276,15 @@ def calculate_sensitivity(self):
262276 self .sensitivity = xr .DataArray (
263277 np .full ((self .num_parameters , self .num_outputs ), np .nan ),
264278 dims = ["parameter" , "output" ],
265- coords = {"parameter" : [ p . uid for p in self .parameters ] ,
266- "output" : self .outputs },
279+ coords = {"parameter" : self .parameter_ids ,
280+ "output" : self .output_ids },
267281 name = "sensitivity"
268282 )
269283 self .sensitivity_normalized = xr .DataArray (
270284 np .full ((self .num_parameters , self .num_outputs ), np .nan ),
271285 dims = ["parameter" , "output" ],
272- coords = {"parameter" : [ p . uid for p in self .parameters ] ,
273- "output" : self .outputs },
286+ coords = {"parameter" : self .parameter_ids ,
287+ "output" : self .output_ids },
274288 name = "sensitivity"
275289 )
276290
@@ -302,21 +316,33 @@ def sensitivity_df(self) -> pd.DataFrame:
302316
303317 def plot_sensitivity (self ):
304318 df = self .sensitivity_df
305- self .plot_sensitivity_df (df )
319+ self .plot_sensitivity_df (
320+ df = df ,
321+ parameter_labels = {p .uid : p .name for p in self .parameters },
322+ output_labels = {q .uid : q .name for q in self .outputs },
323+ )
306324
307325 @staticmethod
308- def plot_sensitivity_df (df : pd .DataFrame , cutoff = 0.1 , cluster_rows : bool = True ):
326+ def plot_sensitivity_df (
327+ df : pd .DataFrame ,
328+ parameter_labels : dict [str , str ],
329+ output_labels : dict [str , str ],
330+ cutoff = 0.1 , cluster_rows : bool = True
331+ ):
309332 from sbmlsim .sensitivity .plots import heatmap
310333 console .print (df )
311334
312335 # TODO: labels of parameters
313336 # TODO: labels of outputs
314337 # TODO: better position of colorbar
315338
316- heatmap (df , cutoff = cutoff , cluster_rows = False )
317-
318-
319-
339+ heatmap (
340+ df ,
341+ parameter_labels = parameter_labels ,
342+ output_labels = output_labels ,
343+ cutoff = cutoff ,
344+ cluster_rows = False
345+ )
320346
321347@dataclass
322348class SamplingSensitivityAnalysis (SensitivityAnalysis ):
0 commit comments