-
Notifications
You must be signed in to change notification settings - Fork 611
Adds the PhysicsNeMo-Mesh changes required for GLOBE 3D #1483
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
peterdsharpe
merged 6 commits into
NVIDIA:main
from
peterdsharpe:psharpe/add-mesh-improvements-for-GLOBE-3D
Mar 12, 2026
+1,097
−252
Merged
Changes from 3 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
6344bfc
Adds the PhysicsNeMo-Mesh changes required for GLOBE 3D
peterdsharpe 2318514
Merge branch 'main' into psharpe/add-mesh-improvements-for-GLOBE-3D
peterdsharpe ce5f84c
Merge branch 'main' into psharpe/add-mesh-improvements-for-GLOBE-3D
peterdsharpe 5e41a20
Fixes docstring example for compute_cell_normals to reflect correct n…
peterdsharpe 6d6806c
Refactor compute_cell_areas and compute_cell_normals functions to use…
peterdsharpe 779b10b
Merge branch 'main' into psharpe/add-mesh-improvements-for-GLOBE-3D
peterdsharpe File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,191 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. | ||
| # SPDX-FileCopyrightText: All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Cell area (n-simplex volume) computation for simplicial meshes. | ||
|
|
||
| Computes the volume of each n-simplex from its edge vectors using | ||
| dimension-specific closed-form expressions where possible: | ||
|
|
||
| - **Edges** (n=1): vector norm. | ||
| - **Triangles** (n=2): Lagrange identity (works in any spatial dimension). | ||
| - **Tetrahedra** (n=3): scalar triple product in 3-space, or Sarrus' rule | ||
| on the 3x3 Gram matrix for higher spatial dimensions. | ||
| - **General** (n>=4): Gram determinant via ``torch.det``. | ||
|
|
||
| The closed-form branches use only multiply-add-sqrt operations, so they | ||
| support reduced-precision dtypes (bfloat16, float16) natively. The general | ||
| fallback disables ``torch.autocast`` to keep ``torch.matmul`` in the | ||
| native dtype, since ``torch.det`` dispatches to cuBLAS LU factorization | ||
| which does not support reduced-precision dtypes. | ||
| """ | ||
|
|
||
| import math | ||
|
|
||
| import torch | ||
| from jaxtyping import Float | ||
|
|
||
|
|
||
| def compute_cell_areas( | ||
| relative_vectors: Float[torch.Tensor, "n_cells n_manifold_dims n_spatial_dims"], | ||
| ) -> Float[torch.Tensor, " n_cells"]: | ||
| """Compute volumes (areas) of n-simplices from edge vectors. | ||
|
|
||
| Given the edge vectors ``e_i = v_{i+1} - v_0`` for each simplex, computes | ||
| the n-dimensional volume: | ||
|
|
||
| .. math:: | ||
| \\text{vol} = \\frac{1}{n!} \\sqrt{\\lvert \\det(E E^T) \\rvert} | ||
|
|
||
| where *E* is the matrix whose rows are the edge vectors. Specialized | ||
| closed-form expressions are used for n <= 3 (see module docstring). | ||
|
|
||
| Args: | ||
| relative_vectors: Edge vectors of shape | ||
| ``(n_cells, n_manifold_dims, n_spatial_dims)``. | ||
| Row *i* is the vector from vertex 0 to vertex *i+1* of each | ||
| simplex. | ||
|
|
||
| Returns: | ||
| Tensor of shape ``(n_cells,)`` with the volume of each simplex. | ||
| For 1-simplices this is edge length, for 2-simplices triangle area, | ||
| for 3-simplices tetrahedral volume, etc. | ||
|
|
||
| Examples: | ||
| >>> # Unit right triangle in 2D | ||
| >>> vecs = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]]) | ||
| >>> compute_cell_areas(vecs) | ||
| tensor([0.5000]) | ||
|
|
||
| >>> # Unit edge in 3D | ||
| >>> vecs = torch.tensor([[[1.0, 0.0, 0.0]]]) | ||
| >>> compute_cell_areas(vecs) | ||
| tensor([1.]) | ||
|
|
||
| >>> # Regular tetrahedron | ||
| >>> vecs = torch.tensor([[[1.0, 0.0, 0.0], | ||
| ... [0.5, 0.866025, 0.0], | ||
| ... [0.5, 0.288675, 0.816497]]]) | ||
| >>> compute_cell_areas(vecs).item() # doctest: +SKIP | ||
| 0.1178... | ||
| """ | ||
| n_manifold_dims = relative_vectors.shape[-2] | ||
|
|
||
| if n_manifold_dims == 1: | ||
| return _edge_lengths(relative_vectors) | ||
| if n_manifold_dims == 2: | ||
| return _triangle_areas(relative_vectors) | ||
| if n_manifold_dims == 3: | ||
| return _tetrahedron_volumes(relative_vectors) | ||
| return _gram_det_volumes(relative_vectors) | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Specialized branches | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def _edge_lengths( | ||
peterdsharpe marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| relative_vectors: Float[torch.Tensor, "n_cells 1 n_spatial_dims"], | ||
| ) -> Float[torch.Tensor, " n_cells"]: | ||
| """Edge length = ||e1||.""" | ||
| return relative_vectors[:, 0].norm(dim=-1) | ||
|
|
||
|
|
||
| def _triangle_areas( | ||
| relative_vectors: Float[torch.Tensor, "n_cells 2 n_spatial_dims"], | ||
| ) -> Float[torch.Tensor, " n_cells"]: | ||
| r"""Triangle area via Lagrange's identity (any spatial dimension). | ||
|
|
||
| .. math:: | ||
| A = \tfrac{1}{2}\sqrt{\|e_1\|^2 \|e_2\|^2 - (e_1 \cdot e_2)^2} | ||
|
|
||
| This is equivalent to ``||e1 x e2|| / 2`` but generalises beyond 3-space. | ||
| """ | ||
| e1, e2 = relative_vectors[:, 0], relative_vectors[:, 1] | ||
| d11 = (e1 * e1).sum(-1) | ||
| d22 = (e2 * e2).sum(-1) | ||
| d12 = (e1 * e2).sum(-1) | ||
| # clamp guards against tiny negative values from floating-point roundoff | ||
| return (d11 * d22 - d12 * d12).clamp(min=0).sqrt() / 2 | ||
|
|
||
|
|
||
| def _tetrahedron_volumes( | ||
| relative_vectors: Float[torch.Tensor, "n_cells 3 n_spatial_dims"], | ||
| ) -> Float[torch.Tensor, " n_cells"]: | ||
| """Tetrahedral volume, dispatching on spatial dimension.""" | ||
| n_spatial_dims = relative_vectors.shape[-1] | ||
| if n_spatial_dims == 3: | ||
| return _tetrahedron_volumes_3d(relative_vectors) | ||
| return _tetrahedron_volumes_general(relative_vectors) | ||
|
|
||
|
|
||
| def _tetrahedron_volumes_3d( | ||
| relative_vectors: Float[torch.Tensor, "n_cells 3 3"], | ||
| ) -> Float[torch.Tensor, " n_cells"]: | ||
| r"""Tetrahedral volume via scalar triple product (3D only). | ||
|
|
||
| .. math:: | ||
| V = \frac{1}{6} \lvert e_1 \cdot (e_2 \times e_3) \rvert | ||
| """ | ||
| e1, e2, e3 = relative_vectors[:, 0], relative_vectors[:, 1], relative_vectors[:, 2] | ||
| return (e1 * torch.linalg.cross(e2, e3)).sum(-1).abs() / 6 | ||
|
|
||
|
|
||
| def _tetrahedron_volumes_general( | ||
| relative_vectors: Float[torch.Tensor, "n_cells 3 n_spatial_dims"], | ||
| ) -> Float[torch.Tensor, " n_cells"]: | ||
| r"""Tetrahedral volume via Sarrus' rule on the 3x3 Gram matrix. | ||
|
|
||
| Computes the 6 unique entries of the symmetric Gram matrix | ||
| :math:`G_{ij} = e_i \cdot e_j` and evaluates its determinant with the | ||
| closed-form 3x3 expansion. Works for any spatial dimension >= 3. | ||
| """ | ||
| e1, e2, e3 = relative_vectors[:, 0], relative_vectors[:, 1], relative_vectors[:, 2] | ||
| ### 6 unique dot products (G is symmetric) | ||
| g11 = (e1 * e1).sum(-1) | ||
| g22 = (e2 * e2).sum(-1) | ||
| g33 = (e3 * e3).sum(-1) | ||
| g12 = (e1 * e2).sum(-1) | ||
| g13 = (e1 * e3).sum(-1) | ||
| g23 = (e2 * e3).sum(-1) | ||
| ### Sarrus' rule: det(G) expanded along first row | ||
| det_G = ( | ||
| g11 * (g22 * g33 - g23 * g23) | ||
| - g12 * (g12 * g33 - g23 * g13) | ||
| + g13 * (g12 * g23 - g22 * g13) | ||
| ) | ||
| return det_G.clamp(min=0).sqrt() / 6 | ||
|
|
||
|
|
||
| def _gram_det_volumes( | ||
| relative_vectors: Float[torch.Tensor, "n_cells n_manifold_dims n_spatial_dims"], | ||
| ) -> Float[torch.Tensor, " n_cells"]: | ||
| r"""General n-simplex volume via Gram determinant (n >= 4). | ||
|
|
||
| Falls back to ``torch.matmul`` + ``torch.det`` for manifold dimensions | ||
| that lack a closed-form specialization. Disables ``torch.autocast`` so | ||
| that ``matmul`` operates in the native dtype of the input, because | ||
| ``torch.det`` dispatches to cuBLAS LU factorization which does not | ||
| support reduced-precision dtypes (bfloat16, float16). | ||
| """ | ||
| with torch.autocast(device_type=relative_vectors.device.type, enabled=False): | ||
| gram_matrix = torch.matmul( | ||
| relative_vectors, | ||
| relative_vectors.transpose(-2, -1), | ||
| ) | ||
| n_manifold_dims = relative_vectors.shape[-2] | ||
| factorial = math.factorial(n_manifold_dims) | ||
| return gram_matrix.det().abs().sqrt() / factorial | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,133 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. | ||
| # SPDX-FileCopyrightText: All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Cell normal computation for codimension-1 simplicial meshes. | ||
|
|
||
| Computes unit normal vectors for each cell using the generalized cross | ||
| product (Hodge star), with dimension-specific closed-form expressions | ||
| where possible: | ||
|
|
||
| - **Edges in 2D** (d=2): 90-degree counterclockwise rotation. | ||
| - **Triangles in 3D** (d=3): ``torch.linalg.cross``. | ||
| - **General** (d>=4): signed minor determinants of the edge-vector matrix. | ||
|
|
||
| The closed-form branches for d=2 and d=3 use only multiply-add | ||
| operations, so they support reduced-precision dtypes (bfloat16, float16) | ||
| natively. The general fallback disables ``torch.autocast`` to keep | ||
| ``torch.det`` in the native dtype, since it dispatches to cuBLAS LU | ||
| factorization which does not support reduced-precision dtypes. | ||
| """ | ||
|
|
||
| import torch | ||
| import torch.nn.functional as F | ||
| from jaxtyping import Float | ||
|
|
||
|
|
||
| def compute_cell_normals( | ||
| relative_vectors: Float[torch.Tensor, "n_cells n_manifold_dims n_spatial_dims"], | ||
| ) -> Float[torch.Tensor, "n_cells n_spatial_dims"]: | ||
| """Compute unit normal vectors for codimension-1 simplices. | ||
|
|
||
| Given the edge vectors ``e_i = v_{i+1} - v_0`` for each simplex, computes | ||
| the outward-pointing unit normal via the generalized cross product. | ||
| The caller must ensure the codimension-1 constraint: | ||
| ``n_manifold_dims == n_spatial_dims - 1``. | ||
|
|
||
| Args: | ||
| relative_vectors: Edge vectors of shape | ||
| ``(n_cells, n_manifold_dims, n_spatial_dims)``. | ||
| Row *i* is the vector from vertex 0 to vertex *i+1* of each | ||
| simplex. Must satisfy ``n_manifold_dims == n_spatial_dims - 1``. | ||
|
|
||
| Returns: | ||
| Tensor of shape ``(n_cells, n_spatial_dims)`` containing unit normal | ||
| vectors. For degenerate cells (zero-area), the normal is a zero | ||
| vector (from ``F.normalize``'s default behavior). | ||
|
|
||
| Examples: | ||
| >>> # Edge in 2D: normal is 90-degree CCW rotation | ||
| >>> vecs = torch.tensor([[[1.0, 0.0]]]) | ||
| >>> compute_cell_normals(vecs) | ||
| tensor([[0., 1.]]) | ||
|
|
||
| >>> # Triangle in XY-plane: normal is +Z | ||
| >>> vecs = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]]) | ||
| >>> compute_cell_normals(vecs) | ||
| tensor([[0., 0., 1.]]) | ||
| """ | ||
| n_spatial_dims = relative_vectors.shape[-1] | ||
|
|
||
| if n_spatial_dims == 2: | ||
| return _normals_2d(relative_vectors) | ||
| if n_spatial_dims == 3: | ||
| return _normals_3d(relative_vectors) | ||
| return _normals_general(relative_vectors) | ||
|
|
||
|
|
||
| # --------------------------------------------------------------------------- | ||
| # Specialized branches | ||
| # --------------------------------------------------------------------------- | ||
|
|
||
|
|
||
| def _normals_2d( | ||
| relative_vectors: Float[torch.Tensor, "n_cells 1 2"], | ||
| ) -> Float[torch.Tensor, "n_cells 2"]: | ||
| """Edge normals in 2D via 90-degree CCW rotation: (x, y) -> (-y, x).""" | ||
| e = relative_vectors[:, 0] # (n_cells, 2) | ||
| normals = torch.stack([-e[:, 1], e[:, 0]], dim=-1) | ||
| return F.normalize(normals, dim=-1) | ||
|
|
||
|
|
||
| def _normals_3d( | ||
| relative_vectors: Float[torch.Tensor, "n_cells 2 3"], | ||
| ) -> Float[torch.Tensor, "n_cells 3"]: | ||
| """Triangle normals in 3D via cross product.""" | ||
| normals = torch.linalg.cross(relative_vectors[:, 0], relative_vectors[:, 1]) | ||
| return F.normalize(normals, dim=-1) | ||
|
|
||
|
|
||
| def _normals_general( | ||
| relative_vectors: Float[torch.Tensor, "n_cells n_manifold_dims n_spatial_dims"], | ||
| ) -> Float[torch.Tensor, "n_cells n_spatial_dims"]: | ||
| """Normals in d >= 4 via signed minor determinants (Hodge star). | ||
|
|
||
| For (n-1) vectors in R^n (rows of E), the normal components are: | ||
| n_i = (-1)^(n-1+i) * det(E with column i removed) | ||
|
|
||
| Disables ``torch.autocast`` because ``torch.det`` dispatches to cuBLAS | ||
| LU factorization which does not support reduced-precision dtypes. | ||
| """ | ||
| n_spatial_dims = relative_vectors.shape[-1] | ||
| n_manifold_dims = relative_vectors.shape[-2] | ||
|
|
||
| with torch.autocast(device_type=relative_vectors.device.type, enabled=False): | ||
| normal_components: list[torch.Tensor] = [] | ||
|
|
||
| for i in range(n_spatial_dims): | ||
| # (n-1)x(n-1) submatrix: remove column i | ||
| # Uses slice concatenation to avoid aten.nonzero (torch.compile | ||
| # graph break from dynamic shapes). | ||
| submatrix = torch.cat( | ||
| [relative_vectors[:, :, :i], relative_vectors[:, :, i + 1 :]], | ||
| dim=-1, | ||
| ) | ||
| det = submatrix.det() | ||
| sign = (-1) ** (n_manifold_dims + i) | ||
| normal_components.append(sign * det) | ||
|
|
||
| normals = torch.stack(normal_components, dim=-1) | ||
|
|
||
| return F.normalize(normals, dim=-1) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.