Skip to content

Commit 44f7eac

Browse files
authored
Albrja/mic 6600/sort get frame (#90)
Albrja/mic 6600/sort get frame Sort output of get_frame - *Category*: Feature - *JIRA issue*: https://jira.ihme.washington.edu/browse/MIC-6600 Changes and notes -containerizes column strings -sort dataframe returned to user from get_frame
1 parent c1bc3de commit 44f7eac

File tree

11 files changed

+336
-86
lines changed

11 files changed

+336
-86
lines changed

Jenkinsfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,6 @@ library("vivarium_build_utils@${get_vbu_version()}")
2929

3030
reusable_pipeline(
3131
scheduled_branches: ["main", "epic/auto-validation"],
32+
requires_slurm: true,
3233
skip_doc_build: true
3334
)

src/vivarium_testing_utils/automated_validation/constants.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from enum import Enum
4+
from typing import NamedTuple
45

56
DRAW_PREFIX = "draw_"
67

@@ -24,3 +25,23 @@ def from_str(cls, source: str) -> DataSource:
2425

2526
LOCATION_ARTIFACT_KEY = "population.location"
2627
POPULATION_STRUCTURE_ARTIFACT_KEY = "population.structure"
28+
29+
30+
class InputDataIndexNames(NamedTuple):
31+
LOCATION_ID: str = "location_id"
32+
SEX_ID: str = "sex_id"
33+
AGE_GROUP_ID: str = "age_group_id"
34+
YEAR_ID: str = "year_id"
35+
PARAMETER: str = "parameter"
36+
CAUSE_ID: str = "cause_id"
37+
AFFECTED_ENTITY: str = "affected_entity"
38+
LOCATION: str = "location"
39+
SEX: str = "sex"
40+
AGE_GROUP: str = "age_group"
41+
AGE_START: str = "age_start"
42+
AGE_END: str = "age_end"
43+
YEAR_START: str = "year_start"
44+
YEAR_END: str = "year_end"
45+
46+
47+
INPUT_DATA_INDEX_NAMES = InputDataIndexNames()

src/vivarium_testing_utils/automated_validation/data_transformation/age_groups.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55
import pandas as pd
66
from loguru import logger
77

8-
AGE_GROUP_COLUMN = "age_group"
9-
AGE_START_COLUMN = "age_start"
10-
AGE_END_COLUMN = "age_end"
8+
from vivarium_testing_utils.automated_validation.constants import INPUT_DATA_INDEX_NAMES
119

1210
AgeTuple = tuple[str, int | float, int | float]
1311
AgeRange = tuple[int | float, int | float]
@@ -316,12 +314,19 @@ def from_dataframe(cls, df: pd.DataFrame) -> AgeSchema:
316314
-------
317315
An AgeSchema with the specified age groups.
318316
"""
319-
has_names = AGE_GROUP_COLUMN in df.index.names
320-
has_ranges = AGE_START_COLUMN in df.index.names and AGE_END_COLUMN in df.index.names
317+
has_age_group = INPUT_DATA_INDEX_NAMES.AGE_GROUP in df.index.names
318+
has_age_range = (
319+
INPUT_DATA_INDEX_NAMES.AGE_START in df.index.names
320+
and INPUT_DATA_INDEX_NAMES.AGE_END in df.index.names
321+
)
321322

322323
# Usually this occurs for the artifact population.age_bins
323-
if has_names and has_ranges:
324-
levels = [AGE_GROUP_COLUMN, AGE_START_COLUMN, AGE_END_COLUMN]
324+
if has_age_group and has_age_range:
325+
levels = [
326+
INPUT_DATA_INDEX_NAMES.AGE_GROUP,
327+
INPUT_DATA_INDEX_NAMES.AGE_START,
328+
INPUT_DATA_INDEX_NAMES.AGE_END,
329+
]
325330
age_groups = list(
326331
df.index.droplevel(list(set(df.index.names) - set(levels)))
327332
.reorder_levels(levels)
@@ -330,17 +335,17 @@ def from_dataframe(cls, df: pd.DataFrame) -> AgeSchema:
330335

331336
return cls.from_tuples(age_groups)
332337
# Most artifact dataframes have age start/end but not age group
333-
elif has_ranges:
334-
levels = [AGE_START_COLUMN, AGE_END_COLUMN]
338+
elif has_age_range:
339+
levels = [INPUT_DATA_INDEX_NAMES.AGE_START, INPUT_DATA_INDEX_NAMES.AGE_END]
335340
age_groups = (
336341
df.index.droplevel(list(set(df.index.names) - set(levels)))
337342
.reorder_levels(levels)
338343
.unique()
339344
)
340345
return cls.from_ranges(age_groups)
341346
# Most simulation dataframes have age group but not start/end
342-
elif has_names:
343-
levels = [AGE_GROUP_COLUMN]
347+
elif has_age_group:
348+
levels = [INPUT_DATA_INDEX_NAMES.AGE_GROUP]
344349
age_groups = list(
345350
df.index.droplevel(list(set(df.index.names) - set(levels))).unique()
346351
)
@@ -355,12 +360,16 @@ def to_dataframe(self) -> pd.DataFrame:
355360
Convert the AgeSchema to a DataFrame with age group names and their start and end ages.
356361
"""
357362
data = {
358-
AGE_GROUP_COLUMN: [group.name for group in self.age_groups],
359-
AGE_START_COLUMN: [group.start for group in self.age_groups],
360-
AGE_END_COLUMN: [group.end for group in self.age_groups],
363+
INPUT_DATA_INDEX_NAMES.AGE_GROUP: [group.name for group in self.age_groups],
364+
INPUT_DATA_INDEX_NAMES.AGE_START: [group.start for group in self.age_groups],
365+
INPUT_DATA_INDEX_NAMES.AGE_END: [group.end for group in self.age_groups],
361366
}
362367
return pd.DataFrame(data).set_index(
363-
[AGE_GROUP_COLUMN, AGE_START_COLUMN, AGE_END_COLUMN]
368+
[
369+
INPUT_DATA_INDEX_NAMES.AGE_GROUP,
370+
INPUT_DATA_INDEX_NAMES.AGE_START,
371+
INPUT_DATA_INDEX_NAMES.AGE_END,
372+
]
364373
)
365374

366375
def _validate(self) -> None:
@@ -428,7 +437,11 @@ def _format_dataframe(target_schema: AgeSchema, df: pd.DataFrame) -> pd.DataFram
428437
"""
429438
source_age_schema = AgeSchema.from_dataframe(df)
430439
index_names = list(df.index.names)
431-
for age_group_indices in [AGE_GROUP_COLUMN, AGE_START_COLUMN, AGE_END_COLUMN]:
440+
for age_group_indices in [
441+
INPUT_DATA_INDEX_NAMES.AGE_GROUP,
442+
INPUT_DATA_INDEX_NAMES.AGE_START,
443+
INPUT_DATA_INDEX_NAMES.AGE_END,
444+
]:
432445
if age_group_indices not in index_names:
433446
index_names.append(age_group_indices)
434447
df = pd.merge(
@@ -443,21 +456,22 @@ def _format_dataframe(target_schema: AgeSchema, df: pd.DataFrame) -> pd.DataFram
443456
if source_age_schema.is_subset(target_schema):
444457
return (
445458
pd.merge(
446-
df.droplevel([AGE_GROUP_COLUMN]),
459+
df.droplevel([INPUT_DATA_INDEX_NAMES.AGE_GROUP]),
447460
target_schema.to_dataframe(),
448461
left_index=True,
449462
right_index=True,
450463
)
451464
.reorder_levels(index_names)
452-
.droplevel([AGE_START_COLUMN, AGE_END_COLUMN])
465+
.droplevel([INPUT_DATA_INDEX_NAMES.AGE_START, INPUT_DATA_INDEX_NAMES.AGE_END])
453466
)
454467
else:
455468
logger.info(
456469
f"Rebinning DataFrame age groups from {source_age_schema} to {target_schema}."
457470
)
458471
# if we don't fit pandera schema SimOutputData, assume the data is rate data and raise an error.
459472
data = rebin_count_dataframe(
460-
target_schema, df.droplevel([AGE_START_COLUMN, AGE_END_COLUMN])
473+
target_schema,
474+
df.droplevel([INPUT_DATA_INDEX_NAMES.AGE_START, INPUT_DATA_INDEX_NAMES.AGE_END]),
461475
)
462476
return data
463477

@@ -496,18 +510,20 @@ def rebin_count_dataframe(
496510
# Unstack the DataFrame to get the age groups as columns
497511
unstacked_series = (
498512
df[val_col]
499-
.unstack(level=AGE_GROUP_COLUMN, fill_value=0)
513+
.unstack(level=INPUT_DATA_INDEX_NAMES.AGE_GROUP, fill_value=0)
500514
.reindex(columns=transform_matrix.columns, fill_value=0)
501515
)
502516

503517
# Perform the dot product
504518
result_matrix_for_col = unstacked_series.dot(transform_matrix.T)
505519

506-
# Name the column AGE_GROUP_COLUMN for re-stacking
507-
result_matrix_for_col.columns.name = AGE_GROUP_COLUMN
520+
# Name the column GBD_INDEX_NAMES.AGE_GROUP for re-stacking
521+
result_matrix_for_col.columns.name = INPUT_DATA_INDEX_NAMES.AGE_GROUP
508522

509523
# Stack the new age group columns into the index
510-
stacked_series_for_col = result_matrix_for_col.stack(level=AGE_GROUP_COLUMN)
524+
stacked_series_for_col = result_matrix_for_col.stack(
525+
level=INPUT_DATA_INDEX_NAMES.AGE_GROUP
526+
)
511527
stacked_series_for_col.name = val_col
512528

513529
all_results_series.append(stacked_series_for_col)

src/vivarium_testing_utils/automated_validation/data_transformation/utils.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44

55
import pandas as pd
66
import pandera as pa
7+
from vivarium_inputs.globals import DEMOGRAPHIC_COLUMNS, VIVARIUM_COLUMNS
8+
9+
from vivarium_testing_utils.automated_validation.constants import INPUT_DATA_INDEX_NAMES
710

811
F = TypeVar("F", bound=Callable[..., Any])
912

@@ -58,24 +61,21 @@ def drop_extra_columns(raw_gbd: pd.DataFrame, data_key: str) -> pd.DataFrame:
5861
f"No value columns found in the data. Columns found: {raw_gbd.columns.tolist()}"
5962
)
6063

61-
gbd_cols = ["location_id", "sex_id", "age_group_id", "year_id", "cause_id"]
62-
measure = data_key.split(".")[-1]
63-
if measure in ["exposure", "relative_risk"]:
64-
gbd_cols.append("parameter")
64+
gbd_cols = get_measure_index_names(data_key)
6565
columns_to_keep = [col for col in raw_gbd.columns if col in gbd_cols + value_cols]
6666
return raw_gbd[columns_to_keep]
6767

6868

6969
def set_gbd_index(data: pd.DataFrame, data_key: str) -> pd.DataFrame:
7070
"""Set the index of a GBD DataFrame based on the data key."""
71-
measure = data_key.split(".")[-1]
72-
gbd_cols = ["location_id", "sex_id", "age_group_id", "year_id"]
73-
if measure in ["exposure", "relative_risk"]:
74-
gbd_cols.append("parameter")
75-
if measure != "relative_risk" and "cause_id" in data.columns:
76-
data = data.drop(columns=["cause_id"])
71+
gbd_cols = get_measure_index_names(data_key)
7772

78-
index_cols = [col for col in gbd_cols if col in data.columns]
73+
# CAUSE_ID is expected to be a column when Vivarium Inputs maps all of the IDs to values.
74+
index_cols = [
75+
col
76+
for col in gbd_cols
77+
if col in data.columns and col != INPUT_DATA_INDEX_NAMES.CAUSE_ID
78+
]
7979

8080
formatted = data.set_index(index_cols)
8181
return formatted
@@ -93,3 +93,35 @@ def set_validation_index(data: pd.DataFrame) -> pd.DataFrame:
9393
data = data.set_index(sorted_data_index)
9494

9595
return data
96+
97+
98+
def get_measure_index_names(data_key: str, data_schema: str = "gbd") -> list[str]:
99+
"""Get the expected index names for a given data key.
100+
101+
Parameters
102+
----------
103+
data_key : str
104+
The data key to get the index names for.
105+
data_schema : str
106+
The data schema type. Either "gbd" or "vivarium". Defaults to "gbd".
107+
108+
Returns
109+
-------
110+
list[str]
111+
The list of expected index names for the given data key.
112+
"""
113+
114+
measure = data_key.split(".")[-1]
115+
if data_schema == "gbd":
116+
measure_cols = list(DEMOGRAPHIC_COLUMNS)
117+
else:
118+
measure_cols = list(VIVARIUM_COLUMNS)
119+
if measure in ["exposure", "relative_risk"]:
120+
measure_cols.append(INPUT_DATA_INDEX_NAMES.PARAMETER)
121+
if measure == "relative_risk":
122+
if data_schema == "gbd":
123+
measure_cols.append(INPUT_DATA_INDEX_NAMES.CAUSE_ID)
124+
else:
125+
measure_cols.append(INPUT_DATA_INDEX_NAMES.AFFECTED_ENTITY)
126+
127+
return measure_cols

src/vivarium_testing_utils/automated_validation/interface.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
from vivarium_testing_utils.automated_validation.data_transformation.utils import (
2323
drop_extra_columns,
24+
get_measure_index_names,
2425
set_gbd_index,
2526
set_validation_index,
2627
)
@@ -249,9 +250,10 @@ def get_frame(
249250
sort_by = "percent_error"
250251

251252
if (isinstance(num_rows, int) and num_rows > 0) or num_rows == "all":
252-
return self.comparisons[comparison_key].get_frame(
253+
data = self.comparisons[comparison_key].get_frame(
253254
stratifications, num_rows, sort_by, ascending, aggregate_draws
254255
)
256+
return self.sort_ui_data_index(data, comparison_key)
255257
else:
256258
raise ValueError("num_rows must be a positive integer or literal 'all'")
257259

@@ -344,3 +346,27 @@ def _format_to_vivarium_inputs_conventions(
344346
data = vi.split_interval(data, interval_column="year", split_column_prefix="year")
345347
formatted_data: pd.DataFrame = vi.sort_hierarchical_data(data)
346348
return formatted_data
349+
350+
@staticmethod
351+
def sort_ui_data_index(data: pd.DataFrame, comparison_key: str) -> pd.DataFrame:
352+
"""Sort the data for UI display.
353+
354+
Parameters
355+
----------
356+
data
357+
The DataFrame to sort.
358+
comparison_key
359+
The comparison key for logging purposes.
360+
361+
Returns
362+
-------
363+
The sorted DataFrame.
364+
"""
365+
366+
expected_order = get_measure_index_names(comparison_key, "vivarium")
367+
ordered_cols = [col for col in expected_order if col in data.index.names]
368+
extra_idx_cols = [col for col in data.index.names if col not in ordered_cols]
369+
sorted_index = ordered_cols + extra_idx_cols
370+
sorted = data.reorder_levels(sorted_index).sort_index()
371+
372+
return sorted

tests/automated_validation/conftest.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from vivarium_testing_utils.automated_validation.constants import (
1212
DRAW_INDEX,
13+
INPUT_DATA_INDEX_NAMES,
1314
LOCATION_ARTIFACT_KEY,
1415
SEED_INDEX,
1516
)
@@ -21,9 +22,6 @@
2122
utils,
2223
)
2324
from vivarium_testing_utils.automated_validation.data_transformation.age_groups import (
24-
AGE_END_COLUMN,
25-
AGE_GROUP_COLUMN,
26-
AGE_START_COLUMN,
2725
AgeSchema,
2826
AgeTuple,
2927
)
@@ -138,8 +136,8 @@ def _get_artifact_index() -> pd.MultiIndex:
138136
names=[
139137
"common_stratify_column",
140138
"other_stratify_column",
141-
AGE_START_COLUMN,
142-
AGE_END_COLUMN,
139+
INPUT_DATA_INDEX_NAMES.AGE_START,
140+
INPUT_DATA_INDEX_NAMES.AGE_END,
143141
],
144142
)
145143

@@ -160,11 +158,17 @@ def _create_sample_age_group_df() -> pd.DataFrame:
160158
"""Create sample age group data for testing."""
161159
return pd.DataFrame(
162160
{
163-
AGE_GROUP_COLUMN: ["0_to_4", "5_to_9", "10_to_14"],
164-
AGE_START_COLUMN: [0.0, 5.0, 10.0],
165-
AGE_END_COLUMN: [5.0, 10.0, 15.0],
161+
INPUT_DATA_INDEX_NAMES.AGE_GROUP: ["0_to_4", "5_to_9", "10_to_14"],
162+
INPUT_DATA_INDEX_NAMES.AGE_START: [0.0, 5.0, 10.0],
163+
INPUT_DATA_INDEX_NAMES.AGE_END: [5.0, 10.0, 15.0],
166164
}
167-
).set_index([AGE_GROUP_COLUMN, AGE_START_COLUMN, AGE_END_COLUMN])
165+
).set_index(
166+
[
167+
INPUT_DATA_INDEX_NAMES.AGE_GROUP,
168+
INPUT_DATA_INDEX_NAMES.AGE_START,
169+
INPUT_DATA_INDEX_NAMES.AGE_END,
170+
]
171+
)
168172

169173

170174
@utils.check_io(out=SingleNumericColumn)
@@ -398,9 +402,9 @@ def sample_df_with_ages() -> pd.DataFrame:
398402
names=[
399403
"cause",
400404
"disease",
401-
AGE_GROUP_COLUMN,
402-
AGE_START_COLUMN,
403-
AGE_END_COLUMN,
405+
INPUT_DATA_INDEX_NAMES.AGE_GROUP,
406+
INPUT_DATA_INDEX_NAMES.AGE_START,
407+
INPUT_DATA_INDEX_NAMES.AGE_END,
404408
],
405409
),
406410
)
@@ -514,8 +518,8 @@ def _artifact_population_structure() -> pd.DataFrame:
514518
"location",
515519
"common_stratify_column",
516520
"other_stratify_column",
517-
AGE_START_COLUMN,
518-
AGE_END_COLUMN,
521+
INPUT_DATA_INDEX_NAMES.AGE_START,
522+
INPUT_DATA_INDEX_NAMES.AGE_END,
519523
]
520524
pop = pop.reset_index().set_index(index_order)
521525

@@ -613,10 +617,10 @@ def reference_weights() -> pd.DataFrame:
613617

614618
def is_on_slurm() -> bool:
615619
"""Returns True if the current environment is a SLURM cluster."""
616-
return not shutil.which("sbatch") is not None
620+
return shutil.which("sbatch") is not None
617621

618622

619-
NO_GBD_ACCESS = is_on_slurm()
623+
IS_ON_SLURM = is_on_slurm()
620624

621625

622626
@pytest.fixture

0 commit comments

Comments
 (0)