@@ -423,10 +423,11 @@ plot.ei_sens <- function(
423423# ' et al. (2024).
424424# '
425425# ' @param spec An [ei_spec] object.
426- # ' @param preproc An optional function which takes in a `ei_spec` object (`spec`
427- # ' with one covariate removed) and returns a modified object that includes
428- # ' modified object. Useful to apply any preprocessing, such as a basis
429- # ' transformation, as part of the benchmarking process.
426+ # ' @param preproc An optional function which takes in a data frame of covariates
427+ # ' and returns a transformed data frame or matrix of covariates.
428+ # ' Useful to apply any preprocessing, such as a basis transformation, as part
429+ # ' of the benchmarking process. Passed to [rlang::as_function()], and so supports
430+ # ' `purrr`-style lambda functions.
430431# ' @param subset Passed on to [ei_est()].
431432# '
432433# ' @references
@@ -443,22 +444,13 @@ plot.ei_sens <- function(
443444# ' ei_bench(spec)
444445# '
445446# ' # preprocess to add all 2-way interactions
446- # ' ei_bench(spec, preproc = function(s) {
447- # ' z_cols = match(attr(s, "ei_z"), names(s))
448- # ' s_out = s[-z_cols]
449- # ' z_new = model.matrix(~ .^2 - 1, data = s[z_cols])
450- # ' s_out = cbind(s_out, z_new)
451- # ' ei_spec(s_out, vap_white:vap_other, pres_ind_wal,
452- # ' total = attr(s, "ei_n"), covariates = colnames(z_new))
453- # ' })
447+ # ' ei_bench(spec, preproc = ~ model.matrix(~ .^2 - 1, data = .x))
454448# ' @export
455449ei_bench <- function (spec , preproc = NULL , subset = NULL ) {
456450 validate_ei_spec(spec )
457451
458452 if (! missing(preproc )) {
459- if (! is.function(preproc )) {
460- cli_abort(" {.arg preproc} must be a function." )
461- }
453+ preproc = rlang :: as_function(preproc )
462454 } else {
463455 preproc = function (x ) x
464456 }
@@ -470,7 +462,24 @@ ei_bench <- function(spec, preproc = NULL, subset = NULL) {
470462 apply(regr $ y - regr $ fitted , 2 , var )
471463 }
472464
473- spec_proc = preproc(spec )
465+ make_spec_loo = function (spec , out = character (0 )) {
466+ covs = setdiff(attr(spec , " ei_z" ), out )
467+ z = preproc(spec [, covs ])
468+ if (is.data.frame(z )) {
469+ z = model.matrix(~ 0 + . , z )
470+ }
471+ spec $ z_ = z
472+ ei_spec(
473+ spec ,
474+ predictors = attr(spec , " ei_x" ),
475+ outcome = attr(spec , " ei_y" ),
476+ total = attr(spec , " ei_n" ),
477+ covariates = " z_" ,
478+ strip = FALSE
479+ )
480+ }
481+
482+ spec_proc = make_spec_loo(spec )
474483 regr0 = ei_ridge(spec_proc , vcov = FALSE )
475484 riesz0 = ei_riesz(spec_proc , penalty = regr0 $ penalty )
476485 est0 = ei_est(regr0 , riesz0 , spec_proc , subset = subs )
@@ -480,8 +489,7 @@ ei_bench <- function(spec, preproc = NULL, subset = NULL) {
480489
481490 covs = attr(spec , " ei_z" )
482491 benches = lapply(covs , function (cv ) {
483- spec_loo = reconstruct_ei_spec(spec [setdiff(names(spec ), cv )], spec )
484- spec_loo = preproc(spec_loo )
492+ spec_loo = make_spec_loo(spec , cv )
485493
486494 regr_loo = ei_ridge(spec_loo , vcov = FALSE )
487495 riesz_loo = ei_riesz(spec_loo , penalty = regr_loo $ penalty )
0 commit comments