1+ from __future__ import annotations
2+
13import math
2- from typing import Dict
3- from typing import List
4- from typing import Tuple
54from unittest .mock import Mock
65
76import numpy as np
@@ -28,7 +27,7 @@ def tree() -> _FanovaTree:
2827
2928
3029@pytest .fixture
31- def expected_tree_statistics () -> List [ Dict [str , List ]]:
30+ def expected_tree_statistics () -> list [ dict [str , list ]]:
3231 # Statistics the each node in the tree.
3332 return [
3433 {"values" : [0.1 , 0.2 , 0.5 ], "weights" : [0.75 , 0.25 , 1.0 ]},
@@ -39,7 +38,7 @@ def expected_tree_statistics() -> List[Dict[str, List]]:
3938 ]
4039
4140
42- def test_tree_variance (tree : _FanovaTree , expected_tree_statistics : List [ Dict [str , List ]]) -> None :
41+ def test_tree_variance (tree : _FanovaTree , expected_tree_statistics : list [ dict [str , list ]]) -> None :
4342 # The root node at node index `0` holds the values and weights for all nodes in the tree.
4443 expected_statistics = expected_tree_statistics [0 ]
4544 expected_values = expected_statistics ["values" ]
@@ -87,9 +86,9 @@ def test_tree_variance(tree: _FanovaTree, expected_tree_statistics: List[Dict[st
8786)
8887def test_tree_get_marginal_variance (
8988 tree : _FanovaTree ,
90- features : List [int ],
91- expected : List [ Tuple [ List [Size ], List [ Tuple [NodeIndex , Cardinality ]]]],
92- expected_tree_statistics : List [ Dict [str , List ]],
89+ features : list [int ],
90+ expected : list [ tuple [ list [Size ], list [ tuple [NodeIndex , Cardinality ]]]],
91+ expected_tree_statistics : list [ dict [str , list ]],
9392) -> None :
9493 variance = tree .get_marginal_variance (np .array (features ))
9594
@@ -145,9 +144,9 @@ def test_tree_get_marginal_variance(
145144)
146145def test_tree_get_marginalized_statistics (
147146 tree : _FanovaTree ,
148- feature_vector : List [float ],
149- expected : List [ Tuple [NodeIndex , Cardinality ]],
150- expected_tree_statistics : List [ Dict [str , List ]],
147+ feature_vector : list [float ],
148+ expected : list [ tuple [NodeIndex , Cardinality ]],
149+ expected_tree_statistics : list [ dict [str , list ]],
151150) -> None :
152151 value , weight = tree ._get_marginalized_statistics (np .array (feature_vector ))
153152
@@ -167,7 +166,7 @@ def test_tree_get_marginalized_statistics(
167166
168167
169168def test_tree_statistics (
170- tree : _FanovaTree , expected_tree_statistics : List [ Dict [str , List ]]
169+ tree : _FanovaTree , expected_tree_statistics : list [ dict [str , list ]]
171170) -> None :
172171 statistics = tree ._statistics
173172
@@ -184,13 +183,13 @@ def test_tree_statistics(
184183
185184@pytest .mark .parametrize ("node_index,expected" , [(0 , [0.5 ]), (1 , [0.25 , 0.75 ]), (2 , [0.75 , 1.75 ])])
186185def test_tree_split_midpoints (
187- tree : _FanovaTree , node_index : NodeIndex , expected : List [float ]
186+ tree : _FanovaTree , node_index : NodeIndex , expected : list [float ]
188187) -> None :
189188 np .testing .assert_equal (tree ._split_midpoints [node_index ], expected )
190189
191190
192191@pytest .mark .parametrize ("node_index,expected" , [(0 , [1.0 ]), (1 , [0.5 , 0.5 ]), (2 , [1.5 , 0.5 ])])
193- def test_tree_split_sizes (tree : _FanovaTree , node_index : NodeIndex , expected : List [float ]) -> None :
192+ def test_tree_split_sizes (tree : _FanovaTree , node_index : NodeIndex , expected : list [float ]) -> None :
194193 np .testing .assert_equal (tree ._split_sizes [node_index ], expected )
195194
196195
@@ -205,7 +204,7 @@ def test_tree_split_sizes(tree: _FanovaTree, node_index: NodeIndex, expected: Li
205204 ],
206205)
207206def test_tree_subtree_active_features (
208- tree : _FanovaTree , node_index : NodeIndex , expected : List [bool ]
207+ tree : _FanovaTree , node_index : NodeIndex , expected : list [bool ]
209208) -> None :
210209 active_features : np .ndarray = tree ._subtree_active_features [node_index ] == expected
211210 assert active_features .all ()
0 commit comments