Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions physicsnemo/mesh/geometry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

This module contains fundamental geometric operations that are shared across
the codebase, including:
- Cell area (n-simplex volume) computation
- Cell normal computation for codimension-1 simplices
- Interior angle computation for n-simplices
- Dual mesh (Voronoi/circumcentric) computations
- Circumcenter calculations
Expand All @@ -32,6 +34,8 @@
compute_vertex_angle_sums,
compute_vertex_angles,
)
from physicsnemo.mesh.geometry._cell_areas import compute_cell_areas
from physicsnemo.mesh.geometry._cell_normals import compute_cell_normals
from physicsnemo.mesh.geometry.dual_meshes import (
compute_circumcenters,
compute_cotan_weights_fem,
Expand Down
191 changes: 191 additions & 0 deletions physicsnemo/mesh/geometry/_cell_areas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Cell area (n-simplex volume) computation for simplicial meshes.

Computes the volume of each n-simplex from its edge vectors using
dimension-specific closed-form expressions where possible:

- **Edges** (n=1): vector norm.
- **Triangles** (n=2): Lagrange identity (works in any spatial dimension).
- **Tetrahedra** (n=3): scalar triple product in 3-space, or Sarrus' rule
on the 3x3 Gram matrix for higher spatial dimensions.
- **General** (n>=4): Gram determinant via ``torch.det``.

The closed-form branches use only multiply-add-sqrt operations, so they
support reduced-precision dtypes (bfloat16, float16) natively. The general
fallback disables ``torch.autocast`` to keep ``torch.matmul`` in the
native dtype, since ``torch.det`` dispatches to cuBLAS LU factorization
which does not support reduced-precision dtypes.
"""

import math

import torch
from jaxtyping import Float


def compute_cell_areas(
relative_vectors: Float[torch.Tensor, "n_cells n_manifold_dims n_spatial_dims"],
) -> Float[torch.Tensor, " n_cells"]:
"""Compute volumes (areas) of n-simplices from edge vectors.

Given the edge vectors ``e_i = v_{i+1} - v_0`` for each simplex, computes
the n-dimensional volume:

.. math::
\\text{vol} = \\frac{1}{n!} \\sqrt{\\lvert \\det(E E^T) \\rvert}

where *E* is the matrix whose rows are the edge vectors. Specialized
closed-form expressions are used for n <= 3 (see module docstring).

Args:
relative_vectors: Edge vectors of shape
``(n_cells, n_manifold_dims, n_spatial_dims)``.
Row *i* is the vector from vertex 0 to vertex *i+1* of each
simplex.

Returns:
Tensor of shape ``(n_cells,)`` with the volume of each simplex.
For 1-simplices this is edge length, for 2-simplices triangle area,
for 3-simplices tetrahedral volume, etc.

Examples:
>>> # Unit right triangle in 2D
>>> vecs = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])
>>> compute_cell_areas(vecs)
tensor([0.5000])

>>> # Unit edge in 3D
>>> vecs = torch.tensor([[[1.0, 0.0, 0.0]]])
>>> compute_cell_areas(vecs)
tensor([1.])

>>> # Regular tetrahedron
>>> vecs = torch.tensor([[[1.0, 0.0, 0.0],
... [0.5, 0.866025, 0.0],
... [0.5, 0.288675, 0.816497]]])
>>> compute_cell_areas(vecs).item() # doctest: +SKIP
0.1178...
"""
n_manifold_dims = relative_vectors.shape[-2]

if n_manifold_dims == 1:
return _edge_lengths(relative_vectors)
if n_manifold_dims == 2:
return _triangle_areas(relative_vectors)
if n_manifold_dims == 3:
return _tetrahedron_volumes(relative_vectors)
return _gram_det_volumes(relative_vectors)


# ---------------------------------------------------------------------------
# Specialized branches
# ---------------------------------------------------------------------------


def _edge_lengths(
relative_vectors: Float[torch.Tensor, "n_cells 1 n_spatial_dims"],
) -> Float[torch.Tensor, " n_cells"]:
"""Edge length = ||e1||."""
return relative_vectors[:, 0].norm(dim=-1)


def _triangle_areas(
relative_vectors: Float[torch.Tensor, "n_cells 2 n_spatial_dims"],
) -> Float[torch.Tensor, " n_cells"]:
r"""Triangle area via Lagrange's identity (any spatial dimension).

.. math::
A = \tfrac{1}{2}\sqrt{\|e_1\|^2 \|e_2\|^2 - (e_1 \cdot e_2)^2}

