Skip to content

Commit 30f806f

Browse files
committed
kernel reprediction speedup option
1 parent de01da4 commit 30f806f

File tree

4 files changed

+42
-7
lines changed

4 files changed

+42
-7
lines changed

NEWS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* New `b_conv()` for random convolutional features for regression on images TODO im2col()
66
* New `b_rocket()` for random convolutional features for regression on time series TODO https://arxiv.org/pdf/2012.08791
77
* New `b_echo()` for echo state network features for time series forecasting TODO https://www.ai.rug.nl/minds/uploads/PracticalESN.pdf
8-
* More efficient `b_ker()` option for many predictions TODO
8+
* More efficient `b_ker()` option for many predictions
99
* New vignette on other packages that help produce basis expansions or embeddings.
1010

1111
# bases 0.1.2

R/b_ker.R

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@
1313
#' @inheritParams b_rff
1414
#' @param x The (training) data points at which to evaluate the kernel. If
1515
#' provided, overrides `...`.
16+
#' @param L_inv The inverse of the Cholesky factor of the kernel matrix at the
17+
#' training points. Will be automatically computed if not provided, but in
18+
#' order to avoid recomputing it for new predictions, pass `L_inv = TRUE`,
19+
#' which will save and re-use this matrix for future calls.
1620
#'
1721
#' @returns A matrix of kernel features.
1822
#'
@@ -39,19 +43,26 @@
3943
#' @export
4044
b_ker <- function(..., kernel = k_rbf(),
4145
stdize = c("scale", "box", "symbox", "none"),
42-
x = NULL, shift = NULL, scale = NULL) {
46+
x = NULL, shift = NULL, scale = NULL, L_inv = NULL) {
4347
y = as.matrix(cbind(...))
4448
std = do_std(y, stdize, shift, scale)
4549
y = std$x
4650
if (is.null(x)) x = y
4751

48-
II = diag(nrow(x))
49-
K = kernel(x, x) + 1e-9 * II
50-
m = kernel(y, x) %*% backsolve(chol(K), II)
52+
save = isTRUE(L_inv)
53+
if (is.null(L_inv) || save) {
54+
II = diag(nrow(x))
55+
K = kernel(x, x) + 1e-9 * II
56+
L_inv = backsolve(chol(K), II)
57+
}
58+
m = kernel(y, x) %*% L_inv
5159

5260
attr(m, "x") = x
5361
attr(m, "shift") = std$shift
5462
attr(m, "scale") = std$scale
63+
if (save) {
64+
attr(m, "L_inv") = L_inv
65+
}
5566
attr(m, "call") = rlang::current_call()
5667
class(m) = c("b_ker", "matrix", "array")
5768

@@ -70,7 +81,7 @@ predict.b_ker <- function (object, newdata, ...) {
7081
makepredictcall.b_ker <- function(var, call) {
7182
if (as.character(call)[1L] == "b_ker" ||
7283
(is.call(call) && identical(eval(call[[1L]]), b_ker))) {
73-
at = attributes(var)[c("x", "shift", "scale")]
84+
at = attributes(var)[c("x", "shift", "scale", "L_inv")]
7485
call[names(at)] = at
7586
}
7687
call

man/b_ker.Rd

Lines changed: 7 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-ker-exact.R

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,21 @@ test_that("predict() method works correctly", {
2323
pred_m = suppressWarnings(predict(m, newdata=list(x=xn)))
2424
expect_equal(unname(pred_m[2:21]), unname(fitted(m)[1:20]), tolerance=1e-4)
2525
})
26+
27+
test_that("predict() works with saved L_inv", {
28+
y = c(BJsales)
29+
x = seq_along(y)
30+
xn = c(0:20, 150:200)
31+
32+
m0 = lm(y ~ b_ker(x))
33+
m1 = lm(y ~ b_ker(x, L_inv=TRUE))
34+
35+
expect_equal(predict(m1), predict(m0))
36+
expect_equal(predict(m1), fitted(m1), tolerance=1e-4)
37+
expect_equal(predict(m1), predict(m1, list(x=x)))
38+
39+
skip_on_cran() # don't run timing tests on CRAN
40+
t0 = system.time(for (i in 1:50) predict(m0, list(x = x)))["elapsed"]
41+
t1 = system.time(for (i in 1:50) predict(m1, list(x = x)))["elapsed"]
42+
expect_lt(t1, t0)
43+
})

0 commit comments

Comments
 (0)