Skip to content

Commit 53488aa

Browse files
committed
Add test
1 parent eff4e06 commit 53488aa

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

src/vivarium_testing_utils/automated_validation/comparison.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -318,8 +318,8 @@ def aggregate_strata(self, strata: Collection[str] = ()) -> pd.DataFrame | float
318318
strata = list(strata)
319319
for stratum in strata:
320320
if (
321-
stratum not in self.reference_data.index
322-
and stratum not in self.reference_weights.index
321+
stratum not in self.reference_data.index.names
322+
and stratum not in self.reference_weights.index.names
323323
):
324324
raise ValueError(
325325
f"Stratum '{stratum}' not found in reference data or weights."

tests/automated_validation/test_comparison.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,3 +458,43 @@ def test_fuzzy_comparison_align_datasets_calculation(
458458
index=expected_index,
459459
),
460460
)
461+
462+
463+
def test_aggregate_strata(
464+
mock_ratio_measure: RatioMeasure,
465+
test_data: dict[str, pd.DataFrame],
466+
reference_data: pd.DataFrame,
467+
reference_weights: pd.DataFrame,
468+
) -> None:
469+
"""Test that aggregate_strata correctly aggregates data."""
470+
comparison = FuzzyComparison(
471+
mock_ratio_measure,
472+
DataSource.SIM,
473+
test_data,
474+
DataSource.GBD,
475+
reference_data,
476+
reference_weights,
477+
)
478+
479+
aggregated = comparison.aggregate_strata(["age", "sex"])
480+
# (0, Male) = (0.12 * 0.15 + 0.29 * 0.35) / (0.15 + 0.35)
481+
expected = pd.DataFrame(
482+
{
483+
"value": [
484+
(0.12 * 0.15 + 0.29 * 0.35) / (0.15 + 0.35),
485+
(0.2 * 0.25) / 0.25,
486+
],
487+
},
488+
index=pd.MultiIndex.from_tuples(
489+
[
490+
(0, "male"),
491+
(0, "female"),
492+
],
493+
names=["age", "sex"],
494+
),
495+
)
496+
assert isinstance(aggregated, pd.DataFrame)
497+
pd.testing.assert_frame_equal(aggregated, expected)
498+
499+
with pytest.raises(ValueError, match="not found in reference data or weights"):
500+
comparison.aggregate_strata(["dog", "cat"])

0 commit comments

Comments
 (0)