Skip to content

Commit 2e2f913

Browse files
committed
Fixed a couple of pyright issues.
1 parent 4032a6c commit 2e2f913

9 files changed

Lines changed: 179 additions & 53 deletions

File tree

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ jobs:
4747
run: uv run python -m pytest tests --cov --cov-config=pyproject.toml --cov-report=xml
4848

4949
- name: Check typing
50-
run: uv run mypy
50+
run: uv run pyright
5151

5252

5353
- name: Upload coverage reports to Codecov with GitHub Action on Python 3.11

pyproject.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,17 @@ ignore = [
8585
"E501",
8686
# DoNotAssignLambda
8787
"E731",
88+
# Argument name should be lowercase (conflicting with sklearn)
89+
"N803",
90+
# Dynamically typed expressions (setup protocol type)
91+
"ANN401",
92+
# Boolean-typed positional argument in function definition
93+
"FBT001",
8894
]
8995

96+
[tool.ruff.lint.pep8-naming]
97+
ignore-names = ["X"]
98+
9099
[tool.ruff.lint.isort]
91100
known-first-party = ["hypershap"]
92101
extra-standard-library = ["typing_extensions"]

src/hypershap/games/ablation.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,10 @@
2424

2525
from __future__ import annotations
2626

27-
from typing import TYPE_CHECKING
28-
2927
import numpy as np
3028

3129
from hypershap.games.abstract import AbstractHPIGame
32-
33-
if TYPE_CHECKING:
34-
from hypershap.task import AblationExplanationTask
30+
from hypershap.task import AblationExplanationTask
3531

3632

3733
class AblationGame(AbstractHPIGame):
@@ -82,7 +78,13 @@ def evaluate_single_coalition(self, coalition: np.ndarray) -> float:
8278
baseline_cfg = self._get_explanation_task().baseline_config.get_array()
8379
cfg_of_interest = self._get_explanation_task().config_of_interest.get_array()
8480
blend = np.where(coalition == 0, baseline_cfg, cfg_of_interest)
85-
return self._get_explanation_task().surrogate_model.evaluate(blend)
81+
res = self._get_explanation_task().surrogate_model.evaluate(blend)
82+
83+
# validate that we do not get a list of floats by accident
84+
if isinstance(res, list):
85+
raise TypeError
86+
87+
return res
8688

8789
def _get_explanation_task(self) -> AblationExplanationTask:
8890
"""Retrieve the explanation task associated with this ablation game.
@@ -94,4 +96,6 @@ def _get_explanation_task(self) -> AblationExplanationTask:
9496
AblationExplanationTask: The explanation task associated with this ablation game.
9597
9698
"""
97-
return self.explanation_task
99+
if isinstance(self.explanation_task, AblationExplanationTask):
100+
return self.explanation_task
101+
raise ValueError

src/hypershap/games/optimizerbias.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,8 @@
2424
if TYPE_CHECKING:
2525
import numpy as np
2626

27-
from hypershap.task import OptimizerBiasExplanationTask
28-
2927
from hypershap.games import AbstractHPIGame
28+
from hypershap.task import OptimizerBiasExplanationTask
3029

3130