This is equivalent to ``||e1 x e2|| / 2`` but generalises beyond 3-space.
"""
e1, e2 = relative_vectors[:, 0], relative_vectors[:, 1]
d11 = (e1 * e1).sum(-1)
d22 = (e2 * e2).sum(-1)
d12 = (e1 * e2).sum(-1)
# clamp guards against tiny negative values from floating-point roundoff
return (d11 * d22 - d12 * d12).clamp(min=0).sqrt() / 2


def _tetrahedron_volumes(
relative_vectors: Float[torch.Tensor, "n_cells 3 n_spatial_dims"],
) -> Float[torch.Tensor, " n_cells"]:
"""Tetrahedral volume, dispatching on spatial dimension."""
n_spatial_dims = relative_vectors.shape[-1]
if n_spatial_dims == 3:
return _tetrahedron_volumes_3d(relative_vectors)
return _tetrahedron_volumes_general(relative_vectors)


def _tetrahedron_volumes_3d(
relative_vectors: Float[torch.Tensor, "n_cells 3 3"],
) -> Float[torch.Tensor, " n_cells"]:
r"""Tetrahedral volume via scalar triple product (3D only).

.. math::
V = \frac{1}{6} \lvert e_1 \cdot (e_2 \times e_3) \rvert
"""
e1, e2, e3 = relative_vectors[:, 0], relative_vectors[:, 1], relative_vectors[:, 2]
return (e1 * torch.linalg.cross(e2, e3)).sum(-1).abs() / 6


def _tetrahedron_volumes_general(
relative_vectors: Float[torch.Tensor, "n_cells 3 n_spatial_dims"],
) -> Float[torch.Tensor, " n_cells"]:
r"""Tetrahedral volume via Sarrus' rule on the 3x3 Gram matrix.

Computes the 6 unique entries of the symmetric Gram matrix
:math:`G_{ij} = e_i \cdot e_j` and evaluates its determinant with the
closed-form 3x3 expansion. Works for any spatial dimension >= 3.
"""
e1, e2, e3 = relative_vectors[:, 0], relative_vectors[:, 1], relative_vectors[:, 2]
### 6 unique dot products (G is symmetric)
g11 = (e1 * e1).sum(-1)
g22 = (e2 * e2).sum(-1)
g33 = (e3 * e3).sum(-1)
g12 = (e1 * e2).sum(-1)
g13 = (e1 * e3).sum(-1)
g23 = (e2 * e3).sum(-1)
### Sarrus' rule: det(G) expanded along first row
det_G = (
g11 * (g22 * g33 - g23 * g23)
- g12 * (g12 * g33 - g23 * g13)
+ g13 * (g12 * g23 - g22 * g13)
)
return det_G.clamp(min=0).sqrt() / 6


def _gram_det_volumes(
relative_vectors: Float[torch.Tensor, "n_cells n_manifold_dims n_spatial_dims"],
) -> Float[torch.Tensor, " n_cells"]:
r"""General n-simplex volume via Gram determinant (n >= 4).

Falls back to ``torch.matmul`` + ``torch.det`` for manifold dimensions
that lack a closed-form specialization. Disables ``torch.autocast`` so
that ``matmul`` operates in the native dtype of the input, because
``torch.det`` dispatches to cuBLAS LU factorization which does not
support reduced-precision dtypes (bfloat16, float16).
"""
with torch.autocast(device_type=relative_vectors.device.type, enabled=False):
gram_matrix = torch.matmul(
relative_vectors,
relative_vectors.transpose(-2, -1),
)
n_manifold_dims = relative_vectors.shape[-2]
factorial = math.factorial(n_manifold_dims)
return gram_matrix.det().abs().sqrt() / factorial
133 changes: 133 additions & 0 deletions physicsnemo/mesh/geometry/_cell_normals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Cell normal computation for codimension-1 simplicial meshes.

Computes unit normal vectors for each cell using the generalized cross
product (Hodge star), with dimension-specific closed-form expressions
where possible:

- **Edges in 2D** (d=2): 90-degree counterclockwise rotation.
- **Triangles in 3D** (d=3): ``torch.linalg.cross``.
- **General** (d>=4): signed minor determinants of the edge-vector matrix.

The closed-form branches for d=2 and d=3 use only multiply-add
operations, so they support reduced-precision dtypes (bfloat16, float16)
natively. The general fallback disables ``torch.autocast`` to keep
``torch.det`` in the native dtype, since it dispatches to cuBLAS LU
factorization which does not support reduced-precision dtypes.
"""

import torch
import torch.nn.functional as F
from jaxtyping import Float


