|
| 1 | +from abc import ABC |
| 2 | +from collections.abc import Collection |
| 3 | +from typing import Any, Literal |
| 4 | + |
| 5 | +import pandas as pd |
| 6 | + |
| 7 | +from vivarium_testing_utils.automated_validation.constants import ( |
| 8 | + DRAW_INDEX, |
| 9 | + SEED_INDEX, |
| 10 | + DataSource, |
| 11 | +) |
| 12 | +from vivarium_testing_utils.automated_validation.data_loader import DataLoader |
| 13 | +from vivarium_testing_utils.automated_validation.data_transformation import ( |
| 14 | + age_groups, |
| 15 | + calculations, |
| 16 | +) |
| 17 | +from vivarium_testing_utils.automated_validation.data_transformation.measures import ( |
| 18 | + CategoricalRelativeRisk, |
| 19 | + Measure, |
| 20 | + RatioMeasure, |
| 21 | + RiskExposure, |
| 22 | +) |
| 23 | +from vivarium_testing_utils.automated_validation.visualization import dataframe_utils |
| 24 | + |
| 25 | + |
| 26 | +class MeasureDataBundle(ABC): |
| 27 | + measure: Measure |
| 28 | + source: DataSource |
| 29 | + data_loader: DataLoader |
| 30 | + scenarios: dict[str, str] | None |
| 31 | + |
| 32 | + |
| 33 | +class RatioMeasureDataBundle: |
| 34 | + def __init__( |
| 35 | + self, |
| 36 | + measure: RatioMeasure, |
| 37 | + source: DataSource, |
| 38 | + data_loader: DataLoader, |
| 39 | + age_group_df: pd.DataFrame, |
| 40 | + scenarios: dict[str, str] | None = None, |
| 41 | + ) -> None: |
| 42 | + self.measure = measure |
| 43 | + self.source = source |
| 44 | + self.scenarios = scenarios if scenarios is not None else {} |
| 45 | + self.datasets = self._get_formatted_datasets(data_loader, age_group_df) |
| 46 | + self.weights = self._get_aggregated_weights(data_loader, age_group_df) |
| 47 | + |
| 48 | + @property |
| 49 | + def dataset_names(self) -> dict[str, str]: |
| 50 | + """Return a dictionary of required datasets for the specified source.""" |
| 51 | + if self.source == DataSource.SIM: |
| 52 | + return self.measure.sim_output_datasets |
| 53 | + elif self.source in ([DataSource.ARTIFACT, DataSource.GBD]): |
| 54 | + return self.measure.sim_input_datasets |
| 55 | + else: |
| 56 | + raise ValueError(f"Unsupported data source: {self.source}") |
| 57 | + |
| 58 | + @property |
| 59 | + def index_names(self) -> set[str]: |
| 60 | + return { |
| 61 | + index_name |
| 62 | + for key in self.datasets |
| 63 | + for index_name in self.datasets[key].index.names |
| 64 | + } |
| 65 | + |
| 66 | + def get_metadata(self) -> dict[str, Any]: |
| 67 | + """Organize the data information into a dictionary for display by a styled pandas DataFrame. |
| 68 | + Apply formatting to values that need special handling. |
| 69 | +
|
| 70 | + Returns: |
| 71 | + -------- |
| 72 | + A dictionary containing the formatted data information. |
| 73 | +
|
| 74 | + """ |
| 75 | + dataframe = self.get_measure_data("all") |
| 76 | + data_info: dict[str, Any] = {} |
| 77 | + |
| 78 | + # Source as string |
| 79 | + data_info["source"] = self.source.value |
| 80 | + |
| 81 | + # Index columns as comma-separated string |
| 82 | + data_info["index_columns"] = list(dataframe.index.names) |
| 83 | + |
| 84 | + # Size as formatted string |
| 85 | + size = dataframe.shape |
| 86 | + data_info["size"] = f"{size[0]:,} rows × {size[1]:,} columns" |
| 87 | + |
| 88 | + # Draw information |
| 89 | + if DRAW_INDEX in dataframe.index.names: |
| 90 | + num_draws = dataframe.index.get_level_values(DRAW_INDEX).nunique() |
| 91 | + data_info["num_draws"] = f"{num_draws:,}" |
| 92 | + draw_values = list(dataframe.index.get_level_values(DRAW_INDEX).unique()) |
| 93 | + data_info[DRAW_INDEX + "s"] = dataframe_utils.format_draws_sample( |
| 94 | + draw_values, self.source |
| 95 | + ) |
| 96 | + |
| 97 | + # Seeds information |
| 98 | + if SEED_INDEX in dataframe.index.names: |
| 99 | + num_seeds = dataframe.index.get_level_values(SEED_INDEX).nunique() |
| 100 | + data_info["num_seeds"] = f"{num_seeds:,}" |
| 101 | + |
| 102 | + return data_info |
| 103 | + |
| 104 | + def _get_formatted_datasets( |
| 105 | + self, data_loader: DataLoader, age_group_data: pd.DataFrame |
| 106 | + ) -> dict[str, pd.DataFrame]: |
| 107 | + """Formats measure datasets depending on the source.""" |
| 108 | + raw_datasets = data_loader._get_raw_data_from_source(self.dataset_names, self.source) |
| 109 | + if self.source == DataSource.SIM: |
| 110 | + datasets = self.measure.get_ratio_datasets_from_sim( |
| 111 | + **raw_datasets, |
| 112 | + ) |
| 113 | + elif self.source in [DataSource.ARTIFACT, DataSource.GBD]: |
| 114 | + data = self.measure.get_measure_data_from_sim_inputs(**raw_datasets) |
| 115 | + datasets = {"data": data} |
| 116 | + elif self.source == DataSource.CUSTOM: |
| 117 | + raise NotImplementedError |
| 118 | + else: |
| 119 | + raise ValueError(f"Unsupported data source: {self.source}") |
| 120 | + |
| 121 | + datasets = { |
| 122 | + dataset_name: age_groups.format_dataframe_from_age_bin_df(dataset, age_group_data) |
| 123 | + for dataset_name, dataset in datasets.items() |
| 124 | + } |
| 125 | + datasets = { |
| 126 | + key: calculations.filter_data(dataset, self.scenarios, drop_singles=True) |
| 127 | + for key, dataset in datasets.items() |
| 128 | + } |
| 129 | + |
| 130 | + return datasets |
| 131 | + |
| 132 | + def _get_aggregated_weights( |
| 133 | + self, data_loader: DataLoader, age_group_data: pd.DataFrame |
| 134 | + ) -> pd.DataFrame | None: |
| 135 | + """Fetches and aggregates weights if required by the measure.""" |
| 136 | + if self.source not in [DataSource.ARTIFACT, DataSource.GBD]: |
| 137 | + return None |
| 138 | + |
| 139 | + raw_weights = data_loader._get_raw_data_from_source( |
| 140 | + self.measure.rate_aggregation_weights.weight_keys, self.source |
| 141 | + ) |
| 142 | + weights = self.measure.rate_aggregation_weights.get_weights(**raw_weights) |
| 143 | + return age_groups.format_dataframe_from_age_bin_df(weights, age_group_data) |
| 144 | + |
| 145 | + def get_measure_data( |
| 146 | + self, stratifications: Collection[str] | Literal["all"] |
| 147 | + ) -> pd.DataFrame: |
| 148 | + """Get the measure data, optionally aggregated over specified stratifications.""" |
| 149 | + if self.source == DataSource.SIM: |
| 150 | + return self._aggregate_scenario_stratifications(self.datasets, stratifications) |
| 151 | + elif self.source in [DataSource.ARTIFACT, DataSource.GBD]: |
| 152 | + return self._aggregate_sim_input_stratifications(stratifications) |
| 153 | + elif self.source == DataSource.CUSTOM: |
| 154 | + raise NotImplementedError |
| 155 | + else: |
| 156 | + raise ValueError(f"Unsupported data source: {self.source}") |
| 157 | + |
| 158 | + def _aggregate_scenario_stratifications( |
| 159 | + self, |
| 160 | + datasets: dict[str, pd.DataFrame], |
| 161 | + stratifications: Collection[str] | Literal["all"], |
| 162 | + ) -> pd.DataFrame: |
| 163 | + """This will remove index levels corresponding to the specified stratifications""" |
| 164 | + datasets = { |
| 165 | + key: calculations.stratify(datasets[key], stratifications) for key in datasets |
| 166 | + } |
| 167 | + return self.measure.get_measure_data_from_ratio(**datasets) |
| 168 | + |
| 169 | + def _aggregate_sim_input_stratifications( |
| 170 | + self, stratifications: Collection[str] | Literal["all"] |
| 171 | + ) -> pd.DataFrame: |
| 172 | + """Aggregate the artifact data over specified stratifications. Stratifactions will be retained |
| 173 | + in the returned data.""" |
| 174 | + data = self.datasets["data"].copy() |
| 175 | + if stratifications != "all": |
| 176 | + stratifications = list(stratifications) |
| 177 | + # Retain input_draw, comparison._aggregate_over_draws is the only place we should aggregate over draws. |
| 178 | + if DRAW_INDEX in data.index.names and DRAW_INDEX not in stratifications: |
| 179 | + stratifications.append(DRAW_INDEX) |
| 180 | + if self.weights is None: |
| 181 | + raise ValueError("Weights are required for aggregating artifact data.") |
| 182 | + |
| 183 | + # Update scenario columns to retain during aggregation |
| 184 | + scenario_cols = [] |
| 185 | + # NOTE: This is a hack to handle alignment of index levels in weighted_average. Risk |
| 186 | + # stratification column is treated as a scenario column and the population can be |
| 187 | + # broadcast across each index group since the exposure for each group should sum to 1. |
| 188 | + if isinstance(self.measure, (RiskExposure, CategoricalRelativeRisk)): |
| 189 | + scenario_cols.append(self.measure.risk_stratification_column) |
| 190 | + scenario_cols.extend(list(self.scenarios.keys())) |
| 191 | + weighted_avg = calculations.weighted_average( |
| 192 | + data, self.weights, stratifications, scenario_cols |
| 193 | + ) |
| 194 | + |
| 195 | + # Reference data can be a float or dataframe. Convert floats so dataframes are aligned |
| 196 | + if not isinstance(weighted_avg, pd.DataFrame): |
| 197 | + weighted_avg = pd.DataFrame( |
| 198 | + {"value": [weighted_avg]}, index=pd.Index([0], name="index") |
| 199 | + ) |
| 200 | + return weighted_avg |
0 commit comments