Skip to content

Commit fa21e9d

Browse files
authored
Refactor to use Vec3D instead of Ray (#29)
Part of MET-36 ## Summary of Changes Replace all instances of the `Ray` class with `Vec3D` (primarily in intersector and getter) to streamline the code and remove unused components. This is part of my sequence of PRs for the forward kernels. The hope is to keep each of them small so that even if we are not going to review it now, it will be easier to catch up after the final deadline and understand what changes were maded. ## Test Plan ```bash pixi run test ```
1 parent f6b7866 commit fa21e9d

File tree

6 files changed

+22
-46
lines changed

6 files changed

+22
-46
lines changed

genmetaballs/src/cuda/bindings.cu

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,6 @@ NB_MODULE(_genmetaballs_bindings, m) {
110110
.def("compose", &Pose::compose, "Compose with another pose", nb::arg("pose"))
111111
.def("inv", &Pose::inv, "Inverse pose");
112112

113-
nb::class_<Ray>(geometry, "Ray")
114-
.def(nb::init<Vec3D, Vec3D>())
115-
.def_ro("start", &Ray::start)
116-
.def_ro("direction", &Ray::direction);
117-
118113
/*
119114
* Camera module bindings
120115
*/
@@ -148,7 +143,7 @@ NB_MODULE(_genmetaballs_bindings, m) {
148143
nb::module_ intersector = m.def_submodule("intersector");
149144
intersector.def(
150145
"linear_intersect",
151-
[](const FMB& fmb, const Ray& ray, const Pose& cam_pose) {
146+
[](const FMB& fmb, const Vec3D& ray, const Pose& cam_pose) {
152147
auto [t, d] = LinearIntersector::intersect(fmb, ray, cam_pose);
153148
return std::make_tuple(t, d);
154149
},

genmetaballs/src/cuda/core/geometry.cuh

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,10 +98,3 @@ public:
9898
return {rotinv, -rotinv.apply(tran_)};
9999
}
100100
};
101-
102-
struct Ray {
103-
Vec3D start;
104-
Vec3D direction;
105-
106-
CUDA_CALLABLE Ray(const Vec3D _start, const Vec3D _dir) : start{_start}, direction{_dir} {}
107-
};

genmetaballs/src/cuda/core/getter.cuh

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@
1111
// This is the dummy version of getter, where all FMBs are relevant to any ray
1212
template <MemoryLocation location>
1313
struct AllGetter {
14-
FMBScene<location>& scene;
15-
Pose& extr; // Current assumption: rays are in camera frame
14+
const FMBScene<location>& scene;
15+
const Pose& extr; // Current assumption: rays are in camera frame
1616

17-
CUDA_CALLABLE AllGetter(FMBScene<location>& scene, Pose& extr) : scene(scene), extr(extr) {}
17+
CUDA_CALLABLE AllGetter(const FMBScene<location>& scene, const Pose& extr)
18+
: scene(scene), extr(extr) {}
1819

1920
// It does not bother using the ray, because it simply returns all FMBs
20-
CUDA_CALLABLE FMBScene<location>& get_metaballs(const Ray& ray) const {
21+
CUDA_CALLABLE const FMBScene<location>& get_metaballs(const Vec3D& ray) const {
2122
return scene;
2223
}
2324
};

genmetaballs/src/cuda/core/intersector.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
class LinearIntersector {
1010
public:
1111
/*
12-
* Ray should be in camera frame
12+
* ray should be in camera frame
1313
*/
14-
CUDA_CALLABLE static cuda::std::tuple<float, float> intersect(const FMB& fmb, const Ray& ray,
14+
CUDA_CALLABLE static cuda::std::tuple<float, float> intersect(const FMB& fmb, const Vec3D ray,
1515
const Pose& cam_pose) {
16-
const auto v = cam_pose.get_rot().apply(ray.direction);
16+
const auto v = cam_pose.get_rot().apply(ray);
1717
const auto cov_inv_v = fmb.cov_inv_apply(v);
1818
const auto cam_tran = cam_pose.get_tran();
1919
const auto t = dot(fmb.get_mean() - cam_tran, cov_inv_v) / dot(v, cov_inv_v);

tests/cpp_tests/test_getter.cu

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,11 @@ TEST(AllGetterTest, AllGetterHostTest) {
1919
AllGetter<MemoryLocation::HOST> getter(scene, extr);
2020

2121
// Create test rays
22-
std::vector<Ray> rays = {
23-
Ray{Vec3D{0.0f, 0.0f, 0.0f}, Vec3D{1.0f, 0.0f, 0.0f}},
24-
Ray{Vec3D{1.0f, 1.0f, 1.0f}, Vec3D{0.0f, 1.0f, 0.0f}},
25-
Ray{Vec3D{-1.0f, -1.0f, -1.0f}, Vec3D{0.0f, 0.0f, 1.0f}},
26-
Ray{Vec3D{2.5f, -3.1f, 0.2f}, Vec3D{-0.5f, 0.6f, 0.0f}},
27-
Ray{Vec3D{4.4f, 0.0f, -0.9f}, Vec3D{0.3f, -0.2f, 1.0f}},
28-
Ray{Vec3D{5.0f, 2.2f, 1.1f}, Vec3D{-1.0f, 2.0f, 0.2f}},
29-
Ray{Vec3D{0.0f, 7.0f, 6.0f}, Vec3D{0.0f, -1.0f, -1.0f}},
30-
Ray{Vec3D{-2.0f, 0.0f, 0.0f}, Vec3D{0.2f, 1.1f, 0.7f}},
31-
Ray{Vec3D{9.1f, -0.3f, 2.7f}, Vec3D{-0.3f, 0.1f, 0.0f}},
32-
Ray{Vec3D{1.2f, 8.8f, -4.5f}, Vec3D{1.0f, 0.0f, 1.0f}},
22+
std::vector<Vec3D> rays = {
23+
Vec3D{0.0f, 0.0f, 0.0f}, Vec3D{1.0f, 1.0f, 1.0f}, Vec3D{-1.0f, -1.0f, -1.0f},
24+
Vec3D{2.5f, -3.1f, 0.2f}, Vec3D{4.4f, 0.0f, -0.9f}, Vec3D{5.0f, 2.2f, 1.1f},
25+
Vec3D{0.0f, 7.0f, 6.0f}, Vec3D{-2.0f, 0.0f, 0.0f}, Vec3D{9.1f, -0.3f, 2.7f},
26+
Vec3D{1.2f, 8.8f, -4.5f},
3327
};
3428

3529
// Get reference to all FMBs from the original FMBs object
@@ -49,7 +43,7 @@ TEST(AllGetterTest, AllGetterHostTest) {
4943
}
5044

5145
__global__ void test_get_metaballs_kernel_device(const AllGetter<MemoryLocation::DEVICE> fmb_getter,
52-
const Ray* rays, int num_rays, int* out_sizes) {
46+
const Vec3D* rays, int num_rays, int* out_sizes) {
5347
int idx = threadIdx.x + blockIdx.x * blockDim.x;
5448
const auto& fmbs_returned = fmb_getter.get_metaballs(rays[idx]);
5549
out_sizes[idx] = static_cast<int>(fmbs_returned.size());
@@ -63,17 +57,11 @@ TEST(AllGetterTest, AllGetterDeviceTest) {
6357
AllGetter<MemoryLocation::DEVICE> getter(device_scene, extr);
6458

6559
// Create test rays
66-
std::vector<Ray> rays = {
67-
Ray{Vec3D{0.0f, 0.0f, 0.0f}, Vec3D{1.0f, 0.0f, 0.0f}},
68-
Ray{Vec3D{1.0f, 1.0f, 1.0f}, Vec3D{0.0f, 1.0f, 0.0f}},
69-
Ray{Vec3D{-1.0f, -1.0f, -1.0f}, Vec3D{0.0f, 0.0f, 1.0f}},
70-
Ray{Vec3D{2.5f, -3.1f, 0.2f}, Vec3D{-0.5f, 0.6f, 0.0f}},
71-
Ray{Vec3D{4.4f, 0.0f, -0.9f}, Vec3D{0.3f, -0.2f, 1.0f}},
72-
Ray{Vec3D{5.0f, 2.2f, 1.1f}, Vec3D{-1.0f, 2.0f, 0.2f}},
73-
Ray{Vec3D{0.0f, 7.0f, 6.0f}, Vec3D{0.0f, -1.0f, -1.0f}},
74-
Ray{Vec3D{-2.0f, 0.0f, 0.0f}, Vec3D{0.2f, 1.1f, 0.7f}},
75-
Ray{Vec3D{9.1f, -0.3f, 2.7f}, Vec3D{-0.3f, 0.1f, 0.0f}},
76-
Ray{Vec3D{1.2f, 8.8f, -4.5f}, Vec3D{1.0f, 0.0f, 1.0f}},
60+
std::vector<Vec3D> rays = {
61+
Vec3D{0.0f, 0.0f, 0.0f}, Vec3D{1.0f, 1.0f, 1.0f}, Vec3D{-1.0f, -1.0f, -1.0f},
62+
Vec3D{2.5f, -3.1f, 0.2f}, Vec3D{4.4f, 0.0f, -0.9f}, Vec3D{5.0f, 2.2f, 1.1f},
63+
Vec3D{0.0f, 7.0f, 6.0f}, Vec3D{-2.0f, 0.0f, 0.0f}, Vec3D{9.1f, -0.3f, 2.7f},
64+
Vec3D{1.2f, 8.8f, -4.5f},
7765
};
7866

7967
// Test on GPU for device containers

tests/python_tests/test_intersector.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from genmetaballs.core import fmb, geometry, intersector
77

88
FMB = fmb.FMB
9-
Pose, Vec3D, Rotation, Ray = geometry.Pose, geometry.Vec3D, geometry.Rotation, geometry.Ray
9+
Pose, Vec3D, Rotation = geometry.Pose, geometry.Vec3D, geometry.Rotation
1010

1111

1212
@pytest.fixture
@@ -20,7 +20,6 @@ def test_linear_intersect(rng):
2020
cam_quat, fmb_quat = rng.uniform(size=(2, 4)).astype(np.float32)
2121
fmb_extent, fmb_mu = rng.uniform(size=(2, 3)).astype(np.float32)
2222
cam_tran, ray_dir = rng.uniform(size=(2, 3)).astype(np.float32)
23-
ray_start = np.zeros(3, dtype=np.float32) # in camera frame
2423
# ground truth computation
2524
v = Rot.from_quat(cam_quat).apply(ray_dir)
2625
fmb_rotmat = Rot.from_quat(fmb_quat).as_matrix()
@@ -31,6 +30,6 @@ def test_linear_intersect(rng):
3130
fmb_pose = Pose.from_components(Rotation.from_quat(*fmb_quat), Vec3D(*fmb_mu))
3231
cam_pose = Pose.from_components(Rotation.from_quat(*cam_quat), Vec3D(*cam_tran))
3332
fmb = FMB(fmb_pose, *fmb_extent)
34-
ray = Ray(Vec3D(*ray_start), Vec3D(*ray_dir))
33+
ray = Vec3D(*ray_dir)
3534
t_, d_ = intersector.linear_intersect(fmb, ray, cam_pose)
3635
assert np.isclose(t, t_) and np.isclose(d, d_)

0 commit comments

Comments
 (0)