1616import logging
1717
1818import matplotlib .pyplot as plt
19+ import networkx as nx
20+ import numpy as np
1921from shapiq import ExactComputer , InteractionValues
2022
2123from hypershap .games import (
@@ -144,7 +146,7 @@ def ablation(
144146
145147 def tunability (
146148 self ,
147- baseline_config : Configuration = None ,
149+ baseline_config : Configuration | None = None ,
148150 index : str = "FSII" ,
149151 order : int = 2 ,
150152 n_samples : int = 10_000 ,
@@ -161,6 +163,9 @@ def tunability(
161163 InteractionValues: The computed interaction values.
162164
163165 """
166+ if baseline_config is None :
167+ baseline_config = self .explanation_task .config_space .get_default_configuration ()
168+
164169 # setup explanation task
165170 tunability_task : TunabilityExplanationTask = TunabilityExplanationTask (
166171 config_space = self .explanation_task .config_space ,
@@ -183,7 +188,7 @@ def tunability(
183188
184189 def sensitivity (
185190 self ,
186- baseline_config : Configuration = None ,
191+ baseline_config : Configuration | None = None ,
187192 index : str = "FSII" ,
188193 order : int = 2 ,
189194 n_samples : int = 10_000 ,
@@ -200,18 +205,21 @@ def sensitivity(
200205 InteractionValues: The computed interaction values.
201206
202207 """
208+ if baseline_config is None :
209+ baseline_config = self .explanation_task .config_space .get_default_configuration ()
210+
203211 # setup explanation task
204- tunability_task : SensitivityExplanationTask = SensitivityExplanationTask (
212+ sensitivity_task : SensitivityExplanationTask = SensitivityExplanationTask (
205213 config_space = self .explanation_task .config_space ,
206214 surrogate_model = self .explanation_task .surrogate_model ,
207215 baseline_config = baseline_config ,
208216 )
209217
210218 # setup tunability game and get interaction values
211219 tg = SensitivityGame (
212- explanation_task = tunability_task ,
220+ explanation_task = sensitivity_task ,
213221 cs_searcher = RandomConfigSpaceSearcher (
214- explanation_task = tunability_task ,
222+ explanation_task = sensitivity_task ,
215223 n_samples = n_samples ,
216224 mode = "var" ,
217225 ),
@@ -222,7 +230,7 @@ def sensitivity(
222230
223231 def mistunability (
224232 self ,
225- baseline_config : Configuration = None ,
233+ baseline_config : Configuration | None = None ,
226234 index : str = "FSII" ,
227235 order : int = 2 ,
228236 n_samples : int = 10_000 ,
@@ -239,18 +247,21 @@ def mistunability(
239247 InteractionValues: The computed interaction values.
240248
241249 """
250+ if baseline_config is None :
251+ baseline_config = self .explanation_task .config_space .get_default_configuration ()
252+
242253 # setup explanation task
243- tunability_task : MistunabilityExplanationTask = MistunabilityExplanationTask (
254+ mistunability_task : MistunabilityExplanationTask = MistunabilityExplanationTask (
244255 config_space = self .explanation_task .config_space ,
245256 surrogate_model = self .explanation_task .surrogate_model ,
246257 baseline_config = baseline_config ,
247258 )
248259
249260 # setup tunability game and get interaction values
250261 tg = MistunabilityGame (
251- explanation_task = tunability_task ,
262+ explanation_task = mistunability_task ,
252263 cs_searcher = RandomConfigSpaceSearcher (
253- explanation_task = tunability_task ,
264+ explanation_task = mistunability_task ,
254265 n_samples = n_samples ,
255266 mode = "min" ,
256267 ),
@@ -284,11 +295,10 @@ def optimizer_bias(
284295 surrogate_model = self .explanation_task .surrogate_model ,
285296 optimizer_of_interest = optimizer_of_interest ,
286297 optimizer_ensemble = optimizer_ensemble ,
287- n_workers = self .n_workers ,
288298 )
289299
290300 # setup optimizer bias game and get interaction values
291- og = OptimizerBiasGame (explanation_task = optimizer_bias_task )
301+ og = OptimizerBiasGame (explanation_task = optimizer_bias_task , n_workers = self . n_workers , verbose = self . verbose )
292302 return self .__get_interaction_values (game = og , index = index , order = order )
293303
294304 def plot_si_graph (self , interaction_values : InteractionValues | None = None , save_path : str | None = None ) -> None :
@@ -304,9 +314,11 @@ def plot_si_graph(self, interaction_values: InteractionValues | None = None, sav
304314
305315 # if given interaction values use those, else use cached interaction values
306316 iv = interaction_values if interaction_values is not None else self .last_interaction_values
307- hyperparameter_names = self .explanation_task .get_hyperparameter_names ()
308317
309- import networkx as nx
318+ if not isinstance (iv , InteractionValues ):
319+ raise TypeError
320+
321+ hyperparameter_names = self .explanation_task .get_hyperparameter_names ()
310322
311323 def get_circular_layout (n_players : int ) -> dict :
312324 original_graph , graph_nodes = nx .Graph (), []
@@ -345,9 +357,17 @@ def plot_upset(self, interaction_values: InteractionValues | None = None, save_p
345357
346358 # if given interaction values use those, else use cached interaction values
347359 iv = interaction_values if interaction_values is not None else self .last_interaction_values
360+
361+ if not isinstance (iv , InteractionValues ):
362+ raise TypeError
363+
348364 hyperparameter_names = self .explanation_task .get_hyperparameter_names ()
349365
350366 fig = iv .plot_upset (feature_names = hyperparameter_names , show = False )
367+
368+ if fig is None :
369+ raise TypeError
370+
351371 ax = fig .get_axes ()[0 ]
352372 ax .set_ylabel ("Performance Gain" )
353373 # also add "parameter" to the y-axis label
@@ -374,9 +394,13 @@ def plot_force(self, interaction_values: InteractionValues | None = None, save_p
374394
375395 # if given interaction values use those, else use cached interaction values
376396 iv = interaction_values if interaction_values is not None else self .last_interaction_values
397+
398+ if not isinstance (iv , InteractionValues ):
399+ raise TypeError
400+
377401 hyperparameter_names = self .explanation_task .get_hyperparameter_names ()
378402
379- iv .plot_force (feature_names = hyperparameter_names , show = False )
403+ iv .plot_force (feature_names = np . array ( hyperparameter_names ) , show = False )
380404 plt .tight_layout ()
381405
382406 if save_path is not None :
0 commit comments