Skip to content

Commit 9889376

Browse files
authored
reuse filtered population during results processing when possible (#660)
1 parent 6be348c commit 9889376

File tree

4 files changed: +126 -98 lines changed

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,7 @@
1+
**3.5.3 - 09/23/25**
2+
3+
- Reuse filtered population during results processing when possible
4+
15
**3.5.2 - 09/23/25**
26

37
- Fix type hint in get_output_model_name_string() utility method

src/vivarium/framework/results/context.py

Lines changed: 51 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@ class ResultsContext:
5252
objects to be produced keyed by the observation name.
5353
grouped_observations
5454
Dictionary of observation details. It is of the format
55-
{lifecycle_state: {(pop_filter, stratifications): list[Observation]}}.
55+
{lifecycle_state: {pop_filter: {stratifications: list[Observation]}}}.
5656
Allowable lifecycle_states are "time_step__prepare", "time_step",
5757
"time_step__cleanup", and "collect_metrics".
5858
logger
@@ -65,8 +65,8 @@ def __init__(self) -> None:
6565
self.excluded_categories: dict[str, list[str]] = {}
6666
self.observations: dict[str, Observation] = {}
6767
self.grouped_observations: defaultdict[
68-
str, defaultdict[tuple[str, tuple[str, ...] | None], list[Observation]]
69-
] = defaultdict(lambda: defaultdict(list))
68+
str, defaultdict[str, defaultdict[tuple[str, ...] | None, list[Observation]]]
69+
] = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
7070

