Skip to content

Commit 466ce98

Browse files
committed
Add malaria forecasting scripts and data processing pipeline
- Implemented parallel processing for malaria age-sex forecasting in `96_as_malaria_shifts_parallel.py`.
- Created main forecasting logic in `as_malaria_shifts.py`, including data loading, processing, and saving outputs.
- Developed full age-sex dataset generation in `05_make_full_as_ds.py`, utilizing xarray for efficient data handling.
- Added Jupyter notebook `06_make_mega_as_draws.ipynb` for generating mega datasets from individual draws.
- Enhanced data path management and ensured compatibility with existing data structures.
- Included comprehensive logging for memory and time tracking during processing.
1 parent 4f0d954 commit 466ce98

31 files changed

+3065
-252
lines changed

src/idd_forecast_mbp/02_data_prep/04_rake_as_A2_to_GBD.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
from idd_forecast_mbp import constants as rfc
66
from idd_forecast_mbp.helper_functions import level_filter
77
from idd_forecast_mbp.parquet_functions import read_parquet_with_integer_ids, write_parquet
8+
from idd_forecast_mbp.xarray_functions import convert_to_xarray, write_netcdf
89

910
PROCESSED_DATA_PATH = rfc.PROCESSED_DATA_PATH
1011
FORECASTING_DATA_PATH = rfc.FORECASTING_DATA_PATH
1112
GBD_DATA_PATH = rfc.GBD_DATA_PATH
1213
FHS_DATA_PATH = f"{PROCESSED_DATA_PATH}/age_specific_fhs"
1314

1415
as_full_cause_df_path_template = '{PROCESSED_DATA_PATH}/as_full_{cause}_df.parquet'
16+
as_full_cause_ds_path_template = '{PROCESSED_DATA_PATH}/as_full_{cause}_ds.nc'
1517
################################################################
1618
#### Paths, loading, and cleaning
1719
################################################################
@@ -157,5 +159,22 @@
157159
as_full_cause_df.loc[as_full_cause_df['age_group_id'].isin(force_zero_age_ids), [f'{cause}_{measure_map[measure]["short"]}_{metric}' for measure in measure_map]] = 0
158160
as_full_cause_df.loc[as_full_cause_df['age_group_id'].isin(force_zero_age_ids), [f'aa_{cause}_{measure_map[measure]["short"]}_{metric}' for measure in measure_map]] = 0
159161
as_full_cause_df_path = as_full_cause_df_path_template.format(PROCESSED_DATA_PATH=PROCESSED_DATA_PATH, cause=cause)
162+
as_full_cause_ds_path = as_full_cause_ds_path_template.format(PROCESSED_DATA_PATH=PROCESSED_DATA_PATH, cause=cause)
160163
write_parquet(as_full_cause_df, as_full_cause_df_path)
164+
165+
166+
as_full_cause_ds = convert_to_xarray(
167+
as_full_cause_df,
168+
dimensions=['location_id', 'year_id', 'sex_id', 'age_group_id'],
169+
dimension_dtypes={'location_id': 'int32', 'year_id': 'int16', 'sex_id': 'int16', 'age_group_id': 'int16'},
170+
auto_optimize_dtypes=True
171+
)
172+
173+
write_netcdf(as_full_cause_ds, as_full_cause_ds_path,
174+
compression=True,
175+
compression_level=4,
176+
chunking=True,
177+
chunk_by_dim={'location_id': 1500, 'year_id': 79},
178+
engine='netcdf4'
179+
)
161180
print(f"Wrote {as_full_cause_df_path}")
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
"""Launch a jobmon workflow that builds alt forecasted malaria dataframes.

Creates one slurm task per (ssp_scenario, draw) pair, each running
``alt_forecasted_malaria_dataframes.py`` from this package's
``02_data_prep`` directory.
"""
import getpass
import sys
import uuid
from pathlib import Path

from jobmon.client.tool import Tool  # type: ignore

from idd_forecast_mbp import constants as rfc

repo_name = rfc.repo_name
package_name = rfc.package_name

# Directory containing the worker script launched by each task.
SCRIPT_ROOT = rfc.REPO_ROOT / repo_name / "src" / package_name / "02_data_prep"

ssp_scenarios = rfc.ssp_scenarios
draws = rfc.draws

# Jobmon setup: per-user log directories for task stdout/stderr.
user = getpass.getuser()

