Skip to content

Commit f108263

Browse files
committed
tidy implementation; tests, printing; etc.
1 parent cd1261f commit f108263

File tree

5 files changed

+63
-32
lines changed

5 files changed

+63
-32
lines changed

R/ei_ridge.R

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,9 @@
9191
#' if it is contained within \eqn{[0, 1]}, for instance, then the bounds will
9292
#' be `c(0, 1)`. The default `bounds = FALSE` uses an unbounded outcome.
9393
#' @param sum_one If `TRUE`, the outcome variables are constrained to sum to one.
94-
#' Can only apply when `bounds` are enforced and there are more than one
95-
#' outcome variables.
94+
#' Can only apply when `bounds` are enforced and there is more than one
95+
#' outcome variable. The default `NULL` infers `sum_one = TRUE` when the bounds
96+
#' are `c(0, 1)` and the outcome variables sum to 1.
9697
#' @param scale If `TRUE`, scale covariates `z` to have unit variance.
9798
#' @param vcov If `TRUE`, calculate and return the covariance matrix of the
9899
#' estimated coefficients. Ignored when `bounds` are provided.
@@ -120,14 +121,14 @@
120121
#' min(fitted(ei_ridge(spec)))
121122
#' min(fitted(ei_ridge(spec, bounds = 0:1)))
122123
#' @export
123-
ei_ridge <- function(x, ..., weights, bounds = FALSE, sum_one = FALSE, penalty = NULL, scale = TRUE, vcov = TRUE) {
124+
ei_ridge <- function(x, ..., weights, bounds = FALSE, sum_one = NULL, penalty = NULL, scale = TRUE, vcov = TRUE) {
124125
UseMethod("ei_ridge")
125126
}
126127

127128

128129
#' @export
129130
#' @rdname ei_ridge
130-
ei_ridge.formula <- function(formula, data, weights, bounds=FALSE, sum_one = FALSE,
131+
ei_ridge.formula <- function(formula, data, weights, bounds=FALSE, sum_one = NULL,
131132
penalty=NULL, scale=TRUE, vcov=TRUE, ...) {
132133
forms = ei_forms(formula)
133134
form_preds = terms(rlang::new_formula(lhs=NULL, rhs=forms$predictors))
@@ -154,7 +155,7 @@ ei_ridge.formula <- function(formula, data, weights, bounds=FALSE, sum_one = FAL
154155

155156
#' @export
156157
#' @rdname ei_ridge
157-
ei_ridge.ei_spec <- function(x, weights, bounds=FALSE, sum_one = FALSE, penalty=NULL,
158+
ei_ridge.ei_spec <- function(x, weights, bounds=FALSE, sum_one = NULL, penalty=NULL,
158159
scale=TRUE, vcov=TRUE, ...) {
159160
spec = x
160161
validate_ei_spec(spec)
@@ -184,7 +185,7 @@ ei_ridge.ei_spec <- function(x, weights, bounds=FALSE, sum_one = FALSE, penalty=
184185

185186
#' @export
186187
#' @rdname ei_ridge
187-
ei_ridge.data.frame <- function(x, y, z, weights, bounds=FALSE, sum_one = FALSE, penalty=NULL,
188+
ei_ridge.data.frame <- function(x, y, z, weights, bounds=FALSE, sum_one = NULL, penalty=NULL,
188189
scale=TRUE, vcov=TRUE, ...) {
189190
if (length(both <- intersect(colnames(x), colnames(z))) > 0) {
190191
cli_abort(c("Predictors and covariates must be distinct",
@@ -213,7 +214,7 @@ ei_ridge.data.frame <- function(x, y, z, weights, bounds=FALSE, sum_one = FALSE,
213214

214215
#' @export
215216
#' @rdname ei_ridge
216-
ei_ridge.matrix <- function(x, y, z, weights, bounds=FALSE, sum_one = FALSE, penalty=NULL,
217+
ei_ridge.matrix <- function(x, y, z, weights, bounds=FALSE, sum_one = NULL, penalty=NULL,
217218
scale=TRUE, vcov=TRUE, ...) {
218219
ei_ridge.data.frame(x, y, z, weights, penalty, sum_one, bounds, scale, vcov, ...)
219220
}
@@ -275,6 +276,9 @@ ei_ridge_bridge <- function(processed, vcov, ...) {
275276
if (ncol(z) == 0) {
276277
bp$penalty = 0
277278
}
279+
if (is.null(bp$sum_one) && all(bp$bounds == c(0, 1))) {
280+
bp$sum_one = isTRUE(all.equal(rowSums(y), rep(1, nrow(y))))
281+
}
278282

279283
fit <- ei_ridge_impl(x, y, z, weights, bp$bounds, bp$sum_one, bp$penalty, vcov)
280284

@@ -315,7 +319,7 @@ ei_ridge_bridge <- function(processed, vcov, ...) {
315319
#' @rdname ei-impl
316320
#' @export
317321
ei_ridge_impl <- function(x, y, z, weights=rep(1, nrow(x)),
318-
bounds=c(-Inf, Inf), sum_one=FALSE, penalty=NULL, vcov=TRUE) {
322+
bounds=c(-Inf, Inf), sum_one=NULL, penalty=NULL, vcov=TRUE) {
319323
int_scale = if (!is.null(penalty) && penalty == 0) 1 + 1e2*sqrt(penalty) else 1e4
320324
xz = row_kronecker(x, z, int_scale)
321325
sqrt_w = sqrt(weights / mean(weights))
@@ -336,6 +340,7 @@ ei_ridge_impl <- function(x, y, z, weights=rep(1, nrow(x)),
336340
if (is.null(penalty)) {
337341
penalty = ridge_auto(udv, y, sqrt_w, FALSE)$penalty
338342
}
343+
339344
ridge_bounds(xz, z, y, weights, bounds, sum_one, penalty)
340345
}
341346

@@ -415,7 +420,9 @@ print.ei_ridge <- function(x, ...) {
415420
nrow(x$fitted), " observations")
416421
bounds = x$blueprint$bounds
417422
if (any(is.finite(bounds))) {
418-
cat_line("With outcome bounded in (", bounds[1], ", ", bounds[2], ")")
423+
sumt1 = if (isTRUE(x$blueprint$sum_one)) " and constrained to sum to 1" else ""
424+
pl = if (ncol(m$y) > 1) "s" else ""
425+
cat_line("With outcome", pl, " bounded in (", bounds[1], ", ", bounds[2], ")", sumt1)
419426
}
420427
cat_line("Fit with penalty = ", signif(x$penalty))
421428
}

R/rr_impl.R

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,19 @@ ridge_bounds <- function(xz, z, y, weights, bounds, sum_one=FALSE, penalty=0) {
122122
cli_abort("{.fn ridge_bounds} requires at least one finite bound.")
123123
}
124124

125+
fit_err = \(e) {
126+
cli_abort(c(
127+
"Constrained ridge regression failed with inconsistent constraints.",
128+
">" = "Try setting {.arg sum_one=FALSE} or relaxing the bounds."
129+
), call = NULL)
130+
}
125131
if (isFALSE(sum_one)) {
126132
coefs = matrix(nrow = nrow(dvecs), ncol = ncol(dvecs))
127133
for (i in seq_len(n_y)) {
128-
fit = quadprog::solve.QP.compact(R, dvecs[, i], Amat, Aind, bvec, factorized = TRUE)
134+
fit = tryCatch(
135+
quadprog::solve.QP.compact(R, dvecs[, i], Amat, Aind, bvec, factorized = TRUE),
136+
error = fit_err
137+
)
129138
coefs[, i] = fit$solution
130139
}
131140
} else {
@@ -149,14 +158,17 @@ ridge_bounds <- function(xz, z, y, weights, bounds, sum_one=FALSE, penalty=0) {
149158
}
150159
bvec_y = c(rep(1, n * n_x), rep(1, n_y) %x% bvec)
151160

152-
fit = quadprog::solve.QP.compact(
153-
R_y,
154-
c(dvecs),
155-
Amat_y,
156-
Aind_y,
157-
bvec_y,
158-
meq = n * n_x,
159-
factorized = TRUE
161+
do_fit = function(eq) {
162+
quadprog::solve.QP.compact(R_y, c(dvecs), Amat_y, Aind_y, bvec_y, meq = eq, factorized = TRUE)
163+
}
164+
165+
# relax to inequality constraint if sum-to-one fails
166+
fit <- tryCatch(
167+
do_fit(n * n_x),
168+
error = \(e_outer) {
169+
cli_warn("Relaxing sum-to-one constraint to inequality to achieve feasible solution.", call=NULL)
170+
tryCatch(do_fit(0), error = fit_err)
171+
}
160172
)
161173
coefs = matrix(fit$solution, nrow = nrow(dvecs), ncol = ncol(dvecs))
162174
}

man/ei-impl.Rd

Lines changed: 4 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/ei_ridge.Rd

Lines changed: 8 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-ridge.R

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,17 +92,27 @@ test_that("leave-one-out shortcut is correct for Riesz regression", {
9292
expect_equal(fit_naive$loo, loo_act, tolerance = 0.2)
9393
})
9494

95-
test_that("ridge bounds work", {
95+
test_that("ridge constraints work", {
9696
d = elec_1968
97-
form = pres_rep_nix ~ vap_white + vap_black + vap_other | state + pop_city +
97+
form = pres_dem_hum + pres_rep_nix + pres_ind_wal + pres_abs ~ vap_white |
9898
pop_urban + pop_rural + farm + educ_elem + educ_hsch + educ_coll +
9999
inc_00_03k + inc_03_08k + inc_08_25k + inc_25_99k + log(pop) + pres_turn
100100

101101
m = ei_ridge(form, data=elec_1968)
102-
m01 = ei_ridge(form, data=elec_1968, bounds=0:1)
102+
m01 = ei_ridge(form, data=elec_1968, bounds=0:1, sum_one=FALSE)
103+
m01s = ei_ridge(form, data=elec_1968, bounds=c(0, 1), sum_one=TRUE)
104+
m01def = ei_ridge(form, data=elec_1968, bounds=NULL, sum_one=NULL)
103105

104106
expect_true(min(fitted(m)) < 0)
105-
expect_true(min(fitted(m01)) > 0)
107+
expect_true(min(fitted(m01)) > -.Machine$double.eps)
106108
expect_true(all(ei_est(m01, data=elec_1968, total=pres_total)$estimate > 0))
107109
expect_true(all(ei_est(m01, data=elec_1968, total=pres_total)$estimate < 1))
110+
111+
expect_true(min(fitted(m01s)) > -.Machine$double.eps)
112+
expect_true(all(ei_est(m01s, data=elec_1968, total=pres_total)$estimate > 0))
113+
expect_true(all(ei_est(m01s, data=elec_1968, total=pres_total)$estimate < 1))
114+
115+
tots = rowSums(as.matrix(ei_est(m01s, data=elec_1968, total=pres_total)))
116+
expect_true(all.equal(tots, c(vap_white=1, .other=1)))
117+
expect_identical(m01def, m01s) # check defaults infer correctly
108118
})

0 commit comments

Comments
 (0)