7171
@property
7272
def name(self) -> str:
@@ -112,18 +112,18 @@ def set_stratifications(self) -> None:
112112
"""
113113
used_stratifications: set[str] = set()
114114
for state_observations in self.grouped_observations.values():
115-
for observation_details in state_observations.items():
116-
(_, stratification_names), observations = observation_details
117-
if stratification_names is None:
118-
continue
119-
120-
used_stratifications |= set(stratification_names)
121-
for observation in observations:
122-
observation.stratifications = tuple(
123-
self.stratifications[name]
124-
for name in stratification_names
125-
if name in self.stratifications
126-
)
115+
for pop_filter_observations in state_observations.values():
116+
for stratification_names, observations in pop_filter_observations.items():
117+
if stratification_names is None:
118+
continue
119+
120+
used_stratifications |= set(stratification_names)
121+
for observation in observations:
122+
observation.stratifications = tuple(
123+
self.stratifications[name]
124+
for name in stratification_names
125+
if name in self.stratifications
126+
)
127127

128128
if unused_stratifications := set(self.stratifications.keys()) - used_stratifications:
129129
self.logger.info(
@@ -272,8 +272,8 @@ def register_observation(
272272
**kwargs,
273273
)
274274
self.observations[name] = observation
275-
self.grouped_observations[observation.when][
276-
(observation.pop_filter, stratifications)
275+
self.grouped_observations[observation.when][observation.pop_filter][
276+
stratifications
277277
].append(observation)
278278
return observation
279279

@@ -318,26 +318,36 @@ def gather_results(
318318

319319
# Optimization: We store all the producers by pop_filter and stratifications
320320
# so that we only have to apply them once each time we compute results.
321-
for (pop_filter, stratification_names), observations in self.grouped_observations[
321+
for pop_filter, stratification_observations in self.grouped_observations[
322322
lifecycle_state
323323
].items():
324-
observations = [obs for obs in observations if obs in event_observations]
325-
if not observations:
324+
event_pop_filter_observations = [
325+
observation
326+
for observations in stratification_observations.values()
327+
for observation in observations
328+
if observation in event_observations
329+
]
330+
if not event_pop_filter_observations:
326331
continue
327332

328-
# Results production can be simplified to
329-
# filter -> groupby -> aggregate in all situations we've seen.
330-
filtered_pop = self._filter_population(
331-
population, pop_filter, stratification_names
332-
)
333-
if filtered_pop.empty:
333+
filtered_population = self._filter_population(population, pop_filter)
334+
if filtered_population.empty:
334335
continue
335-
else:
336+
337+
for stratification_names, observations in stratification_observations.items():
338+
observations = [
339+
obs for obs in observations if obs in event_pop_filter_observations
340+
]
341+
if not observations:
342+
continue
343+
336344
pop: pd.DataFrame | DataFrameGroupBy[tuple[str, ...] | str, bool]
337-
if stratification_names is None:
338-
pop = filtered_pop
339-
else:
340-
pop = self._get_groups(stratification_names, filtered_pop)
345+
pop = self._drop_na_stratifications(filtered_population, stratification_names)
346+
if pop.empty:
347+
continue
348+
if stratification_names is not None:
349+
pop = self._get_groups(stratification_names, pop)
350+
341351
for observation in observations:
342352
results = observation.observe(pop, stratification_names)
343353
yield (results, observation.name, observation.results_updater)
@@ -357,7 +367,8 @@ def get_observations(self, event: Event) -> list[Observation]:
357367
"""
358368
return [
359369
observation
360-
for observations in self.grouped_observations[event.name].values()
370+
for stratification_observations in self.grouped_observations[event.name].values()
371+
for observations in stratification_observations.values()
361372
for observation in observations
362373
if observation.to_observe(event)
363374
]
@@ -436,24 +447,24 @@ def get_required_values(
436447
required_values.update(stratification.requires_values)
437448
return list(required_values)
438449

439-
def _filter_population(
440-
self,
441-
population: pd.DataFrame,
442-
pop_filter: str,
443-
stratification_names: tuple[str, ...] | None,
450+
def _filter_population(self, population: pd.DataFrame, pop_filter: str) -> pd.DataFrame:
451+
"""Filter out simulants not to observe."""
452+
return population.query(pop_filter) if pop_filter else population.copy()
453+
454+
def _drop_na_stratifications(
455+
self, population: pd.DataFrame, stratification_names: tuple[str, ...] | None
444456
) -> pd.DataFrame:
445457
"""Filter out simulants not to observe."""
446-
pop = population.query(pop_filter) if pop_filter else population.copy()
447458
if stratification_names:
448459
# Drop all rows in the mapped_stratification columns that have NaN values
449460
# (which only exist if the mapper returned an excluded category).
450-
pop = pop.dropna(
461+
population = population.dropna(
451462
subset=[
452463
get_mapped_col_name(stratification)
453464
for stratification in stratification_names
454465
]
455466
)
456-
return pop
467+
return population
457468

458469
@staticmethod
459470
def _get_groups(

tests/framework/results/test_context.py

Lines changed: 32 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -752,45 +752,47 @@ def get_required_resources_kwargs(
752752

753753

754754
@pytest.mark.parametrize(
755-
"pop_filter, stratifications",
756-
[
757-
('familiar=="cat"', tuple()),
758-
('familiar=="spaghetti_yeti"', tuple()),
759-
("", ("new_col1",)),
760-
("", ("new_col1", "new_col2")),
761-
('familiar=="cat"', ("new_col1",)),
762-
("", tuple()),
763-
],
764-
ids=[
765-
"pop_filter",
766-
"pop_filter_empties_dataframe",
767-
"single_excluded_stratification",
768-
"two_excluded_stratifications",
769-
"pop_filter_and_excluded_stratification",
770-
"no_pop_filter_or_excluded_stratifications",
771-
],
755+
"pop_filter",
756+
['familiar=="cat"', 'familiar=="spaghetti_yeti"', ""],
757+
ids=["pop_filter", "pop_filter_empties_dataframe", "no_pop_filter"],
772758
)
773-
def test__filter_population(pop_filter: str, stratifications: tuple[str, ...]) -> None:
759+
def test__filter_population(pop_filter: str) -> None:
774760
population = BASE_POPULATION.copy()
775-
if stratifications:
776-
# Make some of the stratifications missing to mimic mapping to excluded categories
777-
population["new_col1"] = "new_value1"
778-
population.loc[population["tracked"] == True, "new_col1"] = np.nan
779-
if len(stratifications) == 2:
780-
population["new_col2"] = "new_value2"
781-
population.loc[population["new_col1"].notna(), "new_col2"] = np.nan
782-
# Add on the post-stratified columns
783-
for stratification in stratifications:
784-
mapped_col = f"{stratification}_mapped_values"
785-
population[mapped_col] = population[stratification]
786761

787762
filtered_pop = ResultsContext()._filter_population(
788-
population=population, pop_filter=pop_filter, stratification_names=stratifications
763+
population=population, pop_filter=pop_filter
789764
)
790765
expected = population.copy()
791766
if pop_filter:
792767
familiar = pop_filter.split("==")[1].strip('"')
793768
expected = expected[expected["familiar"] == familiar]
769+
assert filtered_pop.equals(expected)
770+
771+
772+
@pytest.mark.parametrize(
773+
"stratifications",
774+
[tuple(), ("new_col1",), ("new_col1", "new_col2")],
775+
ids=[
776+
"no_stratifications",
777+
"single_excluded_stratification",
778+
"two_excluded_stratifications",
779+
],
780+
)
781+
def test__drop_na_stratifications(stratifications: tuple[str, ...]) -> None:
782+
population = BASE_POPULATION.copy()
783+
population["new_col1"] = "new_value1"
784+
population.loc[population["tracked"] == True, "new_col1"] = np.nan
785+
population["new_col2"] = "new_value2"
786+
population.loc[population["new_col1"].notna(), "new_col2"] = np.nan
787+
# Add on the post-stratified columns
788+
for stratification in stratifications:
789+
mapped_col = f"{stratification}_mapped_values"
790+
population[mapped_col] = population[stratification]
791+
792+
filtered_pop = ResultsContext()._drop_na_stratifications(
793+
population=population, stratification_names=stratifications
794+
)
795+
expected = population.copy()
794796
for stratification in stratifications:
795797
expected = expected[expected[stratification].notna()]
796798
assert filtered_pop.equals(expected)

tests/framework/results/test_interface.py

Lines changed: 39 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -195,9 +195,9 @@ def test_register_stratified_observation(mocker: MockerFixture) -> None:
195195

196196
grouped_observations = interface._manager._results_context.grouped_observations
197197
assert len(grouped_observations) == 1
198-
((filter, stratifications), observations) = list(
199-
grouped_observations["some-when"].items()
200-
)[0]
198+
filter = list(grouped_observations["some-when"].keys())[0]
199+
stratifications = list(grouped_observations["some-when"][filter])[0]
200+
observations = grouped_observations["some-when"][filter][stratifications]
201201
assert filter == "some-filter"
202202
assert isinstance(stratifications, tuple) # for mypy in following set(stratifications)
203203
assert set(stratifications) == {
@@ -239,13 +239,15 @@ def test_register_unstratified_observation(mocker: MockerFixture) -> None:
239239
results_gatherer=lambda _: pd.DataFrame(),
240240
results_updater=lambda _, __: pd.DataFrame(),
241241
)
242-
observations = interface._manager._results_context.grouped_observations
243-
assert len(observations) == 1
244-
((filter, stratification), observation) = list(observations["some-when"].items())[0]
242+
grouped_observations = interface._manager._results_context.grouped_observations
243+
assert len(grouped_observations) == 1
244+
filter = list(grouped_observations["some-when"].keys())[0]
245+
stratifications = list(grouped_observations["some-when"][filter])[0]
246+
observations = grouped_observations["some-when"][filter][stratifications]
245247
assert filter == "some-filter"
246-
assert stratification is None
247-
assert len(observation) == 1
248-
obs = observation[0]
248+
assert stratifications is None
249+
assert len(observations) == 1
250+
obs = observations[0]
249251
assert obs.name == "some-name"
250252
assert obs.pop_filter == "some-filter"
251253
assert obs.when == "some-when"
@@ -346,26 +348,33 @@ def test_register_multiple_adding_observations(mocker: MockerFixture) -> None:
346348
)
347349
# Test observation gets added
348350
assert len(interface._manager._results_context.grouped_observations) == 1
349-
# Test for default pop_filter
350-
assert ("tracked==True", ()) in interface._manager._results_context.grouped_observations[
351-
lifecycle_states.TIME_STEP_CLEANUP
352-
]
351+
assert (
352+
interface._manager._results_context.grouped_observations[
353+
lifecycle_states.TIME_STEP_CLEANUP
354+
]["tracked==True"][()][0].name
355+
== "living_person_time"
356+
)
357+
353358
interface.register_adding_observation(
354359
name="undead_person_time",
355-
pop_filter="undead == True",
360+
pop_filter="undead==True",
356361
when=lifecycle_states.TIME_STEP_PREPARE,
357362
aggregator=_silly_aggregator,
358363
)
359364
# Test new observation gets added
360365
assert len(interface._manager._results_context.grouped_observations) == 2
361-
# Preserve other observation and its pop filter
362-
assert ("tracked==True", ()) in interface._manager._results_context.grouped_observations[
363-
lifecycle_states.TIME_STEP_CLEANUP
364-
]
365-
# Test for overridden pop_filter
366-
assert ("undead == True", ()) in interface._manager._results_context.grouped_observations[
367-
lifecycle_states.TIME_STEP_PREPARE
368-
]
366+
assert (
367+
interface._manager._results_context.grouped_observations[
368+
lifecycle_states.TIME_STEP_CLEANUP
369+
]["tracked==True"][()][0].name
370+
== "living_person_time"
371+
)
372+
assert (
373+
interface._manager._results_context.grouped_observations[
374+
lifecycle_states.TIME_STEP_PREPARE
375+
]["undead==True"][()][0].name
376+
== "undead_person_time"
377+
)
369378

370379

371380
@pytest.mark.parametrize("resource_type", ["value", "column"])
@@ -484,13 +493,15 @@ def test_register_concatenating_observation(mocker: MockerFixture) -> None:
484493
requires_values=["some-value", "some-other-value"],
485494
results_formatter=lambda _, __: pd.DataFrame(),
486495
)
487-
observations = interface._manager._results_context.grouped_observations
488-
assert len(observations) == 1
489-
((filter, stratification), observation) = list(observations["some-when"].items())[0]
496+
grouped_observations = interface._manager._results_context.grouped_observations
497+
assert len(grouped_observations) == 1
498+
filter = list(grouped_observations["some-when"].keys())[0]
499+
stratifications = list(grouped_observations["some-when"][filter])[0]
500+
observations = grouped_observations["some-when"][filter][stratifications]
490501
assert filter == "some-filter"
491-
assert stratification is None
492-
assert len(observation) == 1
493-
obs = observation[0]
502+
assert stratifications is None
503+
assert len(observations) == 1
504+
obs = observations[0]
494505
assert obs.name == "some-name"
495506
assert obs.pop_filter == "some-filter"
496507
assert obs.when == "some-when"

0 commit comments

Comments
 (0)