Skip to content

Commit 1c6ecce

Browse files
committed
feat: Add parallel summary file creation and visualization scripts
- Introduced a new script for creating summary files in parallel (`03_create_summary_files_parallel.py`).
- Updated `combine_as_draws.py` to save draws in NetCDF format instead of HDF5.
- Created a new script for generating summary files (`create_summary_files.py`) that aggregates data and calculates rates.
- Added a new R script for visualizing malaria model forecasts (`visualize_malaria_model.r`).
- Enhanced the constants file to include a path for visualization data.
- Added new functions in `hd5_functions.py` for creating HDF5 structures and appending data.
- Updated `xarray_functions.py` to improve validation during NetCDF writing.
1 parent 9aef0f2 commit 1c6ecce

22 files changed

+982
-1489
lines changed

src/idd_forecast_mbp/02_data_prep/07_forecasted_dataframes_non_draw_part.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@
4444
}
4545

4646
# DAH
47-
dah_df_path = f"{VARIABLE_DATA_PATH}/dah_df.parquet"
47+
# dah_df_path = f"{VARIABLE_DATA_PATH}/dah_df.parquet"
48+
dah_df_path = f"{PROCESSED_DATA_PATH}/dah_df_2025_07_08.parquet"
4849

4950
urban_paths = {
5051
"urban_threshold_300": "{VARIABLE_DATA_PATH}/urban_threshold_300.0_simple_mean.parquet",
@@ -112,8 +113,13 @@
112113

113114
print("Reading DAH data...")
114115
dah_df = read_parquet_with_integer_ids(dah_df_path)
115-
dah_df = dah_df.filter(regex="location_id|year_id|total")
116-
forecast_df = forecast_df.merge(dah_df, on=["location_id", "year_id"], how = "left")
116+
dah_df = dah_df.rename(columns={'location_id': 'A0_location_id'})
117+
dah_df = dah_df.drop(columns=['population', 'location_name', 'iso3'], errors='ignore')
118+
# dah_df = dah_df.filter(regex="location_id|year_id|total")
119+
forecast_df = forecast_df.merge(dah_df, on=["A0_location_id", "year_id"], how = "left")
120+
# Set any NaN values in the total column to 0
121+
forecast_df['mal_DAH_total'] = forecast_df['mal_DAH_total'].fillna(0)
122+
forecast_df['mal_DAH_total_per_capita'] = forecast_df['mal_DAH_total_per_capita'].fillna(0)
117123

118124
print("Writing malaria forecast non-draw part...")
119125
cause = "malaria"
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"id": "1e6d9ab8",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"import numpy as np # type: ignore\n",
11+
"import pandas as pd # type: ignore\n",
12+
"from idd_forecast_mbp import constants as rfc\n",
13+
"from idd_forecast_mbp.helper_functions import merge_dataframes, read_income_paths, read_urban_paths, level_filter\n",
14+
"from idd_forecast_mbp.parquet_functions import read_parquet_with_integer_ids, write_parquet\n",
15+
"\n",
16+
"FORECASTING_DATA_PATH = rfc.MODEL_ROOT / \"04-forecasting_data\"\n",
17+
"PROCESSED_DATA_PATH = rfc.MODEL_ROOT / \"02-processed_data\"\n",
18+
"aa_full_population_df_path = f\"{PROCESSED_DATA_PATH}/aa_2023_full_population_df.parquet\"\n",
19+
"\n",
20+
"hierarchy_df_path = f'{PROCESSED_DATA_PATH}/full_hierarchy_lsae_1209.parquet'\n",
21+
"hierarchy_df = read_parquet_with_integer_ids(hierarchy_df_path)\n",
22+
"\n",
23+
"dah_df_path = f\"{PROCESSED_DATA_PATH}/dah_df_2025_07_08.parquet\""
24+
]
25+
},
26+
{
27+
"cell_type": "code",
28+
"execution_count": 2,
29+
"id": "0a2fd7ad",
30+
"metadata": {},
31+
"outputs": [],
32+
"source": [
33+
"new_dah_path = '/mnt/share/resource_tracking/forecasting/dah_channel_HFA/FGH_2024_submission_5_reference/dah_by_channel_hfa_recip_1990_2100.csv'\n",
34+
"new_dah_df = pd.read_csv(new_dah_path)\n",
35+
"new_dah_df = new_dah_df[(new_dah_df['hfa'] == 'mal') & (new_dah_df['year'] >= 2000)]\n",
36+
"new_dah_df = new_dah_df.groupby(['year', 'recip']).agg({'dah': 'sum'}).reset_index()\n",
37+
"new_dah_df = new_dah_df.rename(columns={'recip': 'iso3', 'dah': 'mal_DAH_total', 'year': 'year_id'})\n",
38+
"\n",
39+
"A0_hierarchy_df = hierarchy_df[hierarchy_df['level'] == 3].copy()\n",
40+
"A0_hierarchy_df = A0_hierarchy_df[['location_id', 'location_name', 'ihme_loc_id']].drop_duplicates().reset_index(drop=True)\n",
41+
"A0_hierarchy_df = A0_hierarchy_df.rename(columns={'ihme_loc_id': 'iso3'})\n",
42+
"\n",
43+
"new_dah_df = new_dah_df.merge(A0_hierarchy_df, on='iso3', how='inner')\n",
44+
"A0_location_filter = ('location_id', 'in', A0_hierarchy_df['location_id'].unique().tolist())\n",
45+
"pop_df = read_parquet_with_integer_ids(aa_full_population_df_path, filters=[A0_location_filter])\n",
46+
"new_dah_df = new_dah_df.merge(pop_df, on=['location_id', 'year_id'], how='left')\n",
47+
"\n",
48+
"new_dah_df['mal_DAH_total_per_capita'] = new_dah_df['mal_DAH_total'] / new_dah_df['population']"
49+
]
50+
},
51+
{
52+
"cell_type": "code",
53+
"execution_count": 3,
54+
"id": "97c723d4",
55+
"metadata": {},
56+
"outputs": [
57+
{
58+
"name": "stdout",
59+
"output_type": "stream",
60+
"text": [
61+
"✅ Metadata validation passed for /mnt/team/idd/pub/forecast-mbp/02-processed_data/dah_df_2025_07_08.parquet\n"
62+
]
63+
},
64+
{
65+
"data": {
66+
"text/plain": [
67+
"True"
68+
]
69+
},
70+
"execution_count": 3,
71+
"metadata": {},
72+
"output_type": "execute_result"
73+
}
74+
],
75+
"source": [
76+
"write_parquet(new_dah_df,dah_df_path)"
77+
]
78+
}
79+
],
80+
"metadata": {
81+
"kernelspec": {
82+
"display_name": "forecast-mbp",
83+
"language": "python",
84+
"name": "python3"
85+
},
86+
"language_info": {
87+
"codemirror_mode": {
88+
"name": "ipython",
89+
"version": 3
90+
},
91+
"file_extension": ".py",
92+
"mimetype": "text/x-python",
93+
"name": "python",
94+
"nbconvert_exporter": "python",
95+
"pygments_lexer": "ipython3",
96+
"version": "3.12.9"
97+
}
98+
},
99+
"nbformat": 4,
100+
"nbformat_minor": 5
101+
}
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
rm(list = ls())
2+
#
3+
4+
require(glue)
5+
require(mgcv)
6+
require(scam)
7+
require(arrow)
8+
require(data.table)
9+
10+
"%ni%" <- Negate("%in%")
11+
"%nlike%" <- Negate("%like%")
12+
13+
###########################################
14+
dah_scenario_name = 'Baseline'
15+
draw = '077'
16+
17+
REPO_DIR = "/mnt/team/idd/pub/forecast-mbp"
18+
last_year <- 2022
19+
data_path <- glue("{REPO_DIR}/03-modeling_data")
20+
FORECASTING_DATA_PATH = glue("{REPO_DIR}/04-forecasting_data")
21+
22+
ssp585_df_path <- glue("{FORECASTING_DATA_PATH}/malaria_forecast_ssp_scenario_ssp585_dah_scenario_{dah_scenario_name}_draw_{draw}.parquet")
23+
ssp585_df <-as.data.frame(arrow::read_parquet(ssp585_df_path))
24+
ssp585_df$A0_af <- as.factor(ssp585_df$A0_af)
25+
26+
past_data <- ssp585_df[!is.na(ssp585_df$malaria_pfpr), ]  # logical mask keeps all rows when nothing is NA; -which() on an empty match would drop every row
27+
past_data <- past_data[!is.na(past_data$gdppc_mean), ]  # logical mask keeps all rows when nothing is NA; -which() on an empty match would drop every row
28+
29+
past_data$malaria_suit_fraction <- past_data$malaria_suitability / 365
30+
past_data$malaria_suit_fraction <- pmin(pmax(past_data$malaria_suit_fraction, 0.001), 0.999)
31+
past_data$logit_malaria_suitability <- log(past_data$malaria_suit_fraction / (1 - past_data$malaria_suit_fraction))
32+
33+
malaria_pfpr_mod <- scam(logit_malaria_pfpr ~ logit_malaria_suitability +
34+
s(gdppc_mean, k = 6, bs = 'mpd') +
35+
s(mal_DAH_total_per_capita, k = 6, bs = 'mpd') +
36+
people_flood_days_per_capita +
37+
A0_af,
38+
data = past_data,
39+
optimizer = "efs", # Faster optimizer
40+
control = list(maxit = 300)) # Limit iterations
41+
42+
mod_df <- past_data[which(past_data$aa_malaria_mort_rate > 0),]
43+
mortality_scam_mod <- scam(log_aa_malaria_mort_rate ~ s(logit_malaria_pfpr, k = 10, bs = "mpi") +
44+
log_gdppc_mean +
45+
A0_af,
46+
data = mod_df,
47+
optimizer = "efs", # Faster optimizer
48+
control = list(maxit = 300)) # Limit iterations
49+
50+
mod_df <- past_data[which(past_data$aa_malaria_inc_rate > 0),]
51+
incidence_scam_mod <- scam(log_aa_malaria_inc_rate ~ s(logit_malaria_pfpr, k = 10, bs = "mpi") +
52+
log_gdppc_mean + A0_af,
53+
data = mod_df,
54+
optimizer = "efs", # Faster optimizer
55+
control = list(maxit = 300)) # Limit iterations
56+
57+
mod_df <- past_data[which(past_data$base_malaria_mort_rate > 0),]
58+
mortality_base_scam_mod <- scam(log_base_malaria_mort_rate ~ s(logit_malaria_pfpr, k = 10, bs = "mpi") +
59+
log_gdppc_mean +
60+
A0_af,
61+
data = mod_df,
62+
optimizer = "efs", # Faster optimizer
63+
control = list(maxit = 300)) # Limit iterations
64+
65+
mod_df <- past_data[which(past_data$base_malaria_inc_rate > 0),]
66+
incidence_base_scam_mod <- scam(log_base_malaria_inc_rate ~ s(logit_malaria_pfpr, k = 10, bs = "mpi") +
67+
log_gdppc_mean + A0_af,
68+
data = mod_df,
69+
optimizer = "efs", # Faster optimizer
70+
control = list(maxit = 300)) # Limit iterations
71+
72+
73+
model_names <- c("malaria_pfpr_mod", "mortality_scam_mod", "incidence_scam_mod", "mortality_base_scam_mod",
74+
"incidence_base_scam_mod")
75+
76+
save(list = model_names, file = glue("{data_path}/2025_07_03_malaria_models.RData"))
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
rm(list = ls())
2+
#
3+
4+
require(glue)
5+
require(mgcv)
6+
require(scam)
7+
require(arrow)
8+
require(data.table)
9+
10+
"%ni%" <- Negate("%in%")
11+
"%nlike%" <- Negate("%like%")
12+
13+
###########################################
14+
15+
REPO_DIR = "/mnt/team/idd/pub/forecast-mbp"
16+
last_year <- 2022
17+
data_path <- glue("{REPO_DIR}/03-modeling_data")
18+
FORECASTING_DATA_PATH = glue("{REPO_DIR}/04-forecasting_data")
19+
20+
df_path <- glue("{FORECASTING_DATA_PATH}/malaria_forecast_ssp_scenario_ssp126_dah_scenario_Baseline_draw_000.parquet")
21+
df <-as.data.frame(arrow::read_parquet(df_path))
22+
df$A0_af <- as.factor(df$A0_af)
23+
24+
past_data <- df[!is.na(df$malaria_pfpr), ]  # logical mask keeps all rows when nothing is NA; -which() on an empty match would drop every row
25+
past_data <- past_data[!is.na(past_data$gdppc_mean), ]  # logical mask keeps all rows when nothing is NA; -which() on an empty match would drop every row
26+
27+
past_data$malaria_suit_fraction <- past_data$malaria_suitability / 365
28+
past_data$malaria_suit_fraction <- pmin(pmax(past_data$malaria_suit_fraction, 0.001), 0.999)
29+
past_data$logit_malaria_suitability <- log(past_data$malaria_suit_fraction / (1 - past_data$malaria_suit_fraction))
30+
31+
32+
33+
malaria_pfpr_mod <- scam(logit_malaria_pfpr ~ logit_malaria_suitability +
34+
s(gdppc_mean, k = 6, bs = 'mpd') +
35+
s(mal_DAH_total_per_capita, k = 6, bs = 'mpd') +
36+
people_flood_days_per_capita +
37+
A0_af,
38+
data = past_data,
39+
optimizer = "efs", # Faster optimizer
40+
control = list(maxit = 300)) # Limit iterations
41+
42+
43+
44+
45+
mod_df <- past_data[which(past_data$aa_malaria_mort_rate > 0),]
46+
mortality_scam_mod <- scam(log_aa_malaria_mort_rate ~ s(logit_malaria_pfpr, k = 10, bs = "mpi") +
47+
log_gdppc_mean +
48+
A0_af,
49+
data = mod_df,
50+
optimizer = "efs", # Faster optimizer
51+
control = list(maxit = 300)) # Limit iterations
52+
53+
mod_df <- past_data[which(past_data$aa_malaria_inc_rate > 0),]
54+
incidence_scam_mod <- scam(log_aa_malaria_inc_rate ~ s(logit_malaria_pfpr, k = 10, bs = "mpi") +
55+
log_gdppc_mean + A0_af,
56+
data = mod_df,
57+
optimizer = "efs", # Faster optimizer
58+
control = list(maxit = 300)) # Limit iterations
59+
60+
mod_df <- past_data[which(past_data$base_malaria_mort_rate > 0),]
61+
mortality_base_scam_mod <- scam(log_base_malaria_mort_rate ~ s(logit_malaria_pfpr, k = 10, bs = "mpi") +
62+
log_gdppc_mean +
63+
A0_af,
64+
data = mod_df,
65+
optimizer = "efs", # Faster optimizer
66+
control = list(maxit = 300)) # Limit iterations
67+
68+
mod_df <- past_data[which(past_data$base_malaria_inc_rate > 0),]
69+
incidence_base_scam_mod <- scam(log_base_malaria_inc_rate ~ s(logit_malaria_pfpr, k = 10, bs = "mpi") +
70+
log_gdppc_mean + A0_af,
71+
data = mod_df,
72+
optimizer = "efs", # Faster optimizer
73+
control = list(maxit = 300)) # Limit iterations
74+
75+
76+
model_names <- c("malaria_pfpr_mod", "mortality_scam_mod", "incidence_scam_mod", "mortality_base_scam_mod",
77+
"incidence_base_scam_mod")
78+
79+
save(list = model_names, file = glue("{data_path}/2025_07_08_malaria_models.RData"))
80+
81+
82+
83+
84+
85+
percentiles = seq(0, 1, by = 0.05)
86+
mal_dah_perc = sapply(percentiles, function(p) {
87+
quantile(past_data$mal_DAH_total_per_capita, p, na.rm = TRUE)
88+
})
89+
90+
mal_dah_perc = unique(mal_dah_perc)
91+
92+
bin_df <- data.frame(bin_start = head(mal_dah_perc, -1),
93+
bin_end = tail(mal_dah_perc, -1),
94+
mean_residual = NA,
95+
Q1 = NA,
96+
Q3 = NA)
97+
for (i in bin_df$bin_start){
98+
tmp_locs <- which(past_data$mal_DAH_total_per_capita >= i &
99+
past_data$mal_DAH_total_per_capita < (i + 0.01))
100+
bin_df$mean_residual[which(bin_df$bin_start == i)] <- mean(malaria_pfpr_mod$residuals[tmp_locs])
101+
bin_df$Q1[which(bin_df$bin_start == i)] <- quantile(malaria_pfpr_mod$residuals[tmp_locs], 0.25)
102+
bin_df$Q3[which(bin_df$bin_start == i)] <- quantile(malaria_pfpr_mod$residuals[tmp_locs], 0.75)
103+
}
104+
105+
106+
par(mfrow=c(3,1))
107+
plot(malaria_pfpr_mod, select = 2)
108+
plot(bin_df$bin_start+bin_df$bin_end, bin_df$mean_residual, type = 'n',xlim = c(0, max(bin_df$bin_end)), ylim = c(min(bin_df$Q1), max(bin_df$Q3)))
109+
abline(h = 0, lty = 2)
110+
for (i in seq_along(bin_df$bin_start)){
111+
lines(c(bin_df$bin_start[i], bin_df$bin_end[i]),
112+
c(bin_df$mean_residual[i], bin_df$mean_residual[i]),
113+
col = "blue", lwd = 2)
114+
lines(rep((bin_df$bin_start[i] + bin_df$bin_end[i]) / 2, 2),
115+
c(bin_df$Q1[i], bin_df$Q3[i]),
116+
col = "red", lwd = 2)
117+
}
118+
plot(bin_df$bin_start+1e-6, bin_df$mean_residual, type = 'n', xlim = c(min(bin_df$bin_start) + 1e-6, max(bin_df$bin_end)), ylim = c(min(bin_df$Q1), max(bin_df$Q3)), log = 'x')
119+
abline(h = 0, lty = 2)
120+
for (i in seq_along(bin_df$bin_start)){
121+
lines(c(bin_df$bin_start[i]+1e-6, bin_df$bin_end[i]),
122+
c(bin_df$mean_residual[i], bin_df$mean_residual[i]),
123+
col = "blue", lwd = 2)
124+
lines(rep((bin_df$bin_start[i]+1e-6 + bin_df$bin_end[i]) / 2, 2),
125+
c(bin_df$Q1[i], bin_df$Q3[i]),
126+
col = "red", lwd = 2)
127+
}
128+
129+

0 commit comments

Comments
 (0)