Skip to content

Commit 44e09cd

Browse files
committed
Add to_observe
1 parent ec15a6e commit 44e09cd

File tree

3 files changed

+12
-1
lines changed

3 files changed

+12
-1
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
**4.3.10 - 09/23/25**
2+
3+
- Bugfix: Fix bug in PublicHealthObserver to retain stratification columns in results
4+
15
**4.3.9 - 09/16/25**
26

37
- Bugfix: require alive column in register_transition_count_observation()

src/vivarium_public_health/results/observer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import pandas as pd
1414
from vivarium.framework.engine import Builder
15+
from vivarium.framework.event import Event
1516
from vivarium.framework.results import Observer
1617

1718
from vivarium_public_health.results.columns import COLUMNS
@@ -38,6 +39,7 @@ def register_adding_observation(
3839
excluded_stratifications: list[str] = [],
3940
aggregator_sources: list[str] | None = None,
4041
aggregator: Callable[[pd.DataFrame], float | pd.Series] = len,
42+
to_observe: Callable[[Event], bool] = lambda event: True,
4143
) -> None:
4244
"""Registers an adding observation to the results system.
4345
@@ -73,6 +75,9 @@ def register_adding_observation(
7375
List of population view columns to be used in the `aggregator`.
7476
aggregator
7577
Function that computes the quantity for this observation.
78+
to_observe
79+
Function that takes an event and returns a boolean indicating whether
80+
the observation should be performed for that event.
7681
"""
7782
builder.results.register_adding_observation(
7883
name=name,
@@ -85,6 +90,7 @@ def register_adding_observation(
8590
excluded_stratifications=excluded_stratifications,
8691
aggregator_sources=aggregator_sources,
8792
aggregator=aggregator,
93+
to_observe=to_observe,
8894
)
8995

9096
def format_results(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
@@ -207,7 +213,7 @@ def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
207213
def get_sub_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
208214
"""Get the 'sub_entity' column.
209215
210-
This method should be overwritten in subclasses to provide the 'sub_entity' column.
216+
This method can be overwritten in subclasses to provide the 'sub_entity' column.
211217
212218
Parameters
213219
----------

tests/results/test_disability_observer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def test_disability_observer_setup(mocker):
7272
excluded_stratifications=observer.configuration.exclude,
7373
aggregator_sources=cause_pipelines,
7474
aggregator=observer.disability_weight_aggregator,
75+
to_observe=mocker.ANY,
7576
)
7677

7778
assert set(observer.disability_classes) == set([DiseaseState, RiskAttributableDisease])

0 commit comments

Comments
 (0)