Skip to content

Commit ec4dac5

Browse files
authored
Support callable and tuple(str, callable) for population filter (#139)
Albrja/mic-7064/Support callable and tuple(str, callable) for population filter Support callable and tuple(str, callable) for population filter - *Category*: Feature - *JIRA issue*: https://jira.ihme.washington.edu/browse/MIC-7064 Changes and notes -Support callable and tuple(str, callable) for population filter
1 parent a1b4145 commit ec4dac5

11 files changed

Lines changed: 187 additions & 72 deletions

File tree

libs/engine/CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
**5.2.0 - 06/29/26**
2+
3+
- Support a callable population filter (or ``(query, callable)`` tuple) for all observations
4+
15
**5.1.7 - 06/22/26**
26

37
- Pin vivarium-build-utils to v4.x and update Makefile to use ``vivarium.build_utils``

libs/engine/docs/nitpick-exceptions

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ py:class PandasObject
4848
py:class DataFrameGroupBy
4949
py:class ResultsFormatter
5050
py:class ResultsUpdater
51+
py:class PopulationFilter
52+
py:class _PopulationFilter
53+
py:class vivarium.engine.framework.results.interface._PopulationFilter
5154
py:class _NestedDict
5255
py:exc ResultsConfigurationError
5356
py:exc vivarium.engine.framework.results.exceptions.ResultsConfigurationError

libs/engine/docs/source/conf.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import importlib.metadata
66

7-
87
# -- Project information -----------------------------------------------------
98

109
project = "vivarium.engine"

libs/engine/src/vivarium/engine/framework/results/context.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
if TYPE_CHECKING:
2828
from vivarium.engine.framework.engine import Builder
29-
from vivarium.engine.framework.results.interface import PopulationFilter
29+
from vivarium.engine.framework.results.interface import _PopulationFilter
3030

3131

3232
class ResultsContext:
@@ -52,7 +52,7 @@ class ResultsContext:
5252
objects to be produced keyed by the observation name.
5353
grouped_observations
5454
Dictionary of observation details. It is of the format
55-
{lifecycle_state: {PopulationFilter: {stratifications: list[Observation]}}}.
55+
{lifecycle_state: {_PopulationFilter: {stratifications: list[Observation]}}}.
5656
Allowable lifecycle_states are "time_step__prepare", "time_step",
5757
"time_step__cleanup", and "collect_metrics".
5858
logger
@@ -67,7 +67,7 @@ def __init__(self) -> None:
6767
self.grouped_observations: defaultdict[
6868
str,
6969
defaultdict[
70-
PopulationFilter,
70+
_PopulationFilter,
7171
defaultdict[tuple[str, ...] | None, list[Observation]],
7272
],
7373
] = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
@@ -223,7 +223,7 @@ def register_observation(
223223
self,
224224
observation_type: type[Observation],
225225
name: str,
226-
population_filter: PopulationFilter,
226+
population_filter: _PopulationFilter,
227227
when: str,
228228
requires_attributes: list[str],
229229
stratifications: tuple[str, ...] | None,
@@ -430,14 +430,17 @@ def get_required_attributes(
430430
return list(required_attributes)
431431

432432
def _filter_population(
433-
self, population: pd.DataFrame, population_filter: PopulationFilter
433+
self, population: pd.DataFrame, population_filter: _PopulationFilter
434434
) -> pd.DataFrame:
435435
"""Filter out simulants not to observe."""
436436
query = population_filter.query
437437
if not population_filter.include_untracked:
438438
# combine the tracking query with the population filter query
439439
query = pop_utils.combine_queries(query, self.get_tracked_query())
440-
return population.query(query) if query else population.copy()
440+
filtered = population.query(query) if query else population.copy()
441+
if population_filter.index_filter is not None:
442+
filtered = filtered.loc[population_filter.index_filter(filtered.index)]
443+
return filtered
441444

442445
def _drop_na_stratifications(
443446
self, population: pd.DataFrame, stratification_names: tuple[str, ...] | None

libs/engine/src/vivarium/engine/framework/results/interface.py

Lines changed: 64 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from __future__ import annotations
1010

1111
from collections.abc import Callable
12-
from typing import TYPE_CHECKING, Any, NamedTuple, Sequence, Union
12+
from typing import TYPE_CHECKING, Any, Sequence, Union
1313

1414
import pandas as pd
1515
from pandas.core.groupby.generic import DataFrameGroupBy
@@ -40,6 +40,11 @@
4040
ResultsGatherer = Callable[[ResultsGathererInput], pd.DataFrame]
4141
"""A Callable that optionally takes a possibly stratified population and returns
4242
new observation results."""
43+
PopulationIndexFilter = Callable[[pd.Index], pd.Index] # type: ignore [type-arg]
44+
"""A Callable that takes a population index and returns the subset of indices to keep."""
45+
PopulationFilter = Union[str, PopulationIndexFilter, tuple[str, PopulationIndexFilter]]
46+
"""A population filter given as a Pandas query string, an index-filter callable, or a
47+
``(query, callable)`` tuple that applies the query first followed by the callable."""
4348

4449

4550
def _required_function_placeholder(
@@ -76,11 +81,44 @@ def _default_unstratified_observation_formatter(
7681
return results
7782

7883

79-
class PopulationFilter(NamedTuple):
80-
"""Container class for population query string and include_untracked flag."""
84+
class _PopulationFilter:
85+
"""Container for a query string, the untracked flag, and an optional index filter.
8186
82-
query: str = ""
83-
include_untracked: bool = False
87+
The ``pop_filter`` argument may be a Pandas query string, an index-filter
88+
callable, or a ``(query, callable)`` tuple; it is resolved into the separate
89+
``query`` and ``index_filter`` attributes on construction.
90+
"""
91+
92+
def __init__(
93+
self, pop_filter: PopulationFilter = "", include_untracked: bool = False
94+
) -> None:
95+
self.query, self.index_filter = self._parse_pop_filter(pop_filter)
96+
self.include_untracked = include_untracked
97+
98+
@staticmethod
99+
def _parse_pop_filter(
100+
pop_filter: PopulationFilter,
101+
) -> tuple[str, PopulationIndexFilter | None]:
102+
"""Resolve a pop_filter argument into a (query, index_filter) pair."""
103+
if isinstance(pop_filter, str):
104+
return pop_filter, None
105+
if isinstance(pop_filter, tuple):
106+
return pop_filter
107+
return "", pop_filter
108+
109+
def __eq__(self, other: object) -> bool:
110+
"""Compare by field value so equal filters share a grouping key."""
111+
if not isinstance(other, _PopulationFilter):
112+
return NotImplemented
113+
return (self.query, self.include_untracked, self.index_filter) == (
114+
other.query,
115+
other.include_untracked,
116+
other.index_filter,
117+
)
118+
119+
def __hash__(self) -> int:
120+
"""Hash by field value, consistent with __eq__, for use as a dict key."""
121+
return hash((self.query, self.include_untracked, self.index_filter))
84122

85123

86124
class ResultsInterface(Interface):
@@ -199,7 +237,7 @@ def register_binned_stratification(
199237
def register_stratified_observation(
200238
self,
201239
name: str,
202-
pop_filter: str = "",
240+
pop_filter: PopulationFilter = "",
203241
include_untracked: bool = False,
204242
when: str = lifecycle_states.COLLECT_METRICS,
205243
requires_attributes: list[str] = [],
@@ -219,8 +257,9 @@ def register_stratified_observation(
219257
Name of the observation. It will also be the name of the output results file
220258
for this particular observation.
221259
pop_filter
222-
A Pandas query filter string to filter the population down to the simulants who should
223-
be considered for the observation.
260+
A filter selecting which simulants to observe. Either a Pandas query string, a
261+
callable mapping the population index to the subset of indices to keep, or a
262+
``(query, callable)`` tuple that applies the query first followed by the callable.
224263
include_untracked
225264
Whether to include simulants who are untracked from this observation.
226265
when
@@ -254,7 +293,7 @@ def register_stratified_observation(
254293
self._manager.register_observation(
255294
observation_type=StratifiedObservation,
256295
name=name,
257-
population_filter=PopulationFilter(pop_filter, include_untracked),
296+
population_filter=_PopulationFilter(pop_filter, include_untracked),
258297
when=when,
259298
requires_attributes=requires_attributes,
260299
results_updater=results_updater,
@@ -269,7 +308,7 @@ def register_stratified_observation(
269308
def register_unstratified_observation(
270309
self,
271310
name: str,
272-
pop_filter: str = "",
311+
pop_filter: PopulationFilter = "",
273312
include_untracked: bool = False,
274313
when: str = lifecycle_states.COLLECT_METRICS,
275314
requires_attributes: list[str] = [],
@@ -286,8 +325,9 @@ def register_unstratified_observation(
286325
Name of the observation. It will also be the name of the output results file
287326
for this particular observation.
288327
pop_filter
289-
A Pandas query filter string to filter the population down to the simulants who should
290-
be considered for the observation.
328+
A filter selecting which simulants to observe. Either a Pandas query string, a
329+
callable mapping the population index to the subset of indices to keep, or a
330+
``(query, callable)`` tuple that applies the query first followed by the callable.
291331
include_untracked
292332
Whether to include simulants who are untracked from this observation.
293333
when
@@ -317,7 +357,7 @@ def register_unstratified_observation(
317357
self._manager.register_observation(
318358
observation_type=UnstratifiedObservation,
319359
name=name,
320-
population_filter=PopulationFilter(pop_filter, include_untracked),
360+
population_filter=_PopulationFilter(pop_filter, include_untracked),
321361
when=when,
322362
requires_attributes=requires_attributes,
323363
results_updater=results_updater,
@@ -329,7 +369,7 @@ def register_unstratified_observation(
329369
def register_adding_observation(
330370
self,
331371
name: str,
332-
pop_filter: str = "",
372+
pop_filter: PopulationFilter = "",
333373
include_untracked: bool = False,
334374
when: str = lifecycle_states.COLLECT_METRICS,
335375
requires_attributes: list[str] = [],
@@ -355,8 +395,9 @@ def register_adding_observation(
355395
Name of the observation. It will also be the name of the output results file
356396
for this particular observation.
357397
pop_filter
358-
A Pandas query filter string to filter the population down to the simulants who should
359-
be considered for the observation.
398+
A filter selecting which simulants to observe. Either a Pandas query string, a
399+
callable mapping the population index to the subset of indices to keep, or a
400+
``(query, callable)`` tuple that applies the query first followed by the callable.
360401
include_untracked
361402
Whether to include simulants who are untracked from this observation.
362403
when
@@ -382,7 +423,7 @@ def register_adding_observation(
382423
self._manager.register_observation(
383424
observation_type=AddingObservation,
384425
name=name,
385-
population_filter=PopulationFilter(pop_filter, include_untracked),
426+
population_filter=_PopulationFilter(pop_filter, include_untracked),
386427
when=when,
387428
requires_attributes=requires_attributes,
388429
results_formatter=results_formatter,
@@ -396,7 +437,7 @@ def register_adding_observation(
396437
def register_concatenating_observation(
397438
self,
398439
name: str,
399-
pop_filter: str = "",
440+
pop_filter: PopulationFilter = "",
400441
include_untracked: bool = False,
401442
when: str = lifecycle_states.COLLECT_METRICS,
402443
requires_attributes: list[str] = [],
@@ -418,15 +459,16 @@ def register_concatenating_observation(
418459
Name of the observation. It will also be the name of the output results file
419460
for this particular observation.
420461
pop_filter
421-
A Pandas query filter string to filter the population down to the simulants who should
422-
be considered for the observation.
462+
A filter selecting which simulants to observe. Either a Pandas query string, a
463+
callable mapping the population index to the subset of indices to keep, or a
464+
``(query, callable)`` tuple that applies the query first followed by the callable.
423465
include_untracked
424466
Whether to include simulants who are untracked from this observation.
425467
when
426468
Name of the lifecycle phase the observation should happen. Valid values are:
427469
"time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics".
428470
requires_attributes
429-
The population attributes that are required by the `aggregator`.
471+
The population attributes to record for this observation.
430472
results_formatter
431473
Function that formats the raw observation results.
432474
to_observe
@@ -435,7 +477,7 @@ def register_concatenating_observation(
435477
self._manager.register_observation(
436478
observation_type=ConcatenatingObservation,
437479
name=name,
438-
population_filter=PopulationFilter(pop_filter, include_untracked),
480+
population_filter=_PopulationFilter(pop_filter, include_untracked),
439481
when=when,
440482
requires_attributes=requires_attributes,
441483
results_formatter=results_formatter,

libs/engine/src/vivarium/engine/framework/results/manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
if TYPE_CHECKING:
2727
from vivarium.engine.framework.engine import Builder
28-
from vivarium.engine.framework.results.interface import PopulationFilter
28+
from vivarium.engine.framework.results.interface import _PopulationFilter
2929

3030

3131
class ResultsManager(Manager):
@@ -255,7 +255,7 @@ def register_observation(
255255
self,
256256
observation_type: type[Observation],
257257
name: str,
258-
population_filter: PopulationFilter,
258+
population_filter: _PopulationFilter,
259259
when: str,
260260
requires_attributes: list[str],
261261
**kwargs: Any,

libs/engine/src/vivarium/engine/framework/results/observation.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
)
3939

4040
if TYPE_CHECKING:
41-
from vivarium.engine.framework.results.interface import PopulationFilter
41+
from vivarium.engine.framework.results.interface import _PopulationFilter
4242

4343
VALUE_COLUMN = "value"
4444

@@ -55,7 +55,7 @@ class Observation(ABC):
5555
name: str
5656
"""Name of the observation. It will also be the name of the output results file
5757
for this particular observation."""
58-
population_filter: PopulationFilter
58+
population_filter: _PopulationFilter
5959
"""A named tuple of population filtering details. The first item is a Pandas
6060
query string to filter the population down to the simulants who should be
6161
considered for the observation. The second item is a boolean indicating whether
@@ -149,7 +149,7 @@ class UnstratifiedObservation(Observation):
149149
def __init__(
150150
self,
151151
name: str,
152-
population_filter: PopulationFilter,
152+
population_filter: _PopulationFilter,
153153
when: str,
154154
requires_attributes: list[str],
155155
results_gatherer: Callable[[pd.DataFrame], pd.DataFrame],
@@ -233,7 +233,7 @@ class StratifiedObservation(Observation):
233233
def __init__(
234234
self,
235235
name: str,
236-
population_filter: PopulationFilter,
236+
population_filter: _PopulationFilter,
237237
when: str,
238238
requires_attributes: list[str],
239239
results_updater: Callable[[pd.DataFrame, pd.DataFrame], pd.DataFrame],
@@ -441,7 +441,7 @@ class AddingObservation(StratifiedObservation):
441441
def __init__(
442442
self,
443443
name: str,
444-
population_filter: PopulationFilter,
444+
population_filter: _PopulationFilter,
445445
when: str,
446446
requires_attributes: list[str],
447447
results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame],
@@ -526,7 +526,7 @@ class ConcatenatingObservation(UnstratifiedObservation):
526526
def __init__(
527527
self,
528528
name: str,
529-
population_filter: PopulationFilter,
529+
population_filter: _PopulationFilter,
530530
when: str,
531531
requires_attributes: list[str],
532532
results_formatter: Callable[[str, pd.DataFrame], pd.DataFrame],

0 commit comments

Comments
 (0)