Skip to content

Commit 93bd80d

Browse files
committed
remove mesh.init()
1 parent 73d63aa commit 93bd80d

File tree

6 files changed

+129
-42
lines changed

6 files changed

+129
-42
lines changed

firedrake/cython/dmcommon.pyx

+38-2
Original file line numberDiff line numberDiff line change
@@ -3687,7 +3687,8 @@ def create_halo_exchange_sf(PETSc.DM dm):
36873687
@cython.wraparound(False)
36883688
def submesh_create(PETSc.DM dm,
36893689
label_name,
3690-
PetscInt label_value):
3690+
PetscInt label_value,
3691+
PetscBool ignore_label_halo):
36913692
"""Create submesh.
36923693
36933694
Parameters
@@ -3698,6 +3699,8 @@ def submesh_create(PETSc.DM dm,
36983699
Name of the label
36993700
label_value : int
37003701
Value in the label
3702+
ignore_label_halo : bool
3703+
If labeled points in the halo are ignored.
37013704
37023705
"""
37033706
cdef:
@@ -3706,7 +3709,7 @@ def submesh_create(PETSc.DM dm,
37063709
PETSc.SF ownership_transfer_sf = PETSc.SF()
37073710

37083711
label = dm.getLabel(label_name)
3709-
CHKERR(DMPlexFilter(dm.dm, label.dmlabel, label_value, PETSC_FALSE, PETSC_TRUE, &ownership_transfer_sf.sf, &subdm.dm))
3712+
CHKERR(DMPlexFilter(dm.dm, label.dmlabel, label_value, ignore_label_halo, PETSC_TRUE, &ownership_transfer_sf.sf, &subdm.dm))
37103713
submesh_update_facet_labels(dm, subdm)
37113714
submesh_correct_entity_classes(dm, subdm, ownership_transfer_sf)
37123715
return subdm
@@ -3919,3 +3922,36 @@ def submesh_create_cell_closure_cell_submesh(PETSc.DM subdm,
39193922
CHKERR(PetscFree(subpoint_indices_inv))
39203923
CHKERR(ISRestoreIndices(subpoint_is.iset, &subpoint_indices))
39213924
return subcell_closure
3925+
3926+
3927+
@cython.boundscheck(False)
3928+
@cython.wraparound(False)
3929+
def get_dm_cell_types(PETSc.DM dm):
3930+
"""Return all cell types in the mesh.
3931+
3932+
Parameters
3933+
----------
3934+
dm : PETSc.DM
3935+
The parent dm.
3936+
3937+
Returns
3938+
-------
3939+
tuple
3940+
Tuple of all cell types in the mesh.
3941+
3942+
"""
3943+
cdef:
3944+
PetscInt cStart, cEnd, c
3945+
np.ndarray found, found_all
3946+
PetscDMPolytopeType celltype
3947+
3948+
cStart, cEnd = dm.getHeightStratum(0)
3949+
found = np.zeros((DM_NUM_POLYTOPES, ), dtype=IntType)
3950+
found_all = np.zeros((DM_NUM_POLYTOPES, ), dtype=IntType)
3951+
for c in range(cStart, cEnd):
3952+
CHKERR(DMPlexGetCellType(dm.dm, c, &celltype))
3953+
found[celltype] = 1
3954+
dm.comm.tompi4py().Allreduce(found, found_all, op=MPI.MAX)
3955+
return tuple(
3956+
polytope_type_enum for polytope_type_enum, found in enumerate(found_all) if found
3957+
)

firedrake/cython/petschdr.pxi

+24
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,27 @@ cdef extern from "petscsys.h" nogil:
2929
int PetscFree2(void*,void*)
3030
int PetscSortIntWithArray(PetscInt,PetscInt[],PetscInt[])
3131

32+
cdef extern from "petscdmtypes.h" nogil:
33+
ctypedef enum PetscDMPolytopeType "DMPolytopeType":
34+
DM_POLYTOPE_POINT
35+
DM_POLYTOPE_SEGMENT
36+
DM_POLYTOPE_POINT_PRISM_TENSOR
37+
DM_POLYTOPE_TRIANGLE
38+
DM_POLYTOPE_QUADRILATERAL
39+
DM_POLYTOPE_SEG_PRISM_TENSOR
40+
DM_POLYTOPE_TETRAHEDRON
41+
DM_POLYTOPE_HEXAHEDRON
42+
DM_POLYTOPE_TRI_PRISM
43+
DM_POLYTOPE_TRI_PRISM_TENSOR
44+
DM_POLYTOPE_QUAD_PRISM_TENSOR
45+
DM_POLYTOPE_PYRAMID
46+
DM_POLYTOPE_FV_GHOST
47+
DM_POLYTOPE_INTERIOR_GHOST
48+
DM_POLYTOPE_UNKNOWN
49+
DM_POLYTOPE_UNKNOWN_CELL
50+
DM_POLYTOPE_UNKNOWN_FACE
51+
DM_NUM_POLYTOPES
52+
3253
cdef extern from "petscdmplex.h" nogil:
3354
int DMPlexGetHeightStratum(PETSc.PetscDM,PetscInt,PetscInt*,PetscInt*)
3455
int DMPlexGetDepthStratum(PETSc.PetscDM,PetscInt,PetscInt*,PetscInt*)
@@ -56,6 +77,9 @@ cdef extern from "petscdmplex.h" nogil:
5677
int DMPlexGetSubpointMap(PETSc.PetscDM,PETSc.PetscDMLabel*)
5778
int DMPlexSetSubpointMap(PETSc.PetscDM,PETSc.PetscDMLabel)
5879

80+
int DMPlexSetCellType(PETSc.PetscDM,PetscInt,PetscDMPolytopeType)
81+
int DMPlexGetCellType(PETSc.PetscDM,PetscInt,PetscDMPolytopeType*)
82+
5983
cdef extern from "petscdmlabel.h" nogil:
6084
struct _n_DMLabel
6185
ctypedef _n_DMLabel* DMLabel "DMLabel"

firedrake/mesh.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,12 @@ def ufl_mesh(self):
742742
"""
743743
return self._ufl_mesh
744744

745+
@property
746+
@abc.abstractmethod
747+
def dm_cell_types(self):
748+
"""All ``DM.PolytopeType``s of cells in the mesh."""
749+
pass
750+
745751
@property
746752
@abc.abstractmethod
747753
def cell_closure(self):
@@ -1247,6 +1253,11 @@ def _renumber_entities(self, reorder):
12471253
reordering = None
12481254
return dmcommon.plex_renumbering(self.topology_dm, self._entity_classes, reordering)
12491255

1256+
@property
1257+
def dm_cell_types(self):
1258+
"""All ``DM.PolytopeType``s of cells in the mesh."""
1259+
return dmcommon.get_dm_cell_types(self.topology_dm)
1260+
12501261
@utils.cached_property
12511262
def cell_closure(self):
12521263
"""2D array of ordered cell closures
@@ -1744,6 +1755,11 @@ def _ufl_mesh(self):
17441755
cell = self._ufl_cell
17451756
return ufl.Mesh(finat.ufl.VectorElement("Lagrange", cell, 1, dim=cell.topological_dimension()))
17461757

1758+
@property
1759+
def dm_cell_types(self):
1760+
"""All ``DM.PolytopeType``s of cells in the mesh."""
1761+
raise NotImplementedError("Notimplemented for ExtrudedMeshTopology")
1762+
17471763
@utils.cached_property
17481764
def cell_closure(self):
17491765
"""2D array of ordered cell closures
@@ -1981,6 +1997,11 @@ def _renumber_entities(self, reorder):
19811997
else:
19821998
return dmcommon.plex_renumbering(self.topology_dm, self._entity_classes, None)
19831999

2000+
@property
2001+
def dm_cell_types(self):
2002+
"""All ``DM.PolytopeType``s of cells in the mesh."""
2003+
return (PETSc.DM.PolytopeType.POINT,)
2004+
19842005
@utils.cached_property # TODO: Recalculate if mesh moves
19852006
def cell_closure(self):
19862007
"""2D array of ordered cell closures
@@ -4603,7 +4624,7 @@ def SubDomainData(geometric_expr):
46034624
return op2.Subset(m.cell_set, indices)
46044625

46054626

4606-
def Submesh(mesh, subdim, subdomain_id, label_name=None, name=None):
4627+
def Submesh(mesh, subdim, subdomain_id, label_name=None, ignore_label_halo=False, name=None):
46074628
"""Construct a submesh from a given mesh.
46084629
46094630
Parameters
@@ -4616,6 +4637,8 @@ def Submesh(mesh, subdim, subdomain_id, label_name=None, name=None):
46164637
Subdomain ID representing the submesh.
46174638
label_name : str
46184639
Name of the label to search ``subdomain_id`` in.
4640+
ignore_label_halo : bool
4641+
If labeled points in the halo are ignored.
46194642
name : str
46204643
Name of the submesh.
46214644
@@ -4669,7 +4692,7 @@ def Submesh(mesh, subdim, subdomain_id, label_name=None, name=None):
46694692
if label_name is None:
46704693
label_name = dmcommon.CELL_SETS_LABEL
46714694
name = name or _generate_default_submesh_name(mesh.name)
4672-
subplex = dmcommon.submesh_create(plex, label_name, subdomain_id)
4695+
subplex = dmcommon.submesh_create(plex, label_name, subdomain_id, ignore_label_halo)
46734696
subplex.setName(_generate_default_mesh_topology_name(name))
46744697
if subplex.getDimension() != subdim:
46754698
raise RuntimeError(f"Found subplex dim ({subplex.getDimension()}) != expected ({subdim})")

firedrake/mg/mesh.py

+40-26
Original file line numberDiff line numberDiff line change
@@ -131,36 +131,32 @@ def MeshHierarchy(mesh, refinement_levels,
131131
raise RuntimeError("Cannot create a NetgenHierarchy from a mesh that has not been generated by\
132132
Netgen.")
133133

134+
mesh_with_overlap = mesh
135+
mesh.init()
136+
dm_cell_type, = mesh.dm_cell_types
137+
tdim = mesh.topology_dm.getDimension()
138+
# Virtually "invert" addOverlap.
139+
# -- This is algorithmically guaranteed.
140+
mesh = firedrake.Submesh(mesh_with_overlap, tdim, dm_cell_type, label_name="celltype", ignore_label_halo=True)
134141
cdm = mesh.topology_dm
142+
cdm.removeLabel("pyop2_core")
143+
cdm.removeLabel("pyop2_owned")
144+
cdm.removeLabel("pyop2_ghost")
135145
cdm.setRefinementUniform(True)
136146
dms = []
137147
if mesh.comm.size > 1 and mesh._grown_halos:
138148
raise RuntimeError("Cannot refine parallel overlapped meshes "
139149
"(make sure the MeshHierarchy is built immediately after the Mesh)")
140-
parameters = {}
141-
if distribution_parameters is not None:
142-
parameters.update(distribution_parameters)
143-
else:
144-
parameters.update(mesh._distribution_parameters)
145-
146-
parameters["partition"] = False
147-
distribution_parameters = parameters
148-
149150
if callbacks is not None:
150151
before, after = callbacks
151152
else:
152153
before = after = lambda dm, i: None
153-
154154
for i in range(refinement_levels*refinements_per_level):
155155
if i % refinements_per_level == 0:
156156
before(cdm, i)
157157
rdm = cdm.refine()
158158
if i % refinements_per_level == 0:
159159
after(rdm, i)
160-
rdm.removeLabel("pyop2_core")
161-
rdm.removeLabel("pyop2_owned")
162-
rdm.removeLabel("pyop2_ghost")
163-
164160
dms.append(rdm)
165161
cdm = rdm
166162
# Fix up coords if refining embedded circle or sphere
@@ -172,20 +168,38 @@ def MeshHierarchy(mesh, refinement_levels,
172168
coords = cdm.getCoordinatesLocal().array.reshape(-1, mesh.geometric_dimension())
173169
scale = mesh._radius / np.linalg.norm(coords, axis=1).reshape(-1, 1)
174170
coords *= scale
175-
176-
meshes = [mesh] + [mesh_builder(dm, dim=mesh.geometric_dimension(),
177-
distribution_parameters=distribution_parameters,
178-
reorder=reorder, comm=mesh.comm)
179-
for dm in dms]
180-
181-
lgmaps = []
182-
for i, m in enumerate(meshes):
183-
no = impl.create_lgmap(m.topology_dm)
171+
lgmaps_without_overlap = [
172+
impl.create_lgmap(dm) for dm in [mesh.topology_dm] + dms
173+
]
174+
parameters = {}
175+
if distribution_parameters is not None:
176+
parameters.update(distribution_parameters)
177+
else:
178+
parameters.update(mesh_with_overlap._distribution_parameters)
179+
parameters["partition"] = False
180+
meshes = [mesh_with_overlap] + [
181+
mesh_builder(
182+
dm,
183+
dim=mesh_with_overlap.geometric_dimension(),
184+
distribution_parameters=parameters,
185+
reorder=reorder,
186+
comm=mesh.comm,
187+
)
188+
for dm in dms
189+
]
190+
for m in meshes:
184191
m.init()
185-
o = impl.create_lgmap(m.topology_dm)
192+
#distribution_parameters_noop={
193+
# "partition": False,
194+
# "overlap_type": (firedrake.mesh.DistributedMeshOverlapType.NONE, 0),
195+
#}
196+
lgmaps_with_overlap = []
197+
for i, m in enumerate(meshes):
198+
lgmaps_with_overlap.append(impl.create_lgmap(m.topology_dm))
186199
m.topology_dm.setRefineLevel(i)
187-
lgmaps.append((no, o))
188-
200+
lgmaps = [
201+
(no, o) for no, o in zip(lgmaps_without_overlap, lgmaps_with_overlap)
202+
]
189203
coarse_to_fine_cells = []
190204
fine_to_coarse_cells = [None]
191205
for (coarse, fine), (clgmaps, flgmaps) in zip(zip(meshes[:-1], meshes[1:]),

tests/firedrake/multigrid/test_basics.py

-10
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,3 @@ def test_refine_square_ncell_parallel():
6060
# Should be fewer than 4 times the number of coarse cells due to
6161
# halo shrinking.
6262
assert mh[1].num_cells() < 4 * mh[0].num_cells()
63-
64-
65-
@pytest.mark.parallel(nprocs=2)
66-
def test_refining_overlapped_mesh_fails_parallel():
67-
m = UnitSquareMesh(4, 4)
68-
69-
m.init()
70-
71-
with pytest.raises(RuntimeError):
72-
MeshHierarchy(m, 1)

tests/firedrake/multigrid/test_grid_transfer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,8 @@ def test_grid_transfer_deformed(deformed_hierarchy, deformed_transfer_type):
269269
run_prolongation(deformed_hierarchy, vector, space, degrees)
270270

271271

272-
@pytest.fixture(params=["interval", "triangle", "quadrilateral", "tetrahedron"], scope="module")
272+
#@pytest.fixture(params=["interval", "triangle", "quadrilateral", "tetrahedron"], scope="module")
273+
@pytest.fixture(params=["interval"], scope="module")
273274
def periodic_cell(request):
274275
return request.param
275276

@@ -313,7 +314,6 @@ def exact_primal_periodic(mesh, vector, degree):
313314
return expr
314315

315316

316-
@pytest.mark.parallel(nprocs=3)
317317
def test_grid_transfer_periodic(periodic_hierarchy, periodic_space):
318318
degrees = [4]
319319
vector = False

0 commit comments

Comments
 (0)