Skip to content

Commit de01da4

Browse files
committed
im2col written
1 parent 0987a43 commit de01da4

File tree

5 files changed

+87
-2
lines changed

5 files changed

+87
-2
lines changed

R/cpp11.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,7 @@ dist_l2 <- function(x, y) {
1111
dist_l1 <- function(x, y) {
1212
.Call(`_bases_dist_l1`, x, y)
1313
}
14+
15+
im2col <- function(x, h, w, c, size, stride) {
16+
.Call(`_bases_im2col`, x, h, w, c, size, stride)
17+
}

src/bart.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#include <cpp11.hpp>
2-
#include <vector>
3-
#include <iostream>
2+
43
using namespace cpp11;
54

65
/*

src/cpp11.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,20 @@ extern "C" SEXP _bases_dist_l1(SEXP x, SEXP y) {
2626
return cpp11::as_sexp(dist_l1(cpp11::as_cpp<cpp11::decay_t<const doubles_matrix<>>>(x), cpp11::as_cpp<cpp11::decay_t<const doubles_matrix<>>>(y)));
2727
END_CPP11
2828
}
29+
// im2col.cpp
30+
doubles_matrix<> im2col(const doubles& x, int h, int w, int c, int size, int stride);
31+
extern "C" SEXP _bases_im2col(SEXP x, SEXP h, SEXP w, SEXP c, SEXP size, SEXP stride) {
32+
BEGIN_CPP11
33+
return cpp11::as_sexp(im2col(cpp11::as_cpp<cpp11::decay_t<const doubles&>>(x), cpp11::as_cpp<cpp11::decay_t<int>>(h), cpp11::as_cpp<cpp11::decay_t<int>>(w), cpp11::as_cpp<cpp11::decay_t<int>>(c), cpp11::as_cpp<cpp11::decay_t<int>>(size), cpp11::as_cpp<cpp11::decay_t<int>>(stride)));
34+
END_CPP11
35+
}
2936

3037
extern "C" {
3138
static const R_CallMethodDef CallEntries[] = {
3239
{"_bases_dist_l1", (DL_FUNC) &_bases_dist_l1, 2},
3340
{"_bases_dist_l2", (DL_FUNC) &_bases_dist_l2, 2},
3441
{"_bases_forest_mat", (DL_FUNC) &_bases_forest_mat, 4},
42+
{"_bases_im2col", (DL_FUNC) &_bases_im2col, 6},
3543
{NULL, NULL, 0}
3644
};
3745
}

src/dist.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <cpp11.hpp>
2+
23
using namespace cpp11;
34

45
/*

src/im2col.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
#include <cpp11.hpp>
2+
3+
using namespace cpp11;
4+
5+
/*
6+
* Convert an (h, w, c) image to a (h'*w'*c, s*s) matrix so that convolutions can
7+
* be performed by matrix multiplication.
8+
* - `size` is the height/width of the kernel (which must be square)
9+
* - `stride` is the spacing between kernel evaluations
10+
* - No padding is performed
11+
* - The output columns are in column-major order, ready for multiplication by
12+
* a column-major-order kernel
13+
*/
14+
[[cpp11::register]]
15+
doubles_matrix<> im2col(const doubles& x, int h, int w, int c,
16+
int size, int stride) {
17+
int h_out = (h - size + 1) / stride;
18+
int w_out = (w - size + 1) / stride;
19+
writable::doubles_matrix<> out(h_out * w_out * c, size * size);
20+
21+
for (int l = 0; l < c; l++) { // channels
22+
for (int j = 0; j < w_out; j++) { // columns
23+
for (int i = 0; i < h_out; i++) { // rows
24+
int out_row = i + h_out * (j + w_out * l);
25+
int idx_in = i*stride + h * (j*stride + w * l);
26+
27+
if (size == 2) {
28+
out(out_row, 0) = x[idx_in];
29+
out(out_row, 1) = x[idx_in + 1];
30+
out(out_row, 2) = x[idx_in + h];
31+
out(out_row, 3) = x[idx_in + 1 + h];
32+
} else if (size == 3) {
33+
out(out_row, 0) = x[idx_in];
34+
out(out_row, 1) = x[idx_in + 1];
35+
out(out_row, 2) = x[idx_in + 2];
36+
out(out_row, 3) = x[idx_in + h];
37+
out(out_row, 4) = x[idx_in + 1 + h];
38+
out(out_row, 5) = x[idx_in + 2 + h];
39+
out(out_row, 6) = x[idx_in + 2*h];
40+
out(out_row, 7) = x[idx_in + 1 + 2*h];
41+
out(out_row, 8) = x[idx_in + 2 + 2*h];
42+
} else if (size == 4) {
43+
out(out_row, 0) = x[idx_in];
44+
out(out_row, 1) = x[idx_in + 1];
45+
out(out_row, 2) = x[idx_in + 2];
46+
out(out_row, 3) = x[idx_in + 3];
47+
out(out_row, 4) = x[idx_in + h];
48+
out(out_row, 5) = x[idx_in + 1 + h];
49+
out(out_row, 6) = x[idx_in + 2 + h];
50+
out(out_row, 7) = x[idx_in + 3 + h];
51+
out(out_row, 8) = x[idx_in + 2*h];
52+
out(out_row, 9) = x[idx_in + 1 + 2*h];
53+
out(out_row, 10) = x[idx_in + 2 + 2*h];
54+
out(out_row, 11) = x[idx_in + 3 + 2*h];
55+
out(out_row, 12) = x[idx_in + 3*h];
56+
out(out_row, 13) = x[idx_in + 1 + 3*h];
57+
out(out_row, 14) = x[idx_in + 2 + 3*h];
58+
out(out_row, 15) = x[idx_in + 3 + 3*h];
59+
} else {
60+
// general double loop for other sizes
61+
for (int k_j = 0; k_j < size; k_j++) {
62+
for (int k_i = 0; k_i < size; k_i++) {
63+
int out_col = k_i + size * k_j;
64+
out(out_row, out_col) = x[idx_in + k_i + h*k_j];
65+
}
66+
}
67+
}
68+
}
69+
}
70+
}
71+
72+
return out;
73+
}

0 commit comments

Comments
 (0)