@@ -52,7 +52,7 @@ class ResultsContext:
5252 objects to be produced keyed by the observation name.
5353 grouped_observations
5454 Dictionary of observation details. It is of the format
55- {lifecycle_state: {( pop_filter, stratifications) : list[Observation]}}.
55+ {lifecycle_state: {pop_filter: { stratifications: list[Observation]} }}.
5656 Allowable lifecycle_states are "time_step__prepare", "time_step",
5757 "time_step__cleanup", and "collect_metrics".
5858 logger
@@ -65,8 +65,8 @@ def __init__(self) -> None:
6565 self .excluded_categories : dict [str , list [str ]] = {}
6666 self .observations : dict [str , Observation ] = {}
6767 self .grouped_observations : defaultdict [
68- str , defaultdict [tuple [ str , tuple [str , ...] | None ] , list [Observation ]]
69- ] = defaultdict (lambda : defaultdict (list ))
68+ str , defaultdict [str , defaultdict [ tuple [str , ...] | None , list [Observation ] ]]
69+ ] = defaultdict (lambda : defaultdict (lambda : defaultdict ( list ) ))
7070
7171 @property
7272 def name (self ) -> str :
@@ -112,18 +112,18 @@ def set_stratifications(self) -> None:
112112 """
113113 used_stratifications : set [str ] = set ()
114114 for state_observations in self .grouped_observations .values ():
115- for observation_details in state_observations .items ():
116- ( _ , stratification_names ) , observations = observation_details
117- if stratification_names is None :
118- continue
119-
120- used_stratifications |= set (stratification_names )
121- for observation in observations :
122- observation .stratifications = tuple (
123- self .stratifications [name ]
124- for name in stratification_names
125- if name in self .stratifications
126- )
115+ for pop_filter_observations in state_observations .values ():
116+ for stratification_names , observations in pop_filter_observations . items ():
117+ if stratification_names is None :
118+ continue
119+
120+ used_stratifications |= set (stratification_names )
121+ for observation in observations :
122+ observation .stratifications = tuple (
123+ self .stratifications [name ]
124+ for name in stratification_names
125+ if name in self .stratifications
126+ )
127127
128128 if unused_stratifications := set (self .stratifications .keys ()) - used_stratifications :
129129 self .logger .info (
@@ -272,8 +272,8 @@ def register_observation(
272272 ** kwargs ,
273273 )
274274 self .observations [name ] = observation
275- self .grouped_observations [observation .when ][
276- ( observation . pop_filter , stratifications )
275+ self .grouped_observations [observation .when ][observation . pop_filter ][
276+ stratifications
277277 ].append (observation )
278278 return observation
279279
@@ -318,26 +318,36 @@ def gather_results(
318318
319319 # Optimization: We store all the producers by pop_filter and stratifications
320320 # so that we only have to apply them once each time we compute results.
321- for ( pop_filter , stratification_names ), observations in self .grouped_observations [
321+ for pop_filter , stratification_observations in self .grouped_observations [
322322 lifecycle_state
323323 ].items ():
324- observations = [obs for obs in observations if obs in event_observations ]
325- if not observations :
324+ event_pop_filter_observations = [
325+ observation
326+ for observations in stratification_observations .values ()
327+ for observation in observations
328+ if observation in event_observations
329+ ]
330+ if not event_pop_filter_observations :
326331 continue
327332
328- # Results production can be simplified to
329- # filter -> groupby -> aggregate in all situations we've seen.
330- filtered_pop = self ._filter_population (
331- population , pop_filter , stratification_names
332- )
333- if filtered_pop .empty :
333+ filtered_population = self ._filter_population (population , pop_filter )
334+ if filtered_population .empty :
334335 continue
335- else :
336+
337+ for stratification_names , observations in stratification_observations .items ():
338+ observations = [
339+ obs for obs in observations if obs in event_pop_filter_observations
340+ ]
341+ if not observations :
342+ continue
343+
336344 pop : pd .DataFrame | DataFrameGroupBy [tuple [str , ...] | str , bool ]
337- if stratification_names is None :
338- pop = filtered_pop
339- else :
340- pop = self ._get_groups (stratification_names , filtered_pop )
345+ pop = self ._drop_na_stratifications (filtered_population , stratification_names )
346+ if pop .empty :
347+ continue
348+ if stratification_names is not None :
349+ pop = self ._get_groups (stratification_names , pop )
350+
341351 for observation in observations :
342352 results = observation .observe (pop , stratification_names )
343353 yield (results , observation .name , observation .results_updater )
@@ -357,7 +367,8 @@ def get_observations(self, event: Event) -> list[Observation]:
357367 """
358368 return [
359369 observation
360- for observations in self .grouped_observations [event .name ].values ()
370+ for stratification_observations in self .grouped_observations [event .name ].values ()
371+ for observations in stratification_observations .values ()
361372 for observation in observations
362373 if observation .to_observe (event )
363374 ]
@@ -436,24 +447,24 @@ def get_required_values(
436447 required_values .update (stratification .requires_values )
437448 return list (required_values )
438449
439- def _filter_population (
440- self ,
441- population : pd .DataFrame ,
442- pop_filter : str ,
443- stratification_names : tuple [str , ...] | None ,
450+ def _filter_population (self , population : pd .DataFrame , pop_filter : str ) -> pd .DataFrame :
451+ """Filter out simulants not to observe."""
452+ return population .query (pop_filter ) if pop_filter else population .copy ()
453+
454+ def _drop_na_stratifications (
455+ self , population : pd .DataFrame , stratification_names : tuple [str , ...] | None
444456 ) -> pd .DataFrame :
445457 """Filter out simulants not to observe."""
446- pop = population .query (pop_filter ) if pop_filter else population .copy ()
447458 if stratification_names :
448459 # Drop all rows in the mapped_stratification columns that have NaN values
449460 # (which only exist if the mapper returned an excluded category).
450- pop = pop .dropna (
461+ population = population .dropna (
451462 subset = [
452463 get_mapped_col_name (stratification )
453464 for stratification in stratification_names
454465 ]
455466 )
456- return pop
467+ return population
457468
458469 @staticmethod
459470 def _get_groups (
0 commit comments