33# This source code is licensed under the MIT license found in the
44# LICENSE file in the root directory of this source tree.
55
6- # pyre-unsafe
7-
8- from __future__ import absolute_import , division , print_function , unicode_literals
6+ # pyre-strict
7+
8+ from __future__ import (
9+ absolute_import ,
10+ annotations ,
11+ division ,
12+ print_function ,
13+ unicode_literals ,
14+ )
915
1016import balance .testutil
1117
@@ -36,13 +42,13 @@ class Testrake(
3642
3743 def _assert_rake_raises_with_message (
3844 self ,
39- expected_message ,
40- sample_df ,
41- sample_weights ,
42- target_df ,
43- target_weights ,
44- ** kwargs ,
45- ):
45+ expected_message : str ,
46+ sample_df : pd . DataFrame ,
47+ sample_weights : pd . Series | None ,
48+ target_df : pd . DataFrame ,
49+ target_weights : pd . Series | None ,
50+ ** kwargs : object ,
51+ ) -> None :
4652 """
4753 Helper method to assert that rake raises an error with a specific message.
4854
@@ -65,7 +71,7 @@ def _assert_rake_raises_with_message(
6571 ** kwargs ,
6672 )
6773
68- def test_rake_input_assertions (self ):
74+ def test_rake_input_assertions (self ) -> None :
6975 """
7076 Test that rake() properly validates input parameters.
7177
@@ -156,7 +162,7 @@ def test_rake_input_assertions(self):
156162 pd .Series ((1 ,) * (n_rows - 1 )),
157163 )
158164
159- def test_rake_fails_when_all_na (self ):
165+ def test_rake_fails_when_all_na (self ) -> None :
160166 """
161167 Test that rake() properly handles cases where all values are NaN.
162168
@@ -221,7 +227,7 @@ def test_rake_fails_when_all_na(self):
221227 transformations = None ,
222228 )
223229
224- def test_rake_weights (self ):
230+ def test_rake_weights (self ) -> None :
225231 """
226232 Test basic rake weighting functionality with categorical data.
227233
@@ -261,7 +267,7 @@ def test_rake_weights(self):
261267 pd .Series ([1.67 , 0.33 ] * 6 , name = "rake_weight" ).rename_axis ("index" ),
262268 )
263269
264- def test_rake_weight_trimming_applied (self ):
270+ def test_rake_weight_trimming_applied (self ) -> None :
265271 """Verify that rake forwards trimming arguments to the adjustment helper."""
266272
267273 df_sample = pd .DataFrame (
@@ -306,7 +312,7 @@ def test_rake_weight_trimming_applied(self):
306312
307313 pd .testing .assert_series_equal (trimmed ["weight" ], expected )
308314
309- def test_rake_percentile_trimming_applied (self ):
315+ def test_rake_percentile_trimming_applied (self ) -> None :
310316 """Percentile trimming parameters should be honoured by rake."""
311317
312318 df_sample = pd .DataFrame (
@@ -351,7 +357,7 @@ def test_rake_percentile_trimming_applied(self):
351357
352358 pd .testing .assert_series_equal (trimmed ["weight" ], expected )
353359
354- def test_rake_weights_with_weighted_input (self ):
360+ def test_rake_weights_with_weighted_input (self ) -> None :
355361 """
356362 Test rake weighting with pre-weighted target data.
357363
@@ -391,7 +397,7 @@ def test_rake_weights_with_weighted_input(self):
391397 pd .Series ([1.25 , 0.25 ] * 6 , name = "rake_weight" ).rename_axis ("index" ),
392398 )
393399
394- def test_rake_weights_scale_to_pop (self ):
400+ def test_rake_weights_scale_to_pop (self ) -> None :
395401 """
396402 Test that rake weights properly scale to match target population size.
397403
@@ -427,7 +433,7 @@ def test_rake_weights_scale_to_pop(self):
427433
428434 self .assertEqual (round (sum (adjusted ["weight" ]), 2 ), 15.0 )
429435
430- def test_rake_expected_weights_with_na (self ):
436+ def test_rake_expected_weights_with_na (self ) -> None :
431437 """
432438 Test rake weighting behavior with NaN values using different na_action strategies.
433439
@@ -486,7 +492,7 @@ def test_rake_expected_weights_with_na(self):
486492 pd .Series ([1.67 , 1.0 , 0.33 ] * 6 , name = "weight" ),
487493 )
488494
489- def test_rake_consistency_with_default_arguments (self ):
495+ def test_rake_consistency_with_default_arguments (self ) -> None :
490496 """
491497 Test consistency of rake function results with default parameters.
492498
@@ -565,7 +571,7 @@ def test_rake_consistency_with_default_arguments(self):
565571 ),
566572 )
567573
568- def test_variable_order_alphabetized (self ):
574+ def test_variable_order_alphabetized (self ) -> None :
569575 """
570576 Test that variable ordering is consistent and alphabetized.
571577
@@ -617,7 +623,7 @@ def test_variable_order_alphabetized(self):
617623 adjusted_two ["weight" ],
618624 )
619625
620- def test_rake_levels_warnings (self ):
626+ def test_rake_levels_warnings (self ) -> None :
621627 """
622628 Test warning and error handling for mismatched categorical levels.
623629
@@ -679,7 +685,7 @@ def test_rake_levels_warnings(self):
679685 target_excess_levels .weight_column ,
680686 )
681687
682- def test__proportional_array_from_dict (self ):
688+ def test__proportional_array_from_dict (self ) -> None :
683689 """
684690 Test the _proportional_array_from_dict utility function.
685691
@@ -716,7 +722,7 @@ def test__proportional_array_from_dict(self):
716722 ["a" , "a" , "a" , "b" , "b" , "b" , "b" , "b" ],
717723 )
718724
719- def test__realize_dicts_of_proportions (self ):
725+ def test__realize_dicts_of_proportions (self ) -> None :
720726 """
721727 Test the _realize_dicts_of_proportions utility function.
722728
@@ -768,7 +774,7 @@ def test__realize_dicts_of_proportions(self):
768774 },
769775 )
770776
771- def test_prepare_marginal_dist_for_raking (self ):
777+ def test_prepare_marginal_dist_for_raking (self ) -> None :
772778 """
773779 Test the prepare_marginal_dist_for_raking utility function.
774780
@@ -810,7 +816,7 @@ def test_prepare_marginal_dist_for_raking(self):
810816 },
811817 )
812818
813- def test_run_ipf_numpy_matches_expected_margins (self ):
819+ def test_run_ipf_numpy_matches_expected_margins (self ) -> None :
814820 """Validate that the NumPy IPF solver hits the requested marginals."""
815821
816822 original = np .array ([[5.0 , 3.0 ], [2.0 , 4.0 ]])
@@ -831,7 +837,7 @@ def test_run_ipf_numpy_matches_expected_margins(self):
831837 np .testing .assert_allclose (fitted .sum (axis = 0 ), target_cols , rtol = 0 , atol = 1e-6 )
832838 self .assertGreater (len (iterations ), 0 )
833839
834- def test_run_ipf_numpy_handles_zero_targets (self ):
840+ def test_run_ipf_numpy_handles_zero_targets (self ) -> None :
835841 """Ensure zero-valued margins do not introduce NaNs or divergence."""
836842
837843 original = np .array ([[4.0 , 1.0 , 0.0 ], [0.0 , 3.0 , 2.0 ]])
@@ -852,7 +858,7 @@ def test_run_ipf_numpy_handles_zero_targets(self):
852858 np .testing .assert_allclose (fitted .sum (axis = 1 ), target_rows , atol = 1e-9 )
853859 np .testing .assert_allclose (fitted .sum (axis = 0 ), target_cols , atol = 1e-9 )
854860
855- def test_run_ipf_numpy_flags_non_convergence (self ):
861+ def test_run_ipf_numpy_flags_non_convergence (self ) -> None :
856862 """The solver should report non-convergence when the iteration budget is exhausted."""
857863
858864 original = np .array ([[1.0 , 0.0 ], [0.0 , 1.0 ]])
@@ -869,7 +875,7 @@ def test_run_ipf_numpy_flags_non_convergence(self):
869875
870876 self .assertEqual (converged , 0 )
871877
872- def test_rake_zero_weight_levels_respected (self ):
878+ def test_rake_zero_weight_levels_respected (self ) -> None :
873879 """Variable levels with zero target weight should collapse to zero mass."""
874880
875881 sample_df = pd .DataFrame (
0 commit comments