Skip to content

Commit 73d82ef

Browse files
Copilot and kaiemjoy authored
Fix cluster-robust SE handling
Co-authored-by: kaiemjoy <16113030+kaiemjoy@users.noreply.github.com>
1 parent 0b1825e commit 73d82ef

4 files changed

Lines changed: 155 additions & 92 deletions

File tree

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
shifted the random-number stream out of sync and made simulated values, and
2929
their snapshots, differ between macOS, Windows, and Linux). Simulated
3030
values change slightly as a result of this fix. (#447)
31+
* Cluster-robust standard errors now treat multiple `cluster_var` columns as
32+
multi-way clustering instead of collapsing them to a single interaction, and
33+
they are no longer allowed to be smaller than the corresponding model-based
34+
standard errors. (#543)
3135
* Corrected default axis labels in `strat_ests_barplot()` (`xlab`) and
3236
`strat_ests_scatterplot()` (`ylab`) to say "seroincidence" rather than
3337
"seroconversion"/"incidence".

R/compute_cluster_robust_var.R

Lines changed: 34 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -17,103 +17,45 @@
1717
fit,
1818
cluster_var,
1919
stratum_var = NULL) {
20-
# Extract stored data (already split by antigen_iso)
2120
pop_data_list <- attr(fit, "pop_data")
22-
sr_params_list <- attr(fit, "sr_params")
23-
noise_params_list <- attr(fit, "noise_params")
24-
antigen_isos <- attr(fit, "antigen_isos")
25-
26-
# Get MLE estimate
27-
log_lambda_mle <- fit$estimate
28-
29-
# Combine pop_data list back into a single data frame
30-
# to get cluster info
3121
pop_data_combined <- do.call(rbind, pop_data_list)
32-
33-
# Compute score (gradient) using numerical differentiation
34-
# The score is the derivative of log-likelihood w.r.t. log(lambda)
35-
epsilon <- 1e-6
36-
37-
# For each observation, compute the contribution to the score
38-
# We need to identify which cluster each observation belongs to
39-
40-
# Handle multiple clustering levels by creating composite cluster ID
41-
if (length(cluster_var) == 1) {
42-
cluster_ids <- pop_data_combined[[cluster_var]]
43-
} else {
44-
# Create composite cluster ID from multiple variables
45-
cluster_ids <- interaction(
46-
pop_data_combined[, cluster_var, drop = FALSE],
47-
drop = TRUE,
48-
sep = "_"
49-
)
50-
}
51-
52-
# Get unique clusters
53-
unique_clusters <- unique(cluster_ids)
54-
n_clusters <- length(unique_clusters)
55-
56-
# Compute cluster-level scores
57-
cluster_scores <- numeric(n_clusters)
58-
59-
for (i in seq_along(unique_clusters)) {
60-
cluster_id <- unique_clusters[i]
61-
62-
# Get observations in this cluster
63-
cluster_mask <- cluster_ids == cluster_id
64-
65-
# Create temporary pop_data with only this cluster
66-
pop_data_cluster <- pop_data_combined[cluster_mask, , drop = FALSE]
67-
68-
# Split by antigen
69-
pop_data_cluster_list <- split(
70-
pop_data_cluster,
71-
pop_data_cluster$antigen_iso
72-
)
73-
74-
# Ensure all antigen_isos are represented
75-
# (add empty data frames if missing)
76-
for (ag in antigen_isos) {
77-
if (!ag %in% names(pop_data_cluster_list)) {
78-
# Create empty data frame with correct structure
79-
pop_data_cluster_list[[ag]] <- pop_data_list[[ag]][0, , drop = FALSE]
80-
}
22+
standard_var_log_lambda <- 1 / fit$hessian |> as.numeric()
23+
24+
subset_cluster_vars <- unlist(
25+
lapply(seq_along(cluster_var), function(n_vars) {
26+
utils::combn(cluster_var, n_vars, simplify = FALSE)
27+
}),
28+
recursive = FALSE
29+
)
30+
31+
cluster_var_terms <- vapply(subset_cluster_vars, length, integer(1))
32+
robust_var_log_lambda <- 0
33+
34+
for (i in seq_along(subset_cluster_vars)) {
35+
cluster_vars_subset <- subset_cluster_vars[[i]]
36+
if (length(cluster_vars_subset) == 1) {
37+
cluster_ids <- pop_data_combined[[cluster_vars_subset]]
38+
} else {
39+
cluster_ids <- interaction(
40+
pop_data_combined[, cluster_vars_subset, drop = FALSE],
41+
drop = TRUE,
42+
sep = "_"
43+
)
8144
}
8245

83-
# Compute log-likelihood for this cluster at MLE
84-
ll_cluster_mle <- -(.nll(
85-
log.lambda = log_lambda_mle,
86-
pop_data = pop_data_cluster_list,
87-
antigen_isos = antigen_isos,
88-
curve_params = sr_params_list,
89-
noise_params = noise_params_list,
90-
verbose = FALSE
91-
))
92-
93-
# Compute log-likelihood at MLE + epsilon
94-
ll_cluster_plus <- -(.nll(
95-
log.lambda = log_lambda_mle + epsilon,
96-
pop_data = pop_data_cluster_list,
97-
antigen_isos = antigen_isos,
98-
curve_params = sr_params_list,
99-
noise_params = noise_params_list,
100-
verbose = FALSE
101-
))
102-
103-
# Numerical derivative (score for this cluster)
104-
cluster_scores[i] <- (ll_cluster_plus - ll_cluster_mle) / epsilon
46+
robust_var_log_lambda <- robust_var_log_lambda +
47+
(-1)^(cluster_var_terms[[i]] + 1) *
48+
.compute_cluster_var_oneway(
49+
fit = fit,
50+
cluster_ids = cluster_ids,
51+
pop_data_combined = pop_data_combined
52+
)
10553
}
10654

107-
# Compute B matrix (middle of sandwich)
108-
# B = sum of outer products of cluster scores
109-
b_matrix <- sum(cluster_scores^2) # nolint: object_name_linter
110-
111-
# Get Hessian (already computed by nlm)
112-
h_matrix <- fit$hessian # nolint: object_name_linter
113-
114-
# Sandwich variance: V = H^(-1) * B * H^(-1)
115-
# Since we have a scalar parameter, this simplifies to:
116-
var_log_lambda_robust <- b_matrix / (h_matrix^2)
55+
robust_var_log_lambda <- max(
56+
standard_var_log_lambda,
57+
robust_var_log_lambda
58+
)
11759

118-
return(var_log_lambda_robust)
60+
return(robust_var_log_lambda)
11961
}

R/compute_cluster_var_oneway.R

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#' One-way cluster-robust (sandwich) variance of log(lambda)
#'
#' For each distinct value of `cluster_ids`, the cluster's log-likelihood is
#' differentiated numerically at the fitted `log(lambda)` (forward
#' difference, step `1e-6`). The squared cluster scores are summed and the
#' sum is divided by the squared Hessian stored in `fit`.
#'
#' @param fit a `seroincidence` object from [est_seroincidence()]; its
#'   `estimate` and `hessian` elements and its `pop_data`, `sr_params`,
#'   `noise_params` and `antigen_isos` attributes are read
#' @param cluster_ids cluster identifier for each row of `pop_data_combined`
#' @param pop_data_combined population data for all antigen isotypes in a
#'   single data frame; its `antigen_iso` column is used to split the rows
#'
#' @return one-way cluster-robust variance of log(lambda), computed as
#'   `sum(cluster_scores^2) / fit$hessian^2`
#' @keywords internal
#' @noRd
.compute_cluster_var_oneway <- function(
    fit,
    cluster_ids,
    pop_data_combined) {
  step <- 1e-6
  mle <- fit$estimate
  isos <- attr(fit, "antigen_isos")
  curves <- attr(fit, "sr_params")
  noise <- attr(fit, "noise_params")
  pop_by_iso <- attr(fit, "pop_data")

  # Log-likelihood of one cluster's data at a given log(lambda);
  # `.nll()` returns the negative log-likelihood, hence the sign flip.
  loglik_at <- function(log_lambda, cluster_data) {
    -(.nll(
      log.lambda = log_lambda,
      pop_data = cluster_data,
      antigen_isos = isos,
      curve_params = curves,
      noise_params = noise,
      verbose = FALSE
    ))
  }

  clusters <- unique(cluster_ids)
  scores <- numeric(length(clusters))

  for (k in seq_along(clusters)) {
    in_cluster <- cluster_ids == clusters[k]
    rows <- pop_data_combined[in_cluster, , drop = FALSE]
    by_iso <- split(rows, rows$antigen_iso)

    # Give `.nll()` an entry for every isotype: an isotype with no rows in
    # this cluster gets a zero-row slice of the fitted data as a placeholder.
    for (iso in isos) {
      if (iso %in% names(by_iso)) next
      by_iso[[iso]] <- pop_by_iso[[iso]][0, , drop = FALSE]
    }

    # Forward-difference approximation of this cluster's score
    ll_base <- loglik_at(mle, by_iso)
    ll_step <- loglik_at(mle + step, by_iso)
    scores[k] <- (ll_step - ll_base) / step
  }

  # Sandwich form for a single parameter: meat / bread^2
  meat <- sum(scores^2)
  meat / (fit$hessian^2)
}

tests/testthat/test-cluster_robust_se.R

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,57 @@ test_that("multiple cluster variables work correctly", {
145145
# Standard errors should be positive
146146
expect_true(sum_multi$SE > 0)
147147
})
148+
149+
test_that("singleton cluster IDs do not reduce standard errors", {
  withr::local_seed(20241213)

  # Give every row its own cluster ID, so each cluster is a singleton
  pop <- sees_pop_data_pk_100
  pop$household_id <- seq_len(nrow(pop))

  # Same model twice: without clustering, then clustered on household_id
  fit_plain <- est_seroincidence(
    pop_data = pop,
    sr_param = typhoid_curves_nostrat_100,
    noise_param = example_noise_params_pk,
    antigen_isos = c("HlyE_IgG", "HlyE_IgA")
  )
  fit_clustered <- est_seroincidence(
    pop_data = pop,
    sr_param = typhoid_curves_nostrat_100,
    noise_param = example_noise_params_pk,
    antigen_isos = c("HlyE_IgG", "HlyE_IgA"),
    cluster_var = "household_id"
  )

  se_plain <- summary(fit_plain, verbose = FALSE)$SE
  se_clustered <- summary(fit_clustered, verbose = FALSE)$SE

  expect_equal(se_clustered, se_plain)
})
174+
175+
test_that("nested multi-level clustering uses the broader cluster level", {
  withr::local_seed(20241213)

  # household_id is unique per row; commune cycles through 1:10
  pop <- sees_pop_data_pk_100
  pop$household_id <- seq_len(nrow(pop))
  pop$commune <- rep(1:10, length.out = nrow(pop))

  # Cluster on commune alone, then on commune and household together
  fit_commune <- est_seroincidence(
    pop_data = pop,
    sr_param = typhoid_curves_nostrat_100,
    noise_param = example_noise_params_pk,
    antigen_isos = c("HlyE_IgG", "HlyE_IgA"),
    cluster_var = "commune"
  )
  fit_both <- est_seroincidence(
    pop_data = pop,
    sr_param = typhoid_curves_nostrat_100,
    noise_param = example_noise_params_pk,
    antigen_isos = c("HlyE_IgG", "HlyE_IgA"),
    cluster_var = c("commune", "household_id")
  )

  se_commune <- summary(fit_commune, verbose = FALSE)$SE
  se_both <- summary(fit_both, verbose = FALSE)$SE

  expect_equal(se_both, se_commune)
})

0 commit comments

Comments
 (0)