Skip to content

Commit fb4f159

Browse files
authored
Adds the PhysicsNeMo-Mesh changes required for GLOBE 3D (#1483)
* Adds the PhysicsNeMo-Mesh changes required for GLOBE 3D * Fixes docstring example for compute_cell_normals to reflect correct normal vector output in 2D case. * Refactor compute_cell_areas and compute_cell_normals functions to use match-case syntax for improved readability and maintainability.
1 parent 219aca3 commit fb4f159

File tree

11 files changed

+1097
-252
lines changed

11 files changed

+1097
-252
lines changed

physicsnemo/mesh/geometry/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
1919
This module contains fundamental geometric operations that are shared across
2020
the codebase, including:
21+
- Cell area (n-simplex volume) computation
22+
- Cell normal computation for codimension-1 simplices
2123
- Interior angle computation for n-simplices
2224
- Dual mesh (Voronoi/circumcentric) computations
2325
- Circumcenter calculations
@@ -32,6 +34,8 @@
3234
compute_vertex_angle_sums,
3335
compute_vertex_angles,
3436
)
37+
from physicsnemo.mesh.geometry._cell_areas import compute_cell_areas
38+
from physicsnemo.mesh.geometry._cell_normals import compute_cell_normals
3539
from physicsnemo.mesh.geometry.dual_meshes import (
3640
compute_circumcenters,
3741
compute_cotan_weights_fem,
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-FileCopyrightText: All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""Cell area (n-simplex volume) computation for simplicial meshes.
18+
19+
Computes the volume of each n-simplex from its edge vectors using
20+
dimension-specific closed-form expressions where possible:
21+
22+
- **Edges** (n=1): vector norm.
23+
- **Triangles** (n=2): Lagrange identity (works in any spatial dimension).
24+
- **Tetrahedra** (n=3): scalar triple product in 3-space, or Sarrus' rule
25+
on the 3x3 Gram matrix for higher spatial dimensions.
26+
- **General** (n>=4): Gram determinant via ``torch.det``.
27+
28+
The closed-form branches use only multiply-add-sqrt operations, so they
29+
support reduced-precision dtypes (bfloat16, float16) natively. The general
30+
fallback disables ``torch.autocast`` to keep ``torch.matmul`` in the
31+
native dtype, since ``torch.det`` dispatches to cuBLAS LU factorization
32+
which does not support reduced-precision dtypes.
33+
"""
34+
35+
import math
36+
37+
import torch
38+
from jaxtyping import Float
39+
40+
41+
def compute_cell_areas(
42+
relative_vectors: Float[torch.Tensor, "n_cells n_manifold_dims n_spatial_dims"],
43+
) -> Float[torch.Tensor, " n_cells"]:
44+
"""Compute volumes (areas) of n-simplices from edge vectors.
45+
46+
Given the edge vectors ``e_i = v_{i+1} - v_0`` for each simplex, computes
47+
the n-dimensional volume:
48+
49+
.. math::
50+
\\text{vol} = \\frac{1}{n!} \\sqrt{\\lvert \\det(E E^T) \\rvert}
51+
52+
where *E* is the matrix whose rows are the edge vectors. Specialized
53+
closed-form expressions are used for n <= 3 (see module docstring).
54+
55+
Args:
56+
relative_vectors: Edge vectors of shape
57+
``(n_cells, n_manifold_dims, n_spatial_dims)``.
58+
Row *i* is the vector from vertex 0 to vertex *i+1* of each
59+
simplex.
60+
61+
Returns:
62+
Tensor of shape ``(n_cells,)`` with the volume of each simplex.
63+
For 1-simplices this is edge length, for 2-simplices triangle area,
64+
for 3-simplices tetrahedral volume, etc.
65+
66+
Examples:
67+
>>> # Unit right triangle in 2D
68+
>>> vecs = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])
69+
>>> compute_cell_areas(vecs)
70+
tensor([0.5000])
71+
72+
>>> # Unit edge in 3D
73+
>>> vecs = torch.tensor([[[1.0, 0.0, 0.0]]])
74+
>>> compute_cell_areas(vecs)
75+
tensor([1.])
76+
77+
>>> # Regular tetrahedron
78+
>>> vecs = torch.tensor([[[1.0, 0.0, 0.0],
79+
... [0.5, 0.866025, 0.0],
80+
... [0.5, 0.288675, 0.816497]]])
81+
>>> compute_cell_areas(vecs).item() # doctest: +SKIP
82+
0.1178...
83+
"""
84+
n_manifold_dims = relative_vectors.shape[-2]
85+
86+
match n_manifold_dims:
87+
case 1:
88+
return _edge_lengths(relative_vectors)
89+
case 2:
90+
return _triangle_areas(relative_vectors)
91+
case 3:
92+
return _tetrahedron_volumes(relative_vectors)
93+
case _:
94+
return _gram_det_volumes(relative_vectors)
95+
96+
97+
# ---------------------------------------------------------------------------
98+
# Specialized branches
99+
# ---------------------------------------------------------------------------
100+
101+
102+
def _edge_lengths(
103+
relative_vectors: Float[torch.Tensor, "n_cells 1 n_spatial_dims"],
104+
) -> Float[torch.Tensor, " n_cells"]:
105+
"""Edge length = ||e1||."""
106+
return relative_vectors[:, 0].norm(dim=-1)
107+
108+
109+
def _triangle_areas(
110+
relative_vectors: Float[torch.Tensor, "n_cells 2 n_spatial_dims"],
111+
) -> Float[torch.Tensor, " n_cells"]:
112+
r"""Triangle area via Lagrange's identity (any spatial dimension).
113+
114+
.. math::
115+
A = \tfrac{1}{2}\sqrt{\|e_1\|^2 \|e_2\|^2 - (e_1 \cdot e_2)^2}
116+
117+
This is equivalent to ``||e1 x e2|| / 2`` but generalises beyond 3-space.
118+
"""
119+
e1, e2 = relative_vectors[:, 0], relative_vectors[:, 1]
120+
d11 = (e1 * e1).sum(-1)
121+
d22 = (e2 * e2).sum(-1)
122+
d12 = (e1 * e2).sum(-1)
123+
# clamp guards against tiny negative values from floating-point roundoff
124+
return (d11 * d22 - d12 * d12).clamp(min=0).sqrt() / 2
125+
126+
127+
def _tetrahedron_volumes(
128+
relative_vectors: Float[torch.Tensor, "n_cells 3 n_spatial_dims"],
129+
) -> Float[torch.Tensor, " n_cells"]:
130+
"""Tetrahedral volume, dispatching on spatial dimension."""
131+
n_spatial_dims = relative_vectors.shape[-1]
132+
if n_spatial_dims == 3:
133+
return _tetrahedron_volumes_3d(relative_vectors)
134+
return _tetrahedron_volumes_general(relative_vectors)
135+
136+
137+
def _tetrahedron_volumes_3d(
138+
relative_vectors: Float[torch.Tensor, "n_cells 3 3"],
139+
) -> Float[torch.Tensor, " n_cells"]:
140+
r"""Tetrahedral volume via scalar triple product (3D only).
141+
142+
.. math::
143+
V = \frac{1}{6} \lvert e_1 \cdot (e_2 \times e_3) \rvert
144+
"""
145+
e1, e2, e3 = relative_vectors[:, 0], relative_vectors[:, 1], relative_vectors[:, 2]
146+
return (e1 * torch.linalg.cross(e2, e3)).sum(-1).abs() / 6
147+
148+
149+
def _tetrahedron_volumes_general(
150+
relative_vectors: Float[torch.Tensor, "n_cells 3 n_spatial_dims"],
151+
) -> Float[torch.Tensor, " n_cells"]:
152+
r"""Tetrahedral volume via Sarrus' rule on the 3x3 Gram matrix.
153+
154+
Computes the 6 unique entries of the symmetric Gram matrix
155+
:math:`G_{ij} = e_i \cdot e_j` and evaluates its determinant with the
156+
closed-form 3x3 expansion. Works for any spatial dimension >= 3.
157+
"""
158+
e1, e2, e3 = relative_vectors[:, 0], relative_vectors[:, 1], relative_vectors[:, 2]
159+
### 6 unique dot products (G is symmetric)
160+
g11 = (e1 * e1).sum(-1)
161+
g22 = (e2 * e2).sum(-1)
162+
g33 = (e3 * e3).sum(-1)
163+
g12 = (e1 * e2).sum(-1)
164+
g13 = (e1 * e3).sum(-1)
165+
g23 = (e2 * e3).sum(-1)
166+
### Sarrus' rule: det(G) expanded along first row
167+
det_G = (
168+
g11 * (g22 * g33 - g23 * g23)
169+
- g12 * (g12 * g33 - g23 * g13)
170+
+ g13 * (g12 * g23 - g22 * g13)
171+
)
172+
return det_G.clamp(min=0).sqrt() / 6
173+
174+
175+
def _gram_det_volumes(
176+
relative_vectors: Float[torch.Tensor, "n_cells n_manifold_dims n_spatial_dims"],
177+
) -> Float[torch.Tensor, " n_cells"]:
178+
r"""General n-simplex volume via Gram determinant (n >= 4).
179+
180+
Falls back to ``torch.matmul`` + ``torch.det`` for manifold dimensions
181+
that lack a closed-form specialization. Disables ``torch.autocast`` so
182+
that ``matmul`` operates in the native dtype of the input, because
183+
``torch.det`` dispatches to cuBLAS LU factorization which does not
184+
support reduced-precision dtypes (bfloat16, float16).
185+
"""
186+
with torch.autocast(device_type=relative_vectors.device.type, enabled=False):
187+
gram_matrix = torch.matmul(
188+
relative_vectors,
189+
relative_vectors.transpose(-2, -1),
190+
)
191+
n_manifold_dims = relative_vectors.shape[-2]
192+
factorial = math.factorial(n_manifold_dims)
193+
return gram_matrix.det().abs().sqrt() / factorial
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
2+
# SPDX-FileCopyrightText: All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""Cell normal computation for codimension-1 simplicial meshes.
18+
19+
Computes unit normal vectors for each cell using the generalized cross
20+
product (Hodge star), with dimension-specific closed-form expressions
21+
where possible:
22+
23+
- **Edges in 2D** (d=2): 90-degree counterclockwise rotation.
24+
- **Triangles in 3D** (d=3): ``torch.linalg.cross``.
25+
- **General** (d>=4): signed minor determinants of the edge-vector matrix.
26+
27+
The closed-form branches for d=2 and d=3 use only multiply-add
28+
operations, so they support reduced-precision dtypes (bfloat16, float16)
29+
natively. The general fallback disables ``torch.autocast`` to keep
30+
``torch.det`` in the native dtype, since it dispatches to cuBLAS LU
31+
factorization which does not support reduced-precision dtypes.
32+
"""
33+
34+
import torch
35+
import torch.nn.functional as F
36+
from jaxtyping import Float
37+
38+
39+
def compute_cell_normals(
40+
relative_vectors: Float[torch.Tensor, "n_cells n_manifold_dims n_spatial_dims"],
41+
) -> Float[torch.Tensor, "n_cells n_spatial_dims"]:
42+
"""Compute unit normal vectors for codimension-1 simplices.
43+
44+
Given the edge vectors ``e_i = v_{i+1} - v_0`` for each simplex, computes
45+
the outward-pointing unit normal via the generalized cross product.
46+
The caller must ensure the codimension-1 constraint:
47+
``n_manifold_dims == n_spatial_dims - 1``.
48+
49+
Args:
50+
relative_vectors: Edge vectors of shape
51+
``(n_cells, n_manifold_dims, n_spatial_dims)``.
52+
Row *i* is the vector from vertex 0 to vertex *i+1* of each
53+
simplex. Must satisfy ``n_manifold_dims == n_spatial_dims - 1``.
54+
55+
Returns:
56+
Tensor of shape ``(n_cells, n_spatial_dims)`` containing unit normal
57+
vectors. For degenerate cells (zero-area), the normal is a zero
58+
vector (from ``F.normalize``'s default behavior).
59+
60+
Examples:
61+
>>> # Edge in 2D: normal is 90-degree CCW rotation
62+
>>> vecs = torch.tensor([[[1.0, 0.0]]])
63+
>>> compute_cell_normals(vecs)
64+
tensor([[-0., 1.]])
65+
66+
>>> # Triangle in XY-plane: normal is +Z
67+
>>> vecs = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
68+
>>> compute_cell_normals(vecs)
69+
tensor([[0., 0., 1.]])
70+
"""
71+
n_spatial_dims = relative_vectors.shape[-1]
72+
73+
match n_spatial_dims:
74+
case 2:
75+
return _normals_2d(relative_vectors)
76+
case 3:
77+
return _normals_3d(relative_vectors)
78+
case _:
79+
return _normals_general(relative_vectors)
80+
81+
82+
# ---------------------------------------------------------------------------
83+
# Specialized branches
84+
# ---------------------------------------------------------------------------
85+
86+
87+
def _normals_2d(
88+
relative_vectors: Float[torch.Tensor, "n_cells 1 2"],
89+
) -> Float[torch.Tensor, "n_cells 2"]:
90+
"""Edge normals in 2D via 90-degree CCW rotation: (x, y) -> (-y, x)."""
91+
e = relative_vectors[:, 0] # (n_cells, 2)
92+
normals = torch.stack([-e[:, 1], e[:, 0]], dim=-1)
93+
return F.normalize(normals, dim=-1)
94+
95+
96+
def _normals_3d(
97+
relative_vectors: Float[torch.Tensor, "n_cells 2 3"],
98+
) -> Float[torch.Tensor, "n_cells 3"]:
99+
"""Triangle normals in 3D via cross product."""
100+
normals = torch.linalg.cross(relative_vectors[:, 0], relative_vectors[:, 1])
101+
return F.normalize(normals, dim=-1)
102+
103+
104+
def _normals_general(
105+
relative_vectors: Float[torch.Tensor, "n_cells n_manifold_dims n_spatial_dims"],
106+
) -> Float[torch.Tensor, "n_cells n_spatial_dims"]:
107+
"""Normals in d >= 4 via signed minor determinants (Hodge star).
108+
109+
For (n-1) vectors in R^n (rows of E), the normal components are:
110+
n_i = (-1)^(n-1+i) * det(E with column i removed)
111+
112+
Disables ``torch.autocast`` because ``torch.det`` dispatches to cuBLAS
113+
LU factorization which does not support reduced-precision dtypes.
114+
"""
115+
n_spatial_dims = relative_vectors.shape[-1]
116+
n_manifold_dims = relative_vectors.shape[-2]
117+
118+
with torch.autocast(device_type=relative_vectors.device.type, enabled=False):
119+
normal_components: list[torch.Tensor] = []
120+
121+
for i in range(n_spatial_dims):
122+
# (n-1)x(n-1) submatrix: remove column i
123+
# Uses slice concatenation to avoid aten.nonzero (torch.compile
124+
# graph break from dynamic shapes).
125+
submatrix = torch.cat(
126+
[relative_vectors[:, :, :i], relative_vectors[:, :, i + 1 :]],
127+
dim=-1,
128+
)
129+
det = submatrix.det()
130+
sign = (-1) ** (n_manifold_dims + i)
131+
normal_components.append(sign * det)
132+
133+
normals = torch.stack(normal_components, dim=-1)
134+
135+
return F.normalize(normals, dim=-1)

