99from __future__ import annotations
1010
1111from collections .abc import Callable
12- from typing import TYPE_CHECKING , Any , NamedTuple , Sequence , Union
12+ from typing import TYPE_CHECKING , Any , Sequence , Union
1313
1414import pandas as pd
1515from pandas .core .groupby .generic import DataFrameGroupBy
4040ResultsGatherer = Callable [[ResultsGathererInput ], pd .DataFrame ]
4141"""A Callable that optionally takes a possibly stratified population and returns
4242new observation results."""
# Signature expected of a callable population filter: it receives the population
# index and returns the subset of that index to keep.
# NOTE(review): the ``type-arg`` ignore is presumably needed because ``pd.Index`` is
# used unparameterized here -- confirm against the type-checker configuration.
PopulationIndexFilter = Callable[[pd.Index], pd.Index]  # type: ignore [type-arg]
"""A Callable that takes a population index and returns the subset of indices to keep."""
# Union of every accepted spelling of a ``pop_filter`` argument: query only,
# callable only, or a ``(query, callable)`` pair.
PopulationFilter = Union[str, PopulationIndexFilter, tuple[str, PopulationIndexFilter]]
"""A population filter given as a Pandas query string, an index-filter callable, or a
``(query, callable)`` tuple that applies the query first followed by the callable."""
4348
4449
4550def _required_function_placeholder (
@@ -76,11 +81,44 @@ def _default_unstratified_observation_formatter(
7681 return results
7782
7883
79- class PopulationFilter ( NamedTuple ) :
80- """Container class for population query string and include_untracked flag."""
84+ class _PopulationFilter :
85+ """Container for a query string, the untracked flag, and an optional index filter.
8186
82- query : str = ""
83- include_untracked : bool = False
87+ The ``pop_filter`` argument may be a Pandas query string, an index-filter
88+ callable, or a ``(query, callable)`` tuple; it is resolved into the separate
89+ ``query`` and ``index_filter`` attributes on construction.
90+ """
91+
92+ def __init__ (
93+ self , pop_filter : PopulationFilter = "" , include_untracked : bool = False
94+ ) -> None :
95+ self .query , self .index_filter = self ._parse_pop_filter (pop_filter )
96+ self .include_untracked = include_untracked
97+
98+ @staticmethod
99+ def _parse_pop_filter (
100+ pop_filter : PopulationFilter ,
101+ ) -> tuple [str , PopulationIndexFilter | None ]:
102+ """Resolve a pop_filter argument into a (query, index_filter) pair."""
103+ if isinstance (pop_filter , str ):
104+ return pop_filter , None
105+ if isinstance (pop_filter , tuple ):
106+ return pop_filter
107+ return "" , pop_filter
108+
109+ def __eq__ (self , other : object ) -> bool :
110+ """Compare by field value so equal filters share a grouping key."""
111+ if not isinstance (other , _PopulationFilter ):
112+ return NotImplemented
113+ return (self .query , self .include_untracked , self .index_filter ) == (
114+ other .query ,
115+ other .include_untracked ,
116+ other .index_filter ,
117+ )
118+
119+ def __hash__ (self ) -> int :
120+ """Hash by field value, consistent with __eq__, for use as a dict key."""
121+ return hash ((self .query , self .include_untracked , self .index_filter ))
84122
85123
86124class ResultsInterface (Interface ):
@@ -199,7 +237,7 @@ def register_binned_stratification(
199237 def register_stratified_observation (
200238 self ,
201239 name : str ,
202- pop_filter : str = "" ,
240+ pop_filter : PopulationFilter = "" ,
203241 include_untracked : bool = False ,
204242 when : str = lifecycle_states .COLLECT_METRICS ,
205243 requires_attributes : list [str ] = [],
@@ -219,8 +257,9 @@ def register_stratified_observation(
219257 Name of the observation. It will also be the name of the output results file
220258 for this particular observation.
221259 pop_filter
222- A Pandas query filter string to filter the population down to the simulants who should
223- be considered for the observation.
260+ A filter selecting which simulants to observe. Either a Pandas query string, a
261+ callable mapping the population index to the subset of indices to keep, or a
262+ ``(query, callable)`` tuple that applies the query first followed by the callable.
224263 include_untracked
225264 Whether to include simulants who are untracked from this observation.
226265 when
@@ -254,7 +293,7 @@ def register_stratified_observation(
254293 self ._manager .register_observation (
255294 observation_type = StratifiedObservation ,
256295 name = name ,
257- population_filter = PopulationFilter (pop_filter , include_untracked ),
296+ population_filter = _PopulationFilter (pop_filter , include_untracked ),
258297 when = when ,
259298 requires_attributes = requires_attributes ,
260299 results_updater = results_updater ,
@@ -269,7 +308,7 @@ def register_stratified_observation(
269308 def register_unstratified_observation (
270309 self ,
271310 name : str ,
272- pop_filter : str = "" ,
311+ pop_filter : PopulationFilter = "" ,
273312 include_untracked : bool = False ,
274313 when : str = lifecycle_states .COLLECT_METRICS ,
275314 requires_attributes : list [str ] = [],
@@ -286,8 +325,9 @@ def register_unstratified_observation(
286325 Name of the observation. It will also be the name of the output results file
287326 for this particular observation.
288327 pop_filter
289- A Pandas query filter string to filter the population down to the simulants who should
290- be considered for the observation.
328+ A filter selecting which simulants to observe. Either a Pandas query string, a
329+ callable mapping the population index to the subset of indices to keep, or a
330+ ``(query, callable)`` tuple that applies the query first followed by the callable.
291331 include_untracked
292332 Whether to include simulants who are untracked from this observation.
293333 when
@@ -317,7 +357,7 @@ def register_unstratified_observation(
317357 self ._manager .register_observation (
318358 observation_type = UnstratifiedObservation ,
319359 name = name ,
320- population_filter = PopulationFilter (pop_filter , include_untracked ),
360+ population_filter = _PopulationFilter (pop_filter , include_untracked ),
321361 when = when ,
322362 requires_attributes = requires_attributes ,
323363 results_updater = results_updater ,
@@ -329,7 +369,7 @@ def register_unstratified_observation(
329369 def register_adding_observation (
330370 self ,
331371 name : str ,
332- pop_filter : str = "" ,
372+ pop_filter : PopulationFilter = "" ,
333373 include_untracked : bool = False ,
334374 when : str = lifecycle_states .COLLECT_METRICS ,
335375 requires_attributes : list [str ] = [],
@@ -355,8 +395,9 @@ def register_adding_observation(
355395 Name of the observation. It will also be the name of the output results file
356396 for this particular observation.
357397 pop_filter
358- A Pandas query filter string to filter the population down to the simulants who should
359- be considered for the observation.
398+ A filter selecting which simulants to observe. Either a Pandas query string, a
399+ callable mapping the population index to the subset of indices to keep, or a
400+ ``(query, callable)`` tuple that applies the query first followed by the callable.
360401 include_untracked
361402 Whether to include simulants who are untracked from this observation.
362403 when
@@ -382,7 +423,7 @@ def register_adding_observation(
382423 self ._manager .register_observation (
383424 observation_type = AddingObservation ,
384425 name = name ,
385- population_filter = PopulationFilter (pop_filter , include_untracked ),
426+ population_filter = _PopulationFilter (pop_filter , include_untracked ),
386427 when = when ,
387428 requires_attributes = requires_attributes ,
388429 results_formatter = results_formatter ,
@@ -396,7 +437,7 @@ def register_adding_observation(
396437 def register_concatenating_observation (
397438 self ,
398439 name : str ,
399- pop_filter : str = "" ,
440+ pop_filter : PopulationFilter = "" ,
400441 include_untracked : bool = False ,
401442 when : str = lifecycle_states .COLLECT_METRICS ,
402443 requires_attributes : list [str ] = [],
@@ -418,15 +459,16 @@ def register_concatenating_observation(
418459 Name of the observation. It will also be the name of the output results file
419460 for this particular observation.
420461 pop_filter
421- A Pandas query filter string to filter the population down to the simulants who should
422- be considered for the observation.
462+ A filter selecting which simulants to observe. Either a Pandas query string, a
463+ callable mapping the population index to the subset of indices to keep, or a
464+ ``(query, callable)`` tuple that applies the query first followed by the callable.
423465 include_untracked
424466 Whether to include simulants who are untracked from this observation.
425467 when
426468 Name of the lifecycle phase the observation should happen. Valid values are:
427469 "time_step__prepare", "time_step", "time_step__cleanup", or "collect_metrics".
428470 requires_attributes
429- The population attributes that are required by the `aggregator` .
471+ The population attributes to record for this observation .
430472 results_formatter
431473 Function that formats the raw observation results.
432474 to_observe
@@ -435,7 +477,7 @@ def register_concatenating_observation(
435477 self ._manager .register_observation (
436478 observation_type = ConcatenatingObservation ,
437479 name = name ,
438- population_filter = PopulationFilter (pop_filter , include_untracked ),
480+ population_filter = _PopulationFilter (pop_filter , include_untracked ),
439481 when = when ,
440482 requires_attributes = requires_attributes ,
441483 results_formatter = results_formatter ,
0 commit comments