Skip to content

Commit fdf955d

Browse files
talgalilifacebook-github-bot
authored andcommitted
Migrate test_rake.py to pyre-strict mode (facebookresearch#161)
Summary: Migrated test_rake.py from # pyre-unsafe to # pyre-strict mode as part of the comprehensive type safety improvements in the balance test suite. Changes: - Converted test_rake.py from # pyre-unsafe to # pyre-strict - Updated file header to use from __future__ import annotations for modern type hint support - Added comprehensive type annotations for all test functions and helper variables - Fixed type handling throughout test cases to satisfy strict mode requirements All test cases pass without modifications to test logic - the changes are purely type annotation improvements that enhance code maintainability and prevent future type safety regressions. Differential Revision: D87727292
1 parent d3b1d96 commit fdf955d

File tree

1 file changed

+34
-28
lines changed

1 file changed

+34
-28
lines changed

tests/test_rake.py

Lines changed: 34 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,15 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
# pyre-unsafe
7-
8-
from __future__ import absolute_import, division, print_function, unicode_literals
6+
# pyre-strict
7+
8+
from __future__ import (
9+
absolute_import,
10+
annotations,
11+
division,
12+
print_function,
13+
unicode_literals,
14+
)
915

1016
import balance.testutil
1117

@@ -36,13 +42,13 @@ class Testrake(
3642

3743
def _assert_rake_raises_with_message(
3844
self,
39-
expected_message,
40-
sample_df,
41-
sample_weights,
42-
target_df,
43-
target_weights,
44-
**kwargs,
45-
):
45+
expected_message: str,
46+
sample_df: pd.DataFrame,
47+
sample_weights: pd.Series | None,
48+
target_df: pd.DataFrame,
49+
target_weights: pd.Series | None,
50+
**kwargs: object,
51+
) -> None:
4652
"""
4753
Helper method to assert that rake raises an error with a specific message.
4854
@@ -65,7 +71,7 @@ def _assert_rake_raises_with_message(
6571
**kwargs,
6672
)
6773

68-
def test_rake_input_assertions(self):
74+
def test_rake_input_assertions(self) -> None:
6975
"""
7076
Test that rake() properly validates input parameters.
7177
@@ -156,7 +162,7 @@ def test_rake_input_assertions(self):
156162
pd.Series((1,) * (n_rows - 1)),
157163
)
158164

159-
def test_rake_fails_when_all_na(self):
165+
def test_rake_fails_when_all_na(self) -> None:
160166
"""
161167
Test that rake() properly handles cases where all values are NaN.
162168
@@ -221,7 +227,7 @@ def test_rake_fails_when_all_na(self):
221227
transformations=None,
222228
)
223229

224-
def test_rake_weights(self):
230+
def test_rake_weights(self) -> None:
225231
"""
226232
Test basic rake weighting functionality with categorical data.
227233
@@ -261,7 +267,7 @@ def test_rake_weights(self):
261267
pd.Series([1.67, 0.33] * 6, name="rake_weight").rename_axis("index"),
262268
)
263269

264-
def test_rake_weight_trimming_applied(self):
270+
def test_rake_weight_trimming_applied(self) -> None:
265271
"""Verify that rake forwards trimming arguments to the adjustment helper."""
266272

267273
df_sample = pd.DataFrame(
@@ -306,7 +312,7 @@ def test_rake_weight_trimming_applied(self):
306312

307313
pd.testing.assert_series_equal(trimmed["weight"], expected)
308314

309-
def test_rake_percentile_trimming_applied(self):
315+
def test_rake_percentile_trimming_applied(self) -> None:
310316
"""Percentile trimming parameters should be honoured by rake."""
311317

312318
df_sample = pd.DataFrame(
@@ -351,7 +357,7 @@ def test_rake_percentile_trimming_applied(self):
351357

352358
pd.testing.assert_series_equal(trimmed["weight"], expected)
353359

354-
def test_rake_weights_with_weighted_input(self):
360+
def test_rake_weights_with_weighted_input(self) -> None:
355361
"""
356362
Test rake weighting with pre-weighted target data.
357363
@@ -391,7 +397,7 @@ def test_rake_weights_with_weighted_input(self):
391397
pd.Series([1.25, 0.25] * 6, name="rake_weight").rename_axis("index"),
392398
)
393399

394-
def test_rake_weights_scale_to_pop(self):
400+
def test_rake_weights_scale_to_pop(self) -> None:
395401
"""
396402
Test that rake weights properly scale to match target population size.
397403
@@ -427,7 +433,7 @@ def test_rake_weights_scale_to_pop(self):
427433

428434
self.assertEqual(round(sum(adjusted["weight"]), 2), 15.0)
429435

430-
def test_rake_expected_weights_with_na(self):
436+
def test_rake_expected_weights_with_na(self) -> None:
431437
"""
432438
Test rake weighting behavior with NaN values using different na_action strategies.
433439
@@ -486,7 +492,7 @@ def test_rake_expected_weights_with_na(self):
486492
pd.Series([1.67, 1.0, 0.33] * 6, name="weight"),
487493
)
488494

489-
def test_rake_consistency_with_default_arguments(self):
495+
def test_rake_consistency_with_default_arguments(self) -> None:
490496
"""
491497
Test consistency of rake function results with default parameters.
492498
@@ -565,7 +571,7 @@ def test_rake_consistency_with_default_arguments(self):
565571
),
566572
)
567573