def compute_cell_normals(
relative_vectors: Float[torch.Tensor, "n_cells n_manifold_dims n_spatial_dims"],
) -> Float[torch.Tensor, "n_cells n_spatial_dims"]:
"""Compute unit normal vectors for codimension-1 simplices.

Given the edge vectors ``e_i = v_{i+1} - v_0`` for each simplex, computes
the outward-pointing unit normal via the generalized cross product.
The caller must ensure the codimension-1 constraint:
``n_manifold_dims == n_spatial_dims - 1``.

Args:
relative_vectors: Edge vectors of shape
``(n_cells, n_manifold_dims, n_spatial_dims)``.
Row *i* is the vector from vertex 0 to vertex *i+1* of each
simplex. Must satisfy ``n_manifold_dims == n_spatial_dims - 1``.

Returns:
Tensor of shape ``(n_cells, n_spatial_dims)`` containing unit normal
vectors. For degenerate cells (zero-area), the normal is a zero
vector (from ``F.normalize``'s default behavior).

Examples:
>>> # Edge in 2D: normal is 90-degree CCW rotation
>>> vecs = torch.tensor([[[1.0, 0.0]]])
>>> compute_cell_normals(vecs)
tensor([[0., 1.]])

>>> # Triangle in XY-plane: normal is +Z
>>> vecs = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
>>> compute_cell_normals(vecs)
tensor([[0., 0., 1.]])
"""
n_spatial_dims = relative_vectors.shape[-1]

if n_spatial_dims == 2:
return _normals_2d(relative_vectors)
if n_spatial_dims == 3:
return _normals_3d(relative_vectors)
return _normals_general(relative_vectors)


# ---------------------------------------------------------------------------
# Specialized branches
# ---------------------------------------------------------------------------


def _normals_2d(
relative_vectors: Float[torch.Tensor, "n_cells 1 2"],
) -> Float[torch.Tensor, "n_cells 2"]:
"""Edge normals in 2D via 90-degree CCW rotation: (x, y) -> (-y, x)."""
e = relative_vectors[:, 0] # (n_cells, 2)
normals = torch.stack([-e[:, 1], e[:, 0]], dim=-1)
return F.normalize(normals, dim=-1)


def _normals_3d(
relative_vectors: Float[torch.Tensor, "n_cells 2 3"],
) -> Float[torch.Tensor, "n_cells 3"]:
"""Triangle normals in 3D via cross product."""
normals = torch.linalg.cross(relative_vectors[:, 0], relative_vectors[:, 1])
return F.normalize(normals, dim=-1)


def _normals_general(
relative_vectors: Float[torch.Tensor, "n_cells n_manifold_dims n_spatial_dims"],
) -> Float[torch.Tensor, "n_cells n_spatial_dims"]:
"""Normals in d >= 4 via signed minor determinants (Hodge star).

For (n-1) vectors in R^n (rows of E), the normal components are:
n_i = (-1)^(n-1+i) * det(E with column i removed)

Disables ``torch.autocast`` because ``torch.det`` dispatches to cuBLAS
LU factorization which does not support reduced-precision dtypes.
"""
n_spatial_dims = relative_vectors.shape[-1]
n_manifold_dims = relative_vectors.shape[-2]

with torch.autocast(device_type=relative_vectors.device.type, enabled=False):
normal_components: list[torch.Tensor] = []

for i in range(n_spatial_dims):
# (n-1)x(n-1) submatrix: remove column i
# Uses slice concatenation to avoid aten.nonzero (torch.compile
# graph break from dynamic shapes).
submatrix = torch.cat(
[relative_vectors[:, :, :i], relative_vectors[:, :, i + 1 :]],
dim=-1,
)
det = submatrix.det()
sign = (-1) ** (n_manifold_dims + i)
normal_components.append(sign * det)

normals = torch.stack(normal_components, dim=-1)

return F.normalize(normals, dim=-1)
8 changes: 4 additions & 4 deletions physicsnemo/mesh/io/io_pyvista.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def to_pyvista(
pv = importlib.import_module("pyvista")

### Convert points to numpy and pad to 3D if needed (PyVista requires 3D points)
points_np = mesh.points.cpu().numpy()
points_np = mesh.points.float().cpu().numpy()

if mesh.n_spatial_dims < 3:
# Pad with zeros to make 3D
Expand Down Expand Up @@ -373,19 +373,19 @@ def to_pyvista(

### Convert data dictionaries (flatten high-rank tensors for VTK compatibility)
for k, v in mesh.point_data.items(include_nested=True, leaves_only=True):
arr = v.cpu().numpy()
arr = v.float().cpu().numpy()
pv_mesh.point_data[str(k)] = (
arr.reshape(arr.shape[0], -1) if arr.ndim > 2 else arr
)

for k, v in mesh.cell_data.items(include_nested=True, leaves_only=True):
arr = v.cpu().numpy()
arr = v.float().cpu().numpy()
pv_mesh.cell_data[str(k)] = (
arr.reshape(arr.shape[0], -1) if arr.ndim > 2 else arr
)

for k, v in mesh.global_data.items(include_nested=True, leaves_only=True):
arr = v.cpu().numpy()
arr = v.float().cpu().numpy()
pv_mesh.field_data[str(k)] = (
arr.reshape(arr.shape[0], -1) if arr.ndim > 2 else arr
)
Expand Down
Loading
Loading