Skip to content

Commit d9d4b62

Browse files
committed
fix no covs setting
1 parent 60fe535 commit d9d4b62

File tree

4 files changed

+23
-11
lines changed

4 files changed

+23
-11
lines changed

R/ei_ridge.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,7 @@ ei_ridge.ei_spec <- function(x, weights, penalty=NULL, scale=TRUE, ...) {
131131

132132
form = as.formula(paste0(
133133
paste0(attr(spec, "ei_y"), collapse=" + "), " ~ ",
134-
paste0(attr(spec, "ei_x"), collapse=" + "), " + ",
135-
paste0(attr(spec, "ei_z"), collapse=" + ")
134+
paste0(c(attr(spec, "ei_x"), attr(spec, "ei_z")), collapse=" + ")
136135
))
137136

138137
bp = hardhat::new_default_formula_blueprint(

R/ei_riesz.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,7 @@ ei_riesz.ei_spec <- function(x, weights, penalty, scale=TRUE, ...) {
7575

7676
form = as.formula(paste0(
7777
paste0(attr(spec, "ei_y"), collapse=" + "), " ~ ",
78-
paste0(attr(spec, "ei_x"), collapse=" + "), " + ",
79-
paste0(attr(spec, "ei_z"), collapse=" + ")
78+
paste0(c(attr(spec, "ei_x"), attr(spec, "ei_z")), collapse=" + ")
8079
))
8180

8281
bp = hardhat::new_default_formula_blueprint(

R/rr_impl.R

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,15 @@ ridge_hat_naive <- function(
7979
riesz_naive <- function(xz, p, total, weights, group=1, penalty=0) {
8080
n_x = ncol(xz) %/% (1L + p)
8181
use = c(group, n_x + p*(group-1) + seq_len(p))
82-
Dz = crossprod(xz[, use], total) / mean(xz[, group] * total)
82+
Dz = crossprod(xz[, use, drop=FALSE], total) / mean(xz[, group] * total)
8383
Lambda = diag(rep(penalty, ncol(xz)))
8484

8585
XXinv = solve(crossprod(xz, weights*xz) + Lambda)
86-
xzAinv = xz %*% XXinv[, use]
86+
xzAinv = xz %*% XXinv[, use, drop=FALSE]
8787
alpha = c(xzAinv %*% Dz)
8888

8989
h1m = 1 - ridge_hat_naive(xz, weights, XXinv, penalty)
90-
xzi = xz[, use] * total / mean(xz[, group] * total)
90+
xzi = xz[, use, drop=FALSE] * total / mean(xz[, group] * total)
9191
loo = c((alpha - rowSums(xzAinv * xzi)) / h1m)
9292

9393
list(alpha = alpha, loo = loo, nu2 = NA)
@@ -98,18 +98,19 @@ riesz_naive <- function(xz, p, total, weights, group=1, penalty=0) {
9898
riesz_svd <- function(xz, udv, p, total, weights, sqrt_w, group=1, penalty=0) {
9999
n_x = ncol(xz) %/% (1L + p)
100100
use = c(group, n_x + p*(group-1) + seq_len(p))
101-
Dz = colSums(xz[, use] * total) / mean(xz[, group] * total)
101+
Dz = colSums(xz[, use, drop=FALSE] * total) / mean(xz[, group] * total)
102102
d_pen = c(udv$d / (udv$d^2 + penalty))
103103

104-
xzAinv = (udv$u / sqrt_w) %*% (d_pen * t(udv$v[use, ]))
104+
xzAinv = (udv$u / sqrt_w) %*% (d_pen * t(udv$v[use, , drop=FALSE]))
105105
alpha = xzAinv %*% Dz
106106

107107
h1m = 1 - ridge_hat_svd(udv, penalty)
108-
xzi = xz[, use] * total / mean(xz[, group] * total)
108+
xzi = xz[, use, drop=FALSE] * total / mean(xz[, group] * total)
109109
loo = rowSums(-xzAinv * shift_cols(xzi, Dz)) / h1m
110110

111111
# Neyman-orthogonal estimate of criterion fn
112-
nu2 = sum(crossprod(Dz, udv$v[use, ])^2 * (2/(udv$d^2 + penalty) - d_pen^2)) / nrow(xz)
112+
nu2 = sum(crossprod(Dz, udv$v[use, , drop=FALSE])^2 *
113+
(2/(udv$d^2 + penalty) - d_pen^2)) / nrow(xz)
113114

114115
list(alpha = alpha, loo = loo, nu2 = nu2)
115116
}

tests/testthat/test-ei_est.R

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,16 @@ test_that("Shrinkage has the expected effect", {
5050
expect_lt(mean(est_w$std.error), mean(est_w0$std.error))
5151
# expect_lt(mean(est_d$std.error), mean(est_d0$std.error))
5252
})
53+
54+
test_that("Estimation methods work with no predictors", {
55+
spec = ei_spec(elec_1968, vap_white:vap_other, pres_dem_hum:pres_oth,
56+
total = pres_total)
57+
m = ei_ridge(spec)
58+
rr = ei_riesz(spec, penalty=m$penalty)
59+
60+
expect_no_error({
61+
est_p = ei_est(m, data=spec)
62+
est_w = ei_est(rr, data=spec)
63+
est_d = ei_est(m, rr, data=spec)
64+
})
65+
})

0 commit comments

Comments
 (0)