Skip to content

Commit 4b6e3a1

Browse files
mugammahorizon-blue
authored andcommitted
all tests passing
1 parent eac6c13 commit 4b6e3a1

File tree

7 files changed

+81
-17
lines changed

7 files changed

+81
-17
lines changed

genmetaballs/src/cuda/bindings.cu

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ NB_MODULE(_genmetaballs_bindings, m) {
6262
auto extent = self.get_extent();
6363
return std::tuple{extent.x, extent.y, extent.z};
6464
})
65+
.def("cov_inv_apply", &FMB::cov_inv_apply,
66+
"apply the inverse covariance matrix to the given vector", nb::arg("vec"))
6567
.def("quadratic_form", &FMB::quadratic_form,
6668
"Evaluate the associated quadratic form at the given vector", nb::arg("vec"));
6769

@@ -146,11 +148,11 @@ NB_MODULE(_genmetaballs_bindings, m) {
146148
nb::module_ intersector = m.def_submodule("intersector");
147149
intersector.def(
148150
"linear_intersect",
149-
[](const FMB& fmb, const Ray& ray) {
150-
auto [t, d] = LinearIntersector::intersect(fmb, ray);
151+
[](const FMB& fmb, const Ray& ray, const Pose& cam_pose) {
152+
auto [t, d] = LinearIntersector::intersect(fmb, ray, cam_pose);
151153
return std::make_tuple(t, d);
152154
},
153-
"Linear intersection of ray and FMB.", nb::arg("fmb"), nb::arg("ray"));
155+
"Linear intersection of ray and FMB.", nb::arg("fmb"), nb::arg("ray"), nb::arg("cam_pose"));
154156

155157
/*
156158
* Utils module bindings

genmetaballs/src/cuda/core/fmb.cu

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,18 @@
22
#include "geometry.cuh"
33
#include "utils.cuh"
44

5+
CUDA_CALLABLE __forceinline__ Vec3D vecdiv(const Vec3D u, const Vec3D v) {
6+
return {u.x / v.x, u.y / v.y, u.z / v.z};
7+
}
8+
9+
CUDA_CALLABLE Vec3D FMB::cov_inv_apply(const Vec3D vec) const {
10+
const auto rot = pose_.get_rot();
11+
return rot.inv().apply(vecdiv(rot.apply(vec), extent_));
12+
}
13+
514
CUDA_CALLABLE float FMB::quadratic_form(const Vec3D vec) const {
6-
const auto shftd_vec = vec - pose_.get_tran();
7-
const auto rot_shftd_vec = pose_.get_rot().apply(shftd_vec);
8-
const auto scaled_rot_shftd_vec = Vec3D(
9-
rot_shftd_vec.x / extent_.x, rot_shftd_vec.y / extent_.y, rot_shftd_vec.z / extent_.z);
10-
return dot(rot_shftd_vec, scaled_rot_shftd_vec);
15+
const auto shifted_vec = vec - get_mean();
16+
return dot(shifted_vec, cov_inv_apply(shifted_vec));
1117
}
1218

1319
template <>

genmetaballs/src/cuda/core/fmb.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ public:
3434
CUDA_CALLABLE float3 get_extent() const {
3535
return extent_;
3636
}
37+
CUDA_CALLABLE float3 get_mean() const {
38+
return pose_.get_tran();
39+
}
40+
41+
CUDA_CALLABLE Vec3D cov_inv_apply(const Vec3D) const;
3742

3843
CUDA_CALLABLE float quadratic_form(const Vec3D) const;
3944
};

genmetaballs/src/cuda/core/forward.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ __global__ render_kernel(const Getter fmb_getter, const Blender blender,
3333
for (const auto& [pixel_coords, ray] : pixel_coords_and_rays) {
3434
float w0 = 0.0f, tf = 0.0f, sumexpd = 0.0f;
3535
for (const auto& fmb : fmb_getter->get_metaballs(ray)) {
36-
const auto& [t, d] = Intersector::intersect(fmb, ray);
36+
const auto& [t, d] = Intersector::intersect(fmb, ray, extr);
3737
w = blender->blend(t, d, fmb, ray);
3838
sumexpd += exp(d); // numerically unstable. use logsumexp
3939
tf += t;

genmetaballs/src/cuda/core/intersector.cuh

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
// implement equation (6) in the paper
99
class LinearIntersector {
1010
public:
11-
CUDA_CALLABLE static cuda::std::tuple<float, float> intersect(const FMB& fmb, const Ray& ray) {
12-
auto vecdiv = [](const Vec3D& u, const Vec3D& v) {
13-
return Vec3D{u.x / v.x, u.y / v.y, u.z / v.z};
14-
};
15-
auto rot = fmb.get_pose().get_rot();
16-
auto tmp = rot.inv().apply(vecdiv(rot.apply(ray.direction), fmb.get_extent()));
17-
auto t = dot(fmb.get_pose().get_tran() - ray.start, tmp) / dot(ray.direction, tmp);
18-
return {t, fmb.quadratic_form(ray.start + t * ray.direction)};
11+
/*
12+
* Ray should be in camera frame
13+
*/
14+
CUDA_CALLABLE static cuda::std::tuple<float, float> intersect(const FMB& fmb, const Ray& ray,
15+
const Pose& cam_pose) {
16+
const auto v = cam_pose.get_rot().apply(ray.direction);
17+
const auto cov_inv_v = fmb.cov_inv_apply(v);
18+
const auto cam_tran = cam_pose.get_tran();
19+
const auto t = dot(fmb.get_mean() - cam_tran, cov_inv_v) / dot(v, cov_inv_v);
20+
return {t, fmb.quadratic_form(cam_tran + t * v)};
1921
}
2022
};

