2525# ' violations of the accounting identity. If `bounds = NULL`, they will be
2626# ' inferred from the outcome variable: if it is contained within \eqn{[0, 1]},
2727# ' for instance, then the bounds will be `c(0, 1)`. Setting `bounds = FALSE`
28- # ' forces unbounded estimates.
28+ # ' forces unbounded estimates. The default uses the `bounds` attribute of
29+ # ' `regr`, if available, or infers from the outcome variable otherwise.
30+ # ' @inheritParams ei_ridge
2931# ' @param conf_level A numeric specifying the level for confidence intervals.
3032# ' If `FALSE` (the default), no confidence intervals are calculated.
3133# ' For `regr` arguments from [ei_wrap_model()], confidence intervals will not
5052# ' suppressWarnings(ei_est_local(m, spec, bounds=c(0.01, 0.2)))
5153# ' }
5254# ' @export
53- ei_est_local = function (regr , data , r_cov = NULL , bounds = NULL , conf_level = FALSE , unimodal = TRUE ) {
55+ ei_est_local = function (
56+ regr ,
57+ data ,
58+ r_cov = NULL ,
59+ bounds = regr $ blueprint $ bounds ,
60+ sum_one = NULL ,
61+ conf_level = FALSE ,
62+ unimodal = TRUE
63+ ) {
5464 y = est_check_outcome(regr , data , NULL )
5565 n = nrow(y )
5666 n_y = ncol(y )
5767
58- cli_warn(" Local confidence intervals do not yet incorporate prediction uncertainty." ,
59- .frequency = " regularly" , .frequency_id = " ei_est_local_temp" )
68+ cli_warn(
69+ " Local confidence intervals do not yet incorporate prediction uncertainty." ,
70+ .frequency = " regularly" ,
71+ .frequency_id = " ei_est_local_temp"
72+ )
6073
6174 rl = est_check_regr(regr , data , n , NULL , n_y , sd = TRUE )
6275 rl <<- rl
6376 n_x = length(rl $ preds )
6477 if (inherits(regr , " ei_wrapped" ) && ! isFALSE(conf_level )) {
65- cli_warn(" Local confidence intervals with wrapped model objects
78+ cli_warn(
79+ " Local confidence intervals with wrapped model objects
6680 do not incorporate prediction uncertainty." ,
67- .frequency = " regularly" , .frequency_id = " ei_est_local" )
81+ .frequency = " regularly" ,
82+ .frequency_id = " ei_est_local"
83+ )
6884 }
6985
70- bounds = ei_bounds(bounds , y )
86+ bounds = ei_bounds(bounds , y , clamp = 1e-8 )
87+ if (is.null(sum_one ) && all(bounds == c(0 , 1 ))) {
88+ sum_one = isTRUE(all.equal(rowSums(y ), rep(1 , nrow(y ))))
89+ }
7190
7291 # Process r_cov; TODO: heteroskedastic model
7392 if (is.null(r_cov )) {
@@ -86,25 +105,32 @@ ei_est_local = function(regr, data, r_cov=NULL, bounds=NULL, conf_level=FALSE, u
86105 r_cov = lapply(r_cov , chol )
87106
88107 ests = list ()
108+ ests [[k ]] =
109+ eta = do.call(cbind , rl $ preds )
110+ eps = y - rl $ yhat
111+ R_cov = diag(n_x * n_y )
89112 for (k in seq_len(n_y )) {
90- eta = vapply(rl $ preds , function (p ) p [, k ], numeric (n ))
91- eta <<- eta
92- eta_proj = local_proj(rl $ x , eta , y [, k ] - rl $ yhat [, k ], r_cov [[k ]], bounds )
93- eta_proj <<- eta_proj
94-
95- ests [[k ]] = tibble :: new_tibble(list (
96- .row = rep(seq_len(n ), n_x ),
97- predictor = rep(colnames(rl $ x ), each = n ),
98- outcome = rep(colnames(y )[k ], n * n_x ),
99- estimate = c(eta_proj ),
100- std.error = NA # sqrt(c(proj[[2]]))
101- ), class = " ei_est_local" )
113+ idx = (k - 1 ) * n_x + seq_len(n_x )
114+ R_cov [idx , idx ] = r_cov [[k ]]
102115 }
103-
104- ests = do.call(rbind , ests )
116+ eta_proj = local_proj(rl $ x , eta , eps , R_cov , bounds , sum_one )
117+ ests = lapply(seq_len(n_y ), function (k ) {
118+ tibble :: new_tibble(
119+ list (
120+ .row = rep(seq_len(n ), n_x ),
121+ predictor = rep(colnames(rl $ x ), each = n ),
122+ outcome = rep(colnames(y )[k ], n * n_x ),
123+ estimate = c(eta_proj [, k + seq(0 , by = n_y , length.out = n_x )]),
124+ std.error = NA # sqrt(c(proj[[2]]))
125+ ),
126+ class = " ei_est_local"
127+ )
128+ }) | >
129+ do.call(rbind , args = _)
130+ attr(ests , " proj_misses" ) = attr(eta_proj , " misses" )
105131
106132 if (! isFALSE(conf_level )) {
107- fac = if (isTRUE(unimodal )) 4 / 9 else 1
133+ fac = if (isTRUE(unimodal )) 4 / 9 else 1
108134 chebyshev = sqrt(fac / (1 - conf_level ))
109135 ests $ conf.low = ests $ estimate - chebyshev * ests $ std.error
110136 ests $ conf.high = ests $ estimate + chebyshev * ests $ std.error
@@ -135,38 +161,78 @@ as.array.ei_est_local = function(x, ...) {
135161
136162# Solve QP to project estimates onto tomography plane and into bounds
137163# Not the fastest possible implementation (pure C++ would be better), but fast enough
local_proj = function(x, eta, eps, r_cov, bounds, sum_one) {
    # Arguments, as used below:
    #   x        matrix with one row per observation and `n_x` columns
    #   eta      n x (n_x * n_y) matrix of initial estimates; columns ordered
    #            (x1y1, x1y2, ..., x2y1, x2y2, ...)
    #   eps      n x n_y matrix of residuals to be absorbed by the displacement
    #   r_cov    (n_x * n_y)-square matrix handed to `solve.QP()` as `Dmat`
    #            with `factorized = TRUE`
    #   bounds   length-2 vector (lower, upper); an infinite entry disables
    #            that side
    #   sum_one  if TRUE, each predictor's estimates across outcomes must sum
    #            to one
    # Returns `eta` plus the solved displacement, with attribute "misses"
    # holding the row indices where no QP solution was found.
    n = nrow(eta)
    n_x = ncol(x)
    n_y = ncol(eps)
    sum_one = isTRUE(sum_one)
    eta_diff = matrix(nrow = n, ncol = n_x * n_y)

    # avoid overflow
    r_cov = r_cov / sqrt(norm(crossprod(r_cov), "2"))

    # parameters are the displacement in each estimate
    #   (x1y1, x1y2, x1y3, x2y1, x2y2, x2y3, ...)
    # minimize overall displacement st x-weighted displacement = residual
    # and (optionally) bounds and sum-to-one constraints are satisfied
    zeros = rep(0, n_x * n_y)
    Amat = matrix(0, nrow = n_x * n_y, ncol = n_y * 2) # i-specific, filled later
    # the residual equality is encoded as a pair of inequalities (>= eps and
    # <= eps) so that it can be loosened by a tolerance in `constr_pt()`
    b0 = cbind(eps, -eps)
    if (sum_one) {
        if (n_y == 1 || all(bounds == c(-Inf, Inf))) {
            cli_abort(
                "Using {.arg sum_one} requires multiple bounded outcomes.",
                call = parent.frame()
            )
        }
        # one equality constraint per predictor; these columns must come
        # first because `solve.QP()` treats the first `meq` as equalities
        rs_mat = diag(n_x) %x% rep(1, n_y)
        Amat = cbind(rs_mat, Amat)
        b0 = cbind(1 - eta %*% rs_mat, b0)
    }
    if (!is.infinite(bounds[1])) {
        Amat = cbind(Amat, diag(n_x * n_y))
        b0 = cbind(b0, bounds[1] - eta)
    }
    if (!is.infinite(bounds[2])) {
        Amat = cbind(Amat, -diag(n_x * n_y))
        b0 = cbind(b0, -bounds[2] + eta)
    }

    # columns of `Amat` / entries of `bvec` holding the residual constraints
    idx_eps = sum_one * n_x + seq_len(2 * n_y)
    patt_eps = cbind(diag(n_y), -diag(n_y))

    # Solve the QP for one observation, loosening the residual constraints
    # by `tol`. Reads the current `Amat` from the enclosing frame.
    constr_pt = function(Dmat, bvec, tol) {
        bvec[idx_eps] = bvec[idx_eps] - tol
        quadprog::solve.QP(
            Dmat = Dmat, # distance metric
            dvec = zeros,
            Amat = Amat,
            bvec = bvec,
            meq = sum_one * n_x,
            factorized = TRUE
        )$solution
    }

    # Track failures locally. A super-assignment (`<<-`) here would write to
    # the enclosing environment and leave the local record empty, so the
    # "misses" attribute would never be populated.
    missed = logical(n)
    for (i in seq_len(n)) {
        Amat[, idx_eps] = x[i, ] %x% patt_eps
        tol = 1e-12
        repeat {
            ans = tryCatch(constr_pt(r_cov, b0[i, ], tol), error = \(e) NULL)
            if (!is.null(ans)) break
            if (tol > 0.005) {
                # give up: shift every predictor's estimate by the residual
                missed[i] = TRUE
                ans = rep(eps[i, ], n_x)
                break
            }
            tol = tol * 1000
        }
        eta_diff[i, ] = ans
    }

    out = eta + eta_diff
    attr(out, "misses") = which(missed)
    out
}
171237
172238local_basis = function (x ) {
0 commit comments