Commit 3821ec8

Multiple comparisons (#1)
* Add functionality for multiple treatment or control groups
* Initial incorporation of multiple comparisons (not tested)
* Fix bugs in multiple comparisons
* Only fill out x matrices when both groups have units, add option to weight supplemental balance
* Incorporate sparse matrices to allow for very large problems
* Add threads argument
* Update descriptions, work on cleaning optimize_controls
* Finish cleaning inputs
* Allow >2 comparisons; update docs
* Add checks, tests, remove ratio_star
* Updates documentation
* Add multiple comparisons to vignette
* Update vignette and docs
* Fix package checks
* Remove print
* Attempt to debug why test is failing
* Set RNG
* debug test
* debug test
* debug test
* remove print statements; update comments
* Add cran installation to readme
1 parent 5d2c87f commit 3821ec8

42 files changed: 1920 additions, 583 deletions

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -7,3 +7,5 @@ Meta
 .httr-oauth
 .DS_Store
 docs
+/doc/
+/Meta/

DESCRIPTION

Lines changed: 11 additions & 6 deletions
@@ -1,14 +1,17 @@
 Package: natstrat
 Type: Package
 Title: Obtain Unweighted Natural Strata that Balance Many Covariates
-Version: 1.0.0
+Version: 2.0.0
 Authors@R: c(
     person("Katherine", "Brumberg", email = "kbrum@wharton.upenn.edu",
     role=c("aut", "cre")))
-Description: Natural strata fix a constant ratio of controls to treated units within
-    each stratum. This ratio need not be an integer. The control units are
-    chosen using randomized rounding of a linear program that balances many
-    covariates.
+Description: Natural strata can be used in observational studies to balance
+    the distributions of many covariates across any number of treatment
+    groups and any number of comparisons. These strata have proportional
+    amounts of units within each stratum across the treatments, allowing
+    for simple interpretation and aggregation across strata. Within each
+    stratum, the units are chosen using randomized rounding of a linear
+    program that balances many covariates.
     To solve the linear program, the 'Gurobi' commercial optimization software
     is recommended, but not required. The 'gurobi' R package can be installed following the instructions
     at <https://www.gurobi.com/documentation/9.1/refman/ins_the_r_package.html>.
@@ -25,7 +28,9 @@ Imports:
     pps,
     sampling,
     ggplot2,
-    rlang
+    rlang,
+    ramify,
+    slam
 Depends:
     R (>= 2.10),
     caret
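
The Description above says units are chosen by randomized rounding of a linear program. As a rough illustration of the general idea only, the following base R sketch rounds a hypothetical fractional solution `a` with independent Bernoulli draws. The vector `a` and the helper `round_randomly` are assumptions made for this example; the package may well use a dependent sampling scheme that also preserves stratum totals, so this should not be read as natstrat's implementation.

```r
# Hypothetical fractional LP solution: a[i] in [0, 1] is the extent to which
# the relaxed linear program selects unit i.
a <- c(1, 0, 0.25, 0.75, 0.5, 0.5)

# Independent randomized rounding: select unit i with probability a[i].
# Entries that are already 0 or 1 are never changed by the rounding.
round_randomly <- function(a) {
  stats::rbinom(n = length(a), size = 1, prob = a) == 1
}

set.seed(1)
selected <- round_randomly(a)

# Averaged over many draws, the selection frequencies approach a.
freq <- rowMeans(replicate(2000, round_randomly(a)))
```

Independent draws do not fix how many units are selected in total, which is one reason a real implementation might prefer a rounding scheme with fixed sample sizes.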

NAMESPACE

Lines changed: 4 additions & 0 deletions
@@ -7,6 +7,8 @@ export(generate_qs)
 export(optimize_controls)
 export(stand)
 import(ggplot2)
+import(ramify)
+import(slam)
 importFrom(caret,dummyVars)
 importFrom(rlang,.data)
 importFrom(stats,as.formula)
@@ -15,6 +17,8 @@ importFrom(stats,median)
 importFrom(stats,model.frame.default)
 importFrom(stats,na.pass)
 importFrom(stats,predict)
+importFrom(stats,rbinom)
+importFrom(stats,rmultinom)
 importFrom(stats,sd)
 importFrom(stats,setNames)
 importFrom(stats,terms)

NEWS.md

Lines changed: 29 additions & 0 deletions
@@ -1,3 +1,32 @@
+# natstrat 2.0.0 (2021-10-12)
+
+This version adds several new functionalities:
+* Multiple treatment or control groups
+* Multiple separate comparisons, using various subsets of the treatment and control groups,
+  for which units are chosen in order to balance covariate distributions
+  for all comparisons simultaneously
+
+Several changes to the interface have been made:
+
+* `z` should generally be a factor instead of a vector as before
+* `treated` and `control` specifications, if needed, should each be a level of `z`
+* many arguments can be specified for each of the treatment levels
+* `q_s`, `max_entry_s` can have a row per treatment level
+* `ratio`, `max_ratio` can have an entry per treatment level
+* inputs for the supplemental comparison have been added across the functions: `q_star_s`,
+  `ratio_star`, `treated_star`, `weight_star`
+
+There are several changes to the outputs:
+
+* `optimize_controls` now has only one version of `eps`, `objective`, `objective_wo_importances`
+  instead of a raw and regular version. The version now reported is the raw version, not corrected
+  for missingness. If you would like corrected versions, refer to the standardized differences
+  outputted by `check_balance` instead
+* `generate_constraints` now returns only standardized outputs, not centered. The centering
+  now takes place within `optimize_controls` instead
+
+
+
 # natstrat 1.0.0 (2021-05-17)

 The first released version.
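
To make the interface changes listed for 2.0.0 more concrete, here is a minimal base R sketch of the two input shapes they describe: a factor `z` and a `q_s` matrix with one row per level of `z`. All data, level names, and counts are hypothetical, and no natstrat function is called, so this only illustrates the shapes rather than the package's exact requirements.

```r
# Hypothetical data: three treatment levels and two strata.
z <- factor(c("control", "control", "control", "control",
              "low", "low", "high", "high"),
            levels = c("control", "low", "high"))
st <- factor(c("a", "a", "b", "b", "a", "b", "a", "b"))

# One row per level of z (in level order), one column per stratum.
# The counts are illustrative only: one unit per group and stratum.
q_s <- rbind(control = c(1, 1),
             low     = c(1, 1),
             high    = c(1, 1))
colnames(q_s) <- levels(st)

# Requested counts should not exceed what is available in each cell.
available <- table(z, st)
```

Keeping the rows of `q_s` in the same order as `levels(z)` follows the ordering described in the updated `balance_LP` documentation.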

R/balance_LP.R

Lines changed: 60 additions & 19 deletions
@@ -4,7 +4,10 @@
 #' to select.
 #'
 #' @inheritParams optimize_controls
-#' @param q_s a named vector indicating how many control units are to be selected from each stratum.
+#' @param q_s a named vector or matrix indicating how many control units are to be selected from each stratum.
+#' If there is one control group and all treated units are desired, this can be a vector; otherwise,
+#' this should have one row per treatment group, where the order of the rows matches the order of
+#' the levels of \code{z}, including the treated level.
 #' @param st_vals the unique stratum levels contained in \code{st}.
 #' @param S the number of unique stratum levels contained in \code{st}.
 #' @param N the total number of available controls in the data.
@@ -19,41 +22,77 @@
 #' }
 #'
 #' @keywords internal
+#' @import ramify
+#' @import slam

 balance_LP <- function(z, X, importances, st, st_vals, S, q_s, N,
-                       solver, integer, time_limit) {
+                       solver, integer, time_limit, threads = 1,
+                       weight_comp = 1) {
+
   if (solver == "gurobi" && !requireNamespace("gurobi", quietly = TRUE)) {
     stop("Package \'gurobi\' needed if \"solver\" parameter set to \"gurobi\". Please
          install it or switch the \"solver\" parameter to \"Rglpk\".",
          call. = FALSE)
   }
+  groups <- levels(z)
+  k <- length(groups)
+  kc2 <- choose(k, 2)
+  n_comp <- length(q_s)

   # Set up and solve the linear program
   model <- list()
-  params <- list(TimeLimit = time_limit, OutputFlag = 0)
+  params <- list(TimeLimit = time_limit, OutputFlag = 0, Threads = threads)

   nvars <- dim(X)[2] # number of variables
-  X[is.na(X)] <- 0
-  X0 <- X[z == 0, ]
-  model$obj <- c(rep(0, N), rep(importances, 2))
-  blk1 <- t(X0)
-  ident <- diag(1, nvars, nvars) # identity matrix
-  model$A <- cbind(blk1 / sum(q_s), ident, -ident) # constraints, individual vars
+  model$obj <- rep(0, n_comp * N)
+  for (comp in 1:n_comp) {
+    model$obj <- c(model$obj, rep(rep(importances * weight_comp[comp], 2), kc2))
+  }
+  model$A <- create_balance_matrices(X = X, z = z, N = N, nvars = nvars,
+                                     kc2 = kc2, q_s = q_s, return = "A")$A
+
+  # Now, append stratum size constraints for each comparison
+  st_mats <- simple_triplet_zero_matrix(nrow = k * S, ncol = N)
+  for (group_num in 1:k) {
+    group <- groups[group_num]
+    st_mats[((group_num - 1) * S + 1):(group_num * S), which(z == group)] <- 1 * outer(st_vals, st[z == group], "==")
+  }
+  for (comp in 1:n_comp) {
+    model$A <- rbind(model$A,
+                     cbind(simple_triplet_zero_matrix(nrow = k * S, ncol = (N * (comp - 1))),
+                           st_mats, simple_triplet_zero_matrix(nrow = k * S, ncol = N * (n_comp - comp) + 2 * n_comp * kc2 * nvars)))
+  }
+
+  # Now, if multiple comparisons, add constraint that all a's for a unit add to <= 1
+  # (so that one unit is not chosen for multiple comparisons)
+  if (n_comp > 1) {
+    mat <- do.call(cbind, replicate(n_comp, simple_triplet_diag_matrix(rep(1, N)), simplify=FALSE))
+    model$A <- rbind(model$A, cbind(mat, simple_triplet_zero_matrix(nrow = N, ncol = 2 * n_comp * kc2 * nvars)))
+  }

-  # Now, append stratum size constraints
-  model$A <- rbind(model$A, cbind(1 * outer(st_vals, st[z == 0], "=="), matrix(0, S,
-                                                                               2 * nvars)))
   # Constraints for eps are equalities, number of controls per strata are equalities
-  model$sense <- c(rep("==", nvars), rep("==", S))
-  model$rhs <- c(rep(0, nvars), q_s) # Right hand side of constraints
+  # Constraints for units only counting in one comparison are <=
+  model$sense <- c(rep("==", n_comp * kc2 * nvars), rep("==", n_comp * k * S))
+  if (n_comp > 1) {
+    model$sense <- c(model$sense, rep("<=", N))
+  }
+
+  # right hand side of constraints
+  model$rhs <- rep(0, n_comp * kc2 * nvars)
+  for (comp in 1:n_comp) {
+    model$rhs <- c(model$rhs, ramify::flatten(q_s[[comp]]))
+  }
+  if (n_comp > 1) {
+    model$rhs <- c(model$rhs, rep(1, N))
+  }

-  ndecv <- as.integer(N + (2 * nvars)) # number of decision variables
+  ndecv <- as.integer(n_comp * N + (2 * n_comp * kc2 * nvars)) # number of decision variables
+  model$ub <- c(rep(1, n_comp * N), rep(Inf, 2 * n_comp * kc2 * nvars))
   model$lb <- rep(0, ndecv)
-  model$ub <- c(rep(1, N), rep(Inf, 2 * nvars))
   bounds <- list(lower = list(ind = 1:ndecv, val = model$lb),
                  upper = list(ind = 1:ndecv, val = model$ub))
   if (integer) {
-    model$vtype <- c(rep("B", N), rep("C", 2 * nvars))
+    model$vtype <- c(rep("B", n_comp * N), rep("C", 2 * n_comp * kc2 * nvars))
   } else {
     model$vtype <- rep("C", ndecv)
   }
@@ -64,7 +103,8 @@ balance_LP <- function(z, X, importances, st, st_vals, S, q_s, N,
     } else {
       params$TimeLimit <- 0
     }
-    o <- Rglpk::Rglpk_solve_LP(model$obj, model$A, model$sense, model$rhs, bounds = bounds,
+    o <- Rglpk::Rglpk_solve_LP(obj = model$obj, mat = model$A, dir = model$sense,
+                               rhs = model$rhs, bounds = bounds,
                                types = model$vtype, control = list(
                                  canonicalize_status = FALSE, tm_limit = params$TimeLimit))
     if (o$status != 5) {
@@ -75,7 +115,8 @@ balance_LP <- function(z, X, importances, st, st_vals, S, q_s, N,
   }
   if (solver == "gurobi") {
     # Note that for gurobi, all inequalities are interpreted to be "or equal to"
-    model$sense <- c(rep("=", nvars), rep("=", S))
+    model$sense <- c(rep("=", n_comp * kc2 * nvars), rep("=", n_comp * k * S),
+                     rep("<", if (n_comp > 1) N else 0))
     o <- gurobi::gurobi(model, params)
     if (o$status != "OPTIMAL") {
       warning("No solution found for the linear program.")
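
The updated `balance_LP` lays out `n_comp` blocks of `N` selection variables followed by `2 * n_comp * kc2 * nvars` slack variables, and, when there is more than one comparison, adds one row per unit so that a unit is used in at most one comparison. The sketch below reproduces that layout with dense base R matrices instead of the `slam` sparse matrices used in the diff. The problem sizes, the stratum labels, and the name `one_comparison_per_unit` are hypothetical, and the balance rows built by `create_balance_matrices` are omitted.

```r
# Hypothetical problem sizes (not taken from the package's tests).
N <- 4               # units
n_comp <- 2          # comparisons
k <- 3               # treatment groups
nvars <- 5           # covariates
kc2 <- choose(k, 2)  # pairs of groups balanced within each comparison

# Decision variables: n_comp blocks of N selection indicators, followed by a
# positive and a negative slack for every (comparison, pair, covariate).
n_slack <- 2 * n_comp * kc2 * nvars
ndecv <- n_comp * N + n_slack

# Stratum indicator rows: entry [s, i] is 1 when unit i lies in stratum s.
st <- c("a", "b", "a", "b")
st_vals <- unique(st)
st_mat <- 1 * outer(st_vals, st, "==")

# One row per unit: identity blocks placed side by side pick out that unit's
# selection variable in every comparison; the slack columns are all zero.
one_comparison_per_unit <- cbind(
  do.call(cbind, replicate(n_comp, diag(N), simplify = FALSE)),
  matrix(0, nrow = N, ncol = n_slack)
)

# A candidate solution using unit 1 in comparison 1 and unit 2 in comparison 2.
x <- numeric(ndecv)
x[1] <- 1
x[N + 2] <- 1
usage <- as.vector(one_comparison_per_unit %*% x)
```

Each entry of `usage` counts how many comparisons use that unit, so the constraint holds exactly when every entry is at most 1. Dense matrices are fine at this size; the sparse representation in the diff matters once `N` is large.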

R/check_balance.R

Lines changed: 53 additions & 34 deletions
@@ -6,8 +6,13 @@
 #' This function can also generate love plots of the same quantities.
 #'
 #' @inheritParams stand
+#' @inheritParams optimize_controls
 #' @param X a data frame containing the covariates in the columns over which balance is desired. The number
 #' of rows should equal the length of \code{z}.
+#' @param treated which treatment value should be considered the treated units. This
+#' must be one of the values of \code{z}.
+#' @param control which treatment value should be considered the control units. This
+#' must be one of the values of \code{z}.
 #' @param selected a boolean vector including whether each unit was selected as part of the treated and control
 #' groups for analysis. Should be the same length as \code{z} and typically comes from the results of
 #' \code{\link{optimize_controls}()}.
@@ -65,40 +70,45 @@
 #' selected = results$selected,
 #' plot = TRUE)

-
-check_balance <- function(z, X, st, selected, denom_variance = "treated", plot = FALSE, message = TRUE) {
+check_balance <- function(z, X, st, selected, treated = 1, control = 0,
+                          denom_variance = "treated", plot = FALSE, message = TRUE) {

   if (plot && !requireNamespace("ggplot2", quietly = TRUE) && !requireNamespace("rlang", quietly = TRUE)) {
-    stop("Packages \"ggplot2\" and \"rlang\" needed if \"plot\" argument set to \"TRUE\". Please
+    stop("Packages \"ggplot2\" and \"rlang\" needed if \"plot\" argument set to \"TRUE\". Please
          install these or switch the \"plot\" argument to \"FALSE\".",
-         call. = FALSE)
+         call. = FALSE)
   }

   st <- as.factor(st)
   X[, sapply(X, is.logical)] <- sapply(X[, sapply(X, is.logical)], as.numeric)
   dummies <- dummyVars( ~ ., data = X, levelsOnly = FALSE)
   full_X <- predict(dummies, newdata = X)

-  sd_across <- get_stand_diffs(full_X, z, selected, denom_variance = denom_variance)
+  sd_across <- get_stand_diffs(full_X, z, selected, treated = treated, control = control,
+                               denom_variance = denom_variance)

   sd_strata <- NULL
   for (ist in levels(st)) {
     sd_strata <- rbind(sd_strata, cbind(get_stand_diffs(full_X, z, selected, st, ist,
+                                                        treated = treated, control = control,
                                                         denom_variance = denom_variance), ist))
   }
   colnames(sd_strata)[4] <- "stratum"

-  q_s <- sapply(levels(st), function(ist) {sum( !z & selected & st == ist )})
-  n_s <- sapply(levels(st), function(ist) {sum( !z & st == ist )})
+  q_s <- sapply(levels(st), function(ist) {sum( z == control & selected & st == ist )})
+  n_s <- sapply(levels(st), function(ist) {sum( z == control & st == ist )})

-  fr_tab <- table(z, st)
   sd_strata_avg <- sd_across
   sd_strata_avg[1:dim(sd_strata_avg)[1], 1:2] <- NA
   for (cov in row.names(sd_strata_avg)) {
     sd_strata_avg[cov, 1] <- sum(sapply(levels(st), function(ist) {
-      sd_strata[sd_strata$covariate == cov & sd_strata$stratum == ist, 1] * (n_s[ist] - sum(is.na(X[!z & st == ist, cov]))) })) / (sum(n_s) - sum(is.na(X[!z, cov])))
+      sd_strata[sd_strata$covariate == cov & sd_strata$stratum == ist, 1] *
+        (n_s[ist] - sum(is.na(X[z == control & st == ist, cov]))) })) /
+      (sum(n_s) - sum(is.na(X[z == control, cov])))
     sd_strata_avg[cov, 2] <- sum(sapply(levels(st), function(ist) {
-      sd_strata[sd_strata$covariate == cov & sd_strata$stratum == ist, 2] * (q_s[ist] - sum(is.na(X[!z & st == ist & selected, cov]))) })) / (sum(q_s) - sum(is.na(X[!z & selected, cov])))
+      sd_strata[sd_strata$covariate == cov & sd_strata$stratum == ist, 2] *
+        (q_s[ist] - sum(is.na(X[z == control & st == ist & selected, cov]))) })) /
+      (sum(q_s) - sum(is.na(X[z == control & selected, cov])))
   }

   if (message) {
@@ -169,41 +179,50 @@ check_balance <- function(z, X, st, selected, denom_variance = "treated", plot =
 #' choosing a subset of controls, and one for after. The rows pertain to covariates.
 #' @keywords internal

-get_stand_diffs <- function(data, z, selected, st = NULL, ist = NULL, denom_variance = "treated") {
+get_stand_diffs <- function(data, z, selected, st = NULL, ist = NULL,
+                            treated = 1, control = 0, denom_variance = "treated") {
+  if (is.vector(z)) {
+    z <- as.factor(z)
+  }
   if (!is.null(ist)) {
     ind <- st == ist
   } else {
     ind <- rep(TRUE, length(z))
   }
-  treatedmat_full <- data[z == 1, , drop = FALSE]
-  treatedmat <- data[z == 1 & ind, , drop = FALSE]
   # Standardized differences before matching
-  controlmat_before_full <- data[z == 0, , drop = FALSE]
-  controlmat_before <- data[z == 0 & ind, , drop = FALSE]
+  treatedmat_before_full <- data[z == treated, , drop = FALSE]
+  treatedmat_before <- data[z == treated & ind, , drop = FALSE]
+  treatedmean_before <- apply(treatedmat_before, 2, mean, na.rm = TRUE)
+  controlmat_before_full <- data[z == control, , drop = FALSE]
+  controlmat_before <- data[z == control & ind, , drop = FALSE]
   controlmean_before <- apply(controlmat_before, 2, mean, na.rm = TRUE)
-  treatmean <- apply(treatedmat, 2, mean, na.rm = TRUE)
-  treatvar <- apply(treatedmat_full, 2, var, na.rm = TRUE)
-  controlvar <- apply(controlmat_before_full, 2, var, na.rm = TRUE)
-  if (dim(treatedmat_full)[1] == 1) {
-    treatvar[1:length(treatvar)] <- 0.0
+  variances <- sapply(levels(z), function(group) {
+    return(apply(data[z == group, , drop = FALSE], 2, var, na.rm = TRUE))
+  })
+  if (is.vector(variances)) {
+    variances <- matrix(variances, ncol = 1)
   }
+  variances[is.na(variances)] <- 0
   if (denom_variance == "pooled") {
-    denom <- sqrt((treatvar + controlvar) / 2)
+    denom <- sqrt(rowMeans(variances))
   } else {
-    denom <- sqrt(treatvar)
-    denom[treatvar == 0] <- sqrt(controlvar[treatvar == 0] / 2)
+    denom <- sqrt(variances[, levels(z) == treated])
+    denom[denom == 0] <-
+      sqrt(rowMeans(variances)[denom == 0])
   }
-  stand_diff_before <- rep(NA, length(treatvar))
-  names(stand_diff_before) <- names(treatvar)
-  stand_diff_before <- (treatmean - controlmean_before) / denom
-  stand_diff_before[treatmean == controlmean_before] <- 0.0
+  stand_diff_before <- rep(NA, nrow(variances))
+  names(stand_diff_before) <- dimnames(variances)[[1]]
+  stand_diff_before <- (treatedmean_before - controlmean_before) / denom
+  stand_diff_before[treatedmean_before == controlmean_before] <- 0.0
   # Standardized differences after matching
-  controlmat_after <- data[selected & z == 0 & ind, , drop = FALSE]
+  controlmat_after <- data[selected & z == control & ind, , drop = FALSE]
   controlmean_after <- apply(controlmat_after, 2, mean, na.rm = TRUE)
-  stand_diff_after <- rep(NA, length(treatvar))
-  names(stand_diff_after) <- names(treatvar)
-  stand_diff_after <- (treatmean - controlmean_after) / denom
-  stand_diff_after[treatmean == controlmean_after] <- 0.0
+  treatedmat_after <- data[selected & z == treated & ind, , drop = FALSE]
+  treatedmean_after <- apply(treatedmat_after, 2, mean, na.rm = TRUE)
+  stand_diff_after <- rep(NA, nrow(variances))
+  names(stand_diff_after) <- dimnames(variances)[[1]]
+  stand_diff_after <- (treatedmean_after - controlmean_after) / denom
+  stand_diff_after[treatedmean_after == controlmean_after] <- 0.0
   sd_matrix <- data.frame(abs_stand_diff_before = abs(stand_diff_before),
                           abs_stand_diff_after = abs(stand_diff_after))
   if (!is.null(ist)) {
@@ -329,8 +348,8 @@ plot_stand_diffs <- function(sds, type) {
                                stratum = sds$sd_strata$stratum)

     p <- apply(as.array(unique(sds$sd_strata$stratum)), 1, function(x) {
-      ggplot(plot_dataframe[plot_dataframe$stratum == x,],
-             aes(x = .data$abs_stand_diff, y = .data$covariates)) +
+      ggplot(plot_dataframe[plot_dataframe$stratum == x,],
+             aes(x = .data$abs_stand_diff, y = .data$covariates)) +
         geom_point(size = 5, aes(shape = .data$type)) +
         scale_shape_manual(values = c(4, 1)) +
         geom_vline(xintercept = c(.1,.2), lty = 2) +
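
The revised `get_stand_diffs` compares the means of a chosen `treated` level and a chosen `control` level of `z`, scaled by either the treated group's standard deviation or, for `denom_variance = "pooled"`, the square root of the average variance across all levels. The following base R sketch shows that calculation for a single covariate. The helper `abs_stand_diff`, its `keep` and `pooled` arguments, and the data are hypothetical, and the zero-variance fallback from the diff is left out for brevity.

```r
# Hypothetical data: one covariate, three treatment levels.
z <- factor(c("c", "c", "c", "c", "t", "t", "t", "o", "o"))
x <- c(1, 2, 3, 6, 4, 5, 6, 0, 10)
selected <- c(FALSE, FALSE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE, TRUE)

abs_stand_diff <- function(x, z, keep, treated, control, pooled = FALSE) {
  # Variances use every unit in each group, before any selection.
  variances <- vapply(split(x, z), var, numeric(1), na.rm = TRUE)
  denom <- if (pooled) sqrt(mean(variances)) else sqrt(variances[[treated]])
  # Means use only the units kept for the comparison.
  diff <- mean(x[keep & z == treated], na.rm = TRUE) -
    mean(x[keep & z == control], na.rm = TRUE)
  abs(diff) / denom
}

before <- abs_stand_diff(x, z, rep(TRUE, length(z)), "t", "c")
after <- abs_stand_diff(x, z, selected, "t", "c")
```

Because the denominator is computed from the full groups, `before` and `after` share the same scale, so a smaller `after` reflects closer means rather than a change in the variance estimate.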