log_dir = Path(f"/mnt/share/homes/{user}/{package_name}/")
log_dir.mkdir(parents=True, exist_ok=True)
stdout_dir = log_dir / "stdout"
stderr_dir = log_dir / "stderr"
stdout_dir.mkdir(parents=True, exist_ok=True)
stderr_dir.mkdir(parents=True, exist_ok=True)

# Cluster project to bill jobs against.
project = "proj_rapidresponse"  # Adjust this to your project name if needed

wf_uuid = uuid.uuid4()
tool_name = f"{package_name}_draw_level_dataframe_generation"
tool = Tool(name=tool_name)

# Unique workflow name so repeated launches do not collide.
workflow = tool.create_workflow(
    name=f"{tool_name}_workflow_{wf_uuid}",
    max_concurrently_running=10000,  # Adjust based on system capacity
)

# Default compute resources applied to every task unless overridden.
workflow.set_default_compute_resources_from_dict(
    cluster_name="slurm",
    dictionary={
        "memory": "15G",
        "cores": 1,
        "runtime": "60m",
        "queue": "all.q",
        "project": project,
        "stdout": str(stdout_dir),
        "stderr": str(stderr_dir),
    },
)

# Task template: one invocation of the worker script per scenario/draw.
# NOTE(review): runtime "1m" looks inconsistent with the 50G memory request
# and the 60m workflow default -- confirm it is not a typo for "60m".
task_template = tool.get_task_template(
    template_name="alt_forecasted_malaria_dataframe_creation",
    default_cluster_name="slurm",
    default_compute_resources={
        "memory": "50G",
        "cores": 1,
        "runtime": "1m",
        "queue": "all.q",
        "project": project,
        "stdout": str(stdout_dir),
        "stderr": str(stderr_dir),
    },
    # Double braces survive .format() so jobmon can substitute node args.
    command_template=(
        "python {script_root}/alt_forecasted_malaria_dataframes.py "
        "--ssp_scenario {{ssp_scenario}} "
        "--draw {{draw}} "
    ).format(script_root=SCRIPT_ROOT),
    node_args=["ssp_scenario", "draw"],
    task_args=[],
    op_args=[],
)

# One task per (ssp_scenario, draw) combination.
tasks = [
    task_template.create_task(ssp_scenario=ssp_scenario, draw=draw)
    for ssp_scenario in ssp_scenarios
    for draw in draws
]

print(f"Number of tasks: {len(tasks)}")

if tasks:
    workflow.add_tasks(tasks)
    print("✅ Tasks successfully added to workflow.")
else:
    print("⚠️ No tasks added to workflow. Check task generation.")

try:
    workflow.bind()
    print("✅ Workflow successfully bound.")
    print(f"Running workflow with ID {workflow.workflow_id}.")
    print("For full information see the Jobmon GUI:")
    print(f"https://jobmon-gui.ihme.washington.edu/#/workflow/{workflow.workflow_id}")
except Exception as e:
    # Fix: previously the script fell through and called workflow.run() even
    # after a failed bind; abort instead so the failure is not masked.
    print(f"❌ Workflow binding failed: {e}")
    sys.exit(1)

try:
    status = workflow.run()
    print(f"Workflow {workflow.workflow_id} completed with status {status}.")
except Exception as e:
    print(f"❌ Workflow submission failed: {e}")
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
"""Create draw-level malaria forecast dataframes under alternative DAH scenarios.

For one (ssp_scenario, draw) pair, reads the Baseline-DAH malaria forecast
dataframe, merges in per-capita malaria DAH from each alternative scenario
(reference, better, worse), and writes one parquet file per scenario.
"""
import argparse

from idd_forecast_mbp import constants as rfc
from idd_forecast_mbp.parquet_functions import read_parquet_with_integer_ids, write_parquet

parser = argparse.ArgumentParser(
    description="Add DAH scenarios and create draw level dataframes for forecasting malaria"
)

# Define arguments
parser.add_argument("--ssp_scenario", type=str, required=True,
                    help="ssp scenario name (ssp126, ssp245, ssp585)")
parser.add_argument("--draw", type=str, required=True,
                    help="Draw number (e.g., '001', '002', etc.)")

# Parse arguments
args = parser.parse_args()

ssp_scenario = args.ssp_scenario
draw = args.draw

PROCESSED_DATA_PATH = rfc.MODEL_ROOT / "02-processed_data"
FORECASTING_DATA_PATH = rfc.MODEL_ROOT / "04-forecasting_data"

