Skip to content

Commit ea59650

Browse files
committed
local sum-to-1; more robust sum-to-1 ridge
1 parent f108263 commit ea59650

File tree

8 files changed

+183
-70
lines changed

8 files changed

+183
-70
lines changed

R/ei_est.R

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,12 @@ est_check_regr = function(regr, data, n, xcols, n_y, sd = FALSE) {
327327

328328
preds = list()
329329
sds = if (sd) matrix(nrow = n, ncol = n_x^2) else NULL
330+
if (sd && is.null(regr$vcov_u)) {
331+
cli_abort(c(
332+
"Standard errors not available for this {.arg regr} object.",
333+
">"="Call {.fn ei_ridge} with {.arg vcov = TRUE} to enable."
334+
), call=parent.frame())
335+
}
330336
for (group in seq_along(xcols)) {
331337
use = c(group, n_x + p*(group-1) + seq_len(p))
332338
preds[[xcols[group]]] = z %*% regr$coef[use, ]

R/ei_local.R

Lines changed: 108 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
#' violations of the accounting identity. If `bounds = NULL`, they will be
2626
#' inferred from the outcome variable: if it is contained within \eqn{[0, 1]},
2727
#' for instance, then the bounds will be `c(0, 1)`. Setting `bounds = FALSE`
28-
#' forces unbounded estimates.
28+
#' forces unbounded estimates. The default uses the `bounds` attribute of
29+
#' `regr`, if available, or infers from the outcome variable otherwise.
30+
#' @inheritParams ei_ridge
2931
#' @param conf_level A numeric specifying the level for confidence intervals.
3032
#' If `FALSE` (the default), no confidence intervals are calculated.
3133
#' For `regr` arguments from [ei_wrap_model()], confidence intervals will not
@@ -50,24 +52,41 @@
5052
#' suppressWarnings(ei_est_local(m, spec, bounds=c(0.01, 0.2)))
5153
#' }
5254
#' @export
53-
ei_est_local = function(regr, data, r_cov=NULL, bounds=NULL, conf_level=FALSE, unimodal=TRUE) {
55+
ei_est_local = function(
56+
regr,
57+
data,
58+
r_cov = NULL,
59+
bounds = regr$blueprint$bounds,
60+
sum_one = NULL,
61+
conf_level = FALSE,
62+
unimodal = TRUE
63+
) {
5464
y = est_check_outcome(regr, data, NULL)
5565
n = nrow(y)
5666
n_y = ncol(y)
5767

58-
cli_warn("Local confidence intervals do not yet incorporate prediction uncertainty.",
59-
.frequency="regularly", .frequency_id="ei_est_local_temp")
68+
cli_warn(
69+
"Local confidence intervals do not yet incorporate prediction uncertainty.",
70+
.frequency = "regularly",
71+
.frequency_id = "ei_est_local_temp"
72+
)
6073

6174
rl = est_check_regr(regr, data, n, NULL, n_y, sd = TRUE)
6275
rl <<- rl
6376
n_x = length(rl$preds)
6477
if (inherits(regr, "ei_wrapped") && !isFALSE(conf_level)) {
65-
cli_warn("Local confidence intervals with wrapped model objects
78+
cli_warn(
79+
"Local confidence intervals with wrapped model objects
6680
do not incorporate prediction uncertainty.",
67-
.frequency="regularly", .frequency_id="ei_est_local")
81+
.frequency = "regularly",
82+
.frequency_id = "ei_est_local"
83+
)
6884
}
6985

70-
bounds = ei_bounds(bounds, y)
86+
bounds = ei_bounds(bounds, y, clamp = 1e-8)
87+
if (is.null(sum_one) && all(bounds == c(0, 1))) {
88+
sum_one = isTRUE(all.equal(rowSums(y), rep(1, nrow(y))))
89+
}
7190

7291
# Process r_cov; TODO: heteroskedastic model
7392
if (is.null(r_cov)) {
@@ -86,25 +105,32 @@ ei_est_local = function(regr, data, r_cov=NULL, bounds=NULL, conf_level=FALSE, u
86105
r_cov = lapply(r_cov, chol)
87106

88107
ests = list()
108+
ests[[k]] =
109+
eta = do.call(cbind, rl$preds)
110+
eps = y - rl$yhat
111+
R_cov = diag(n_x * n_y)
89112
for (k in seq_len(n_y)) {
90-
eta = vapply(rl$preds, function(p) p[, k], numeric(n))
91-
eta <<- eta
92-
eta_proj = local_proj(rl$x, eta, y[, k] - rl$yhat[, k], r_cov[[k]], bounds)
93-
eta_proj <<- eta_proj
94-
95-
ests[[k]] = tibble::new_tibble(list(
96-
.row = rep(seq_len(n), n_x),
97-
predictor = rep(colnames(rl$x), each=n),
98-
outcome = rep(colnames(y)[k], n * n_x),
99-
estimate = c(eta_proj),
100-
std.error = NA #sqrt(c(proj[[2]]))
101-
), class="ei_est_local")
113+
idx = (k - 1) * n_x + seq_len(n_x)
114+
R_cov[idx, idx] = r_cov[[k]]
102115
}
103-
104-
ests = do.call(rbind, ests)
116+
eta_proj = local_proj(rl$x, eta, eps, R_cov, bounds, sum_one)
117+
ests = lapply(seq_len(n_y), function(k) {
118+
tibble::new_tibble(
119+
list(
120+
.row = rep(seq_len(n), n_x),
121+
predictor = rep(colnames(rl$x), each = n),
122+
outcome = rep(colnames(y)[k], n * n_x),
123+
estimate = c(eta_proj[, k + seq(0, by=n_y, length.out=n_x)]),
124+
std.error = NA #sqrt(c(proj[[2]]))
125+
),
126+
class = "ei_est_local"
127+
)
128+
}) |>
129+
do.call(rbind, args = _)
130+
attr(ests, "proj_misses") = attr(eta_proj, "misses")
105131

106132
if (!isFALSE(conf_level)) {
107-
fac = if (isTRUE(unimodal)) 4/9 else 1
133+
fac = if (isTRUE(unimodal)) 4 / 9 else 1
108134
chebyshev = sqrt(fac / (1 - conf_level))
109135
ests$conf.low = ests$estimate - chebyshev * ests$std.error
110136
ests$conf.high = ests$estimate + chebyshev * ests$std.error
@@ -135,38 +161,78 @@ as.array.ei_est_local = function(x, ...) {
135161

136162
# Solve QP to project estimates onto tomography plane and into bounds
137163
# Not the fastest possible implementation (pure C++ would be better), but fast enough
138-
local_proj = function(x, eta, eps, r_cov, bounds) {
164+
local_proj = function(x, eta, eps, r_cov, bounds, sum_one) {
139165
n = nrow(eta)
140166
n_x = ncol(x)
141-
eta_diff = matrix(nrow = n, ncol = n_x)
142-
143-
zeros = rep(0, n_x)
144-
Amat = cbind(zeros)
145-
b0 = cbind(eps)
167+
n_y = ncol(eps)
168+
sum_one = isTRUE(sum_one)
169+
eta_diff = matrix(nrow = n, ncol = n_x * n_y)
170+
171+
# avoid overflow
172+
r_cov = r_cov / sqrt(norm(crossprod(r_cov), "2"))
173+
174+
# parameters are the displacement in each estimate
175+
# (x1y1, x1y2, x1y3, x2y1, x2y2, x2y3, ...)
176+
# minimize overall displacement st x-weighted displacement = residual
177+
# and (optionally) bounds and sum-to-one constraints are satisfied
178+
zeros = rep(0, n_x * n_y)
179+
Amat = matrix(0, nrow = n_x * n_y, ncol = n_y * 2) # i-specific, filled later
180+
b0 = cbind(eps, -eps)
181+
if (sum_one) {
182+
if (n_y == 1 || all(bounds == c(-Inf, Inf))) {
183+
cli_abort(
184+
"Using{.arg sum_one} requires multiple bounded outcomes.",
185+
call = parent.frame()
186+
)
187+
}
188+
rs_mat = diag(n_x) %x% rep(1, n_y)
189+
Amat = cbind(rs_mat, Amat)
190+
b0 = cbind(1 - eta %*% rs_mat, b0)
191+
}
146192
if (!is.infinite(bounds[1])) {
147-
Amat = cbind(Amat, diag(n_x))
193+
Amat = cbind(Amat, diag(n_x * n_y))
148194
b0 = cbind(b0, bounds[1] - eta)
149195
}
150196
if (!is.infinite(bounds[2])) {
151-
Amat = cbind(Amat, -diag(n_x))
197+
Amat = cbind(Amat, -diag(n_x * n_y))
152198
b0 = cbind(b0, -bounds[2] + eta)
153199
}
154200

201+
idx_eps = sum_one * n_x + seq_len(2*n_y)
202+
patt_eps = cbind(diag(n_y), -diag(n_y))
203+
204+
constr_pt = function(Dmat, bvec, tol) {
205+
bvec[idx_eps] = bvec[idx_eps] - tol
206+
quadprog::solve.QP(
207+
Dmat = Dmat, # distance metric
208+
dvec = zeros,
209+
Amat = Amat,
210+
bvec = bvec,
211+
meq = sum_one * n_x,
212+
factorized = TRUE
213+
)$solution
214+
}
215+
216+
misses = integer(0)
155217
for (i in seq_len(n)) {
156-
Amat[, 1] = x[i, ]
157-
eta_diff[i, ] = tryCatch({
158-
quadprog::solve.QP(
159-
Dmat = r_cov,
160-
dvec = zeros,
161-
Amat = Amat,
162-
bvec = b0[i, ],
163-
meq = 1,
164-
factorized = TRUE
165-
)$solution
166-
}, error = \(e) eps[i])
218+
Amat[, idx_eps] = x[i, ] %x% patt_eps
219+
tol = 1e-12
220+
repeat {
221+
ans = tryCatch(constr_pt(r_cov, b0[i, ], tol), error = \(e) NULL)
222+
if (!is.null(ans)) break
223+
if (tol > 0.005) {
224+
misses <<- c(misses, i)
225+
ans = rep(eps[i, ], n_x)
226+
break
227+
}
228+
tol = tol * 1000
229+
}
230+
eta_diff[i, ] = ans
167231
}
168232

169-
eta + eta_diff
233+
out = eta + eta_diff
234+
attr(out, "misses") = misses
235+
out
170236
}
171237

172238
local_basis = function(x) {

R/ei_ridge.R

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@
9696
#' are `c(0, 1)` the outcome variables sum to 1.
9797
#' @param scale If `TRUE`, scale covariates `z` to have unit variance.
9898
#' @param vcov If `TRUE`, calculate and return the covariance matrix of the
99-
#' estimated coefficients. Ignored when `bounds` are provided.
99+
#' estimated coefficients. When `bounds` are provided, the covariance matrix
100+
#' for the unbounded estimate is returned as a conservative approximation.
100101
#' @param ... Not currently used, but required for extensibility.
101102
#'
102103
#' @returns An `ei_ridge` object, which supports various [ridge-methods].
@@ -327,21 +328,27 @@ ei_ridge_impl <- function(x, y, z, weights=rep(1, nrow(x)),
327328

328329
vcov = isTRUE(vcov)
329330
enforce = is.finite(bounds)
330-
fit = if (!any(enforce)) { # unbounded
331+
if (!any(enforce)) { # unbounded
331332
if (isTRUE(sum_one)) {
332333
cli_abort("{.fn ei_ridge} cannot enforce sum-to-one constraint when outcome is unbounded.")
333334
}
334-
if (is.null(penalty)) {
335+
fit = if (is.null(penalty)) {
335336
ridge_auto(udv, y, sqrt_w, vcov)
336337
} else {
337338
ridge_svd(udv, y, sqrt_w, penalty, vcov)
338339
}
339340
} else {
341+
if (is.null(penalty) || vcov) {
342+
unb_fit = ridge_auto(udv, y, sqrt_w, vcov)
343+
}
340344
if (is.null(penalty)) {
341-
penalty = ridge_auto(udv, y, sqrt_w, FALSE)$penalty
345+
penalty = unb_fit$penalty
342346
}
343347

344-
ridge_bounds(xz, z, y, weights, bounds, sum_one, penalty)
348+
fit = ridge_bounds(xz, z, y, weights, bounds, sum_one, penalty)
349+
if (vcov) {
350+
fit$vcov_u = unb_fit$vcov_u
351+
}
345352
}
346353

347354
rownames(fit$coef) = colnames(xz)

R/rr_impl.R

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,23 @@ ridge_bounds <- function(xz, z, y, weights, bounds, sum_one=FALSE, penalty=0) {
163163
}
164164

165165
# relax to inequality constraint if sum-to-one fails
166-
fit <- tryCatch(
167-
do_fit(n * n_x),
168-
error = \(e_outer) {
169-
cli_warn("Relaxing sum-to-one constraint to inequality to achieve feasible solution.", call=NULL)
170-
tryCatch(do_fit(0), error = fit_err)
166+
eq_constr = n * n_x
167+
repeat {
168+
fit = tryCatch(do_fit(eq_constr), error = \(e) NULL)
169+
if (!is.null(fit)) break
170+
if (eq_constr > 0) {
171+
eq_constr = max(eq_constr - n, 0) # reduce by one group
172+
} else {
173+
fit_err()
174+
break
171175
}
172-
)
176+
}
177+
if (eq_constr < n * n_x) {
178+
cli_warn(
179+
"Relaxing {n * n_x - eq_constr} sum-to-one constraint{?s} to inequality to achieve feasible solution.",
180+
call = NULL
181+
)
182+
}
173183
coefs = matrix(fit$solution, nrow = nrow(dvecs), ncol = ncol(dvecs))
174184
}
175185

explore/local.R

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,40 @@ devtools::load_all(".")
22
library(tidyverse)
33

44
data(elec_1968)
5-
elec_1968$vap_nonwhite = 1 - elec_1968$vap_white
5+
elec_1968 = elec_1968 |>
6+
mutate(vap_nonwhite = 1 - vap_white, pres_abs = pmax(1e-6, pres_abs)) |>
7+
ei_proportions(pres_dem_hum, pres_rep_nix, pres_ind_wal, pres_abs, clamp = 1e-12) |>
8+
select(-.total)
69

7-
spec = ei_spec(elec_1968, c(vap_white, vap_nonwhite), pres_ind_wal,
10+
spec = ei_spec(elec_1968, c(vap_white, vap_black, vap_other), c(pres_dem_hum, pres_rep_nix, pres_ind_wal, pres_abs),
811
total = pres_total, covariates = c(state, pop_urban, pop_rural, educ_elem:educ_coll, farm, inc_00_03k:inc_25_99k))
912

10-
m = ei_ridge(spec)
13+
m = ei_ridge(spec, bounds=0:1, sum_one=F)
14+
m = ei_ridge(spec, bounds=F)
1115
rr = ei_riesz(spec, penalty = m$penalty)
1216

1317
# mean(c(y - rowSums(eta * x)) * weights(rr)[, 2])
14-
wx = x * weights(spec) / rep(colMeans(x * weights(spec)), each=n)
15-
16-
ei_est(m, data = spec)
17-
ei_est(m, rr, data = spec)
18-
eif = eta_proj * wx
19-
est = colMeans(eif)
20-
vcov = crossprod(shift_cols(eif, est)) / (n - 1)^2
21-
cbind(estimate=est, std.error=sqrt(diag(vcov)))
22-
23-
ei_est_local(m, spec, conf_level = 0.95)
18+
# wx = x * weights(spec) / rep(colMeans(x * weights(spec)), each=n)
19+
20+
ei_est(m, data = spec) |>
21+
summarize(err = sum(estimate) - 1, .by = predictor)
22+
ei_est(m, rr, data = spec) |>
23+
summarize(err = sum(estimate) - 1, .by = predictor)
24+
25+
# eif = eta_proj * wx
26+
# est = colMeans(eif)
27+
# vcov = crossprod(shift_cols(eif, est)) / (n - 1)^2
28+
# cbind(estimate=est, std.error=sqrt(diag(vcov)))
29+
30+
ei_est_local(m, spec, conf_level = 0.95, bounds=c(0, 1), sum_one = F) |>
31+
# ei_est_local(m, spec, conf_level = 0.95, bounds=F, sum_one = F) |>
32+
(\(x) { print(attr(x, "proj_misses")); x })() |>
33+
# dplyr::filter(estimate < -1e-6 | estimate > 1)
34+
# print()
35+
summarize(err = sum(estimate) - 1, .by = c(.row, predictor)) |>
36+
# arrange(-err)
37+
pull() |>
38+
hist()
2439

2540
k = 1
2641
n = nrow(spec)

man/ei-impl.Rd

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/ei_est_local.Rd

Lines changed: 9 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/ei_ridge.Rd

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)