Skip to content

Commit ef79c31

Browse files
authored
Merge pull request optuna#5731 from guisp03/fix/future-annotations-test_tree.py
Use `__future__.annotations` in `tests/importance_tests/fanova_tests/test_tree.py`
2 parents c014b9d + e07a7ce commit ef79c31

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

tests/importance_tests/fanova_tests/test_tree.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1+
from __future__ import annotations
2+
13
import math
2-
from typing import Dict
3-
from typing import List
4-
from typing import Tuple
54
from unittest.mock import Mock
65

76
import numpy as np
@@ -28,7 +27,7 @@ def tree() -> _FanovaTree:
2827

2928

3029
@pytest.fixture
31-
def expected_tree_statistics() -> List[Dict[str, List]]:
30+
def expected_tree_statistics() -> list[dict[str, list]]:
3231
# Statistics the each node in the tree.
3332
return [
3433
{"values": [0.1, 0.2, 0.5], "weights": [0.75, 0.25, 1.0]},
@@ -39,7 +38,7 @@ def expected_tree_statistics() -> List[Dict[str, List]]:
3938
]
4039

4140

42-
def test_tree_variance(tree: _FanovaTree, expected_tree_statistics: List[Dict[str, List]]) -> None:
41+
def test_tree_variance(tree: _FanovaTree, expected_tree_statistics: list[dict[str, list]]) -> None:
4342
# The root node at node index `0` holds the values and weights for all nodes in the tree.
4443
expected_statistics = expected_tree_statistics[0]
4544
expected_values = expected_statistics["values"]
@@ -87,9 +86,9 @@ def test_tree_variance(tree: _FanovaTree, expected_tree_statistics: List[Dict[st
8786
)
8887
def test_tree_get_marginal_variance(
8988
tree: _FanovaTree,
90-
features: List[int],
91-
expected: List[Tuple[List[Size], List[Tuple[NodeIndex, Cardinality]]]],
92-
expected_tree_statistics: List[Dict[str, List]],
89+
features: list[int],
90+
expected: list[tuple[list[Size], list[tuple[NodeIndex, Cardinality]]]],
91+
expected_tree_statistics: list[dict[str, list]],
9392
) -> None:
9493
variance = tree.get_marginal_variance(np.array(features))
9594

@@ -145,9 +144,9 @@ def test_tree_get_marginal_variance(
145144
)
146145
def test_tree_get_marginalized_statistics(
147146
tree: _FanovaTree,
148-
feature_vector: List[float],
149-
expected: List[Tuple[NodeIndex, Cardinality]],
150-
expected_tree_statistics: List[Dict[str, List]],
147+
feature_vector: list[float],
148+
expected: list[tuple[NodeIndex, Cardinality]],
149+
expected_tree_statistics: list[dict[str, list]],
151150
) -> None:
152151
value, weight = tree._get_marginalized_statistics(np.array(feature_vector))
153152

@@ -167,7 +166,7 @@ def test_tree_get_marginalized_statistics(
167166

168167

169168
def test_tree_statistics(
170-
tree: _FanovaTree, expected_tree_statistics: List[Dict[str, List]]
169+
tree: _FanovaTree, expected_tree_statistics: list[dict[str, list]]
171170
) -> None:
172171
statistics = tree._statistics
173172

@@ -184,13 +183,13 @@ def test_tree_statistics(
184183

185184
@pytest.mark.parametrize("node_index,expected", [(0, [0.5]), (1, [0.25, 0.75]), (2, [0.75, 1.75])])
186185
def test_tree_split_midpoints(
187-
tree: _FanovaTree, node_index: NodeIndex, expected: List[float]
186+
tree: _FanovaTree, node_index: NodeIndex, expected: list[float]
188187
) -> None:
189188
np.testing.assert_equal(tree._split_midpoints[node_index], expected)
190189

191190

192191
@pytest.mark.parametrize("node_index,expected", [(0, [1.0]), (1, [0.5, 0.5]), (2, [1.5, 0.5])])
193-
def test_tree_split_sizes(tree: _FanovaTree, node_index: NodeIndex, expected: List[float]) -> None:
192+
def test_tree_split_sizes(tree: _FanovaTree, node_index: NodeIndex, expected: list[float]) -> None:
194193
np.testing.assert_equal(tree._split_sizes[node_index], expected)
195194

196195

@@ -205,7 +204,7 @@ def test_tree_split_sizes(tree: _FanovaTree, node_index: NodeIndex, expected: Li
205204
],
206205
)
207206
def test_tree_subtree_active_features(
208-
tree: _FanovaTree, node_index: NodeIndex, expected: List[bool]
207+
tree: _FanovaTree, node_index: NodeIndex, expected: list[bool]
209208
) -> None:
210209
active_features: np.ndarray = tree._subtree_active_features[node_index] == expected
211210
assert active_features.all()

0 commit comments

Comments
 (0)