Skip to content

Commit 2dfcb1f

Browse files
committed
convert rel validation scripts into fx
1 parent b8d684c commit 2dfcb1f

File tree

2 files changed

+61
-42
lines changed

2 files changed

+61
-42
lines changed

epimodel-sti/R/validate_simulated_rels.R

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@
1212
#' @importFrom tidyr pivot_longer
1313
#' @importFrom rlang .data
1414
#' @export
15-
get_target_degrees_age_race <- function(target_yaml_file, nets = c("main", "casual"), joint_attrs = c("age", "race")) {
15+
get_target_degrees_age_race <- function(yaml_params_loc, nets = c("main", "casual"), joint_attrs = c("age", "race")) {
1616
# load network targets from yaml file
17-
x <- read_yaml(target_yaml_file)
17+
x <- read_yaml(yaml_params_loc)
1818

1919
# Validate inputs, currently only supports main and casual networks, age and race joint attributes
20-
if (nets != c("main", "casual")) {
20+
if (sum(nets == c("main", "casual")) != 2) {
2121
stop("Currently only 'main' and 'casual' networks are supported, in that order.")
2222
}
2323

24-
if (joint_attrs != c("age", "race")) {
24+
if (sum(joint_attrs == c("age", "race")) != 2) {
2525
stop("Currently only race and race are supported as joint attributes, in that order.")
2626
}
2727

@@ -33,9 +33,6 @@ get_target_degrees_age_race <- function(target_yaml_file, nets = c("main", "casu
3333
if (length(joint_attrs) != 2) {
3434
stop("joint_attrs must be a character vector of length 2.")
3535
}
36-
if (!attr %in% names(x[[nets[1]]][["nodefactor"]])) {
37-
stop("Specified attribute not found in the YAML file.")
38-
}
3936

4037
joint_name <- paste0(joint_attrs[1], "_", joint_attrs[2])
4138

@@ -54,7 +51,7 @@ get_target_degrees_age_race <- function(target_yaml_file, nets = c("main", "casu
5451

5552
dat <- data.frame(
5653
main = x[[nets[1]]]$nodefactor[[joint_name]],
57-
casual = x[[nets[1]]]$nodefactor[[joint_name]],
54+
casual = x[[nets[2]]]$nodefactor[[joint_name]],
5855
age = rep(ages, (length(races))),
5956
race = rep(races, each = length(ages))
6057
)
@@ -184,6 +181,7 @@ plot_edges_history <- function(x, network, type) {
184181
# frequency of rels by age in networks at end of simulation
185182
summarize_final_degrees <- function(sim) {
186183
simdat <- NULL
184+
nsims <- sim$control$nsims
187185

188186
for (i in seq_len(nsims)) {
189187
this_sim <- paste0("sim", i)
@@ -219,8 +217,50 @@ summarize_final_degrees <- function(sim) {
219217
dplyr::mutate(data = "simulated")
220218
}
221219

222-
# mean rel durs at end (may not match targets if simulation is not long enough)
223-
get_mean_durations <- function(sim) {
220+
#' @title Plot Final Degrees for Main and Casual Networks
221+
#' @description Plots the final degrees of individuals in the main and casual networks summarized across simulations
222+
#' and compares them to target degrees extracted from a YAML file.
223+
#' @param sim A simulation object of class `EpiModel::netsim`.
224+
#' @param network A character string specifying the network type, either "main" or "casual".
225+
#' @param yaml_params_loc Path to the YAML file containing target degrees.
226+
#' @return A ggplot object showing the final degrees for the specified network type,
227+
#' comparing simulated degrees to target degrees.
228+
#' @importFrom ggplot2 ggplot aes geom_point geom_errorbar facet_wrap
229+
#' @importFrom dplyr filter mutate
230+
#' @importFrom rlang .data
231+
#' @export
232+
plot_final_degrees <- function(sim, network, yaml_params_loc) {
233+
if (!network %in% c("main", "casual")) {
234+
stop("network must be either 'main' or 'casual'.")
235+
}
236+
237+
s <- summarize_final_degrees(sim)
238+
t <- get_target_degrees_age_race(yaml_params_loc) |>
239+
dplyr::mutate(IQR1 = degree, IQR3 = degree) # targets do not have IQRs
240+
241+
y <- rbind(s, t)
242+
243+
y |>
244+
dplyr::filter(.data$type == network) |>
245+
ggplot2::ggplot(ggplot2::aes(x = .data$age, y = .data$degree, color = .data$data)) +
246+
ggplot2::geom_point() +
247+
ggplot2::geom_errorbar(ggplot2::aes(ymin = .data$IQR1, ymax = .data$IQR3), width = 0.2) +
248+
ggplot2::facet_wrap(~ .data$race)
249+
}
250+
251+
252+
#' @title Get Mean Durations of Relationships at End of Simulation
253+
#' @description Calculates the mean durations of relationships in the main and casual networks at the end
254+
#' of the simulation, comparing them to target durations specified in a YAML file.
255+
#' @param sim A simulation object of class `EpiModel::netsim`.
256+
#' @param nets A character vector specifying the networks to calculate durations for, default is c("main", "casual").
257+
#' @param yaml_params_loc Path to the YAML file containing target durations.
258+
#' @return A data frame summarizing the target and simulated mean durations for each network,
259+
#' along with the standard deviation of the simulated durations.
260+
#' @importFrom yaml read_yaml
261+
#' @export
262+
get_mean_durations <- function(sim, nets = c("main", "casual"), yaml_params_loc) {
263+
x <- read_yaml(yaml_params_loc)
224264
main_durs <- NULL
225265
casual_durs <- NULL
226266
nsims <- sim$control$nsims
@@ -238,8 +278,12 @@ get_mean_durations <- function(sim) {
238278
}
239279

240280
data.frame(
241-
type = c("main", "casual"),
242-
mean_duration = c(mean(main_durs), mean(casual_durs)),
243-
sd_duration = c(sd(main_durs), sd(casual_durs))
281+
network = c("main", "casual"),
282+
target = c(
283+
x[[nets[1]]]$duration$overall,
284+
x[[nets[2]]]$duration$overall
285+
),
286+
sim_mean = c(mean(main_durs), mean(casual_durs)),
287+
sim_sd = c(sd(main_durs), sd(casual_durs))
244288
)
245289
}

scripts/prep/4_network_burnin.R

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -69,40 +69,15 @@ if (!dir.exists(localtests_dir)) {
6969
# Save the simulation object to a file
7070
saveRDS(sim, file = here::here("localtests", "sim_30yrs.rds"))
7171

72-
# Plotting and summarizing the simulation results
73-
## optional: first extract edge history df with get_edges_history(sim) to have as separate object
72+
# Plotting and summarizing the simulation results (relationship stats)
7473
plot_edges_history(sim, "main", "percent")
7574
plot_edges_history(sim, "casual", "percent")
7675

7776
yaml_params_loc <- here::here("networks", "params", "nw_params.yaml")
7877
plot_final_degrees(sim, "main")
7978
plot_final_degrees(sim, "casual")
80-
s <- summarize_final_degrees(sim)
81-
t <- get_target_degrees_age_race(yaml_params_loc)
8279

80+
plot_final_degrees(sim, "main", yaml_params_loc)
81+
plot_final_degrees(sim, "casual", yaml_params_loc)
8382

84-
s |>
85-
dplyr::filter(type == "main") |>
86-
ggplot2::ggplot(ggplot2::aes(x = age, y = degree, color = data)) +
87-
ggplot2::geom_point() +
88-
ggplot2::geom_errorbar(ggplot2::aes(ymin = IQR1, ymax = IQR3), width = 0.2) +
89-
ggplot2::geom_point(
90-
data = t |> dplyr::filter(type == "main"),
91-
ggplot2::aes(x = age, y = degree, color = data), linewidth = 3
92-
) +
93-
ggplot2::facet_wrap(~race)
94-
95-
s |>
96-
dplyr::filter(type == "casual") |>
97-
ggplot2::ggplot(ggplot2::aes(x = age, y = degree, color = data)) +
98-
ggplot2::geom_point() +
99-
ggplot2::geom_errorbar(ggplot2::aes(ymin = IQR1, ymax = IQR3), width = 0.2) +
100-
ggplot2::geom_point(
101-
data = t |> dplyr::filter(type == "casual"),
102-
ggplot2::aes(x = age, y = degree, color = data), linewidth = 3
103-
) +
104-
ggplot2::facet_wrap(~race)
105-
106-
107-
108-
get_mean_durations(sim)
83+
get_mean_durations(sim, yaml_params_loc = yaml_params_loc)

0 commit comments

Comments
 (0)