1313# ' @inheritParams b_rff
1414# ' @param x The (training) data points at which to evaluate the kernel. If
1515# ' provided, overrides `...`.
16+ # ' @param L_inv The inverse of the Cholesky factor of the kernel matrix at the
17+ # ' training points. Will be automatically computed if not provided, but in
18+ # ' order to avoid recomputing it for new predictions, pass `L_inv = TRUE`,
19+ # ' which will save and re-use this matrix for future calls.
1620# '
1721# ' @returns A matrix of kernel features.
1822# '
3943# ' @export
4044b_ker <- function (... , kernel = k_rbf(),
4145 stdize = c(" scale" , " box" , " symbox" , " none" ),
42- x = NULL , shift = NULL , scale = NULL ) {
46+ x = NULL , shift = NULL , scale = NULL , L_inv = NULL ) {
4347 y = as.matrix(cbind(... ))
4448 std = do_std(y , stdize , shift , scale )
4549 y = std $ x
4650 if (is.null(x )) x = y
4751
48- II = diag(nrow(x ))
49- K = kernel(x , x ) + 1e-9 * II
50- m = kernel(y , x ) %*% backsolve(chol(K ), II )
52+ save = isTRUE(L_inv )
53+ if (is.null(L_inv ) || save ) {
54+ II = diag(nrow(x ))
55+ K = kernel(x , x ) + 1e-9 * II
56+ L_inv = backsolve(chol(K ), II )
57+ }
58+ m = kernel(y , x ) %*% L_inv
5159
5260 attr(m , " x" ) = x
5361 attr(m , " shift" ) = std $ shift
5462 attr(m , " scale" ) = std $ scale
63+ if (save ) {
64+ attr(m , " L_inv" ) = L_inv
65+ }
5566 attr(m , " call" ) = rlang :: current_call()
5667 class(m ) = c(" b_ker" , " matrix" , " array" )
5768
@@ -70,7 +81,7 @@ predict.b_ker <- function (object, newdata, ...) {
7081makepredictcall.b_ker <- function (var , call ) {
7182 if (as.character(call )[1L ] == " b_ker" ||
7283 (is.call(call ) && identical(eval(call [[1L ]]), b_ker ))) {
73- at = attributes(var )[c(" x" , " shift" , " scale" )]
84+ at = attributes(var )[c(" x" , " shift" , " scale" , " L_inv " )]
7485 call [names(at )] = at
7586 }
7687 call
0 commit comments