9191# ' if it is contained within \eqn{[0, 1]}, for instance, then the bounds will
9292# ' be `c(0, 1)`. The default `bounds = FALSE` uses an unbounded outcome.
9393# ' @param sum_one If `TRUE`, the outcome variables are constrained to sum to one.
94- # ' Can only apply when `bounds` are enforced and there are more than one
95- # ' outcome variables.
94+ # ' Can only apply when `bounds` are enforced and there is more than one
95+ # ' outcome variable. The default `NULL` infers `sum_one = TRUE` when the bounds
96+ # ' are `c(0, 1)` the outcome variables sum to 1.
9697# ' @param scale If `TRUE`, scale covariates `z` to have unit variance.
9798# ' @param vcov If `TRUE`, calculate and return the covariance matrix of the
9899# ' estimated coefficients. Ignored when `bounds` are provided.
120121# ' min(fitted(ei_ridge(spec)))
121122# ' min(fitted(ei_ridge(spec, bounds = 0:1)))
122123# ' @export
123- ei_ridge <- function (x , ... , weights , bounds = FALSE , sum_one = FALSE , penalty = NULL , scale = TRUE , vcov = TRUE ) {
124+ ei_ridge <- function (x , ... , weights , bounds = FALSE , sum_one = NULL , penalty = NULL , scale = TRUE , vcov = TRUE ) {
124125 UseMethod(" ei_ridge" )
125126}
126127
127128
128129# ' @export
129130# ' @rdname ei_ridge
130- ei_ridge.formula <- function (formula , data , weights , bounds = FALSE , sum_one = FALSE ,
131+ ei_ridge.formula <- function (formula , data , weights , bounds = FALSE , sum_one = NULL ,
131132 penalty = NULL , scale = TRUE , vcov = TRUE , ... ) {
132133 forms = ei_forms(formula )
133134 form_preds = terms(rlang :: new_formula(lhs = NULL , rhs = forms $ predictors ))
@@ -154,7 +155,7 @@ ei_ridge.formula <- function(formula, data, weights, bounds=FALSE, sum_one = FAL
154155
155156# ' @export
156157# ' @rdname ei_ridge
157- ei_ridge.ei_spec <- function (x , weights , bounds = FALSE , sum_one = FALSE , penalty = NULL ,
158+ ei_ridge.ei_spec <- function (x , weights , bounds = FALSE , sum_one = NULL , penalty = NULL ,
158159 scale = TRUE , vcov = TRUE , ... ) {
159160 spec = x
160161 validate_ei_spec(spec )
@@ -184,7 +185,7 @@ ei_ridge.ei_spec <- function(x, weights, bounds=FALSE, sum_one = FALSE, penalty=
184185
185186# ' @export
186187# ' @rdname ei_ridge
187- ei_ridge.data.frame <- function (x , y , z , weights , bounds = FALSE , sum_one = FALSE , penalty = NULL ,
188+ ei_ridge.data.frame <- function (x , y , z , weights , bounds = FALSE , sum_one = NULL , penalty = NULL ,
188189 scale = TRUE , vcov = TRUE , ... ) {
189190 if (length(both <- intersect(colnames(x ), colnames(z ))) > 0 ) {
190191 cli_abort(c(" Predictors and covariates must be distinct" ,
@@ -213,7 +214,7 @@ ei_ridge.data.frame <- function(x, y, z, weights, bounds=FALSE, sum_one = FALSE,
213214
214215# ' @export
215216# ' @rdname ei_ridge
216- ei_ridge.matrix <- function (x , y , z , weights , bounds = FALSE , sum_one = FALSE , penalty = NULL ,
217+ ei_ridge.matrix <- function (x , y , z , weights , bounds = FALSE , sum_one = NULL , penalty = NULL ,
217218 scale = TRUE , vcov = TRUE , ... ) {
218219 ei_ridge.data.frame(x , y , z , weights , penalty , sum_one , bounds , scale , vcov , ... )
219220}
@@ -275,6 +276,9 @@ ei_ridge_bridge <- function(processed, vcov, ...) {
275276 if (ncol(z ) == 0 ) {
276277 bp $ penalty = 0
277278 }
279+ if (is.null(bp $ sum_one ) && all(bp $ bounds == c(0 , 1 ))) {
280+ bp $ sum_one = isTRUE(all.equal(rowSums(y ), rep(1 , nrow(y ))))
281+ }
278282
279283 fit <- ei_ridge_impl(x , y , z , weights , bp $ bounds , bp $ sum_one , bp $ penalty , vcov )
280284
@@ -315,7 +319,7 @@ ei_ridge_bridge <- function(processed, vcov, ...) {
315319# ' @rdname ei-impl
316320# ' @export
317321ei_ridge_impl <- function (x , y , z , weights = rep(1 , nrow(x )),
318- bounds = c(- Inf , Inf ), sum_one = FALSE , penalty = NULL , vcov = TRUE ) {
322+ bounds = c(- Inf , Inf ), sum_one = NULL , penalty = NULL , vcov = TRUE ) {
319323 int_scale = if (! is.null(penalty ) && penalty == 0 ) 1 + 1e2 * sqrt(penalty ) else 1e4
320324 xz = row_kronecker(x , z , int_scale )
321325 sqrt_w = sqrt(weights / mean(weights ))
@@ -336,6 +340,7 @@ ei_ridge_impl <- function(x, y, z, weights=rep(1, nrow(x)),
336340 if (is.null(penalty )) {
337341 penalty = ridge_auto(udv , y , sqrt_w , FALSE )$ penalty
338342 }
343+
339344 ridge_bounds(xz , z , y , weights , bounds , sum_one , penalty )
340345 }
341346
@@ -415,7 +420,9 @@ print.ei_ridge <- function(x, ...) {
415420 nrow(x $ fitted ), " observations" )
416421 bounds = x $ blueprint $ bounds
417422 if (any(is.finite(bounds ))) {
418- cat_line(" With outcome bounded in (" , bounds [1 ], " , " , bounds [2 ], " )" )
423+ sumt1 = if (isTRUE(x $ blueprint $ sum_one )) " and constrained to sum to 1" else " "
424+ pl = if (ncol(m $ y ) > 1 ) " s" else " "
425+ cat_line(" With outcome" , pl , " bounded in (" , bounds [1 ], " , " , bounds [2 ], " )" , sumt1 )
419426 }
420427 cat_line(" Fit with penalty = " , signif(x $ penalty ))
421428}
0 commit comments