33# This source code is licensed under the MIT license found in the
44# LICENSE file in the root directory of this source tree.
55
6- # pyre-unsafe
6+ # pyre-strict
7+
8+ from __future__ import annotations
9+
10+ from typing import Any , cast
711
812import balance .testutil
913
1418class TestBalance_weights_stats (
1519 balance .testutil .BalanceTestCase ,
1620):
17- def test__check_weights_are_valid (self ):
21+ def test__check_weights_are_valid (self ) -> None :
1822 """Test validation of weight arrays for statistical calculations.
1923
2024 Verifies that _check_weights_are_valid correctly validates different
@@ -52,7 +56,7 @@ def test__check_weights_are_valid(self):
5256 negative_w = [- 1 , 0 , 1 ]
5357 _check_weights_are_valid (negative_w )
5458
55- def test_design_effect (self ):
59+ def test_design_effect (self ) -> None :
5660 """Test calculation of design effect for weighted samples.
5761
5862 Design effect measures the loss of precision due to weighting.
@@ -68,7 +72,7 @@ def test_design_effect(self):
6872 )
6973 self .assertEqual (type (design_effect (pd .Series ((0 , 1 , 2 , 3 )))), np .float64 )
7074
71- def test_nonparametric_skew (self ):
75+ def test_nonparametric_skew (self ) -> None :
7276 """Test calculation of nonparametric skewness measure.
7377
7478 Tests skewness calculation for various distributions including
@@ -82,7 +86,7 @@ def test_nonparametric_skew(self):
8286 self .assertEqual (nonparametric_skew (pd .Series ((1 , 2 , 3 , 4 ))), 0 )
8387 self .assertEqual (nonparametric_skew (pd .Series ((1 , 1 , 1 , 2 ))), 0.5 )
8488
85- def test_prop_above_and_below (self ):
89+ def test_prop_above_and_below (self ) -> None :
8690 """Test calculation of proportions above and below thresholds.
8791
8892 Tests the prop_above_and_below function with default thresholds,
@@ -92,21 +96,32 @@ def test_prop_above_and_below(self):
9296 from balance .stats_and_plots .weights_stats import prop_above_and_below
9397
9498 # Test with identical values
99+ result1 = prop_above_and_below (pd .Series ((1 , 1 , 1 , 1 )))
100+ self .assertIsNotNone (result1 )
101+ assert result1 is not None # Type narrowing for pyre
102+ assert isinstance (result1 , pd .Series ) # Type narrowing for pyre
95103 self .assertEqual (
96- prop_above_and_below ( pd . Series (( 1 , 1 , 1 , 1 ))) .astype (int ).to_list (),
104+ result1 .astype (int ).to_list (),
97105 [0 , 0 , 0 , 0 , 0 , 1 , 0 , 0 , 0 , 0 ],
98106 )
99107
100108 # Test with varying values
109+ result2 = prop_above_and_below (pd .Series ((1 , 2 , 3 , 4 )))
110+ self .assertIsNotNone (result2 )
111+ assert result2 is not None # Type narrowing for pyre
112+ assert isinstance (result2 , pd .Series ) # Type narrowing for pyre
101113 self .assertEqual (
102- prop_above_and_below ( pd . Series (( 1 , 2 , 3 , 4 ))) .to_list (),
114+ result2 .to_list (),
103115 [0.0 , 0.0 , 0.0 , 0.25 , 0.5 , 0.5 , 0.0 , 0.0 , 0.0 , 0.0 ],
104116 )
105117
106118 # Test custom thresholds
107119 result = prop_above_and_below (
108120 pd .Series ((1 , 2 , 3 , 4 )), below = (0.1 , 0.5 ), above = (2 , 3 )
109121 )
122+ self .assertIsNotNone (result )
123+ assert result is not None # Type narrowing for pyre
124+ assert isinstance (result , pd .Series ) # Type narrowing for pyre
110125 self .assertEqual (result .to_list (), [0.0 , 0.25 , 0.0 , 0.0 ])
111126 self .assertEqual (
112127 result .index .to_list (),
@@ -119,14 +134,18 @@ def test_prop_above_and_below(self):
119134 )
120135
121136 # Test return_as_series = False
122- result = prop_above_and_below (pd .Series ((1 , 2 , 3 , 4 )), return_as_series = False )
137+ result_dict = prop_above_and_below (
138+ pd .Series ((1 , 2 , 3 , 4 )), return_as_series = False
139+ )
140+ self .assertIsNotNone (result_dict )
141+ assert result_dict is not None # Type narrowing for pyre
123142 expected = {
124143 "below" : [0.0 , 0.0 , 0.0 , 0.25 , 0.5 ],
125144 "above" : [0.5 , 0.0 , 0.0 , 0.0 , 0.0 ],
126145 }
127- self .assertEqual ({k : v .to_list () for k , v in result .items ()}, expected )
146+ self .assertEqual ({k : v .to_list () for k , v in result_dict .items ()}, expected )
128147
129- def test_weighted_median_breakdown_point (self ):
148+ def test_weighted_median_breakdown_point (self ) -> None :
130149 """Test calculation of weighted median breakdown point.
131150
132151 Tests the weighted median breakdown point calculation which measures
@@ -151,7 +170,7 @@ def test_weighted_median_breakdown_point(self):
151170class TestBalance_weighted_stats (
152171 balance .testutil .BalanceTestCase ,
153172):
154- def test__prepare_weighted_stat_args (self ):
173+ def test__prepare_weighted_stat_args (self ) -> None :
155174 """Test preparation of arguments for weighted statistical functions.
156175
157176 Tests the _prepare_weighted_stat_args function which standardizes
@@ -196,11 +215,11 @@ def test__prepare_weighted_stat_args(self):
196215
197216 # Check that it catches wrong input types
198217 with self .assertRaises (TypeError ):
199- v , w = pd .Series ([1 , 2 ]), "wrong_type"
200- v2 , w2 = _prepare_weighted_stat_args (v , w )
218+ v , w = pd .Series ([1 , 2 ]), "wrong_type" # type: ignore[assignment]
219+ v2 , w2 = _prepare_weighted_stat_args (v , w ) # type: ignore[arg-type]
201220 with self .assertRaises (TypeError ):
202- v , w = pd .Series ([1 , 2 ]), (1 , 2 )
203- v2 , w2 = _prepare_weighted_stat_args (v , w )
221+ v , w = pd .Series ([1 , 2 ]), (1 , 2 ) # type: ignore[assignment]
222+ v2 , w2 = _prepare_weighted_stat_args (v , w ) # type: ignore[arg-type]
204223 with self .assertRaises (TypeError ):
205224 v , w = (1 , 2 ), pd .Series ([1 , 2 ])
206225 v2 , w2 = _prepare_weighted_stat_args (v , w )
@@ -251,7 +270,7 @@ def test__prepare_weighted_stat_args(self):
251270 v , w = pd .Series ([- 1 , 0 , 1 , np .inf ]), pd .Series ([np .inf , 1 , - 2 , 1.0 ])
252271 v2 , w2 = _prepare_weighted_stat_args (v , w )
253272
254- def test_weighted_mean (self ):
273+ def test_weighted_mean (self ) -> None :
255274 """Test calculation of weighted mean for various input types.
256275
257276 Tests weighted mean calculation with different input formats,
@@ -368,7 +387,7 @@ def test_weighted_mean(self):
368387 pd .Series ((- 1 * 1 + 2 * 2 + 1 * 3 + 2 * 4 ) / (1 + 2 + 3 + 4 )),
369388 )
370389
371- def test_var_of_weighted_mean (self ):
390+ def test_var_of_weighted_mean (self ) -> None :
372391 """Test calculation of variance of weighted mean.
373392
374393 Tests the variance calculation of weighted means with and without
@@ -389,7 +408,7 @@ def test_var_of_weighted_mean(self):
389408 pd .Series (0.24 ),
390409 )
391410
392- def test_ci_of_weighted_mean (self ):
411+ def test_ci_of_weighted_mean (self ) -> None :
393412 """Test calculation of confidence intervals for weighted means.
394413
395414 Tests confidence interval calculations for weighted means with
@@ -417,7 +436,7 @@ def test_ci_of_weighted_mean(self):
417436 {"a" : (1.738 , 4.262 ), "b" : (1.0 , 1.0 )},
418437 )
419438
420- def test_weighted_var (self ):
439+ def test_weighted_var (self ) -> None :
421440 """Test calculation of weighted variance.
422441
423442 Tests weighted variance calculations with and without weights
@@ -435,7 +454,7 @@ def test_weighted_var(self):
435454 weighted_var (pd .Series ((1 , 2 )), pd .Series ((1 , 2 ))), pd .Series (0.5 )
436455 )
437456
438- def test_weighted_sd (self ):
457+ def test_weighted_sd (self ) -> None :
439458 """Test calculation of weighted standard deviation.
440459
441460 Tests weighted standard deviation calculations with various inputs
@@ -459,7 +478,7 @@ def test_weighted_sd(self):
459478 manual_std = np .sqrt (np .sum ((x2 - x2 .mean ()) ** 2 ) / (len (x ) - 1 ))
460479 self .assertEqual (round (weighted_sd (x )[0 ], 5 ), round (manual_std , 5 ))
461480
462- def test_weighted_quantile (self ):
481+ def test_weighted_quantile (self ) -> None :
463482 """Test calculation of weighted quantiles.
464483
465484 Tests weighted quantile calculations with various input formats
@@ -469,18 +488,18 @@ def test_weighted_quantile(self):
469488 from balance .stats_and_plots .weighted_stats import weighted_quantile
470489
471490 self .assertEqual (
472- weighted_quantile (np .arange (1 , 100 , 1 ), 0.5 ).values ,
491+ weighted_quantile (np .arange (1 , 100 , 1 ), [ 0.5 ] ).values ,
473492 np .array (((50 ,),)),
474493 )
475494
476495 # In R: reldist::wtd.quantile(c(1, 2, 3), q=c(0.5, 0.75), weight=c(1, 1, 2))
477496 self .assertEqual (
478- weighted_quantile (np .array ([1 , 2 , 3 ]), ( 0.5 , 0.75 ) ).values ,
497+ weighted_quantile (np .array ([1 , 2 , 3 ]), [ 0.5 , 0.75 ] ).values ,
479498 np .array (((2 ,), (3 ,))),
480499 )
481500
482501 self .assertEqual (
483- weighted_quantile (np .array ([1 , 2 , 3 ]), 0.5 , np .array ([1 , 1 , 2 ])).values ,
502+ weighted_quantile (np .array ([1 , 2 , 3 ]), [ 0.5 ] , np .array ([1 , 1 , 2 ])).values ,
484503 np .percentile ([1 , 2 , 3 , 3 ], 50 ),
485504 )
486505
@@ -540,7 +559,7 @@ def test_weighted_quantile(self):
540559 np .array ([[2.0 , 1.0 ], [2.0 , 1.0 ]]),
541560 )
542561
543- def test_descriptive_stats (self ):
562+ def test_descriptive_stats (self ) -> None :
544563 """Test calculation of descriptive statistics with weights.
545564
546565 Tests the descriptive_stats function with various statistics
@@ -629,7 +648,7 @@ def test_descriptive_stats(self):
629648class TestBalance_weighted_comparisons_stats (
630649 balance .testutil .BalanceTestCase ,
631650):
632- def test_outcome_variance_ratio (self ):
651+ def test_outcome_variance_ratio (self ) -> None :
633652 """Test calculation of outcome variance ratios between datasets.
634653
635654 Tests the outcome_variance_ratio function which compares variance
@@ -653,7 +672,7 @@ def test_outcome_variance_ratio(self):
653672 pd .Series ([1.0 , 1.0 ], index = ["j" , "k" ]),
654673 )
655674
656- def test__weights_per_covars_names (self ):
675+ def test__weights_per_covars_names (self ) -> None :
657676 """Test calculation of weights per covariate names.
658677
659678 Tests the _weights_per_covars_names function which assigns weights
@@ -696,7 +715,7 @@ def test__weights_per_covars_names(self):
696715
697716 self .assertEqual (outcome , expected )
698717
699- def test_asmd (self ):
718+ def test_asmd (self ) -> None :
700719 """Test calculation of Absolute Standardized Mean Differences (ASMD).
701720
702721 Tests the asmd function which calculates standardized mean differences
@@ -709,16 +728,16 @@ def test_asmd(self):
709728 # Using wild card since it will return:
710729 # "sample_df must be pd.DataFrame, is* <class 'pandas.core.series.Series'>"
711730 asmd (
712- pd .Series ((0 , 1 , 2 , 3 )),
713- pd .Series ((0 , 1 , 2 , 3 )),
731+ pd .Series ((0 , 1 , 2 , 3 )), # type: ignore[arg-type]
732+ pd .Series ((0 , 1 , 2 , 3 )), # type: ignore[arg-type]
714733 pd .Series ((0 , 1 , 2 , 3 )),
715734 pd .Series ((0 , 1 , 2 , 3 )),
716735 )
717736
718737 with self .assertRaisesRegex (ValueError , "target_df must be pd.DataFrame, is*" ):
719738 asmd (
720739 pd .DataFrame ({"a" : (0 , 1 , 2 , 3 )}),
721- pd .Series ((0 , 1 , 2 , 3 )),
740+ pd .Series ((0 , 1 , 2 , 3 )), # type: ignore[arg-type]
722741 pd .Series ((0 , 1 , 2 , 3 )),
723742 pd .Series ((0 , 1 , 2 , 3 )),
724743 )
@@ -740,7 +759,7 @@ def test_asmd(self):
740759 asmd (
741760 pd .DataFrame ({"a" : (1 , 2 ), "b" : (- 1 , 12 )}),
742761 pd .DataFrame ({"a" : (3 , 4 ), "c" : (5 , 6 )}),
743- std_type = "magic variance type that doesn't exist" ,
762+ std_type = cast ( Any , "magic variance type that doesn't exist" ) ,
744763 )
745764
746765 # TODO: (p2) add comparison to the following numbers
@@ -915,7 +934,7 @@ def test_asmd(self):
915934 )
916935 self .assertTrue (all ((np .round (r2 , 5 )) == np .array ([2.82843 , 0.70711 , 1.76777 ])))
917936
918- def test__aggregate_asmd_by_main_covar (self ):
937+ def test__aggregate_asmd_by_main_covar (self ) -> None :
919938 """Test aggregation of ASMD values by main covariate.
920939
921940 Tests the _aggregate_asmd_by_main_covar function which groups
@@ -942,7 +961,7 @@ def test__aggregate_asmd_by_main_covar(self):
942961
943962 self .assertEqual (outcome , expected )
944963
945- def test_asmd_improvement (self ):
964+ def test_asmd_improvement (self ) -> None :
946965 """Test calculation of ASMD improvement ratios.
947966
948967 Tests the asmd_improvement function which measures the improvement
@@ -988,7 +1007,7 @@ def test_asmd_improvement(self):
9881007class TestBalance_general_stats (
9891008 balance .testutil .BalanceTestCase ,
9901009):
991- def test_relative_response_rates (self ):
1010+ def test_relative_response_rates (self ) -> None :
9921011 """Test calculation of relative response rates across columns.
9931012
9941013 Tests the relative_response_rates function which calculates response
0 commit comments