Skip to content

Commit 49ec6f5

Browse files
authored
reorganisation/tidy (#34)
* reorganisation/tidy * load script
1 parent 3ce35cc commit 49ec6f5

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+302
-42367
lines changed

.github/workflows/render-report.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ jobs:
3333

3434
- name: Compile the report
3535
run: |
36-
rmarkdown::render("output/results.Rmd")
36+
rmarkdown::render("report/results.Rmd")
3737
shell: Rscript {0}
3838

3939
- name: Create Pull Request
@@ -46,5 +46,5 @@ jobs:
4646
branch: "render-report-${{ github.run_number }}"
4747
labels: "documentation"
4848
add-paths: |
49-
output/results.pdf
49+
report/results.pdf
5050
token: ${{ secrets.GITHUB_TOKEN }}

.github/workflows/test-report.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ jobs:
3434

3535
- name: Compile the report
3636
run: |
37-
rmarkdown::render("output/results.Rmd")
37+
rmarkdown::render("report/results.Rmd")
3838
shell: Rscript {0}
3939

4040
- name: Upload pdf as an artifact
4141
uses: actions/upload-artifact@v4
4242
with:
4343
name: results
44-
path: output/results.pdf
44+
path: report/results.pdf

R/import-data.R

Lines changed: 10 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,3 @@
1-
# Functions to import and save data for predictions and observations
2-
# Examples:
3-
# forecasts <- get_forecasts()
4-
# obs <- get_observed()
5-
# forecasts <- left_join(forecasts, obs,
6-
# by = c("location", "target_end_date"))
7-
# anomalies <- get_anomalies()
8-
# forecasts <- anti_join(forecasts, anomalies,
9-
# by = c("target_end_date", "location"))
10-
111
library(here)
122
library(dplyr)
133
library(readr)
@@ -16,57 +6,10 @@ library(arrow)
166
library(tidyr)
177
library(ggplot2)
188
library(stringr)
19-
theme_set(theme_minimal())
20-
21-
# Prediction data ------------------------------------------------------
22-
get_forecasts <- function(data_type = "death") {
23-
forecasts <- arrow::read_parquet(here("data",
24-
"covid19-forecast-hub-europe.parquet")) |>
25-
filter(grepl(data_type, target))
26-
27-
forecasts <- forecasts |>
28-
separate(target, into = c("horizon", "target_variable"),
29-
sep = " wk ahead ") |>
30-
# set forecast date to corresponding submission date
31-
mutate(
32-
horizon = as.numeric(horizon),
33-
forecast_date = target_end_date - weeks(horizon) + days(1)) |>
34-
rename(prediction = value) |>
35-
select(location, forecast_date,
36-
horizon, target_end_date,
37-
model, quantile, prediction)
38-
39-
# Exclusions
40-
# dates should be between start of hub and until end of JHU data
41-
forecasts <- forecasts |>
42-
filter(forecast_date >= as.Date("2021-03-07") &
43-
target_end_date <= as.Date("2023-03-10"))
44-
# only keep forecasts up to 4 weeks ahead
45-
forecasts <- filter(forecasts, horizon <= 4)
46-
47-
# only include predictions from models with all quantiles
48-
rm_quantiles <- forecasts |>
49-
group_by(model, forecast_date, location) |>
50-
summarise(q = length(unique(quantile))) |>
51-
filter(q < 23)
52-
forecasts <- anti_join(forecasts, rm_quantiles,
53-
by = c("model", "forecast_date", "location"))
54-
forecasts <- filter(forecasts, !is.na(quantile)) # remove "median"
55-
56-
# remove duplicates
57-
forecasts <- forecasts |>
58-
group_by_all() |>
59-
mutate(duplicate = row_number()) |>
60-
ungroup() |>
61-
filter(duplicate == 1) |>
62-
select(-duplicate)
63-
64-
return(forecasts)
65-
}
9+
library(purrr)
6610

6711
# Observed data ---------------------------------------------------------
68-
# Get raw values
69-
get_observed <- function(data_type = "death") {
12+
walk(c("case", "death"), \(data_type) {
7013
file_name <- paste0("truth_JHU-Incident%20", str_to_title(data_type), "s.csv")
7114
obs <- read_csv(paste0(
7215
"https://raw.githubusercontent.com/covid19-forecast-hub-europe/",
@@ -98,67 +41,13 @@ get_observed <- function(data_type = "death") {
9841
"Stable")))))
9942
obs <- obs |>
10043
select(location, target_end_date, observed, trend)
101-
return(obs)
102-
}
103-
104-
# Observed data ---------------------------------------------------------
105-
# Get raw values
106-
get_pop <- function() {
107-
pop <- read_csv(paste0(
108-
"https://raw.githubusercontent.com/european-modelling-hubs/",
109-
"covid19-forecast-hub-europe/main/data-locations/locations_eu.csv"
110-
), show_col_types = FALSE) |>
111-
select(location, population)
112-
return(pop)
113-
}
114-
115-
# Plot observed data and trend classification
116-
plot_observed <- function() {
117-
obs <- import_observed()
118-
obs |>
119-
ggplot(aes(x = target_end_date, y = log(observed))) +
120-
geom_point(col = trend) +
121-
geom_line(alpha = 0.3) +
122-
scale_x_date() +
123-
labs(x = NULL, y = "Log observed", col = "Trend",
124-
caption = "Trend (coloured points) of weekly change in 3-week moving average") +
125-
theme(legend.position = "bottom", ) +
126-
facet_wrap(facets = "location", ncol = 1,
127-
strip.position = "left")
128-
129-
ggsave(filename = here("output/fig-trends.pdf"),
130-
height = 50, width = 15, limitsize = FALSE)
131-
}
132-
133-
# Anomalies
134-
get_anomalies <- function() {
135-
read_csv("https://raw.githubusercontent.com/covid19-forecast-hub-europe/covid19-forecast-hub-europe/5a2a8d48e018888f652e981c95de0bf05a838135/data-truth/anomalies/anomalies.csv") |>
136-
filter(target_variable == "inc death") |>
137-
select(-target_variable) |>
138-
mutate(anomaly = TRUE)
44+
write_csv(obs, here("data", paste0("observed-", data_type, ".csv")))
13945
}
14046

141-
# Plot anomalies
142-
plot_anomalies <- function() {
143-
obs <- get_observed()
144-
anomalies <- get_anomalies()
145-
146-
obs <- left_join(obs, anomalies) |>
147-
group_by(location) |>
148-
mutate(anomaly = replace_na(anomaly, FALSE))
149-
150-
obs |>
151-
ggplot(aes(x = target_end_date,
152-
y = log(observed),
153-
col = anomaly)) +
154-
geom_line() +
155-
geom_point(size = 0.3) +
156-
scale_x_date() +
157-
labs(x = NULL, y = "Log observed") +
158-
theme(legend.position = "bottom", ) +
159-
facet_wrap(facets = "location", ncol = 1,
160-
strip.position = "left")
161-
162-
ggsave(filename = here("output/fig-anomalies.pdf"),
163-
height = 50, width = 15, limitsize = FALSE)
164-
}
47+
# Population data ---------------------------------------------------------
48+
pop <- read_csv(paste0(
49+
"https://raw.githubusercontent.com/european-modelling-hubs/",
50+
"covid19-forecast-hub-europe/main/data-locations/locations_eu.csv"
51+
), show_col_types = FALSE) |>
52+
select(location, population)
53+
write_csv(pop, here("data", paste0("populations.csv")))

R/model-plots.R

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
library("purrr")
2+
library("dplyr")
3+
library("ggplot2")
4+
library("patchwork")
5+
library("gammit")
6+
source(here("R", "prep-data.R"))
7+
source(here("R", "descriptive.R"))
8+
9+
plot_models <- function(fits, scores, x_labels = TRUE) {
10+
outcomes <- unique(scores$outcome_target)
11+
classification <- classify_models() |>
12+
rename(group = model)
13+
targets <- table_targets(scores) |>
14+
select(group = Model, CountryTargets) |>
15+
distinct()
16+
plots <- map(fits, function(fit) {
17+
plot <- extract_ranef(fit) |>
18+
filter(group_var == "Model") |>
19+
left_join(classification) |>
20+
left_join(targets) |>
21+
mutate(group = sub(".*-", "", group)) |> ## remove institution identifier
22+
select(-group_var) |>
23+
arrange(-value) |>
24+
mutate(group = factor(group, levels = unique(as.character(group)))) |>
25+
ggplot(aes(x = group, col = classification, shape = CountryTargets)) +
26+
geom_point(aes(y = value)) +
27+
geom_linerange(aes(ymin = lower_2.5, ymax = upper_97.5)) +
28+
geom_hline(yintercept = 0, lty = 2) +
29+
labs(y = "Partial effect", x = "Model", colour = NULL, shape = NULL) +
30+
scale_colour_brewer(type = "qual", palette = 2) +
31+
theme(
32+
legend.position = "bottom",
33+
axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1)
34+
) +
35+
coord_flip()
36+
if (!x_labels) {
37+
plot <- plot +
38+
theme(
39+
axis.text.y = element_blank(),
40+
axis.ticks.y = element_blank()
41+
)
42+
}
43+
return(plot)
44+
})
45+
## remove legends
46+
if (length(plots) > 1) {
47+
for (i in seq_len(length(plots) - 1)) {
48+
plots[[i]] <- plots[[i]] + theme(legend.position = "none")
49+
}
50+
}
51+
for (i in seq_along(plots)) {
52+
plots[[i]] <- plots[[i]] + ggtitle(outcomes[i])
53+
}
54+
Reduce(`+`, plots) + plot_layout(ncol = 2)
55+
}
56+
57+
plot_effects <- function(fits, scores) {
58+
map(fits, extract_ranef) |>
59+
bind_rows(.id = "outcome_target") |>
60+
filter(!(group_var %in% c("Model", "location"))) |>
61+
mutate(group = factor(group, levels = unique(as.character(rev(group))))) |>
62+
ggplot(aes(x = group, col = group_var)) +
63+
geom_point(aes(y = value)) +
64+
geom_linerange(aes(ymin = lower_2.5, ymax = upper_97.5)) +
65+
geom_hline(yintercept = 0, lty = 2, alpha = 0.25) +
66+
facet_wrap(~outcome_target, scales = "free_y") +
67+
labs(y = "Partial effect", x = NULL, colour = NULL, shape = NULL) +
68+
scale_colour_brewer(type = "qual", palette = "Set1") +
69+
theme(
70+
legend.position = "bottom",
71+
strip.background = element_blank(),
72+
axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1)
73+
) +
74+
coord_flip()
75+
}

