Skip to content

Commit 5418e7f

Browse files
committed
MG: prolong working in parallel for single refinement
1 parent b3b93c1 commit 5418e7f

File tree

6 files changed

+57
-58
lines changed

6 files changed

+57
-58
lines changed

firedrake/function.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,8 @@ def vector(self):
382382
return vector.Vector(self)
383383

384384
def nodal_dat(self):
385-
return op3.HierarchicalArray(self.function_space().nodal_axes, data=self.dat.data_rw_with_halos)
385+
return op3.HierarchicalArray(self.function_space().nodal_axes,
386+
data=self.dat.data_rw_with_halos)
386387

387388
@PETSc.Log.EventDecorator()
388389
def interpolate(

firedrake/functionspaceimpl.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -712,15 +712,13 @@ def owned_cell_node_list(self):
712712

713713
@utils.cached_property
714714
def nodes(self):
715-
if self.value_size > 1:
716-
raise NotImplementedError("TODO")
717715
ax = self.block_axes
718716
return op3.Axis([op3.AxisComponent((ax.owned.size, ax.size),
719717
"XXX", rank_equal=False)], "nodes", numbering=None, sf=ax.sf)
720718

721719
@utils.cached_property
722720
def nodal_axes(self):
723-
return op3.AxisTree(self.nodes)
721+
return op3.AxisTree.from_iterable([self.nodes, self.value_size])
724722

725723
@utils.cached_property
726724
def topological(self):

firedrake/mg/interface.py

+15-25
Original file line numberDiff line numberDiff line change
@@ -83,17 +83,13 @@ def prolong(coarse, fine):
8383
for d in [coarse, coarse_coords]:
8484
d.dat.assemble()
8585

86-
#op2.par_loop(kernel, next.node_set,
87-
# next.dat(op2.WRITE),
88-
# coarse.dat(op2.READ, fine_to_coarse),
89-
# node_locations.dat(op2.READ),
90-
# coarse_coords.dat(op2.READ, fine_to_coarse_coords))
86+
id_map = utils.owned_node_map(Vf)
9187
op3.do_loop(
92-
n := Vf.nodes.index(),
88+
n := Vf.nodes.owned.index(),
9389
kernel(
94-
next.nodal_dat()[n],
90+
next.nodal_dat()[id_map(n)],
9591
coarse.nodal_dat()[fine_to_coarse(n)],
96-
node_locations.nodal_dat()[n],
92+
node_locations.nodal_dat()[id_map(n)],
9793
coarse_coords.nodal_dat()[fine_to_coarse_coords(n)],
9894
),
9995
)
@@ -153,17 +149,14 @@ def restrict(fine_dual, coarse_dual):
153149
for d in [coarse_coords]:
154150
d.dat.assemble()
155151
kernel = kernels.restrict_kernel(Vf, Vc)
156-
#op2.par_loop(kernel, fine_dual.node_set,
157-
# next.dat(op2.INC, fine_to_coarse),
158-
# fine_dual.dat(op2.READ),
159-
# node_locations.dat(op2.READ),
160-
# coarse_coords.dat(op2.READ, fine_to_coarse_coords))
152+
153+
id_map = utils.owned_node_map(Vf)
161154
op3.do_loop(
162-
n := Vf.nodes.index(),
155+
n := Vf.nodes.owned.index(),
163156
kernel(
164157
next.nodal_dat()[fine_to_coarse(n)],
165-
fine_dual.nodal_dat()[n],
166-
node_locations.nodal_dat()[n],
158+
fine_dual.nodal_dat()[id_map(n)],
159+
node_locations.nodal_dat()[id_map(n)],
167160
coarse_coords.nodal_dat()[fine_to_coarse_coords(n)],
168161
),
169162
)
@@ -235,16 +228,13 @@ def inject(fine, coarse):
235228
# this expanded stencil
236229
for d in [fine, fine_coords]:
237230
d.dat.assemble()
238-
#op2.par_loop(kernel, next.node_set,
239-
# next.dat(op2.INC),
240-
# node_locations.dat(op2.READ),
241-
# fine.dat(op2.READ, coarse_node_to_fine_nodes),
242-
# fine_coords.dat(op2.READ, coarse_node_to_fine_coords))
231+
232+
id_map = utils.owned_node_map(Vc)
243233
op3.do_loop(
244-
n := Vc.nodes.index(),
234+
n := Vc.nodes.owned.index(),
245235
kernel(
246-
next.nodal_dat()[n],
247-
node_locations.nodal_dat()[n],
236+
next.nodal_dat()[id_map(n)],
237+
node_locations.nodal_dat()[id_map(n)],
248238
fine.nodal_dat()[coarse_node_to_fine_nodes(n)],
249239
fine_coords.nodal_dat()[coarse_node_to_fine_coords(n)],
250240
),
@@ -264,7 +254,7 @@ def inject(fine, coarse):
264254
# fine_coords.dat(op2.READ, coarse_cell_to_fine_coords),
265255
# coarse_coords.dat(op2.READ, coarse_coords.cell_node_map()))
266256
op3.do_loop(
267-
c := Vc.mesh().cells.index(),
257+
c := Vc.mesh().cells.owned.index(),
268258
kernel(
269259
next.dat[c],
270260
fine.nodal_dat()[coarse_cell_to_fine_nodes(c)],

firedrake/mg/mesh.py

-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from firedrake.utils import cached_property
99
from firedrake.cython import mgimpl as impl
1010
from .utils import set_level
11-
from firedrake.petsc import PETSc
1211

1312
__all__ = ("HierarchyBase", "MeshHierarchy", "ExtrudedMeshHierarchy", "NonNestedHierarchy",
1413
"SemiCoarsenedExtrudedHierarchy")

firedrake/mg/utils.py

+33-20
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,31 @@ def coarse_to_fine_cell_map(coarse_mesh, fine_mesh, coarse_to_fine_data):
1818
return op3.Map(connectivity)
1919

2020

21+
def create_node_map(iterset, toset, arity=1, values=None):
22+
axes = op3.AxisTree.from_iterable([iterset, arity])
23+
if values is None:
24+
values = numpy.arange(iterset.size, dtype=IntType)
25+
dat = op3.HierarchicalArray(axes, data=values)
26+
return op3.Map({
27+
freeze({"nodes": iterset.owned.component.label}): [
28+
op3.TabulatedMapComponent("nodes", toset.component.label, dat)
29+
]
30+
})
31+
32+
33+
def owned_node_map(V):
34+
""" This should not be necessary
35+
"""
36+
mesh = V.mesh()
37+
key = entity_dofs_key(V.finat_element.entity_dofs())
38+
cache = mesh._shared_data_cache["owned_node_map"]
39+
try:
40+
return cache[key]
41+
except KeyError:
42+
node_map = create_node_map(V.nodes.owned, V.nodes)
43+
return cache.setdefault(key, node_map)
44+
45+
2146
def fine_node_to_coarse_node_map(Vf, Vc):
2247
if len(Vf) > 1:
2348
assert len(Vf) == len(Vc)
@@ -51,15 +76,9 @@ def fine_node_to_coarse_node_map(Vf, Vc):
5176

5277
fine_to_coarse = hierarchy.fine_to_coarse_cells[levelf]
5378
fine_to_coarse_nodes = impl.fine_to_coarse_nodes(Vf, Vc, fine_to_coarse)
54-
55-
axes = op3.AxisTree.from_iterable([Vf.nodes, fine_to_coarse_nodes.shape[1]])
56-
fine_to_coarse_node_dat = op3.HierarchicalArray(axes, data=fine_to_coarse_nodes)
57-
fine_to_coarse_node_map = op3.Map({
58-
freeze({"nodes": "XXX"}): [
59-
op3.TabulatedMapComponent("nodes", "XXX", fine_to_coarse_node_dat)
60-
]
61-
})
62-
return cache.setdefault(key, fine_to_coarse_node_map)
79+
return cache.setdefault(key, create_node_map(Vf.nodes, Vc.nodes,
80+
arity=fine_to_coarse_nodes.shape[1],
81+
values=fine_to_coarse_nodes))
6382

6483

6584
def coarse_node_to_fine_node_map(Vc, Vf):
@@ -95,15 +114,9 @@ def coarse_node_to_fine_node_map(Vc, Vf):
95114

96115
coarse_to_fine = hierarchy.coarse_to_fine_cells[levelc]
97116
coarse_to_fine_nodes = impl.coarse_to_fine_nodes(Vc, Vf, coarse_to_fine)
98-
99-
axes = op3.AxisTree.from_iterable([Vc.nodes, coarse_to_fine_nodes.shape[1]])
100-
coarse_to_fine_node_dat = op3.HierarchicalArray(axes, data=coarse_to_fine_nodes)
101-
coarse_to_fine_node_map = op3.Map({
102-
freeze({"nodes": "XXX"}): [
103-
op3.TabulatedMapComponent("nodes", "XXX", coarse_to_fine_node_dat)
104-
]
105-
})
106-
return cache.setdefault(key, coarse_to_fine_node_map)
117+
return cache.setdefault(key, create_node_map(Vc.nodes, Vf.nodes,
118+
arity=coarse_to_fine_nodes.shape[1],
119+
values=coarse_to_fine_nodes))
107120

108121

109122
def coarse_cell_to_fine_node_map(Vc, Vf):
@@ -153,8 +166,8 @@ def coarse_cell_to_fine_node_map(Vc, Vf):
153166
axes = op3.AxisTree.from_iterable([Vc.nodes, arity*level_ratio])
154167
coarse_cell_to_fine_node_dat = op3.HierarchicalArray(axes, data=coarse_to_fine_nodes)
155168
coarse_cell_to_fine_node_map = op3.Map({
156-
freeze({Vc.mesh().topology.name: Vc.mesh().cell_label}): [
157-
op3.TabulatedMapComponent("nodes", "XXX", coarse_cell_to_fine_node_dat)
169+
freeze({Vc.mesh().topology.name: Vc.mesh().cells.owned.label}): [
170+
op3.TabulatedMapComponent("nodes", Vf.nodes.component.label, coarse_cell_to_fine_node_dat)
158171
]
159172
})
160173
return cache.setdefault(key, coarse_cell_to_fine_node_map)

tests/multigrid/test_grid_transfer.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -21,24 +21,23 @@ def space(request, cell):
2121
return request.param
2222

2323

24-
@pytest.fixture(params=[1, 2], scope="module")
25-
def refinements_per_level(request):
24+
@pytest.fixture(params=[(1, 1), (2, 1), (1, 2)], scope="module")
25+
def nref_refinements_per_level(request):
2626
return request.param
2727

2828

2929
@pytest.fixture(scope="module")
30-
def hierarchy(cell, refinements_per_level):
30+
def hierarchy(cell, nref_refinements_per_level):
3131
if cell == "interval":
3232
mesh = UnitIntervalMesh(3)
33-
return MeshHierarchy(mesh, 2)
3433
elif cell in {"triangle", "triangle-nonnested", "prism"}:
3534
mesh = UnitSquareMesh(3, 3, quadrilateral=False)
3635
elif cell in {"quadrilateral", "hexahedron"}:
3736
mesh = UnitSquareMesh(3, 3, quadrilateral=True)
3837
elif cell == "tetrahedron":
3938
mesh = UnitCubeMesh(2, 2, 2)
4039

41-
nref = {2: 1, 1: 2}[refinements_per_level]
40+
nref, refinements_per_level = nref_refinements_per_level
4241
hierarchy = MeshHierarchy(mesh, nref, refinements_per_level=refinements_per_level)
4342

4443
if cell in {"prism", "hexahedron"}:
@@ -328,6 +327,5 @@ def test_grid_transfer_periodic(periodic_hierarchy, periodic_space):
328327

329328
if __name__ == "__main__":
330329
bmesh = UnitIntervalMesh(3)
331-
mh = MeshHierarchy(bmesh, 1)
332-
333-
run_restriction(mh, False, "CG", [1])
330+
mh = MeshHierarchy(bmesh, 2, refinements_per_level=1)
331+
run_prolongation(mh, False, "CG", [1])

0 commit comments

Comments
 (0)