Skip to content
Merged
96 changes: 96 additions & 0 deletions src/vivarium_csu_alzheimers/components/observers.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ def semesterize(date):

return pop.squeeze(axis=1).dt.date.apply(semesterize)

def map_treatment_durations(self, pop: pd.DataFrame) -> pd.Series:
durations = pop.fillna(0.0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you just set everyone who didn't get treated to have a treatment duration of 0 instead of np.nan, you wouldn't need any mapping at all.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea is that in the interactive sim it might be confusing if someone has 0 months of treatment duration but they have never tested positive or ultimately had the chance to be treated. Maybe you wouldn't find that confusing? I do not feel strongly.

durations = durations.astype(int).squeeze(axis=1)
return durations

def register_stratifications(self, builder):
super().register_stratifications(builder)
builder.results.register_stratification(
Expand All @@ -82,6 +87,13 @@ def register_stratifications(self, builder):
is_vectorized=True,
requires_columns=["event_time"],
)
builder.results.register_stratification(
name="treatment_durations",
categories=list(range(10)),
mapper=self.map_treatment_durations,
is_vectorized=True,
requires_columns=[COLUMNS.TREATMENT_DURATION],
)


class NewSimulantsObserver(Observer):
Expand Down Expand Up @@ -284,9 +296,39 @@ def get_entity_column(self, measure: str, results: pd.DataFrame) -> pd.Series:


class TreatmentObserver(DiseaseObserver):
@property
def columns_required(self) -> list[str]:
return super().columns_required + [
COLUMNS.WAITING_FOR_TREATMENT_EVENT_TIME,
COLUMNS.TREATMENT_DURATION,
]

def __init__(self) -> None:
super().__init__("treatment")

def setup(self, builder):
super().setup(builder)
self.clock = builder.time.clock()
self.sim_start_time = pd.Timestamp(
month=builder.configuration.time.start.month,
day=builder.configuration.time.start.day,
year=builder.configuration.time.start.year,
)

def register_observations(self, builder):
super().register_observations(builder)
self.register_adding_observation(
builder=builder,
name="treatment_duration",
pop_filter='alive == "alive" and tracked==True',
requires_columns=[
COLUMNS.WAITING_FOR_TREATMENT_EVENT_TIME,
],
additional_stratifications=self.configuration.include + ["treatment_durations"],
excluded_stratifications=self.configuration.exclude,
aggregator=self.count_treatment_durations,
)

def register_disease_state_stratification(self, builder: Builder) -> None:
"""Register the disease state stratification.

Expand Down Expand Up @@ -336,3 +378,57 @@ def register_transition_stratification(self, builder: Builder) -> None:
requires_columns=[self.disease, self.previous_state_column_name],
is_vectorized=True,
)

def count_treatment_durations(self, pop: pd.DataFrame) -> float:
"""Aggregate the total treatment durations for simulants in the population."""
# Handle first time step where we have initial population and time step observations
if self.clock() == self.sim_start_time:
treatment_durations = pop[COLUMNS.WAITING_FOR_TREATMENT_EVENT_TIME].notna()
else:
treatment_durations = (
pop[COLUMNS.WAITING_FOR_TREATMENT_EVENT_TIME]
== self.clock() + self.step_size()
)
return sum(treatment_durations)

def format(self, measure: str, results: pd.DataFrame) -> pd.DataFrame:
"""Rename the appropriate column to 'sub_entity'.

The primary thing this method does is rename the appropriate column
(either the transition stratification name of the disease name, depending
on the measure) to 'sub_entity'. We do this here instead of the
'get_sub_entity_column' method simply because we do not want the original
column at all. If we keep it here and then return it as the sub-entity
column later, the final results would have both.

Parameters
----------
measure
The measure.
results
The results to format.

Returns
-------
The formatted results.
"""
results = results.reset_index()
if "transition_count_" in measure:
sub_entity = self.transition_stratification_name
if "person_time_" in measure:
sub_entity = self.disease
# Handle treatment_duration measure
if measure == "treatment_duration":
sub_entity = "treatment_durations"
results.rename(columns={sub_entity: "sub_entity"}, inplace=True)
return results

def get_measure_column(self, measure: str, results: pd.DataFrame) -> pd.Series:
"""Get the 'measure' column values."""
if "transition_count_" in measure:
measure_name = "transition_count"
if "person_time_" in measure:
measure_name = "person_time"
if measure == "treatment_duration":
measure_name = "treatment_duration_count"
return pd.Series(measure_name, index=results.index)
Loading