Skip to content

Commit b12a079

Browse files
horizon-bluemugammaarijit-dasgupta
authored
The forward kernel (#32)
Closes MET-36 ## Summary of Changes Things can run without crash now (see [`test_forward.py`](https://github.com/probcomp/GenMetaBalls/pull/32/files#diff-c455199fc8f8e0e45d6157720518e678e885ae165d3e658f95ed95a2989e132cR16-R41) for examples on how to call the `render_fmbs` function and get back an image on Python side), but I haven't verified correctness yet. I'm going to publish this PR before it's fully ready yet in case we want to refer to it while working on other parts in parallel ## Test Plan ```bash pixi run test ``` --------- Co-authored-by: mugamma <mghavami@mit.edu> Co-authored-by: arijit-dasgupta <arijitdg@mit.edu>
1 parent 773c285 commit b12a079

26 files changed

+2981
-164
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
**/__pycache__/**
2+
**/.ipynb_checkpoints/**
23
build/
34
/data/
45
/output/

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ nanobind_add_module(
6464
# Link the core library to the bindings module
6565
target_link_libraries(_genmetaballs_bindings PRIVATE genmetaballs_core)
6666

67+
# Enable CUDA separable compilation for device code linking
68+
set_target_properties(_genmetaballs_bindings PROPERTIES
69+
CUDA_SEPARABLE_COMPILATION ON
70+
)
71+
6772
# Install the extension into the Python package directory
6873
install(TARGETS _genmetaballs_bindings LIBRARY DESTINATION genmetaballs)
6974

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@ pixi install
1414

1515
### Development Setup
1616

17-
For development:
17+
For development, make sure Mesa is installed and then set up hooks.
1818

1919
```bash
20+
sudo apt install mesa-common-dev
2021
pixi install
21-
pixi run dev-setup
22+
pixi run dev-setup # set-up hooks
2223
```
2324

2425
The `dev-setup` task sets up [pre-commit](https://pre-commit.com/) git hooks:

genmetaballs/src/cuda/bindings.cu

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
#include "core/camera.cuh"
1111
#include "core/confidence.cuh"
1212
#include "core/fmb.cuh"
13+
#include "core/forward.cuh"
1314
#include "core/geometry.cuh"
15+
#include "core/getter.cuh"
1416
#include "core/image.cuh"
1517
#include "core/intersector.cuh"
1618
#include "core/utils.cuh"
@@ -25,6 +27,8 @@ template <MemoryLocation location>
2527
void bind_image_view(nb::module_& m, const char* name);
2628
template <MemoryLocation location>
2729
void bind_fmb_scene(nb::module_& m, const char* name);
30+
template <typename Blender, typename Confidence>
31+
void bind_render_fmbs(nb::module_& m, const char* name);
2832

2933
NB_MODULE(_genmetaballs_bindings, m) {
3034

@@ -41,7 +45,7 @@ NB_MODULE(_genmetaballs_bindings, m) {
4145
[](const ZeroParameterConfidence& c) { return nb::str("ZeroParameterConfidence()"); });
4246

4347
nb::class_<TwoParameterConfidence>(confidence, "TwoParameterConfidence")
44-
.def(nb::init<float, float>())
48+
.def(nb::init<float, float>(), nb::arg("beta4"), nb::arg("beta5"))
4549
.def_ro("beta4", &TwoParameterConfidence::beta4)
4650
.def_ro("beta5", &TwoParameterConfidence::beta5)
4751
.def("get_confidence", &TwoParameterConfidence::get_confidence, nb::arg("sumexpd"),
@@ -67,10 +71,26 @@ NB_MODULE(_genmetaballs_bindings, m) {
6771
.def("cov_inv_apply", &FMB::cov_inv_apply,
6872
"apply the inverse covariance matrix to the given vector", nb::arg("vec"))
6973
.def("quadratic_form", &FMB::quadratic_form,
70-
"Evaluate the associated quadratic form at the given vector", nb::arg("vec"));
74+
"Evaluate the associated quadratic form at the given vector", nb::arg("vec"))
75+
.def("__repr__", [](const FMB& self) {
76+
return nb::str("FMB(pose={}, extent={})").format(self.get_pose(), self.get_extent());
77+
});
7178
bind_fmb_scene<MemoryLocation::HOST>(fmb, "CPUFMBScene");
7279
bind_fmb_scene<MemoryLocation::DEVICE>(fmb, "GPUFMBScene");
7380

81+
/*
82+
* Forward (rendering) module bindings
83+
*/
84+
nb::module_ forward = m.def_submodule("forward", "Forward rendering of FMBs");
85+
bind_render_fmbs<FourParameterBlender, ZeroParameterConfidence>(
86+
forward, "render_fmbs_four_param_zero_confidence");
87+
bind_render_fmbs<ThreeParameterBlender, TwoParameterConfidence>(
88+
forward, "render_fmbs_three_param_two_confidence");
89+
bind_render_fmbs<ThreeParameterBlender, ZeroParameterConfidence>(
90+
forward, "render_fmbs_three_param_zero_confidence");
91+
bind_render_fmbs<FourParameterBlender, TwoParameterConfidence>(
92+
forward, "render_fmbs_four_param_two_confidence");
93+
7494
/*
7595
* Geometry module bindings
7696
*/
@@ -99,9 +119,21 @@ NB_MODULE(_genmetaballs_bindings, m) {
99119
.def(nb::init<>())
100120
.def_static("from_quat", &Rotation::from_quat, "Create rotation from quaternion",
101121
nb::arg("x"), nb::arg("y"), nb::arg("z"), nb::arg("w"))
122+
.def_prop_ro(
123+
"quat",
124+
[](const Rotation& self) {
125+
auto quat = self.get_quat();
126+
return std::tuple{quat.x, quat.y, quat.z, quat.w};
127+
},
128+
"Get quaternion components as (x, y, z, w)")
102129
.def("apply", &Rotation::apply, "Apply rotation to vector", nb::arg("vec"))
103130
.def("compose", &Rotation::compose, "Compose with another rotation", nb::arg("rot"))
104-
.def("inv", &Rotation::inv, "Inverse rotation");
131+
.def("inv", &Rotation::inv, "Inverse rotation")
132+
.def("__repr__", [](const Rotation& self) {
133+
auto quat = self.get_quat();
134+
return nb::str("Rotation(x={}, y={}, z={}, w={})")
135+
.format(quat.x, quat.y, quat.z, quat.w);
136+
});
105137

106138
nb::class_<Pose>(geometry, "Pose")
107139
.def(nb::init<>())
@@ -112,15 +144,17 @@ NB_MODULE(_genmetaballs_bindings, m) {
112144
.def_prop_ro("tran", &Pose::get_tran, "get the translation component")
113145
.def("apply", &Pose::apply, "Apply pose to vector", nb::arg("vec"))
114146
.def("compose", &Pose::compose, "Compose with another pose", nb::arg("pose"))
115-
.def("inv", &Pose::inv, "Inverse pose");
116-
147+
.def("inv", &Pose::inv, "Inverse pose")
148+
.def("__repr__", [](const Pose& self) {
149+
return nb::str("Pose(rot={}, tran={})").format(self.get_rot(), self.get_tran());
150+
});
117151
/*
118152
* Camera module bindings
119153
*/
120154
nb::module_ camera = m.def_submodule("camera", "Camera intrinsics and extrinsics");
121155
nb::class_<Intrinsics>(camera, "Intrinsics")
122-
.def(nb::init<uint32_t, uint32_t, float, float, float, float>(), nb::arg("height"),
123-
nb::arg("width"), nb::arg("fx"), nb::arg("fy"), nb::arg("cx"), nb::arg("cy"))
156+
.def(nb::init<uint32_t, uint32_t, float, float, float, float>(), nb::arg("width"),
157+
nb::arg("height"), nb::arg("fx"), nb::arg("fy"), nb::arg("cx"), nb::arg("cy"))
124158
.def_ro("height", &Intrinsics::height)
125159
.def_ro("width", &Intrinsics::width)
126160
.def_ro("fx", &Intrinsics::fx)
@@ -129,7 +163,11 @@ NB_MODULE(_genmetaballs_bindings, m) {
129163
.def_ro("cy", &Intrinsics::cy)
130164
.def("get_ray_direction", &Intrinsics::get_ray_direction,
131165
"Get the direction of the ray going through pixel (px, py) in camera frame",
132-
nb::arg("px"), nb::arg("py"));
166+
nb::arg("px"), nb::arg("py"))
167+
.def("__repr__", [](const Intrinsics& self) {
168+
return nb::str("Intrinsics(width={}, height={}, fx={}, fy={}, cx={}, cy={})")
169+
.format(self.width, self.height, self.fx, self.fy, self.cx, self.cy);
170+
});
133171

134172
/*
135173
* Image module bindings
@@ -163,7 +201,8 @@ NB_MODULE(_genmetaballs_bindings, m) {
163201
// blender submodule
164202
nb::module_ blender = m.def_submodule("blender");
165203
nb::class_<FourParameterBlender>(blender, "FourParameterBlender")
166-
.def(nb::init<float, float, float, float>())
204+
.def(nb::init<float, float, float, float>(), nb::arg("beta1"), nb::arg("beta2"),
205+
nb::arg("beta3"), nb::arg("eta"))
167206
.def_ro("beta1", &FourParameterBlender::beta1)
168207
.def_ro("beta2", &FourParameterBlender::beta2)
169208
.def_ro("beta3", &FourParameterBlender::beta3)
@@ -176,7 +215,7 @@ NB_MODULE(_genmetaballs_bindings, m) {
176215
});
177216

178217
nb::class_<ThreeParameterBlender>(blender, "ThreeParameterBlender")
179-
.def(nb::init<float, float, float>())
218+
.def(nb::init<float, float, float>(), nb::arg("beta1"), nb::arg("beta2"), nb::arg("eta"))
180219
.def_ro("beta1", &ThreeParameterBlender::beta1)
181220
.def_ro("beta2", &ThreeParameterBlender::beta2)
182221
.def_ro("eta", &ThreeParameterBlender::eta)
@@ -273,3 +312,12 @@ void bind_fmb_scene(nb::module_& m, const char* name) {
273312
return nb::str("{}(size={})").format(name, scene.size());
274313
});
275314
}
315+
316+
template <typename Blender, typename Confidence>
317+
void bind_render_fmbs(nb::module_& m, const char* name) {
318+
m.def(name,
319+
&render_fmbs<AllGetter<MemoryLocation::DEVICE>, LinearIntersector, Blender, Confidence>,
320+
"Render the given FMB scene into the provided image view", nb::arg("fmbs"),
321+
nb::arg("blender"), nb::arg("confidence"), nb::arg("intr"), nb::arg("extr"),
322+
nb::arg("img"));
323+
}

genmetaballs/src/cuda/core/blender.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ struct ThreeParameterBlender {
2121
float beta2;
2222
float eta;
2323

24-
CUDA_CALLABLE __forceinline__ float blend(float t, float d) const {
25-
return expf((beta1 * d) - ((beta2 / eta) * t));
24+
CUDA_CALLABLE __forceinline__ float blend(float tmp, float d) const {
25+
return expf((beta1 * tmp) - ((beta2 / eta) * d));
2626
}
2727
};

genmetaballs/src/cuda/core/camera.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ CUDA_CALLABLE PixelCoordRange::Iterator& PixelCoordRange::Iterator::operator++()
2727
}
2828

2929
CUDA_CALLABLE bool PixelCoordRange::Sentinel::operator==(const Iterator& it) const {
30-
return it.py >= py_end;
30+
// stop if we reach the end of rows, or if the range is empty
31+
return it.py >= py_end || it.px_start >= it.px_end || it.py_start >= py_end;
3132
}
3233

3334
CUDA_CALLABLE PixelCoordRange::Iterator PixelCoordRange::begin() const {

genmetaballs/src/cuda/core/camera.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
#include "utils.cuh"
99

1010
struct Intrinsics {
11-
uint32_t height; // in x direction
12-
uint32_t width; // in y direction
11+
uint32_t width; // in x direction
12+
uint32_t height; // in y direction
1313
float fx;
1414
float fy;
1515
float cx;

genmetaballs/src/cuda/core/fmb.cu

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,17 @@ CUDA_CALLABLE __forceinline__ Vec3D vecdiv(const Vec3D u, const Vec3D v) {
66
return {u.x / v.x, u.y / v.y, u.z / v.z};
77
}
88

9+
// CUDA_CALLABLE Vec3D FMB::cov_inv_apply(const Vec3D vec) const {
10+
// const auto rot = pose_.get_rot();
11+
// return rot.inv().apply(vecdiv(rot.apply(vec), extent_));
12+
// }
13+
914
CUDA_CALLABLE Vec3D FMB::cov_inv_apply(const Vec3D vec) const {
1015
const auto rot = pose_.get_rot();
11-
return rot.inv().apply(vecdiv(rot.apply(vec), extent_));
16+
// Wanted to add more infor here
17+
// Basically the order of the operation has bee swapper to look something like this:
18+
// R @ diag(1/extent) @ R^T @ vec however, i dont think this fixes everything
19+
return rot.apply(vecdiv(rot.inv().apply(vec), extent_));
1220
}
1321

1422
CUDA_CALLABLE float FMB::quadratic_form(const Vec3D vec) const {

genmetaballs/src/cuda/core/forward.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ CUDA_CALLABLE PixelCoordRange get_pixel_coords(const dim3 thread_idx, const dim3
99
const dim3 block_dim, const dim3 grid_dim,
1010
const Intrinsics& intr) {
1111
// compute the number of pixels each thread should process
12-
const auto num_pixels_x = int_ceil_div(intr.height, grid_dim.x * block_dim.x);
13-
const auto num_pixels_y = int_ceil_div(intr.width, grid_dim.y * block_dim.y);
12+
const auto num_pixels_x = int_ceil_div(intr.width, grid_dim.x * block_dim.x);
13+
const auto num_pixels_y = int_ceil_div(intr.height, grid_dim.y * block_dim.y);
1414
const auto start_x = (block_idx.x * block_dim.x + thread_idx.x) * num_pixels_x;
1515
const auto start_y = (block_idx.y * block_dim.y + thread_idx.y) * num_pixels_y;
1616
return PixelCoordRange{.px_start = start_x,
17-
.px_end = min(start_x + num_pixels_x, intr.height),
17+
.px_end = min(start_x + num_pixels_x, intr.width),
1818
.py_start = start_y,
19-
.py_end = min(start_y + num_pixels_y, intr.width)};
19+
.py_end = min(start_y + num_pixels_y, intr.height)};
2020
}

genmetaballs/src/cuda/core/forward.cuh

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -18,31 +18,38 @@ CUDA_CALLABLE PixelCoordRange get_pixel_coords(const dim3 thread_idx, const dim3
1818
const Intrinsics& intr);
1919

2020
template <typename Getter, typename Intersector, typename Blender, typename Confidence>
21-
__global__ void render_kernel(const Getter fmb_getter, const Blender blender,
22-
Confidence const* confidence, Intrinsics const intr, Pose const* extr,
23-
ImageView<MemoryLocation::DEVICE> img) {
21+
__global__ void render_kernel(const FMBScene<MemoryLocation::DEVICE>& fmbs, const Blender& blender,
22+
const Confidence& confidence, const Intrinsics& intr,
23+
const Pose& extr, ImageView<MemoryLocation::DEVICE> img) {
2424
auto pixel_coords = get_pixel_coords(threadIdx, blockIdx, blockDim, gridDim, intr);
25+
auto fmb_getter = Getter(fmbs, extr);
2526

26-
for (const auto& [px, py] : pixel_coords) {
27-
float w0 = 0.0f, tf = 0.0f, sumexpd = 0.0f;
27+
for (const auto [px, py] : pixel_coords) {
28+
float depth_denom = 0.0f, depth_numer = 0.0f, conf_tmp = 0.0f;
2829
auto ray = intr.get_ray_direction(px, py);
29-
for (const auto& fmb : fmb_getter->get_metaballs(ray)) {
30-
const auto& [t, d] = Intersector::intersect(fmb, ray, extr);
31-
auto w = blender->blend(t, d, fmb, ray);
32-
sumexpd += exp(d); // numerically unstable. use logsumexp
33-
tf += t;
34-
w0 += w;
30+
for (const auto& [fmb, lambda] : fmb_getter.get_metaballs(ray)) {
31+
// d: intersection point along the ray
32+
// q: square of Mahalanobis distance at intersection point
33+
const auto& [d, q] = Intersector::intersect(fmb, ray, extr);
34+
auto tmp = -0.5f * q + lambda;
35+
// the next check is needed to match the reference implementation
36+
// even though it is not in the paper.
37+
auto w_tilde = d > 0 ? blender.blend(tmp, d) : 1e-20f;
38+
conf_tmp += exp(tmp); // numerically unstable. use logsumexp
39+
depth_numer += d * w_tilde;
40+
depth_denom += w_tilde;
3541
}
36-
img.confidence[px][py] = confidence->get_confidence(sumexpd);
37-
img.depth[px][py] = tf / w0;
42+
// the indexing is done this way because the underlying array2ds use
43+
// ij indexing, whereas the pixels uses xy indexing
44+
img.confidence[intr.height - py - 1][px] = confidence.get_confidence(conf_tmp);
45+
img.depth[intr.height - py - 1][px] = depth_numer / depth_denom;
3846
}
3947
}
4048

4149
template <typename Getter, typename Intersector, typename Blender, typename Confidence>
42-
void render_fmbs(const FMBScene<MemoryLocation::DEVICE>& fmbs, const Intrinsics& intr,
43-
const Pose& extr) {
44-
// initialize the fmb_getter
45-
auto fmb_getter = Getter(fmbs, extr);
46-
auto& kernel = render_kernel<Getter, Intersector, Blender, Confidence>;
47-
kernel<<<NUM_BLOCKS, THREADS_PER_BLOCK>>>(fmb_getter, fmbs, intr, extr);
50+
void render_fmbs(const FMBScene<MemoryLocation::DEVICE>& fmbs, const Blender& blender,
51+
const Confidence& confidence, const Intrinsics& intr, const Pose& extr,
52+
ImageView<MemoryLocation::DEVICE> img) {
53+
render_kernel<Getter, Intersector, Blender, Confidence>
54+
<<<NUM_BLOCKS, THREADS_PER_BLOCK>>>(fmbs, blender, confidence, intr, extr, img);
4855
}

0 commit comments

Comments
 (0)