568-
def test_variable_order_alphabetized(self):
574+
def test_variable_order_alphabetized(self) -> None:
569575
"""
570576
Test that variable ordering is consistent and alphabetized.
571577
@@ -617,7 +623,7 @@ def test_variable_order_alphabetized(self):
617623
adjusted_two["weight"],
618624
)
619625

620-
def test_rake_levels_warnings(self):
626+
def test_rake_levels_warnings(self) -> None:
621627
"""
622628
Test warning and error handling for mismatched categorical levels.
623629
@@ -679,7 +685,7 @@ def test_rake_levels_warnings(self):
679685
target_excess_levels.weight_column,
680686
)
681687

682-
def test__proportional_array_from_dict(self):
688+
def test__proportional_array_from_dict(self) -> None:
683689
"""
684690
Test the _proportional_array_from_dict utility function.
685691
@@ -716,7 +722,7 @@ def test__proportional_array_from_dict(self):
716722
["a", "a", "a", "b", "b", "b", "b", "b"],
717723
)
718724

719-
def test__realize_dicts_of_proportions(self):
725+
def test__realize_dicts_of_proportions(self) -> None:
720726
"""
721727
Test the _realize_dicts_of_proportions utility function.
722728
@@ -768,7 +774,7 @@ def test__realize_dicts_of_proportions(self):
768774
},
769775
)
770776

771-
def test_prepare_marginal_dist_for_raking(self):
777+
def test_prepare_marginal_dist_for_raking(self) -> None:
772778
"""
773779
Test the prepare_marginal_dist_for_raking utility function.
774780
@@ -810,7 +816,7 @@ def test_prepare_marginal_dist_for_raking(self):
810816
},
811817
)
812818

813-
def test_run_ipf_numpy_matches_expected_margins(self):
819+
def test_run_ipf_numpy_matches_expected_margins(self) -> None:
814820
"""Validate that the NumPy IPF solver hits the requested marginals."""
815821

816822
original = np.array([[5.0, 3.0], [2.0, 4.0]])
@@ -831,7 +837,7 @@ def test_run_ipf_numpy_matches_expected_margins(self):
831837
np.testing.assert_allclose(fitted.sum(axis=0), target_cols, rtol=0, atol=1e-6)
832838
self.assertGreater(len(iterations), 0)
833839

834-
def test_run_ipf_numpy_handles_zero_targets(self):
840+
def test_run_ipf_numpy_handles_zero_targets(self) -> None:
835841
"""Ensure zero-valued margins do not introduce NaNs or divergence."""
836842

837843
original = np.array([[4.0, 1.0, 0.0], [0.0, 3.0, 2.0]])
@@ -852,7 +858,7 @@ def test_run_ipf_numpy_handles_zero_targets(self):
852858
np.testing.assert_allclose(fitted.sum(axis=1), target_rows, atol=1e-9)
853859
np.testing.assert_allclose(fitted.sum(axis=0), target_cols, atol=1e-9)
854860

855-
def test_run_ipf_numpy_flags_non_convergence(self):
861+
def test_run_ipf_numpy_flags_non_convergence(self) -> None:
856862
"""The solver should report non-convergence when the iteration budget is exhausted."""
857863

858864
original = np.array([[1.0, 0.0], [0.0, 1.0]])
@@ -869,7 +875,7 @@ def test_run_ipf_numpy_flags_non_convergence(self):
869875

870876
self.assertEqual(converged, 0)
871877

872-
def test_rake_zero_weight_levels_respected(self):
878+
def test_rake_zero_weight_levels_respected(self) -> None:
873879
"""Variable levels with zero target weight should collapse to zero mass."""
874880

875881
sample_df = pd.DataFrame(

0 commit comments

Comments
 (0)