# Alternative DAH (development assistance for health) scenarios to apply.
new_dah_scenarios = {
    'reference': {
        'name': 'reference',
        'path': f'{PROCESSED_DATA_PATH}/dah_reference_df.parquet'
    },
    'better': {
        'name': 'better',
        'path': f'{PROCESSED_DATA_PATH}/dah_better_df.parquet'
    },
    'worse': {
        'name': 'worse',
        'path': f'{PROCESSED_DATA_PATH}/dah_worse_df.parquet'
    }
}

# Input: Baseline-DAH forecast; output: one file per alternative DAH scenario.
base_dah_scenario_df_path_template = "{FORECASTING_DATA_PATH}/malaria_forecast_ssp_scenario_{ssp_scenario}_dah_scenario_Baseline_draw_{draw}.parquet"
dah_scenario_df_path_template = "{FORECASTING_DATA_PATH}/malaria_forecast_ssp_scenario_{ssp_scenario}_dah_scenario_{dah_scenario_name}_draw_{draw}.parquet"

# Columns carried through from the baseline forecast dataframe.
columns_to_keep = ['location_id', 'year_id', 'people_flood_days_per_capita',
                   'gdppc_mean', 'log_gdppc_mean',
                   'logit_malaria_pfpr',
                   'aa_malaria_mort_rate', 'aa_malaria_inc_rate',
                   'base_malaria_mort_rate', 'base_malaria_inc_rate',
                   'log_aa_malaria_mort_rate', 'log_aa_malaria_inc_rate',
                   'log_base_malaria_mort_rate', 'log_base_malaria_inc_rate',
                   'malaria_suitability', 'year_to_rake_to', 'A0_af']

dah_columns_to_keep = ['location_id', 'year_id', 'mal_DAH_total_per_capita']

base_dah_scenario_df_path = base_dah_scenario_df_path_template.format(
    FORECASTING_DATA_PATH=FORECASTING_DATA_PATH,
    ssp_scenario=ssp_scenario,
    draw=draw
)
base_dah_scenario_df = read_parquet_with_integer_ids(base_dah_scenario_df_path,
                                                     columns=columns_to_keep)

for dah_scenario_name, dah_scenario in new_dah_scenarios.items():
    print(f"Processing DAH scenario: {dah_scenario_name}")

    # Read the new DAH scenario data
    dah_df = read_parquet_with_integer_ids(dah_scenario['path'],
                                           columns=dah_columns_to_keep)

    # Attach the scenario's DAH column to the baseline forecast; left join
    # keeps every baseline (location_id, year_id) row.
    dah_scenario_df = base_dah_scenario_df.merge(dah_df, on=['location_id', 'year_id'], how='left')

    # Locations with no DAH record get zero DAH per capita.
    dah_scenario_df['mal_DAH_total_per_capita'] = dah_scenario_df['mal_DAH_total_per_capita'].fillna(0)

    # Write the output to a new parquet file
    dah_scenario_df_path = dah_scenario_df_path_template.format(
        FORECASTING_DATA_PATH=FORECASTING_DATA_PATH,
        ssp_scenario=ssp_scenario,
        dah_scenario_name=dah_scenario_name,
        draw=draw
    )

    write_parquet(dah_scenario_df, dah_scenario_df_path)

src/idd_forecast_mbp/02_data_prep/forecasted_draw_specific_dengue_dataframes.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,6 @@
4949
forecast_non_draw_df_path = f"{FORECASTING_DATA_PATH}/{cause}_forecast_scenario_{ssp_scenario}_non_draw_part.parquet"
5050
forecast_by_draw_df_path_template = "{FORECASTING_DATA_PATH}/{cause}_forecast_ssp_scenario_{ssp_scenario}_draw_{draw}.parquet"
5151

52-
# Hierarchy path
53-
hierarchy_df_path = f'{PROCESSED_DATA_PATH}/full_hierarchy_lsae_1209.parquet'
54-
hierarchy_df = read_parquet_with_integer_ids(hierarchy_df_path)
55-
56-
5752
# Hierarchy path
5853
hierarchy_df_path = f'{PROCESSED_DATA_PATH}/full_hierarchy_lsae_1209.parquet'
5954
hierarchy_df = read_parquet_with_integer_ids(hierarchy_df_path)

0 commit comments

Comments
 (0)