Skip to content

Commit b8d684c

Browse files
committed
add fx for edge history visualization
1 parent 3ebd90a commit b8d684c

File tree

2 files changed

+94
-38
lines changed

2 files changed

+94
-38
lines changed

epimodel-sti/R/validate_simulated_rels.R

Lines changed: 85 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,14 @@
77
#' @param joint_attrs Character vector specifying the two joint attribute to extract targets for.
88
#' Default is c("age", "race").
99
#' @return A data frame where each cell represents the mean degree for that age, race, and network type.
10+
#' @importFrom yaml read_yaml
11+
#' @importFrom dplyr mutate group_by summarize
12+
#' @importFrom tidyr pivot_longer
13+
#' @importFrom rlang .data
1014
#' @export
1115
get_target_degrees_age_race <- function(target_yaml_file, nets = c("main", "casual"), joint_attrs = c("age", "race")) {
1216
# load network targets from yaml file
13-
x <- yaml::read_yaml(target_yaml_file)
17+
x <- read_yaml(target_yaml_file)
1418

1519
# Validate inputs, currently only supports main and casual networks, age and race joint attributes
1620
if (nets != c("main", "casual")) {
@@ -57,14 +61,14 @@ get_target_degrees_age_race <- function(target_yaml_file, nets = c("main", "casu
5761

5862
# Summarize mean degrees and reshape to long format
5963
dat |>
60-
dplyr::group_by(age, race) |>
61-
dplyr::summarize(
62-
main = mean(main),
63-
casual = mean(casual),
64+
group_by(.data$age, .data$race) |>
65+
summarize(
66+
main = mean(.data$main),
67+
casual = mean(.data$casual),
6468
.groups = "drop"
6569
) |>
66-
dplyr::mutate(data = "targets") |>
67-
tidyr::pivot_longer(
70+
mutate(data = "targets") |>
71+
pivot_longer(
6872
cols = c("main", "casual"),
6973
names_to = "type",
7074
values_to = "degree"
@@ -81,6 +85,8 @@ get_target_degrees_age_race <- function(target_yaml_file, nets = c("main", "casu
8185
#' @return A data frame with time, simulation identifier, edges for main and casual networks,
8286
#' the difference from target edges, and the percentage difference from target edges.
8387
#' @importFrom rlang .data
88+
#' @importFrom dplyr select rename_with mutate filter group_by ungroup all_of
89+
#' @importFrom tidyr pivot_longer
8490
#' @export
8591
get_edges_history <- function(sim, nets = c("main", "casual")) {
8692
edges <- paste0("edges_", nets)
@@ -98,20 +104,82 @@ get_edges_history <- function(sim, nets = c("main", "casual")) {
98104
# Extract edges history from simulation object
99105
sim |>
100106
as.data.frame() |>
101-
dplyr::select(.data$time, .data$sim, dplyr::all_of(edges)) |>
102-
dplyr::mutate(
103-
main_diff = edges[1] - .data$target_main,
104-
casual_diff = edges[2] - .data$target_casual,
105-
main_diff_perc = (main_diff) / .data$target_main * 100,
106-
casual_diff_perc = (casual_diff) / .data$target_casual * 100
107+
select(.data$time, .data$sim, all_of(edges)) |>
108+
rename_with(~ gsub("edges_", "", .), all_of(edges)) |>
109+
pivot_longer(cols = all_of(nets), names_to = "net", values_to = "edges") |>
110+
mutate(
111+
target = ifelse(.data$net == "main", target_main, target_casual),
112+
absolute = .data$edges - .data$target,
113+
percent = (.data$absolute / .data$target) * 100
114+
) |>
115+
pivot_longer(
116+
cols = c("absolute", "percent", "edges"),
117+
names_to = "diff_type",
118+
values_to = "diff"
107119
) |>
108-
dplyr::group_by(time) |>
109-
dplyr::mutate(
110-
mean_main_diff_perc = mean(main_diff_perc, na.rm = TRUE),
111-
mean_casual_diff_perc = mean(casual_diff_perc, na.rm = TRUE)
120+
group_by(.data$time, .data$net, .data$diff_type) |>
121+
mutate(mean = mean(.data$diff, na.rm = TRUE)) |>
122+
ungroup() |>
123+
mutate(target = ifelse(.data$diff_type == "edges", .data$target, 0))
124+
}
125+
126+
#' @title Plot Edges History
127+
#' @description Plots the edges history for a specified network and type of difference (absolute, percent, or edges).
128+
#' @param edges_df A data frame containing the edges history,
129+
#' typically obtained from `get_edges_history()`.
130+
#' @param network A character string specifying the network type, either "main" or "casual".
131+
#' @param type A character string specifying the type of difference to plot, either "percent", "absolute", or "edges".
132+
#' @return A ggplot object showing the edges history over time for the specified network and type.
133+
#' @importFrom ggplot2 ggplot aes geom_line geom_hline labs
134+
#' @importFrom dplyr filter pull
135+
#' @importFrom rlang .data
136+
#' @export
137+
plot_edges_history <- function(x, network, type) {
138+
if (!class(x) %in% c("netsim", "data.frame")) {
139+
stop("x must be a netsim object or a data frame.")
140+
}
141+
if (!type %in% c("percent", "absolute", "edges")) {
142+
stop("type must be one of 'percent', 'absolute', or 'edges'.")
143+
}
144+
if (!network %in% c("main", "casual")) {
145+
stop("network must be either 'main', or 'casual'.")
146+
}
147+
148+
if (class(x) == "netsim") {
149+
edges_df <- get_edges_history(x, nets = network)
150+
} else {
151+
edges_df <- x
152+
}
153+
154+
if (!all(c("time", "sim", "net", "target", "diff_type", "diff", "mean") %in% names(edges_df))) {
155+
stop("edges_df must contain the columns: time, sim, net, target, diff_type, diff, and mean.")
156+
}
157+
158+
target_val <- edges_df |>
159+
filter(net == network, diff_type == type) |>
160+
pull(target) |>
161+
unique()
162+
163+
edges_df |>
164+
filter(.data$net == network, .data$diff_type == type) |>
165+
ggplot(aes(x = .data$time, y = .data$diff, color = .data$sim)) +
166+
geom_line() +
167+
geom_line(aes(y = .data$mean), color = "black", linewidth = 1) +
168+
geom_hline(aes(yintercept = target_val)) +
169+
labs(
170+
title = paste("Edges history for ", network, " network (", type, ")", sep = ""),
171+
y = paste(type, "difference"),
172+
x = "time"
112173
)
113174
}
114175

176+
#' @title Summarize Final Degrees from Simulation
177+
#' @description Summarizes the final degrees of individuals in the main and casual networks
178+
#' at the end of the simulation and calculates the mean degree for each age and race combination.
179+
#' @param sim A simulation object of class `EpiModel::netsim`.
180+
#' @return A data frame summarizing the mean degree, interquartile range (IQR), and data source
181+
#' for each age and race combination
182+
#' @export
115183

116184
# frequency of rels by age in networks at end of simulation
117185
summarize_final_degrees <- function(sim) {

scripts/prep/4_network_burnin.R

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -69,28 +69,16 @@ if (!dir.exists(localtests_dir)) {
6969
# Save the simulation object to a file
7070
saveRDS(sim, file = here::here("localtests", "sim_30yrs.rds"))
7171

72-
plot(sim, y = c("edges_main"), sim.lines = TRUE)
73-
plot(sim, y = c("edges_casual"), sim.lines = TRUE)
74-
75-
76-
e <- get_edges_history(sim)
77-
78-
e |>
79-
ggplot2::ggplot(ggplot2::aes(x = time, y = main_diff_perc, color = sim)) +
80-
ggplot2::geom_line() +
81-
ggplot2::geom_hline(yintercept = 0, linetype = "dashed") +
82-
ggplot2::geom_line(ggplot2::aes(y = mean_main_diff_perc), color = "black", linewidth = 1) +
83-
ggplot2::labs(title = "Main Network Edges Over Time")
84-
85-
e |>
86-
ggplot2::ggplot(ggplot2::aes(x = time, y = casual_diff_perc, color = sim)) +
87-
ggplot2::geom_line() +
88-
ggplot2::geom_hline(yintercept = 0, linetype = "dashed") +
89-
ggplot2::geom_line(ggplot2::aes(y = mean_casual_diff_perc), color = "black", linewidth = 1) +
90-
ggplot2::labs(title = "Casual Network Edges Over Time")
91-
72+
# Plotting and summarizing the simulation results
73+
## optional: first extract edge history df with get_edges_history(sim) to have as separate object
74+
plot_edges_history(sim, "main", "percent")
75+
plot_edges_history(sim, "casual", "percent")
76+
77+
yaml_params_loc <- here::here("networks", "params", "nw_params.yaml")
78+
plot_final_degrees(sim, "main")
79+
plot_final_degrees(sim, "casual")
9280
s <- summarize_final_degrees(sim)
93-
t <- get_target_degrees_age_race(here::here("networks", "params", "nw_params.yaml"))
81+
t <- get_target_degrees_age_race(yaml_params_loc)
9482

9583

9684
s |>

0 commit comments

Comments
 (0)