Skip to content

Commit 8e15c7a

Browse files
authored
Merge pull request #324 from santisoler/all-nodes
Add `total_nodes` method to TreeMeshes
2 parents ee39262 + 327ccc2 commit 8e15c7a

File tree

3 files changed

+59
-4
lines changed

3 files changed

+59
-4
lines changed

discretize/_extensions/tree_ext.pyx

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5953,16 +5953,18 @@ cdef class _TreeMesh:
59535953

59545954
@property
59555955
def cell_nodes(self):
5956-
"""The index of nodes for each cell.
5956+
"""The index of all nodes for each cell.
5957+
5958+
These indices point to non-hanging and hanging nodes.
59575959
59585960
Returns
59595961
-------
59605962
numpy.ndarray of int
59615963
Index array of shape (n_cells, 4) if 2D, or (n_cells, 8) if 3D
59625964
5963-
Notes
5964-
-----
5965-
These indices will also point to hanging nodes.
5965+
See also
5966+
--------
5967+
TreeMesh.total_nodes
59665968
"""
59675969
cdef int_t npc = 4 if self.dim == 2 else 8
59685970
inds = np.empty((self.n_cells, npc), dtype=np.int64)

discretize/tree_mesh.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,21 @@ def refine_surface(
740740
if finalize:
741741
self.finalize()
742742

743+
@property
744+
def total_nodes(self):
745+
"""Gridded hanging and non-hanging nodes locations.
746+
747+
This property returns a numpy array of shape
748+
``(n_total_nodes, dim)`` containing gridded locations for
749+
all hanging and non-hanging nodes in the mesh.
750+
751+
Returns
752+
-------
753+
(n_total_nodes, dim) numpy.ndarray of float
754+
Gridded hanging and non-hanging node locations
755+
"""
756+
return np.vstack((self.nodes, self.hanging_nodes))
757+
743758
@property
744759
def vntF(self):
745760
"""Vector number of total faces along each axis.

tests/tree/test_tree.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
import unittest
3+
import pytest
34
import discretize
45

56
TOL = 1e-8
@@ -340,6 +341,43 @@ def test_cell_nodes(self):
340341
np.testing.assert_equal(cell_2.nodes, cell_nodes[2])
341342

342343

344+
class TestTreeMeshNodes:
345+
@pytest.fixture(params=["2D", "3D"])
346+
def sample_mesh(self, request):
347+
"""Return a sample TreeMesh"""
348+
nc = 8
349+
h1 = np.random.rand(nc) * nc * 0.5 + nc * 0.5
350+
h2 = np.random.rand(nc) * nc * 0.5 + nc * 0.5
351+
if request.param == "2D":
352+
h = [hi / np.sum(hi) for hi in [h1, h2]] # normalize
353+
mesh = discretize.TreeMesh(h)
354+
points = np.array([[0.2, 0.1], [0.8, 0.4]])
355+
levels = np.array([1, 2])
356+
mesh.insert_cells(points, levels, finalize=True)
357+
else:
358+
h3 = np.random.rand(nc) * nc * 0.5 + nc * 0.5
359+
h = [hi / np.sum(hi) for hi in [h1, h2, h3]] # normalize
360+
mesh = discretize.TreeMesh(h, levels=3)
361+
points = np.array([[0.2, 0.1, 0.7], [0.8, 0.4, 0.2]])
362+
levels = np.array([1, 2])
363+
mesh.insert_cells(points, levels, finalize=True)
364+
return mesh
365+
366+
def test_total_nodes(self, sample_mesh):
367+
"""
368+
Test if ``TreeMesh.total_nodes`` works as expected
369+
"""
370+
n_non_hanging_nodes = sample_mesh.n_nodes
371+
# Check if total_nodes contain all non hanging nodes (in the right order)
372+
np.testing.assert_equal(
373+
sample_mesh.total_nodes[:n_non_hanging_nodes, :], sample_mesh.nodes
374+
)
375+
# Check if total_nodes contain all hanging nodes (in the right order)
376+
np.testing.assert_equal(
377+
sample_mesh.total_nodes[n_non_hanging_nodes:, :], sample_mesh.hanging_nodes
378+
)
379+
380+
343381
class Test2DInterpolation(unittest.TestCase):
344382
def setUp(self):
345383
def topo(x):

0 commit comments

Comments
 (0)