Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions tests/test_poststratify.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
# pyre-strict

from __future__ import absolute_import, division, print_function, unicode_literals
from __future__ import (
absolute_import,
annotations,
division,
print_function,
unicode_literals,
)

import balance.testutil

Expand All @@ -19,7 +25,7 @@
class Testpoststratify(
balance.testutil.BalanceTestCase,
):
def test_poststratify(self):
def test_poststratify(self) -> None:
s = pd.DataFrame(
{
"a": (0, 1, 0, 1),
Expand Down Expand Up @@ -101,7 +107,7 @@ def test_poststratify(self):
)
self.assertEqual(expected, result.weights().df.iloc[:, 0].values)

def test_poststratify_weight_trimming_applied(self):
def test_poststratify_weight_trimming_applied(self) -> None:
s = pd.DataFrame(
{
"a": (0, 1, 0, 1),
Expand All @@ -112,20 +118,24 @@ def test_poststratify_weight_trimming_applied(self):
t = s
t_weights = pd.Series([4, 2, 2, 8])

baseline = poststratify(
baseline_result = poststratify(
sample_df=s,
sample_weights=s_weights,
target_df=t,
target_weights=t_weights,
)["weight"]
assert isinstance(baseline_result, pd.Series)
baseline = baseline_result

trimmed = poststratify(
trimmed_result = poststratify(
sample_df=s,
sample_weights=s_weights,
target_df=t,
target_weights=t_weights,
weight_trimming_mean_ratio=1.0,
)["weight"]
assert isinstance(trimmed_result, pd.Series)
trimmed = trimmed_result

expected = balance_adjustment.trim_weights(
baseline,
Expand All @@ -135,7 +145,7 @@ def test_poststratify_weight_trimming_applied(self):

pd.testing.assert_series_equal(trimmed, expected)

def test_poststratify_percentile_trimming_applied(self):
def test_poststratify_percentile_trimming_applied(self) -> None:
s = pd.DataFrame(
{
"a": (0, 1, 0, 1),
Expand All @@ -146,20 +156,24 @@ def test_poststratify_percentile_trimming_applied(self):
t = s
t_weights = pd.Series([4, 2, 2, 8])

baseline = poststratify(
baseline_result = poststratify(
sample_df=s,
sample_weights=s_weights,
target_df=t,
target_weights=t_weights,
)["weight"]
assert isinstance(baseline_result, pd.Series)
baseline = baseline_result

trimmed = poststratify(
trimmed_result = poststratify(
sample_df=s,
sample_weights=s_weights,
target_df=t,
target_weights=t_weights,
weight_trimming_percentile=0.25,
)["weight"]
assert isinstance(trimmed_result, pd.Series)
trimmed = trimmed_result

expected = balance_adjustment.trim_weights(
baseline,
Expand All @@ -169,7 +183,7 @@ def test_poststratify_percentile_trimming_applied(self):

pd.testing.assert_series_equal(trimmed, expected)

def test_poststratify_variables_arg(self):
def test_poststratify_variables_arg(self) -> None:
s = pd.DataFrame(
{
"a": (0, 1, 0, 1),
Expand All @@ -188,7 +202,7 @@ def test_poststratify_variables_arg(self):
)["weight"]
self.assertEqual(result, pd.Series([4.0, 4.0, 2.0, 6.0]))

def test_poststratify_transformations(self):
def test_poststratify_transformations(self) -> None:
# for numeric
size = 10000
s = pd.DataFrame({"age": np.random.uniform(0, 1, size)})
Expand Down Expand Up @@ -239,7 +253,7 @@ def test_poststratify_transformations(self):
self.assertTrue(abs(result[s.x == "b"].sum() / size - 0.035) < eps)
self.assertTrue(abs(result[s.x == "c"].sum() / size - 0.015) < eps)

def test_poststratify_exceptions(self):
def test_poststratify_exceptions(self) -> None:
# column with name weight
s = pd.DataFrame(
{
Expand Down
62 changes: 34 additions & 28 deletions tests/test_rake.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,15 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from __future__ import absolute_import, division, print_function, unicode_literals
# pyre-strict

from __future__ import (
absolute_import,
annotations,
division,
print_function,
unicode_literals,
)

import balance.testutil

Expand Down Expand Up @@ -36,13 +42,13 @@ class Testrake(

def _assert_rake_raises_with_message(
self,
expected_message,
sample_df,
sample_weights,
target_df,
target_weights,
**kwargs,
):
expected_message: str,
sample_df: pd.DataFrame,
sample_weights: pd.Series | None,
target_df: pd.DataFrame,
target_weights: pd.Series | None,
**kwargs: object,
) -> None:
"""
Helper method to assert that rake raises an error with a specific message.
Expand All @@ -65,7 +71,7 @@ def _assert_rake_raises_with_message(
**kwargs,
)

def test_rake_input_assertions(self):
def test_rake_input_assertions(self) -> None:
"""
Test that rake() properly validates input parameters.
Expand Down Expand Up @@ -156,7 +162,7 @@ def test_rake_input_assertions(self):
pd.Series((1,) * (n_rows - 1)),
)

def test_rake_fails_when_all_na(self):
def test_rake_fails_when_all_na(self) -> None:
"""
Test that rake() properly handles cases where all values are NaN.
Expand Down Expand Up @@ -221,7 +227,7 @@ def test_rake_fails_when_all_na(self):
transformations=None,
)

def test_rake_weights(self):
def test_rake_weights(self) -> None:
"""
Test basic rake weighting functionality with categorical data.
Expand Down Expand Up @@ -261,7 +267,7 @@ def test_rake_weights(self):
pd.Series([1.67, 0.33] * 6, name="rake_weight").rename_axis("index"),
)

def test_rake_weight_trimming_applied(self):
def test_rake_weight_trimming_applied(self) -> None:
"""Verify that rake forwards trimming arguments to the adjustment helper."""

df_sample = pd.DataFrame(
Expand Down Expand Up @@ -306,7 +312,7 @@ def test_rake_weight_trimming_applied(self):

pd.testing.assert_series_equal(trimmed["weight"], expected)

def test_rake_percentile_trimming_applied(self):
def test_rake_percentile_trimming_applied(self) -> None:
"""Percentile trimming parameters should be honoured by rake."""

df_sample = pd.DataFrame(
Expand Down Expand Up @@ -351,7 +357,7 @@ def test_rake_percentile_trimming_applied(self):

pd.testing.assert_series_equal(trimmed["weight"], expected)

def test_rake_weights_with_weighted_input(self):
def test_rake_weights_with_weighted_input(self) -> None:
"""
Test rake weighting with pre-weighted target data.
Expand Down Expand Up @@ -391,7 +397,7 @@ def test_rake_weights_with_weighted_input(self):
pd.Series([1.25, 0.25] * 6, name="rake_weight").rename_axis("index"),
)

def test_rake_weights_scale_to_pop(self):
def test_rake_weights_scale_to_pop(self) -> None:
"""
Test that rake weights properly scale to match target population size.
Expand Down Expand Up @@ -427,7 +433,7 @@ def test_rake_weights_scale_to_pop(self):

self.assertEqual(round(sum(adjusted["weight"]), 2), 15.0)

def test_rake_expected_weights_with_na(self):
def test_rake_expected_weights_with_na(self) -> None:
"""
Test rake weighting behavior with NaN values using different na_action strategies.
Expand Down Expand Up @@ -486,7 +492,7 @@ def test_rake_expected_weights_with_na(self):
pd.Series([1.67, 1.0, 0.33] * 6, name="weight"),
)

def test_rake_consistency_with_default_arguments(self):
def test_rake_consistency_with_default_arguments(self) -> None:
"""
Test consistency of rake function results with default parameters.
Expand Down Expand Up @@ -565,7 +571,7 @@ def test_rake_consistency_with_default_arguments(self):
),
)

def test_variable_order_alphabetized(self):
def test_variable_order_alphabetized(self) -> None:
"""
Test that variable ordering is consistent and alphabetized.
Expand Down Expand Up @@ -617,7 +623,7 @@ def test_variable_order_alphabetized(self):
adjusted_two["weight"],
)

def test_rake_levels_warnings(self):
def test_rake_levels_warnings(self) -> None:
"""
Test warning and error handling for mismatched categorical levels.
Expand Down Expand Up @@ -679,7 +685,7 @@ def test_rake_levels_warnings(self):
target_excess_levels.weight_column,
)

def test__proportional_array_from_dict(self):
def test__proportional_array_from_dict(self) -> None:
"""
Test the _proportional_array_from_dict utility function.
Expand Down Expand Up @@ -716,7 +722,7 @@ def test__proportional_array_from_dict(self):
["a", "a", "a", "b", "b", "b", "b", "b"],
)

def test__realize_dicts_of_proportions(self):
def test__realize_dicts_of_proportions(self) -> None:
"""
Test the _realize_dicts_of_proportions utility function.
Expand Down Expand Up @@ -768,7 +774,7 @@ def test__realize_dicts_of_proportions(self):
},
)

def test_prepare_marginal_dist_for_raking(self):
def test_prepare_marginal_dist_for_raking(self) -> None:
"""
Test the prepare_marginal_dist_for_raking utility function.
Expand Down Expand Up @@ -810,7 +816,7 @@ def test_prepare_marginal_dist_for_raking(self):
},
)

def test_run_ipf_numpy_matches_expected_margins(self):
def test_run_ipf_numpy_matches_expected_margins(self) -> None:
"""Validate that the NumPy IPF solver hits the requested marginals."""

original = np.array([[5.0, 3.0], [2.0, 4.0]])
Expand All @@ -831,7 +837,7 @@ def test_run_ipf_numpy_matches_expected_margins(self):
np.testing.assert_allclose(fitted.sum(axis=0), target_cols, rtol=0, atol=1e-6)
self.assertGreater(len(iterations), 0)

def test_run_ipf_numpy_handles_zero_targets(self):
def test_run_ipf_numpy_handles_zero_targets(self) -> None:
"""Ensure zero-valued margins do not introduce NaNs or divergence."""

original = np.array([[4.0, 1.0, 0.0], [0.0, 3.0, 2.0]])
Expand All @@ -852,7 +858,7 @@ def test_run_ipf_numpy_handles_zero_targets(self):
np.testing.assert_allclose(fitted.sum(axis=1), target_rows, atol=1e-9)
np.testing.assert_allclose(fitted.sum(axis=0), target_cols, atol=1e-9)

def test_run_ipf_numpy_flags_non_convergence(self):
def test_run_ipf_numpy_flags_non_convergence(self) -> None:
"""The solver should report non-convergence when the iteration budget is exhausted."""

original = np.array([[1.0, 0.0], [0.0, 1.0]])
Expand All @@ -869,7 +875,7 @@ def test_run_ipf_numpy_flags_non_convergence(self):

self.assertEqual(converged, 0)

def test_rake_zero_weight_levels_respected(self):
def test_rake_zero_weight_levels_respected(self) -> None:
"""Variable levels with zero target weight should collapse to zero mass."""

sample_df = pd.DataFrame(
Expand Down
Loading