physicsnemo/mesh/io/io_pyvista.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def to_pyvista(
324324
pv = importlib.import_module("pyvista")
325325

326326
### Convert points to numpy and pad to 3D if needed (PyVista requires 3D points)
327-
points_np = mesh.points.cpu().numpy()
327+
points_np = mesh.points.float().cpu().numpy()
328328

329329
if mesh.n_spatial_dims < 3:
330330
# Pad with zeros to make 3D
@@ -373,19 +373,19 @@ def to_pyvista(
373373

374374
### Convert data dictionaries (flatten high-rank tensors for VTK compatibility)
375375
for k, v in mesh.point_data.items(include_nested=True, leaves_only=True):
376-
arr = v.cpu().numpy()
376+
arr = v.float().cpu().numpy()
377377
pv_mesh.point_data[str(k)] = (
378378
arr.reshape(arr.shape[0], -1) if arr.ndim > 2 else arr
379379
)
380380

381381
for k, v in mesh.cell_data.items(include_nested=True, leaves_only=True):
382-
arr = v.cpu().numpy()
382+
arr = v.float().cpu().numpy()
383383
pv_mesh.cell_data[str(k)] = (
384384
arr.reshape(arr.shape[0], -1) if arr.ndim > 2 else arr
385385
)
386386

387387
for k, v in mesh.global_data.items(include_nested=True, leaves_only=True):
388-
arr = v.cpu().numpy()
388+
arr = v.float().cpu().numpy()
389389
pv_mesh.field_data[str(k)] = (
390390
arr.reshape(arr.shape[0], -1) if arr.ndim > 2 else arr
391391
)

0 commit comments

Comments
 (0)