@@ -79,15 +79,15 @@ ridge_hat_naive <- function(
7979riesz_naive <- function (xz , p , total , weights , group = 1 , penalty = 0 ) {
8080 n_x = ncol(xz ) %/% (1L + p )
8181 use = c(group , n_x + p * (group - 1 ) + seq_len(p ))
82- Dz = crossprod(xz [, use ], total ) / mean(xz [, group ] * total )
82+ Dz = crossprod(xz [, use , drop = FALSE ], total ) / mean(xz [, group ] * total )
8383 Lambda = diag(rep(penalty , ncol(xz )))
8484
8585 XXinv = solve(crossprod(xz , weights * xz ) + Lambda )
86- xzAinv = xz %*% XXinv [, use ]
86+ xzAinv = xz %*% XXinv [, use , drop = FALSE ]
8787 alpha = c(xzAinv %*% Dz )
8888
8989 h1m = 1 - ridge_hat_naive(xz , weights , XXinv , penalty )
90- xzi = xz [, use ] * total / mean(xz [, group ] * total )
90+ xzi = xz [, use , drop = FALSE ] * total / mean(xz [, group ] * total )
9191 loo = c((alpha - rowSums(xzAinv * xzi )) / h1m )
9292
9393 list (alpha = alpha , loo = loo , nu2 = NA )
@@ -98,18 +98,19 @@ riesz_naive <- function(xz, p, total, weights, group=1, penalty=0) {
9898riesz_svd <- function (xz , udv , p , total , weights , sqrt_w , group = 1 , penalty = 0 ) {
9999 n_x = ncol(xz ) %/% (1L + p )
100100 use = c(group , n_x + p * (group - 1 ) + seq_len(p ))
101- Dz = colSums(xz [, use ] * total ) / mean(xz [, group ] * total )
101+ Dz = colSums(xz [, use , drop = FALSE ] * total ) / mean(xz [, group ] * total )
102102 d_pen = c(udv $ d / (udv $ d ^ 2 + penalty ))
103103
104- xzAinv = (udv $ u / sqrt_w ) %*% (d_pen * t(udv $ v [use , ]))
104+ xzAinv = (udv $ u / sqrt_w ) %*% (d_pen * t(udv $ v [use , , drop = FALSE ]))
105105 alpha = xzAinv %*% Dz
106106
107107 h1m = 1 - ridge_hat_svd(udv , penalty )
108- xzi = xz [, use ] * total / mean(xz [, group ] * total )
108+ xzi = xz [, use , drop = FALSE ] * total / mean(xz [, group ] * total )
109109 loo = rowSums(- xzAinv * shift_cols(xzi , Dz )) / h1m
110110
111111 # Neyman-orthogonal estimate of criterion fn
112- nu2 = sum(crossprod(Dz , udv $ v [use , ])^ 2 * (2 / (udv $ d ^ 2 + penalty ) - d_pen ^ 2 )) / nrow(xz )
112+ nu2 = sum(crossprod(Dz , udv $ v [use , , drop = FALSE ])^ 2 *
113+ (2 / (udv $ d ^ 2 + penalty ) - d_pen ^ 2 )) / nrow(xz )
113114
114115 list (alpha = alpha , loo = loo , nu2 = nu2 )
115116}
0 commit comments