Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 70 additions & 61 deletions balance/balancedf_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from __future__ import annotations

import logging
from typing import Any, Dict, Literal, Tuple
from typing import Any, Callable, Dict, Literal, Tuple

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -1110,6 +1110,50 @@ def _get_df_and_weights(
weights = self._weights.values if (self._weights is not None) else None
return df_model_matrix, weights

@staticmethod
def _apply_comparison_stat_to_BalanceDF(
comparison_func: Callable[..., pd.Series],
sample_BalanceDF: "BalanceDF",
target_BalanceDF: "BalanceDF",
aggregate_by_main_covar: bool = False,
**kwargs: Any,
) -> pd.Series:
"""Generic helper to apply a weighted comparison statistic function to two BalanceDF objects.

This helper function reduces code duplication across multiple comparison methods
(asmd, kld, emd, cvmd, ks) by extracting the common pattern of:
1. Validating inputs are BalanceDF objects
2. Extracting df and weights from both objects
3. Calling the comparison function with the extracted data

Args:
comparison_func (Callable[..., pd.Series]): The comparison function from
weighted_comparisons_stats to apply (e.g., asmd, kld, emd, cvmd, ks).
sample_BalanceDF (BalanceDF): Sample object.
target_BalanceDF (BalanceDF): Target object.
aggregate_by_main_covar (bool, optional): Whether to aggregate by main covariate.
Defaults to False. Passed to the comparison function.
**kwargs: Additional keyword arguments to pass to the comparison function
(e.g., std_type for asmd).

Returns:
pd.Series: The result from the comparison function.
"""
BalanceDF._check_if_not_BalanceDF(sample_BalanceDF, "sample_BalanceDF")
BalanceDF._check_if_not_BalanceDF(target_BalanceDF, "target_BalanceDF")

sample_df_values, sample_weights = sample_BalanceDF._get_df_and_weights()
target_df_values, target_weights = target_BalanceDF._get_df_and_weights()

return comparison_func(
sample_df_values,
target_df_values,
sample_weights,
target_weights,
aggregate_by_main_covar=aggregate_by_main_covar,
**kwargs,
)

@staticmethod
def _asmd_BalanceDF(
sample_BalanceDF: "BalanceDF",
Expand Down Expand Up @@ -1156,19 +1200,12 @@ def _asmd_BalanceDF(
# mean(asmd) 1.756543
# dtype: float64
"""
BalanceDF._check_if_not_BalanceDF(sample_BalanceDF, "sample_BalanceDF")
BalanceDF._check_if_not_BalanceDF(sample_BalanceDF, "target_BalanceDF")

sample_df_values, sample_weights = sample_BalanceDF._get_df_and_weights()
target_df_values, target_weights = target_BalanceDF._get_df_and_weights()

return weighted_comparisons_stats.asmd(
sample_df_values,
target_df_values,
sample_weights,
target_weights,
return BalanceDF._apply_comparison_stat_to_BalanceDF(
weighted_comparisons_stats.asmd,
sample_BalanceDF,
target_BalanceDF,
aggregate_by_main_covar,
std_type="target",
aggregate_by_main_covar=aggregate_by_main_covar,
)

@staticmethod
Expand All @@ -1190,18 +1227,11 @@ def _kld_BalanceDF(
Returns:
pd.Series: See :func:`weighted_comparisons_stats.kld`.
"""
BalanceDF._check_if_not_BalanceDF(sample_BalanceDF, "sample_BalanceDF")
BalanceDF._check_if_not_BalanceDF(target_BalanceDF, "target_BalanceDF")

sample_df_values, sample_weights = sample_BalanceDF._get_df_and_weights()
target_df_values, target_weights = target_BalanceDF._get_df_and_weights()

return weighted_comparisons_stats.kld(
sample_df_values,
target_df_values,
sample_weights,
target_weights,
aggregate_by_main_covar=aggregate_by_main_covar,
return BalanceDF._apply_comparison_stat_to_BalanceDF(
weighted_comparisons_stats.kld,
sample_BalanceDF,
target_BalanceDF,
aggregate_by_main_covar,
)

@staticmethod
Expand All @@ -1223,18 +1253,11 @@ def _emd_BalanceDF(
Returns:
pd.Series: See :func:`weighted_comparisons_stats.emd`.
"""
BalanceDF._check_if_not_BalanceDF(sample_BalanceDF, "sample_BalanceDF")
BalanceDF._check_if_not_BalanceDF(target_BalanceDF, "target_BalanceDF")

sample_df_values, sample_weights = sample_BalanceDF._get_df_and_weights()
target_df_values, target_weights = target_BalanceDF._get_df_and_weights()

return weighted_comparisons_stats.emd(
sample_df_values,
target_df_values,
sample_weights,
target_weights,
aggregate_by_main_covar=aggregate_by_main_covar,
return BalanceDF._apply_comparison_stat_to_BalanceDF(
weighted_comparisons_stats.emd,
sample_BalanceDF,
target_BalanceDF,
aggregate_by_main_covar,
)

@staticmethod
Expand All @@ -1256,18 +1279,11 @@ def _cvmd_BalanceDF(
Returns:
pd.Series: See :func:`weighted_comparisons_stats.cvmd`.
"""
BalanceDF._check_if_not_BalanceDF(sample_BalanceDF, "sample_BalanceDF")
BalanceDF._check_if_not_BalanceDF(target_BalanceDF, "target_BalanceDF")

sample_df_values, sample_weights = sample_BalanceDF._get_df_and_weights()
target_df_values, target_weights = target_BalanceDF._get_df_and_weights()

return weighted_comparisons_stats.cvmd(
sample_df_values,
target_df_values,
sample_weights,
target_weights,
aggregate_by_main_covar=aggregate_by_main_covar,
return BalanceDF._apply_comparison_stat_to_BalanceDF(
weighted_comparisons_stats.cvmd,
sample_BalanceDF,
target_BalanceDF,
aggregate_by_main_covar,
)

@staticmethod
Expand All @@ -1289,18 +1305,11 @@ def _ks_BalanceDF(
Returns:
pd.Series: See :func:`weighted_comparisons_stats.ks`.
"""
BalanceDF._check_if_not_BalanceDF(sample_BalanceDF, "sample_BalanceDF")
BalanceDF._check_if_not_BalanceDF(target_BalanceDF, "target_BalanceDF")

sample_df_values, sample_weights = sample_BalanceDF._get_df_and_weights()
target_df_values, target_weights = target_BalanceDF._get_df_and_weights()

return weighted_comparisons_stats.ks(
sample_df_values,
target_df_values,
sample_weights,
target_weights,
aggregate_by_main_covar=aggregate_by_main_covar,
return BalanceDF._apply_comparison_stat_to_BalanceDF(
weighted_comparisons_stats.ks,
sample_BalanceDF,
target_BalanceDF,
aggregate_by_main_covar,
)

def asmd(
Expand Down
130 changes: 130 additions & 0 deletions tests/test_balancedf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,6 +1301,136 @@ def test_BalanceDF_asmd_aggregate_by_main_covar(self) -> None:
self.assertEqual(outcome_default, expected_default)
self.assertEqual(outcome_main_covar, expected_main_covar)

def test_BalanceDF__kld_BalanceDF(self) -> None:
"""Test _kld_BalanceDF static method directly."""
sample = Sample.from_frame(
pd.DataFrame({"id": (1, 2), "a": (1, 2), "b": (-1, 12), "weight": (1, 2)})
).covars()

target = Sample.from_frame(
pd.DataFrame({"id": (1, 2), "a": (3, 4), "b": (0, 42), "weight": (1, 2)})
).covars()

result = BalanceDF._kld_BalanceDF(sample, target)

# Verify result is a Series with expected keys
self.assertIsInstance(result, pd.Series)
self.assertIn("a", result.index)
self.assertIn("b", result.index)
self.assertIn("mean(kld)", result.index)

# Verify all values are non-negative (KLD property)
self.assertTrue((result >= 0).all())

# Test with aggregate_by_main_covar
result_agg = BalanceDF._kld_BalanceDF(
sample, target, aggregate_by_main_covar=True
)
self.assertIsInstance(result_agg, pd.Series)

def test_BalanceDF__emd_BalanceDF(self) -> None:
"""Test _emd_BalanceDF static method directly."""
sample = Sample.from_frame(
pd.DataFrame({"id": (1, 2), "a": (1, 2), "b": (-1, 12), "weight": (1, 2)})
).covars()

target = Sample.from_frame(
pd.DataFrame({"id": (1, 2), "a": (3, 4), "b": (0, 42), "weight": (1, 2)})
).covars()

result = BalanceDF._emd_BalanceDF(sample, target)

# Verify result is a Series with expected keys
self.assertIsInstance(result, pd.Series)
self.assertIn("a", result.index)
self.assertIn("b", result.index)
self.assertIn("mean(emd)", result.index)

# Verify all values are non-negative (EMD property)
self.assertTrue((result >= 0).all())

# Test with aggregate_by_main_covar
result_agg = BalanceDF._emd_BalanceDF(
sample, target, aggregate_by_main_covar=True
)
self.assertIsInstance(result_agg, pd.Series)

def test_BalanceDF__cvmd_BalanceDF(self) -> None:
"""Test _cvmd_BalanceDF static method directly."""
sample = Sample.from_frame(
pd.DataFrame({"id": (1, 2), "a": (1, 2), "b": (-1, 12), "weight": (1, 2)})
).covars()

target = Sample.from_frame(
pd.DataFrame({"id": (1, 2), "a": (3, 4), "b": (0, 42), "weight": (1, 2)})
).covars()

result = BalanceDF._cvmd_BalanceDF(sample, target)

# Verify result is a Series with expected keys
self.assertIsInstance(result, pd.Series)
self.assertIn("a", result.index)
self.assertIn("b", result.index)
self.assertIn("mean(cvmd)", result.index)

# Verify all values are non-negative (CVMD property)
self.assertTrue((result >= 0).all())

# Test with aggregate_by_main_covar
result_agg = BalanceDF._cvmd_BalanceDF(
sample, target, aggregate_by_main_covar=True
)
self.assertIsInstance(result_agg, pd.Series)

def test_BalanceDF__ks_BalanceDF(self) -> None:
"""Test _ks_BalanceDF static method directly."""
sample = Sample.from_frame(
pd.DataFrame({"id": (1, 2), "a": (1, 2), "b": (-1, 12), "weight": (1, 2)})
).covars()

target = Sample.from_frame(
pd.DataFrame({"id": (1, 2), "a": (3, 4), "b": (0, 42), "weight": (1, 2)})
).covars()

result = BalanceDF._ks_BalanceDF(sample, target)

# Verify result is a Series with expected keys
self.assertIsInstance(result, pd.Series)
self.assertIn("a", result.index)
self.assertIn("b", result.index)
self.assertIn("mean(ks)", result.index)

# Verify all values are in [0, 1] (KS property)
self.assertTrue((result >= 0).all())
self.assertTrue((result <= 1).all())

# Test with aggregate_by_main_covar
result_agg = BalanceDF._ks_BalanceDF(
sample, target, aggregate_by_main_covar=True
)
self.assertIsInstance(result_agg, pd.Series)

def test_BalanceDF_comparison_functions_invalid_input(self) -> None:
"""Test that all comparison functions properly validate inputs."""
sample = Sample.from_frame(
pd.DataFrame({"id": (1, 2), "a": (1, 2), "weight": (1, 2)})
).covars()

# Test with non-BalanceDF inputs
invalid_input = "not a BalanceDF"

with self.assertRaisesRegex(ValueError, "must be balancedf_class.BalanceDF"):
BalanceDF._kld_BalanceDF(invalid_input, sample) # type: ignore

with self.assertRaisesRegex(ValueError, "must be balancedf_class.BalanceDF"):
BalanceDF._emd_BalanceDF(sample, invalid_input) # type: ignore

with self.assertRaisesRegex(ValueError, "must be balancedf_class.BalanceDF"):
BalanceDF._cvmd_BalanceDF(invalid_input, sample) # type: ignore

with self.assertRaisesRegex(ValueError, "must be balancedf_class.BalanceDF"):
BalanceDF._ks_BalanceDF(sample, invalid_input) # type: ignore


class TestBalanceDF_to_download(BalanceTestCase):
def test_BalanceDF_to_download(self) -> None:
Expand Down
Loading