Skip to content
Merged
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
`.to_dual_graph()` methods. These allow Mesh conversion to 0D point clouds, 1D
edge graphs, and 1D dual graphs, respectively, when connectivity information
is not needed.
- Adds `physicsnemo.mesh.generate` subpackage with `marching_cubes` for
isosurface extraction from 3D scalar fields, returning a `Mesh` object.
Supports Warp and scikit-image backends.

### Changed

Expand Down
23 changes: 23 additions & 0 deletions physicsnemo/mesh/generate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mesh generation from implicit representations.

This module provides functions for generating meshes from scalar fields,
including isosurface extraction via the marching cubes algorithm.
"""

from physicsnemo.mesh.generate.marching_cubes import marching_cubes
140 changes: 140 additions & 0 deletions physicsnemo/mesh/generate/marching_cubes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Isosurface extraction via the marching cubes algorithm."""

from typing import TYPE_CHECKING

import numpy as np
import torch
import warp as wp
from jaxtyping import Float

if TYPE_CHECKING:
from physicsnemo.mesh.mesh import Mesh


def marching_cubes(
field: Float[torch.Tensor, "nx ny nz"],
threshold: float = 0.0,
coords: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> "Mesh":
r"""Extract an isosurface from a 3D scalar field using marching cubes.

Given a volumetric scalar field (e.g. a signed distance field), this
function extracts the isosurface at the specified threshold and returns
it as a triangle :class:`~physicsnemo.mesh.Mesh`.

When ``coords`` is provided, vertex positions are mapped from grid-index
space into the physical coordinate system defined by the coordinate
vectors. When ``coords`` is ``None``, vertices are returned in grid-index
space.

Uses `NVIDIA Warp <https://nvidia.github.io/warp/modules/runtime.html#marching-cubes>`_
for the marching cubes implementation.

Parameters
----------
field : torch.Tensor
A 3D scalar field with shape :math:`(N_x, N_y, N_z)`. Converted to
float32 internally if necessary.
threshold : float, optional
Iso-value at which to extract the surface. Default is ``0.0``, which
is the standard choice for signed distance fields.
coords : tuple of 3 torch.Tensor, optional
Physical coordinates along each grid axis, as 1D tensors of lengths
:math:`N_x`, :math:`N_y`, :math:`N_z` respectively (e.g. from
``torch.linspace``). When provided, output vertices are linearly
mapped from grid-index space into the coordinate system defined by
these vectors. When ``None``, vertices are in grid-index space.

Returns
-------
Mesh
A triangle mesh with ``points`` of shape :math:`(N_v, 3)` (float32)
and ``cells`` of shape :math:`(N_f, 3)` (int64).

Raises
------
NotImplementedError
If ``field`` is not 3-dimensional (higher/lower dimensions may be
supported in a future release).
ValueError
If ``coords`` is provided but the lengths do not match the
corresponding ``field`` dimensions.

Notes
-----
This operation is **not differentiable**. The input tensor is detached
and transferred to CPU/NumPy before being passed to Warp's marching cubes
kernel, so gradients do not flow through this function.

Examples
--------
Extract the zero-level set of a sphere SDF on a 64^3 grid in physical
coordinates:

>>> import torch
>>> from physicsnemo.mesh.generate import marching_cubes
>>> coords = torch.linspace(-1, 1, 64)
>>> xx, yy, zz = torch.meshgrid(coords, coords, coords, indexing="ij")
>>> sdf = torch.sqrt(xx**2 + yy**2 + zz**2) - 0.5
>>> sphere = marching_cubes(sdf, threshold=0.0, coords=(coords, coords, coords))
>>> sphere.n_manifold_dims
2
>>> sphere.n_spatial_dims
3
"""
from physicsnemo.mesh.mesh import Mesh

if field.ndim != 3:
raise NotImplementedError(
f"Only 3D scalar fields are currently supported, got {field.ndim}D "
f"tensor with shape {tuple(field.shape)}"
)

if coords is not None:
for dim, c in enumerate(coords):
if c.shape[0] != field.shape[dim]:
raise ValueError(
f"coords[{dim}] has length {c.shape[0]}, but field has "
f"size {field.shape[dim]} along dimension {dim}"
)

field_np = field.detach().cpu().numpy().astype(np.float32)
field_wp = wp.array(field_np)

mc = wp.MarchingCubes(
nx=field_np.shape[0],
ny=field_np.shape[1],
nz=field_np.shape[2],
)
mc.surface(field=field_wp, threshold=threshold)

points = torch.as_tensor(mc.verts.numpy(), dtype=torch.float32) # (N_v, 3)
cells = torch.as_tensor(
mc.indices.numpy().reshape(-1, 3), dtype=torch.int64
) # (N_f, 3)

### Map from grid-index space to physical coordinates
if coords is not None:
for dim, c in enumerate(coords):
n = c.shape[0]
origin = c[0].item()
spacing = (c[-1].item() - origin) / (n - 1) if n > 1 else 1.0
points[:, dim] = origin + points[:, dim] * spacing

return Mesh(points=points, cells=cells)
15 changes: 15 additions & 0 deletions test/mesh/generate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
161 changes: 161 additions & 0 deletions test/mesh/generate/test_marching_cubes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for physicsnemo.mesh.generate.marching_cubes."""

import math

import pytest
import torch

from physicsnemo.mesh.generate import marching_cubes


def _sphere_sdf(
resolution: int = 32, radius: float = 0.5
) -> tuple[torch.Tensor, torch.Tensor]:
"""Create a sphere SDF on a [-1, 1]^3 grid.

Returns the SDF field and the 1D coordinate vector (same for all axes).
"""
coords = torch.linspace(-1, 1, resolution)
xx, yy, zz = torch.meshgrid(coords, coords, coords, indexing="ij")
sdf = torch.sqrt(xx**2 + yy**2 + zz**2) - radius
return sdf, coords


class TestMarchingCubes:
"""Tests for marching cubes isosurface extraction."""

def test_returns_triangle_mesh(self):
sdf, _ = _sphere_sdf()
mesh = marching_cubes(sdf)
assert mesh.n_spatial_dims == 3
assert mesh.n_manifold_dims == 2

def test_nonempty_output(self):
sdf, _ = _sphere_sdf()
mesh = marching_cubes(sdf)
assert mesh.n_points > 0
assert mesh.n_cells > 0

def test_cell_indices_in_range(self):
sdf, _ = _sphere_sdf()
mesh = marching_cubes(sdf)
assert mesh.cells.min() >= 0
assert mesh.cells.max() < mesh.n_points

def test_dtypes(self):
sdf, _ = _sphere_sdf()
mesh = marching_cubes(sdf)
assert mesh.points.dtype == torch.float32
assert mesh.cells.dtype == torch.int64

def test_custom_threshold(self):
sdf, _ = _sphere_sdf(resolution=32, radius=0.5)
mesh_small = marching_cubes(sdf, threshold=0.2)
mesh_large = marching_cubes(sdf, threshold=-0.2)
assert mesh_small.n_points > mesh_large.n_points


class TestCoords:
"""Tests for the coords parameter (physical coordinate mapping)."""

def test_vertices_in_physical_space(self):
"""With coords, vertices should lie within the coordinate bounds."""
sdf, coords = _sphere_sdf(resolution=32, radius=0.5)
mesh = marching_cubes(sdf, coords=(coords, coords, coords))
assert mesh.points.min() >= coords[0].item()
assert mesh.points.max() <= coords[-1].item()

def test_vertices_in_index_space_without_coords(self):
"""Without coords, vertices should be in grid-index space."""
sdf, _ = _sphere_sdf(resolution=32, radius=0.5)
mesh = marching_cubes(sdf)
assert mesh.points.min() >= 0
assert mesh.points.max() <= 31

def test_coords_length_mismatch_raises(self):
sdf, _ = _sphere_sdf(resolution=32)
wrong = torch.linspace(0, 1, 64)
with pytest.raises(ValueError, match="coords"):
marching_cubes(sdf, coords=(wrong, wrong, wrong))

def test_anisotropic_coords(self):
"""Different coordinate ranges per axis should scale accordingly."""
sdf, _ = _sphere_sdf(resolution=32, radius=0.5)
cx = torch.linspace(0, 10, 32)
cy = torch.linspace(-5, 5, 32)
cz = torch.linspace(0, 1, 32)
mesh = marching_cubes(sdf, coords=(cx, cy, cz))
assert mesh.points[:, 0].min() >= 0
assert mesh.points[:, 0].max() <= 10
assert mesh.points[:, 1].min() >= -5
assert mesh.points[:, 1].max() <= 5
assert mesh.points[:, 2].min() >= 0
assert mesh.points[:, 2].max() <= 1


class TestGeometricAccuracy:
"""Geometric validation of extracted isosurfaces."""

def test_sphere_surface_area(self):
"""Surface area of extracted sphere should approximate 4*pi*r^2."""
radius = 0.5
sdf, coords = _sphere_sdf(resolution=64, radius=radius)
mesh = marching_cubes(sdf, coords=(coords, coords, coords))

total_area = mesh.cell_areas.sum().item()
expected_area = 4 * math.pi * radius**2

assert total_area == pytest.approx(expected_area, rel=0.02)

def test_sphere_is_watertight(self):
"""Extracted sphere should be a closed surface."""
sdf, coords = _sphere_sdf(resolution=32, radius=0.5)
mesh = marching_cubes(sdf, coords=(coords, coords, coords))
assert mesh.is_watertight()

def test_sphere_is_manifold(self):
"""Extracted sphere should be a valid 2-manifold."""
sdf, coords = _sphere_sdf(resolution=32, radius=0.5)
mesh = marching_cubes(sdf, coords=(coords, coords, coords))
assert mesh.is_manifold()

def test_sphere_centroid_near_origin(self):
"""Centroid of an origin-centered sphere should be near (0, 0, 0)."""
sdf, coords = _sphere_sdf(resolution=64, radius=0.5)
mesh = marching_cubes(sdf, coords=(coords, coords, coords))
centroid = mesh.points.mean(dim=0)
assert torch.allclose(centroid, torch.zeros(3), atol=0.05)

def test_no_degenerate_cells(self):
"""Extracted mesh should have no zero-area triangles."""
sdf, coords = _sphere_sdf(resolution=32, radius=0.5)
mesh = marching_cubes(sdf, coords=(coords, coords, coords))
assert (mesh.cell_areas > 0).all()


class TestMarchingCubesValidation:
"""Input validation and error handling."""

def test_rejects_2d_input(self):
with pytest.raises(NotImplementedError, match="3D scalar fields"):
marching_cubes(torch.randn(10, 10))

def test_rejects_4d_input(self):
with pytest.raises(NotImplementedError, match="3D scalar fields"):
marching_cubes(torch.randn(10, 10, 10, 10))
Loading