tests/python_tests/test_fmb.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,19 @@ def rng() -> np.random.Generator:
1414
return np.random.default_rng(0)
1515

1616

17+
def test_fmb_cov_inv_apply(rng):
18+
for _ in range(100):
19+
quat = rng.uniform(size=4).astype(np.float32)
20+
tran, extent, vec = rng.uniform(size=(3, 3)).astype(np.float32)
21+
pose = Pose.from_components(Rotation.from_quat(*quat), Vec3D(*tran))
22+
scipy_rot_mat = Rot.from_quat(quat).as_matrix()
23+
cov = scipy_rot_mat.T @ np.diag(extent) @ scipy_rot_mat
24+
theirs = np.linalg.solve(cov, vec)
25+
ours = FMB(pose, *extent).cov_inv_apply(Vec3D(*vec))
26+
ourvec = np.array([ours.x, ours.y, ours.z], dtype=np.float32)
27+
assert np.allclose(theirs, ourvec, atol=1e-6)
28+
29+
1730
def test_fmb_quadratic_form(rng):
1831
for _ in range(100):
1932
quat = rng.uniform(size=4)
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import numpy as np
2+
import pytest
3+
from scipy.spatial.distance import mahalanobis
4+
from scipy.spatial.transform import Rotation as Rot
5+
6+
from genmetaballs.core import fmb, geometry, intersector
7+
8+
FMB = fmb.FMB
9+
Pose, Vec3D, Rotation, Ray = geometry.Pose, geometry.Vec3D, geometry.Rotation, geometry.Ray
10+
11+
12+
@pytest.fixture
13+
def rng() -> np.random.Generator:
14+
return np.random.default_rng(0)
15+
16+
17+
def test_linear_intersect(rng):
18+
for _ in range(100):
19+
# sample
20+
cam_quat, fmb_quat = rng.uniform(size=(2, 4)).astype(np.float32)
21+
fmb_extent, fmb_mu = rng.uniform(size=(2, 3)).astype(np.float32)
22+
cam_tran, ray_dir = rng.uniform(size=(2, 3)).astype(np.float32)
23+
ray_start = np.zeros(3, dtype=np.float32) # in camera frame
24+
# ground truth computation
25+
v = Rot.from_quat(cam_quat).apply(ray_dir)
26+
fmb_rotmat = Rot.from_quat(fmb_quat).as_matrix()
27+
cov_inv = fmb_rotmat.T @ np.diag(1 / fmb_extent) @ fmb_rotmat
28+
t = ((fmb_mu - cam_tran) @ cov_inv @ v) / (v @ cov_inv @ v)
29+
d = mahalanobis(fmb_mu, cam_tran + t * v, cov_inv) ** 2
30+
# cuda computation
31+
fmb_pose = Pose.from_components(Rotation.from_quat(*fmb_quat), Vec3D(*fmb_mu))
32+
cam_pose = Pose.from_components(Rotation.from_quat(*cam_quat), Vec3D(*cam_tran))
33+
fmb = FMB(fmb_pose, *fmb_extent)
34+
ray = Ray(Vec3D(*ray_start), Vec3D(*ray_dir))
35+
t_, d_ = intersector.linear_intersect(fmb, ray, cam_pose)
36+
assert np.isclose(t, t_) and np.isclose(d, d_)

0 commit comments

Comments
 (0)