Skip to content

Commit 6614b46

Browse files
committed
b_tpsob()
1 parent 528ef81 commit 6614b46

File tree

11 files changed

+179
-16
lines changed

11 files changed

+179
-16
lines changed

.air.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
[format]
2-
line-width = 85
2+
line-width = 96
33
indent-width = 4
44
indent-style = "space"

DESCRIPTION

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
Package: bases
22
Title: Basis Expansions for Regression Modeling
33
Version: 0.1.2.9999
4-
Authors@R:
4+
Authors@R:
55
person("Cory", "McCartan", , "mccartan@psu.edu", role = c("aut", "cre", "cph"),
66
comment = c(ORCID = "0000-0002-6251-669X"))
77
Description: Provides various basis expansions for flexible regression modeling,
88
including random Fourier features (Rahimi & Recht, 2007)
99
<https://proceedings.neurips.cc/paper_files/paper/2007/file/013a006f03dbc5392effeb8f18fda755-Paper.pdf>,
10-
exact kernel / Gaussian process feature maps, prior features for Bayesian
10+
exact kernel / Gaussian process feature maps, prior features for Bayesian
1111
Additive Regression Trees (BART) (Chipman et al., 2010) <doi:10.1214/09-AOAS285>,
1212
and a helpful interface for n-way interactions. The provided functions may
1313
be used within any modeling formula, allowing the use of kernel methods and
@@ -20,13 +20,14 @@ Depends:
2020
Imports:
2121
rlang,
2222
stats
23-
Suggests:
23+
Suggests:
2424
recipes,
25+
Sieve,
2526
tibble,
2627
testthat (>= 3.0.0),
2728
knitr,
2829
rmarkdown
29-
LinkingTo:
30+
LinkingTo:
3031
cpp11
3132
License: MIT + file LICENSE
3233
Encoding: UTF-8

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,13 @@ S3method(makepredictcall,b_inter)
99
S3method(makepredictcall,b_ker)
1010
S3method(makepredictcall,b_nn)
1111
S3method(makepredictcall,b_rff)
12+
S3method(makepredictcall,b_tpsob)
1213
S3method(predict,b_bart)
1314
S3method(predict,b_inter)
1415
S3method(predict,b_ker)
1516
S3method(predict,b_nn)
1617
S3method(predict,b_rff)
18+
S3method(predict,b_tpsob)
1719
S3method(predict,ridge)
1820
S3method(print,ridge)
1921
S3method(print,step_basis)
@@ -25,6 +27,7 @@ export(b_inter)
2527
export(b_ker)
2628
export(b_nn)
2729
export(b_rff)
30+
export(b_tpsob)
2831
export(bart_depth_prior)
2932
export(k_lapl)
3033
export(k_matern)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
* `mgcv` smooth interface via `s()` for more flexible penalization TODO
44
* New `b_nn()` for neural network basis expansion
5+
* New `b_tpsob()` for tensor product Sobolev space basis expansion (Zhang and
6+
Simon 2023). Requires the `Sieve` package to be installed.
57
* New `b_conv()` for random convolutional features for regression on images TODO im2col()
68
* New `b_rocket()` for random convolutional features for regression on time series TODO https://arxiv.org/pdf/2012.08791
79
* New `b_echo()` for echo state network features for time series forecasting TODO https://www.ai.rug.nl/minds/uploads/PracticalESN.pdf

R/b_conv.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,6 @@ b_conv <- function(
3131
stride = size,
3232
activation = c("max", "mean", "ppv"),
3333
kernels = NULL
34-
) {}
34+
) {
35+
abort("Not yet implemented.")
36+
}

