Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions genmetaballs/src/cuda/bindings.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ template <MemoryLocation location>
void bind_image(nb::module_& m, const char* name);
template <MemoryLocation location>
void bind_image_view(nb::module_& m, const char* name);
template <MemoryLocation location>
void bind_fmb_scene(nb::module_& m, const char* name);

NB_MODULE(_genmetaballs_bindings, m) {

Expand Down Expand Up @@ -66,6 +68,8 @@ NB_MODULE(_genmetaballs_bindings, m) {
"apply the inverse covariance matrix to the given vector", nb::arg("vec"))
.def("quadratic_form", &FMB::quadratic_form,
"Evaluate the associated quadratic form at the given vector", nb::arg("vec"));
bind_fmb_scene<MemoryLocation::HOST>(fmb, "CPUFMBScene");
bind_fmb_scene<MemoryLocation::DEVICE>(fmb, "GPUFMBScene");

/*
* Geometry module bindings
Expand Down Expand Up @@ -244,3 +248,16 @@ void bind_image(nb::module_& m, const char* name) {
return nb::str("{}(height={}, width={})").format(name, img.num_rows(), img.num_cols());
});
}

template <MemoryLocation location>
void bind_fmb_scene(nb::module_& m, const char* name) {
nb::class_<FMBScene<location>>(m, name)
.def(nb::init<size_t>(), nb::arg("size"))
.def_prop_ro("size", &FMBScene<location>::size)
.def("__len__", &FMBScene<location>::size)
.def("__getitem__", &FMBScene<location>::get_fmb, nb::arg("idx"),
"Get the (FMB, log_weight) tuple at index i")
.def("__repr__", [=](const FMBScene<location>& scene) {
return nb::str("{}(size={})").format(name, scene.size());
});
}
17 changes: 17 additions & 0 deletions genmetaballs/src/genmetaballs/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
TwoParameterConfidence,
ZeroParameterConfidence,
)
from genmetaballs._genmetaballs_bindings.fmb import CPUFMBScene, GPUFMBScene
from genmetaballs._genmetaballs_bindings.image import CPUImage, GPUImage
from genmetaballs._genmetaballs_bindings.utils import CPUFloatArray2D, GPUFloatArray2D, sigmoid

Expand Down Expand Up @@ -47,6 +48,21 @@ def make_image(height: int, width: int, device: DeviceType) -> CPUImage | GPUIma
raise ValueError(f"Unsupported device type: {device}")


def make_fmb_scene(size: int, device: DeviceType) -> CPUFMBScene | GPUFMBScene:
"""Create an FMBScene on the specified device.

Args:
size: The number of FMBs in the scene.
device: 'cpu' or 'gpu' to specify the target device.
"""
if device == "cpu":
return CPUFMBScene(size)
elif device == "gpu":
return GPUFMBScene(size)
else:
raise ValueError(f"Unsupported device type: {device}")


__all__ = [
"array2d_float",
"ZeroParameterConfidence",
Expand All @@ -60,4 +76,5 @@ def make_image(height: int, width: int, device: DeviceType) -> CPUImage | GPUIma
"FourParameterBlender",
"ThreeParameterBlender",
"make_image",
"make_fmb_scene",
]
12 changes: 11 additions & 1 deletion tests/python_tests/test_fmb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from scipy.spatial.distance import mahalanobis
from scipy.spatial.transform import Rotation as Rot

from genmetaballs.core import fmb, geometry
from genmetaballs.core import fmb, geometry, make_fmb_scene

FMB = fmb.FMB
Pose, Vec3D, Rotation = geometry.Pose, geometry.Vec3D, geometry.Rotation
Expand Down Expand Up @@ -38,3 +38,13 @@ def test_fmb_quadratic_form(rng):
FMB(pose, *extent).quadratic_form(Vec3D(*vec)),
mahalanobis(vec, tran, np.linalg.inv(cov)) ** 2,
)


def test_fmb_scene_creation():
cpu_scene = make_fmb_scene(10, device="cpu")
assert isinstance(cpu_scene, fmb.CPUFMBScene)
assert len(cpu_scene) == 10

gpu_scene = make_fmb_scene(20, device="gpu")
assert isinstance(gpu_scene, fmb.GPUFMBScene)
assert len(gpu_scene) == 20