Skip to content

Commit d41fd77

Browse files
committed
revise preproc in ei_bench
1 parent 76d3902 commit d41fd77

File tree

3 files changed

+33
-31
lines changed

3 files changed

+33
-31
lines changed

R/ei_sens.R

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -423,10 +423,11 @@ plot.ei_sens <- function(
423423
#' et al. (2024).
424424
#'
425425
#' @param spec An [ei_spec] object.
426-
#' @param preproc An optional function which takes in a `ei_spec` object (`spec`
427-
#' with one covariate removed) and returns a modified object that includes
428-
#' modified object. Useful to apply any preprocessing, such as a basis
429-
#' transformation, as part of the benchmarking process.
426+
#' @param preproc An optional function which takes in a data frame of covariates
427+
#' and returns a transformed data frame or matrix of covariates.
428+
#' Useful to apply any preprocessing, such as a basis transformation, as part
429+
#' of the benchmarking process. Passed to [rlang::as_function()], and so supports
430+
#' `purrr`-style lambda functions.
430431
#' @param subset Passed on to [ei_est()].
431432
#'
432433
#' @references
@@ -443,22 +444,13 @@ plot.ei_sens <- function(
443444
#' ei_bench(spec)
444445
#'
445446
#' # preprocess to add all 2-way interactions
446-
#' ei_bench(spec, preproc = function(s) {
447-
#' z_cols = match(attr(s, "ei_z"), names(s))
448-
#' s_out = s[-z_cols]
449-
#' z_new = model.matrix(~ .^2 - 1, data = s[z_cols])
450-
#' s_out = cbind(s_out, z_new)
451-
#' ei_spec(s_out, vap_white:vap_other, pres_ind_wal,
452-
#' total = attr(s, "ei_n"), covariates = colnames(z_new))
453-
#' })
447+
#' ei_bench(spec, preproc = ~ model.matrix(~ .^2 - 1, data = .x))
454448
#' @export
455449
ei_bench <- function(spec, preproc = NULL, subset = NULL) {
456450
validate_ei_spec(spec)
457451

458452
if (!missing(preproc)) {
459-
if (!is.function(preproc)) {
460-
cli_abort("{.arg preproc} must be a function.")
461-
}
453+
preproc = rlang::as_function(preproc)
462454
} else {
463455
preproc = function(x) x
464456
}
@@ -470,7 +462,24 @@ ei_bench <- function(spec, preproc = NULL, subset = NULL) {
470462
apply(regr$y - regr$fitted, 2, var)
471463
}
472464

473-
spec_proc = preproc(spec)
465+
make_spec_loo = function(spec, out = character(0)) {
466+
covs = setdiff(attr(spec, "ei_z"), out)
467+
z = preproc(spec[, covs])
468+
if (is.data.frame(z)) {
469+
z = model.matrix(~ 0 + ., z)
470+
}
471+
spec$z_ = z
472+
ei_spec(
473+
spec,
474+
predictors = attr(spec, "ei_x"),
475+
outcome = attr(spec, "ei_y"),
476+
total = attr(spec, "ei_n"),
477+
covariates = "z_",
478+
strip = FALSE
479+
)
480+
}
481+
482+
spec_proc = make_spec_loo(spec)
474483
regr0 = ei_ridge(spec_proc, vcov = FALSE)
475484
riesz0 = ei_riesz(spec_proc, penalty = regr0$penalty)
476485
est0 = ei_est(regr0, riesz0, spec_proc, subset = subs)
@@ -480,8 +489,7 @@ ei_bench <- function(spec, preproc = NULL, subset = NULL) {
480489

481490
covs = attr(spec, "ei_z")
482491
benches = lapply(covs, function(cv) {
483-
spec_loo = reconstruct_ei_spec(spec[setdiff(names(spec), cv)], spec)
484-
spec_loo = preproc(spec_loo)
492+
spec_loo = make_spec_loo(spec, cv)
485493

486494
regr_loo = ei_ridge(spec_loo, vcov = FALSE)
487495
riesz_loo = ei_riesz(spec_loo, penalty = regr_loo$penalty)

R/ei_spec.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ weights.ei_spec <- function(object, normalize = TRUE, ...) {
174174
}
175175
if (isTRUE(normalize)) {
176176
n / mean(n)
177-
} else{
177+
} else {
178178
n
179179
}
180180
}

man/ei_bench.Rd

Lines changed: 6 additions & 12 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)