R/b_tpsob.R

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,79 @@
88
#' \phi_1(x) = 1 \quad\text{and}\quad
99
#' \phi_j(x) = \sqrt{2}\cos(\pi (j-1) x).
1010
#' }
11+
#' The multi-indices \eqn{\mathbf{j}} are generated in a specific order to
12+
#' maximize statistical efficiency.
13+
#' All inputs are standardized to lie in the unit hypercube \eqn{[0, 1]^d}.
1114
#'
15+
#' @inheritParams b_rff
16+
#' @param p The number of basis functions to generate.
17+
#'
18+
#' @returns A matrix of tensor-product Sobolev space basis features.
19+
#'
20+
#' @references
21+
#' Zhang, T., & Simon, N. (2023). Regression in tensor product spaces by the
22+
#' method of sieves. _Electronic Journal of Statistics_, 17(2), 3660.
23+
#'
24+
#' @examples
25+
#' data(quakes)
26+
#'
27+
#' m = ridge(depth ~ b_tpsob(lat, long, p = 100), quakes)
28+
#' plot(fitted(m), quakes$depth)
29+
#'
30+
#' x = 1:150
31+
#' y = as.numeric(BJsales)
32+
#' m = lm(y ~ b_tpsob(x, p = 10))
33+
#' plot(x, y)
34+
#' lines(x, fitted(m), col="blue")
35+
#' @export
1236
b_tpsob <- function(
1337
...,
1438
p = 100,
15-
stdize = c("scale", "box", "symbox", "none"),
1639
shift = NULL,
1740
scale = NULL
18-
) {}
41+
) {
42+
x = as.matrix(cbind(...))
43+
n = nrow(x)
44+
d = ncol(x)
45+
46+
std = do_std(x, "box", shift, scale)
47+
x = std$x
48+
49+
rlang::check_installed("Sieve", "for this basis function.")
50+
idx = Sieve::create_index_matrix(d, p + 1L, interaction_order = p)
51+
idx = idx[1L + seq_len(p), -1, drop = FALSE]
52+
53+
m = matrix(nrow = n, ncol = p)
54+
for (j in seq_len(p)) {
55+
resc = 2^(sum(idx[j, ] > 1L) / 2)
56+
m[, j] = resc * row_prod(cos(pi * x * rep(idx[j, ] - 1L, each = n)))
57+
}
58+
59+
attr(m, "shift") = std$shift
60+
attr(m, "scale") = std$scale
61+
attr(m, "call") = rlang::current_call()
62+
class(m) = c("b_tpsob", "matrix", "array")
63+
64+
m
65+
}
66+
67+
68+
#' @export
69+
predict.b_tpsob <- function(object, newdata, ...) {
70+
if (missing(newdata)) {
71+
return(object)
72+
}
73+
rlang::eval_tidy(makepredictcall(object, attr(object, "call")), newdata)
74+
}
75+
76+
#' @export
77+
makepredictcall.b_tpsob <- function(var, call) {
78+
if (
79+
as.character(call)[1L] == "b_tpsob" ||
80+
(is.call(call) && identical(eval(call[[1L]]), b_tpsob))
81+
) {
82+
at = attributes(var)[c("shift", "scale")]
83+
call[names(at)] = at
84+
}
85+
call
86+
}

R/cpp11.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ dist_l1 <- function(x, y) {
1212
.Call(`_bases_dist_l1`, x, y)
1313
}
1414

15+
row_prod <- function(x) {
16+
.Call(`_bases_row_prod`, x)
17+
}
18+
1519
im2col <- function(x, h, w, c, size, stride) {
1620
.Call(`_bases_im2col`, x, h, w, c, size, stride)
1721
}

man/b_tpsob.Rd

Lines changed: 36 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/cpp11.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@ extern "C" SEXP _bases_dist_l1(SEXP x, SEXP y) {
2626
return cpp11::as_sexp(dist_l1(cpp11::as_cpp<cpp11::decay_t<const doubles_matrix<>>>(x), cpp11::as_cpp<cpp11::decay_t<const doubles_matrix<>>>(y)));
2727
END_CPP11
2828
}
29+
// dist.cpp
30+
doubles row_prod(const doubles_matrix<> x);
31+
extern "C" SEXP _bases_row_prod(SEXP x) {
32+
BEGIN_CPP11
33+
return cpp11::as_sexp(row_prod(cpp11::as_cpp<cpp11::decay_t<const doubles_matrix<>>>(x)));
34+
END_CPP11
35+
}
2936
// im2col.cpp
3037
doubles_matrix<> im2col(const doubles& x, int h, int w, int c, int size, int stride);
3138
extern "C" SEXP _bases_im2col(SEXP x, SEXP h, SEXP w, SEXP c, SEXP size, SEXP stride) {
@@ -40,6 +47,7 @@ static const R_CallMethodDef CallEntries[] = {
4047
{"_bases_dist_l2", (DL_FUNC) &_bases_dist_l2, 2},
4148
{"_bases_forest_mat", (DL_FUNC) &_bases_forest_mat, 4},
4249
{"_bases_im2col", (DL_FUNC) &_bases_im2col, 6},
50+
{"_bases_row_prod", (DL_FUNC) &_bases_row_prod, 1},
4351
{NULL, NULL, 0}
4452
};
4553
}

src/dist.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,27 @@ doubles_matrix<> dist_l1(const doubles_matrix<> x, const doubles_matrix<> y) {
5151

5252
return out;
5353
}
54+
55+
56+
/*
57+
* Row-wise product of `x`: a vector with one entry for each row of `x`,
58+
* equal to the product of all the entries in that row.
59+
*/
60+
[[cpp11::register]]
61+
doubles row_prod(const doubles_matrix<> x) {
62+
int n = x.nrow();
63+
int p = x.ncol();
64+
writable::doubles out(n);
65+
66+
for (int i = 0; i < n; i++) {
67+
out[i] = x(i, 0);
68+
}
69+
for (int j = 1; j < p; j++) {
70+
for (int i = 0; i < n; i++) {
71+
out[i] *= x(i, j);
72+
}
73+
}
74+
75+
return out;
76+
}
77+

0 commit comments

Comments
 (0)