3231
class OptimizerBiasGame(AbstractHPIGame):
@@ -82,7 +81,9 @@ def __init__(
8281
super().__init__(explanation_task, n_workers=n_workers, verbose=verbose)
8382

8483
def _get_explanation_task(self) -> OptimizerBiasExplanationTask:
85-
return self.explanation_task
84+
if isinstance(self.explanation_task, OptimizerBiasExplanationTask):
85+
return self.explanation_task
86+
raise ValueError
8687

8788
def evaluate_single_coalition(self, coalition: np.ndarray) -> float:
8889
"""Evaluate a single coalition by comparing against an optimizer ensemble.

src/hypershap/games/tunability.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@
2424
if TYPE_CHECKING:
2525
import numpy as np
2626

27-
from hypershap.task import TunabilityExplanationTask
27+
from hypershap.task import (
28+
BaselineExplanationTask,
29+
MistunabilityExplanationTask,
30+
SensitivityExplanationTask,
31+
TunabilityExplanationTask,
32+
)
2833

2934
from hypershap.games.abstract import AbstractHPIGame
3035
from hypershap.utils import ConfigSpaceSearcher, RandomConfigSpaceSearcher
@@ -37,8 +42,8 @@ class SearchBasedGame(AbstractHPIGame):
3742

3843
def __init__(
3944
self,
40-
explanation_task: TunabilityExplanationTask,
41-
cs_searcher: ConfigSpaceSearcher = None,
45+
explanation_task: BaselineExplanationTask,
46+
cs_searcher: ConfigSpaceSearcher,
4247
n_workers: int | None = None,
4348
verbose: bool | None = None,
4449
) -> None:
@@ -80,7 +85,7 @@ class TunabilityGame(SearchBasedGame):
8085
def __init__(
8186
self,
8287
explanation_task: TunabilityExplanationTask,
83-
cs_searcher: ConfigSpaceSearcher = None,
88+
cs_searcher: ConfigSpaceSearcher | None = None,
8489
n_workers: int | None = None,
8590
verbose: bool | None = None,
8691
) -> None:
@@ -113,8 +118,8 @@ class SensitivityGame(SearchBasedGame):
113118

114119
def __init__(
115120
self,
116-
explanation_task: TunabilityExplanationTask,
117-
cs_searcher: ConfigSpaceSearcher = None,
121+
explanation_task: SensitivityExplanationTask,
122+
cs_searcher: ConfigSpaceSearcher | None = None,
118123
n_workers: int | None = None,
119124
verbose: bool | None = None,
120125
) -> None:
@@ -143,13 +148,13 @@ def __init__(
143148
super().__init__(explanation_task, cs_searcher, n_workers=n_workers, verbose=verbose)
144149

145150

146-
class MistunabilityGame(TunabilityGame):
151+
class MistunabilityGame(SearchBasedGame):
147152
"""Game representing the mistunability of hyperparameters."""
148153

149154
def __init__(
150155
self,
151-
explanation_task: TunabilityExplanationTask,
152-
cs_searcher: ConfigSpaceSearcher = None,
156+
explanation_task: MistunabilityExplanationTask,
157+
cs_searcher: ConfigSpaceSearcher | None = None,
153158
n_workers: int | None = None,
154159
verbose: bool | None = None,
155160
) -> None:

src/hypershap/hypershap.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import logging
1717

1818
import matplotlib.pyplot as plt
19+
import networkx as nx
20+
import numpy as np
1921
from shapiq import ExactComputer, InteractionValues
2022

2123
from hypershap.games import (
@@ -144,7 +146,7 @@ def ablation(
144146

145147
def tunability(
146148
self,
147-
baseline_config: Configuration = None,
149+
baseline_config: Configuration | None = None,
148150
index: str = "FSII",
149151
order: int = 2,
150152
n_samples: int = 10_000,
@@ -161,6 +163,9 @@ def tunability(
161163
InteractionValues: The computed interaction values.
162164
163165
"""
166+
if baseline_config is None:
167+
baseline_config = self.explanation_task.config_space.get_default_configuration()
168+
164169
# setup explanation task
165170
tunability_task: TunabilityExplanationTask = TunabilityExplanationTask(
166171
config_space=self.explanation_task.config_space,
@@ -183,7 +188,7 @@ def tunability(
183188

184189
def sensitivity(
185190
self,
186-
baseline_config: Configuration = None,
191+
baseline_config: Configuration | None = None,
187192
index: str = "FSII",
188193
order: int = 2,
189194
n_samples: int = 10_000,
@@ -200,18 +205,21 @@ def sensitivity(
200205
InteractionValues: The computed interaction values.
201206
202207
"""
208+
if baseline_config is None:
209+
baseline_config = self.explanation_task.config_space.get_default_configuration()
210+
203211
# setup explanation task
204-
tunability_task: SensitivityExplanationTask = SensitivityExplanationTask(
212+
sensitivity_task: SensitivityExplanationTask = SensitivityExplanationTask(
205213
config_space=self.explanation_task.config_space,
206214
surrogate_model=self.explanation_task.surrogate_model,
207215
baseline_config=baseline_config,
208216
)
209217

210218
# setup tunability game and get interaction values
211219
tg = SensitivityGame(
212-
explanation_task=tunability_task,
220+
explanation_task=sensitivity_task,
213221
cs_searcher=RandomConfigSpaceSearcher(
214-
explanation_task=tunability_task,
222+
explanation_task=sensitivity_task,
215223
n_samples=n_samples,
216224
mode="var",
217225
),
@@ -222,7 +230,7 @@ def sensitivity(
222230

223231
def mistunability(
224232
self,
225-
baseline_config: Configuration = None,
233+
baseline_config: Configuration | None = None,
226234
index: str = "FSII",
227235
order: int = 2,
228236
n_samples: int = 10_000,
@@ -239,18 +247,21 @@ def mistunability(
239247
InteractionValues: The computed interaction values.
240248
241249
"""
250+
if baseline_config is None:
251+
baseline_config = self.explanation_task.config_space.get_default_configuration()
252+
242253
# setup explanation task
243-
tunability_task: MistunabilityExplanationTask = MistunabilityExplanationTask(
254+
mistunability_task: MistunabilityExplanationTask = MistunabilityExplanationTask(
244255
config_space=self.explanation_task.config_space,
245256
surrogate_model=self.explanation_task.surrogate_model,
246257
baseline_config=baseline_config,
247258
)
248259

249260
# setup tunability game and get interaction values
250261
tg = MistunabilityGame(
251-
explanation_task=tunability_task,
262+
explanation_task=mistunability_task,
252263
cs_searcher=RandomConfigSpaceSearcher(
253-
explanation_task=tunability_task,
264+
explanation_task=mistunability_task,
254265
n_samples=n_samples,
255266
mode="min",
256267
),
@@ -284,11 +295,10 @@ def optimizer_bias(
284295
surrogate_model=self.explanation_task.surrogate_model,
285296
optimizer_of_interest=optimizer_of_interest,
286297
optimizer_ensemble=optimizer_ensemble,
287-
n_workers=self.n_workers,
288298
)
289299

290300
# setup optimizer bias game and get interaction values
291-
og = OptimizerBiasGame(explanation_task=optimizer_bias_task)
301+
og = OptimizerBiasGame(explanation_task=optimizer_bias_task, n_workers=self.n_workers, verbose=self.verbose)
292302
return self.__get_interaction_values(game=og, index=index, order=order)
293303

294304
def plot_si_graph(self, interaction_values: InteractionValues | None = None, save_path: str | None = None) -> None:
@@ -304,9 +314,11 @@ def plot_si_graph(self, interaction_values: InteractionValues | None = None, sav
304314

305315
# if given interaction values use those, else use cached interaction values
306316
iv = interaction_values if interaction_values is not None else self.last_interaction_values
307-
hyperparameter_names = self.explanation_task.get_hyperparameter_names()
308317

309-
import networkx as nx
318+
if not isinstance(iv, InteractionValues):
319+
raise TypeError
320+
321+
hyperparameter_names = self.explanation_task.get_hyperparameter_names()
310322

311323
def get_circular_layout(n_players: int) -> dict:
312324
original_graph, graph_nodes = nx.Graph(), []
@@ -345,9 +357,17 @@ def plot_upset(self, interaction_values: InteractionValues | None = None, save_p
345357

346358
# if given interaction values use those, else use cached interaction values
347359
iv = interaction_values if interaction_values is not None else self.last_interaction_values
360+
361+
if not isinstance(iv, InteractionValues):
362+
raise TypeError
363+
348364
hyperparameter_names = self.explanation_task.get_hyperparameter_names()
349365

350366
fig = iv.plot_upset(feature_names=hyperparameter_names, show=False)
367+
368+
if fig is None:
369+
raise TypeError
370+
351371
ax = fig.get_axes()[0]
352372
ax.set_ylabel("Performance Gain")
353373
# also add "parameter" to the y-axis label
@@ -374,9 +394,13 @@ def plot_force(self, interaction_values: InteractionValues | None = None, save_p
374394

375395
# if given interaction values use those, else use cached interaction values
376396
iv = interaction_values if interaction_values is not None else self.last_interaction_values
397+
398+
if not isinstance(iv, InteractionValues):
399+
raise TypeError
400+
377401
hyperparameter_names = self.explanation_task.get_hyperparameter_names()
378402

379-
iv.plot_force(feature_names=hyperparameter_names, show=False)
403+
iv.plot_force(feature_names=np.array(hyperparameter_names), show=False)
380404
plt.tight_layout()
381405

382406
if save_path is not None:

0 commit comments

Comments
 (0)