1212
1313import pandas as pd
1414from vivarium .framework .engine import Builder
15+ from vivarium .framework .event import Event
1516from vivarium .framework .results import Observer
1617
1718from vivarium_public_health .results .columns import COLUMNS
@@ -38,6 +39,7 @@ def register_adding_observation(
3839 excluded_stratifications : list [str ] = [],
3940 aggregator_sources : list [str ] | None = None ,
4041 aggregator : Callable [[pd .DataFrame ], float | pd .Series ] = len ,
42+ to_observe : Callable [[Event ], bool ] = lambda event : True ,
4143 ) -> None :
4244 """Registers an adding observation to the results system.
4345
@@ -73,6 +75,9 @@ def register_adding_observation(
7375 List of population view columns to be used in the `aggregator`.
7476 aggregator
7577 Function that computes the quantity for this observation.
78+ to_observe
79+ Function that takes an event and returns a boolean indicating whether
80+ the observation should be performed for that event.
7681 """
7782 builder .results .register_adding_observation (
7883 name = name ,
@@ -85,6 +90,7 @@ def register_adding_observation(
8590 excluded_stratifications = excluded_stratifications ,
8691 aggregator_sources = aggregator_sources ,
8792 aggregator = aggregator ,
93+ to_observe = to_observe ,
8894 )
8995
9096 def format_results (self , measure : str , results : pd .DataFrame ) -> pd .DataFrame :
@@ -207,7 +213,7 @@ def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
207213 def get_sub_entity_column (self , measure : str , results : pd .DataFrame ) -> pd .Series :
208214 """Get the 'sub_entity' column.
209215
210- This method should be overwritten in subclasses to provide the 'sub_entity' column.
216+ This method can be overwritten in subclasses to provide the 'sub_entity' column.
211217
212218 Parameters
213219 ----------
0 commit comments