Skip to content

Commit 2324a27

Browse files
authored
Albrja/mic 6247/specific rate aggregation weights (#70)
Albrja/mic 6247/specific rate aggregation weights Add specific rate aggregation weights implementation - *Category*: Implementation - *JIRA issue*: https://jira.ihme.washington.edu/browse/MIC-6247 Changes and notes -add specific implementations of rate aggregation weights for all RatioMeasure classes except PopulationStructure
1 parent 197cb1b commit 2324a27

File tree

3 files changed

+152
-38
lines changed

3 files changed

+152
-38
lines changed

src/vivarium_testing_utils/automated_validation/data_transformation/measures.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626
from vivarium_testing_utils.automated_validation.data_transformation.rate_aggregation import (
2727
RateAggregationWeights,
28+
population_weighted,
2829
)
2930

3031

@@ -172,12 +173,13 @@ class Incidence(RatioMeasure):
172173

173174
@property
174175
def rate_aggregation_weights(self) -> RateAggregationWeights:
175-
"""Returns rated aggregated weights."""
176+
"""Returns rate aggregated weights."""
176177
return RateAggregationWeights(
177178
weight_keys={
178179
"population": "population.structure",
179180
"prevalence": f"cause.{self.entity}.prevalence",
180181
},
182+
# TODO: Update formula to account for having more than two states. Only works for SI and SIS models.
181183
formula=lambda population, prevalence: population * (1 - prevalence),
182184
description="Person-time × (1 - prevalence) weighted average",
183185
)
@@ -201,8 +203,8 @@ class Prevalence(RatioMeasure):
201203

202204
@property
203205
def rate_aggregation_weights(self) -> RateAggregationWeights:
204-
"""Will be implemented with MIC-6247."""
205-
raise NotImplementedError
206+
"""Returns rate aggregated weights."""
207+
return population_weighted()
206208

207209
def __init__(self, cause: str) -> None:
208210
super().__init__(
@@ -223,8 +225,15 @@ class SIRemission(RatioMeasure):
223225

224226
@property
225227
def rate_aggregation_weights(self) -> RateAggregationWeights:
226-
"""Will be implemented with MIC-6247."""
227-
raise NotImplementedError
228+
"""Returns rate aggregated weights."""
229+
return RateAggregationWeights(
230+
weight_keys={
231+
"population": "population.structure",
232+
"prevalence": f"cause.{self.entity}.prevalence",
233+
},
234+
formula=lambda population, prevalence: population * prevalence,
235+
description="Person-time × prevalence weighted average",
236+
)
228237

229238
def __init__(self, cause: str) -> None:
230239
super().__init__(
@@ -245,8 +254,8 @@ class CauseSpecificMortalityRate(RatioMeasure):
245254

246255
@property
247256
def rate_aggregation_weights(self) -> RateAggregationWeights:
248-
"""Will be implemented with MIC-6247."""
249-
raise NotImplementedError
257+
"""Returns rate aggregated weights."""
258+
return population_weighted()
250259

251260
def __init__(self, cause: str) -> None:
252261
super().__init__(
@@ -267,8 +276,15 @@ class ExcessMortalityRate(RatioMeasure):
267276

268277
@property
269278
def rate_aggregation_weights(self) -> RateAggregationWeights:
270-
"""Will be implemented with MIC-6247."""
271-
raise NotImplementedError
279+
"""Returns rate aggregated weights."""
280+
return RateAggregationWeights(
281+
weight_keys={
282+
"population": "population.structure",
283+
"prevalence": f"cause.{self.entity}.prevalence",
284+
},
285+
formula=lambda population, prevalence: population * prevalence,
286+
description="Person-time × prevalence weighted average",
287+
)
272288

273289
def __init__(self, cause: str) -> None:
274290
super().__init__(
@@ -296,7 +312,7 @@ class PopulationStructure(RatioMeasure):
296312

297313
@property
298314
def rate_aggregation_weights(self) -> RateAggregationWeights:
299-
"""Will be implemented with MIC-6247."""
315+
"""This will be implemented when we refactor and implement DataBundle Mic-6241."""
300316
raise NotImplementedError
301317

302318
def __init__(self, scenario_columns: list[str]):
@@ -347,8 +363,8 @@ class RiskExposure(RatioMeasure):
347363

348364
@property
349365
def rate_aggregation_weights(self) -> RateAggregationWeights:
350-
"""Will be implemented with MIC-6247."""
351-
raise NotImplementedError
366+
"""Returns rate aggregated weights."""
367+
return population_weighted()
352368

353369
def __init__(self, risk_factor: str) -> None:
354370
super().__init__(
@@ -426,8 +442,8 @@ def artifact_datasets(self) -> dict[str, str]:
426442

427443
@property
428444
def rate_aggregation_weights(self) -> RateAggregationWeights:
429-
"""Will be implemented with MIC-6247."""
430-
raise NotImplementedError
445+
"""Returns rate aggregated weights."""
446+
return self.affected_measure.rate_aggregation_weights
431447

432448
@utils.check_io(
433449
relative_risks=SingleNumericColumn,

src/vivarium_testing_utils/automated_validation/data_transformation/rate_aggregation.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,11 @@ class RateAggregationWeights:
2222
@utils.check_io(out=SingleNumericColumn)
2323
def get_weights(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
2424
return self.formula(*args, **kwargs)
25+
26+
27+
def population_weighted() -> RateAggregationWeights:
28+
return RateAggregationWeights(
29+
weight_keys={"population": "population.structure"},
30+
formula=lambda population: population,
31+
description="Population-weighted average",
32+
)

tests/automated_validation/data_transformation/test_measures.py

Lines changed: 114 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -579,40 +579,130 @@ def test_format_title() -> None:
579579
assert _format_title("measure.entity") == "Measure Entity"
580580

581581

582-
def test_rate_aggregation_weights() -> None:
583-
"""Test the rate_aggregation_weights property of Incidence measure."""
584-
cause = "disease"
585-
measure = Incidence(cause)
586-
582+
@pytest.mark.parametrize(
583+
"measure_class,measure_args,expected_weights_config,expected_description",
584+
[
585+
(
586+
Incidence,
587+
("disease",),
588+
{
589+
"population": "population.structure",
590+
"prevalence": "cause.disease.prevalence",
591+
},
592+
"Person-time × (1 - prevalence) weighted average",
593+
),
594+
(
595+
Prevalence,
596+
("disease",),
597+
{"population": "population.structure"},
598+
"Population-weighted average",
599+
),
600+
(
601+
SIRemission,
602+
("disease",),
603+
{
604+
"population": "population.structure",
605+
"prevalence": "cause.disease.prevalence",
606+
},
607+
"Person-time × prevalence weighted average",
608+
),
609+
(
610+
CauseSpecificMortalityRate,
611+
("disease",),
612+
{"population": "population.structure"},
613+
"Population-weighted average",
614+
),
615+
(
616+
ExcessMortalityRate,
617+
("disease",),
618+
{
619+
"population": "population.structure",
620+
"prevalence": "cause.disease.prevalence",
621+
},
622+
"Person-time × prevalence weighted average",
623+
),
624+
(
625+
RiskExposure,
626+
("child_stunting",),
627+
{"population": "population.structure"},
628+
"Population-weighted average",
629+
),
630+
(
631+
CategoricalRelativeRisk,
632+
(
633+
"risky_risk",
634+
"disease",
635+
"excess_mortality_rate",
636+
"common_stratify_column",
637+
None,
638+
),
639+
{
640+
"population": "population.structure",
641+
"prevalence": "cause.disease.prevalence",
642+
},
643+
"Person-time × prevalence weighted average",
644+
),
645+
(
646+
PopulationStructure,
647+
(["scenario"],),
648+
None, # Not used since it raises NotImplementedError
649+
None, # Not used since it raises NotImplementedError
650+
),
651+
],
652+
)
653+
def test_rate_aggregation_weights(
654+
measure_class: type[RatioMeasure],
655+
measure_args: tuple[str],
656+
expected_weights_config: dict[str, str] | None,
657+
expected_description: str | None,
658+
) -> None:
659+
"""Test the rate_aggregation_weights property of various RatioMeasure subclasses."""
660+
# Create the measure instance
661+
measure = measure_class(*measure_args) # type: ignore[call-arg]
662+
663+
if isinstance(measure, PopulationStructure):
664+
# Test that PopulationStructure raises NotImplementedError
665+
with pytest.raises(NotImplementedError):
666+
_ = measure.rate_aggregation_weights
667+
return
668+
669+
assert expected_weights_config is not None
670+
assert expected_description is not None
587671
# Get the rate aggregation weights
588672
rate_agg_weights = measure.rate_aggregation_weights
589-
590673
# Verify the configuration
591-
expected_keys = {
592-
"population": "population.structure",
593-
"prevalence": f"cause.{cause}.prevalence",
594-
}
595-
assert rate_agg_weights.weight_keys == expected_keys
596-
assert rate_agg_weights.description == "Person-time × (1 - prevalence) weighted average"
674+
assert rate_agg_weights.weight_keys == expected_weights_config
675+
assert rate_agg_weights.description == expected_description
597676

598677
# Create test data matching expected format
599678
test_index = pd.MultiIndex.from_tuples(
600679
[("A", "baseline"), ("B", "baseline")], names=["common_stratify_column", "scenario"]
601680
)
602-
603681
# Population structure data (proportions summing to 1)
604682
population_data = get_expected_dataframe(0.6, 0.4)
605-
# Prevalence data (proportions between 0 and 1)
606-
prevalence_data = get_expected_dataframe(0.1, 0.2)
607-
608-
# Test get_weights with keyword arguments
609-
weights = rate_agg_weights.get_weights(
610-
population=population_data, prevalence=prevalence_data
611-
)
683+
# Mock data from artifact
684+
key_data = get_expected_dataframe(0.1, 0.2)
612685

613-
# Expected calculation: population * (1 - prevalence)
614-
expected_weights = pd.DataFrame(
615-
{"value": [0.6 * (1 - 0.1), 0.4 * (1 - 0.2)]}, index=test_index # [0.54, 0.32]
616-
)
686+
if len(rate_agg_weights.weight_keys) > 1:
687+
weights = rate_agg_weights.get_weights(population_data, key_data)
688+
else:
689+
weights = rate_agg_weights.get_weights(population_data)
690+
691+
# Expected calculation depends on the measure type
692+
if "prevalence" in expected_weights_config:
693+
if "1 - prevalence" in expected_description:
694+
# Incidence: population * (1 - prevalence)
695+
expected_weights = pd.DataFrame(
696+
{"value": [0.6 * (1 - 0.1), 0.4 * (1 - 0.2)]},
697+
index=test_index, # [0.54, 0.32]
698+
)
699+
else:
700+
# SIRemission and ExcessMortalityRate: population * prevalence
701+
expected_weights = pd.DataFrame(
702+
{"value": [0.6 * 0.1, 0.4 * 0.2]}, index=test_index # [0.06, 0.08]
703+
)
704+
else:
705+
# Population weighted measures: just population
706+
expected_weights = population_data
617707

618708
pd.testing.assert_frame_equal(weights, expected_weights)

0 commit comments

Comments
 (0)