R/model-wis.R

Lines changed: 2 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,8 @@ library(readr)
55
library(tidyr)
66
library(purrr)
77
library(mgcv)
8-
library(gratia) # devtools::install_github('gavinsimpson/gratia')
9-
library(broom)
10-
library(ggplot2)
11-
library(broom)
12-
library(gammit)
13-
theme_set(theme_classic())
148
source(here("R", "prep-data.R"))
9+
source(here("R", "descriptive.R"))
1510

1611
# --- Get data ---
1712
data <- prep_data(scoring_scale = "log")
@@ -60,70 +55,4 @@ m.fits <- outcomes |>
6055
)
6156
})
6257

63-
plot_models <- function(fits, scores, x_labels = TRUE) {
64-
outcomes <- unique(scores$outcome_target)
65-
classification <- classify_models() |>
66-
rename(group = model)
67-
targets <- table_targets(scores) |>
68-
select(group = Model, CountryTargets) |>
69-
distinct()
70-
plots <- map(fits, function(fit) {
71-
plot <- extract_ranef(fit) |>
72-
filter(group_var == "Model") |>
73-
left_join(classification) |>
74-
left_join(targets) |>
75-
mutate(group = sub(".*-", "", group)) |> ## remove institution identifier
76-
select(-group_var) |>
77-
arrange(-value) |>
78-
mutate(group = factor(group, levels = unique(as.character(group)))) |>
79-
ggplot(aes(x = group, col = classification, shape = CountryTargets)) +
80-
geom_point(aes(y = value)) +
81-
geom_linerange(aes(ymin = lower_2.5, ymax = upper_97.5)) +
82-
geom_hline(yintercept = 0, lty = 2) +
83-
labs(y = "Partial effect", x = "Model", colour = NULL, shape = NULL) +
84-
scale_colour_brewer(type = "qual", palette = 2) +
85-
theme(
86-
legend.position = "bottom",
87-
axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1)
88-
) +
89-
coord_flip()
90-
if (!x_labels) {
91-
plot <- plot +
92-
theme(
93-
axis.text.y = element_blank(),
94-
axis.ticks.y = element_blank()
95-
)
96-
}
97-
return(plot)
98-
})
99-
## remove legends
100-
if (length(plots) > 1) {
101-
for (i in seq_len(length(plots) - 1)) {
102-
plots[[i]] <- plots[[i]] + theme(legend.position = "none")
103-
}
104-
}
105-
for (i in seq_along(plots)) {
106-
plots[[i]] <- plots[[i]] + ggtitle(outcomes[i])
107-
}
108-
Reduce(`+`, plots) + plot_layout(ncol = 2)
109-
}
110-
111-
plot_effects <- function(fits, scores) {
112-
map(fits, extract_ranef) |>
113-
bind_rows(.id = "outcome_target") |>
114-
filter(!(group_var %in% c("Model", "location"))) |>
115-
mutate(group = factor(group, levels = unique(as.character(rev(group))))) |>
116-
ggplot(aes(x = group, col = group_var)) +
117-
geom_point(aes(y = value)) +
118-
geom_linerange(aes(ymin = lower_2.5, ymax = upper_97.5)) +
119-
geom_hline(yintercept = 0, lty = 2, alpha = 0.25) +
120-
facet_wrap(~outcome_target, scales = "free_y") +
121-
labs(y = "Partial effect", x = NULL, colour = NULL, shape = NULL) +
122-
scale_colour_brewer(type = "qual", palette = "Set1") +
123-
theme(
124-
legend.position = "bottom",
125-
strip.background = element_blank(),
126-
axis.text.x = element_text(angle = 45, vjust = 1, hjust = 1)
127-
) +
128-
coord_flip()
129-
}
58+
saveRDS(m.fits, here("output", "fits.rds"))

R/natural-scale-scores.R

Lines changed: 0 additions & 55 deletions
This file was deleted.

0 commit comments

Comments
 (0)