Skip to content

Commit c152f44

Browse files
committed
MG: generate pyop3 kernels
1 parent baa2fb8 commit c152f44

File tree

8 files changed

+262
-130
lines changed

8 files changed

+262
-130
lines changed

firedrake/cofunction.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -349,14 +349,17 @@ def vector(self):
349349
:class:`Cofunction`"""
350350
return vector.Vector(self)
351351

352+
def nodal_dat(self):
    """Wrap this object's data in an ``op3.HierarchicalArray``.

    The array is built over ``self.function_space().nodal_axes`` and is
    given ``self.dat.data_rw`` as its ``data`` argument.
    """
    axes = self.function_space().nodal_axes
    values = self.dat.data_rw
    return op3.HierarchicalArray(axes, data=values)
354+
352355
@property
def node_set(self):
    r"""The nodes of this :class:`Cofunction`, as exposed by its function space.

    Returns ``self.function_space().nodes``. One or (for rank-1 and 2
    :class:`.FunctionSpace`\s) more degrees of freedom are stored
    at each node.

    NOTE(review): this property used to return
    ``self.function_space().node_set`` (a :class:`pyop2.types.set.Set`).
    The object returned now is presumably a pyop3 axis -- confirm the
    type and check callers that still rely on the pyop2 ``Set`` interface.
    """
    return self.function_space().nodes
360363

361364
def ufl_id(self):
362365
return self.uid

firedrake/cython/mgimpl.pyx

+2-2
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def coarse_to_fine_nodes(Vc, Vf, np.ndarray[PetscInt, ndim=2, mode="c"] coarse_t
8585
ndof = fine_per_cell * fine_cell_per_coarse_cell
8686
if extruded:
8787
ndof *= ratio
88-
coarse_to_fine_map = np.full((coarse_cells,
88+
coarse_to_fine_map = np.full((Vc.node_count,
8989
ndof),
9090
-1,
9191
dtype=IntType)
@@ -142,7 +142,7 @@ def fine_to_coarse_nodes(Vf, Vc, np.ndarray[PetscInt, ndim=2, mode="c"] fine_to_
142142
coarse_per_fine = fine_to_coarse_cells.shape[1]
143143
coarse_per_cell = coarse_map.shape[1]
144144
fine_per_cell = fine_map.shape[1]
145-
fine_to_coarse_map = np.full((fine_cells,
145+
fine_to_coarse_map = np.full((Vf.node_count,
146146
coarse_per_fine*coarse_per_cell),
147147
-1,
148148
dtype=IntType)

firedrake/function.py

+3
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,9 @@ def vector(self):
381381
r"""Return a :class:`.Vector` wrapping the data in this :class:`Function`"""
382382
return vector.Vector(self)
383383

384+
def nodal_dat(self):
    """Return an ``op3.HierarchicalArray`` over this function's nodal axes.

    The axes come from ``self.function_space().nodal_axes`` and the
    array's ``data`` argument is ``self.dat.data_rw``.
    """
    space = self.function_space()
    raw_data = self.dat.data_rw
    return op3.HierarchicalArray(space.nodal_axes, data=raw_data)
386+
384387
@PETSc.Log.EventDecorator()
385388
def interpolate(
386389
self,

firedrake/functionspaceimpl.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -697,17 +697,28 @@ def _cdim(self):
697697

698698
@utils.cached_property
def cell_node_list(self):
    r"""A numpy array mapping mesh cells to function space nodes (includes halo).

    The result is reshaped to ``(cells.size, -1)``, i.e. one row per
    entry of ``self.mesh().cells``.
    """
    # The docstring must be the first statement in the body to be picked up
    # as ``__doc__``, so the local import goes after it.
    # NOTE(review): imported here rather than at module level, presumably
    # to avoid a circular import with firedrake.parloops -- confirm.
    from firedrake.parloops import pack_pyop3_tensor

    cells = self.mesh().cells
    # Pass self.sub(0) to get nodes from the scalar version of this function space
    packed_axes = pack_pyop3_tensor(self.block_axes, self.sub(0), cells.index(), "cell")
    return packed_axes.tabulated_offsets.buffer.data.reshape((cells.size, -1))
706706

707707
@utils.cached_property
def owned_cell_node_list(self):
    r"""A numpy array mapping owned mesh cells to function space nodes.

    This is the leading ``self.mesh().cells.owned.size`` rows of
    :attr:`cell_node_list`.
    """
    num_owned = self.mesh().cells.owned.size
    # NOTE(review): taking a prefix assumes owned cells are listed before
    # any non-owned ones in cell_node_list -- confirm.
    return self.cell_node_list[:num_owned]
712+
713+
@utils.cached_property
def nodes(self):
    """Return an ``op3.Axis`` named ``"nodes"`` for this function space.

    The axis is constructed from ``{"XXX": self.node_count}`` with
    ``numbering=None``.

    Raises
    ------
    NotImplementedError
        If ``self.value_size`` is greater than 1.
    """
    if self.value_size > 1:
        raise NotImplementedError(
            "nodes is not yet implemented for function spaces with "
            f"value_size > 1 (got value_size={self.value_size})"
        )
    # NOTE(review): "XXX" looks like a placeholder component label; it is
    # kept unchanged because other code may look it up by name -- confirm
    # before renaming.
    return op3.Axis({"XXX": self.node_count}, "nodes", numbering=None)
718+
719+
@utils.cached_property
def nodal_axes(self):
    """Return an ``op3.AxisTree`` constructed from :attr:`nodes`."""
    node_axis = self.nodes
    return op3.AxisTree(node_axis)
711722

712723
@utils.cached_property
713724
def topological(self):

firedrake/mg/interface.py

+62-29
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from pyop2 import op2
1+
import pyop3 as op3
22

33
import firedrake
44
from firedrake import ufl_expr
@@ -81,13 +81,21 @@ def prolong(coarse, fine):
8181
# Have to do this, because the node set core size is not right for
8282
# this expanded stencil
8383
for d in [coarse, coarse_coords]:
84-
d.dat.global_to_local_begin(op2.READ)
85-
d.dat.global_to_local_end(op2.READ)
86-
op2.par_loop(kernel, next.node_set,
87-
next.dat(op2.WRITE),
88-
coarse.dat(op2.READ, fine_to_coarse),
89-
node_locations.dat(op2.READ),
90-
coarse_coords.dat(op2.READ, fine_to_coarse_coords))
84+
d.dat.assemble()
85+
#op2.par_loop(kernel, next.node_set,
86+
# next.dat(op2.WRITE),
87+
# coarse.dat(op2.READ, fine_to_coarse),
88+
# node_locations.dat(op2.READ),
89+
# coarse_coords.dat(op2.READ, fine_to_coarse_coords))
90+
op3.do_loop(
91+
n := Vf.nodes.index(),
92+
kernel(
93+
next.nodal_dat()[n],
94+
coarse.nodal_dat()[fine_to_coarse(n)],
95+
node_locations.nodal_dat()[n],
96+
coarse_coords.nodal_dat()[fine_to_coarse_coords(n)],
97+
),
98+
)
9199
coarse = next
92100
Vc = Vf
93101
return fine
@@ -142,14 +150,22 @@ def restrict(fine_dual, coarse_dual):
142150
# Have to do this, because the node set core size is not right for
143151
# this expanded stencil
144152
for d in [coarse_coords]:
145-
d.dat.global_to_local_begin(op2.READ)
146-
d.dat.global_to_local_end(op2.READ)
153+
d.dat.assemble()
147154
kernel = kernels.restrict_kernel(Vf, Vc)
148-
op2.par_loop(kernel, fine_dual.node_set,
149-
next.dat(op2.INC, fine_to_coarse),
150-
fine_dual.dat(op2.READ),
151-
node_locations.dat(op2.READ),
152-
coarse_coords.dat(op2.READ, fine_to_coarse_coords))
155+
#op2.par_loop(kernel, fine_dual.node_set,
156+
# next.dat(op2.INC, fine_to_coarse),
157+
# fine_dual.dat(op2.READ),
158+
# node_locations.dat(op2.READ),
159+
# coarse_coords.dat(op2.READ, fine_to_coarse_coords))
160+
op3.do_loop(
161+
n := Vf.nodes.index(),
162+
kernel(
163+
next.nodal_dat()[n],
164+
fine_dual.nodal_dat()[fine_to_coarse(n)],
165+
node_locations.nodal_dat()[n],
166+
coarse_coords.nodal_dat()[fine_to_coarse_coords(n)],
167+
),
168+
)
153169
fine_dual = next
154170
Vf = Vc
155171
return coarse_dual
@@ -217,13 +233,21 @@ def inject(fine, coarse):
217233
# Have to do this, because the node set core size is not right for
218234
# this expanded stencil
219235
for d in [fine, fine_coords]:
220-
d.dat.global_to_local_begin(op2.READ)
221-
d.dat.global_to_local_end(op2.READ)
222-
op2.par_loop(kernel, next.node_set,
223-
next.dat(op2.INC),
224-
node_locations.dat(op2.READ),
225-
fine.dat(op2.READ, coarse_node_to_fine_nodes),
226-
fine_coords.dat(op2.READ, coarse_node_to_fine_coords))
236+
d.dat.assemble()
237+
#op2.par_loop(kernel, next.node_set,
238+
# next.dat(op2.INC),
239+
# node_locations.dat(op2.READ),
240+
# fine.dat(op2.READ, coarse_node_to_fine_nodes),
241+
# fine_coords.dat(op2.READ, coarse_node_to_fine_coords))
242+
op3.do_loop(
243+
n := Vc.nodes.index(),
244+
kernel(
245+
next.nodal_dat()[n],
246+
node_locations.nodal_dat()[n],
247+
fine.nodal_dat()[coarse_node_to_fine_nodes(n)],
248+
fine_coords.nodal_dat()[coarse_node_to_fine_coords(n)],
249+
),
250+
)
227251
else:
228252
coarse_coords = Vc.mesh().coordinates
229253
fine_coords = Vf.mesh().coordinates
@@ -232,13 +256,22 @@ def inject(fine, coarse):
232256
# Have to do this, because the node set core size is not right for
233257
# this expanded stencil
234258
for d in [fine, fine_coords]:
235-
d.dat.global_to_local_begin(op2.READ)
236-
d.dat.global_to_local_end(op2.READ)
237-
op2.par_loop(kernel, Vc.mesh().cell_set,
238-
next.dat(op2.INC, next.cell_node_map()),
239-
fine.dat(op2.READ, coarse_cell_to_fine_nodes),
240-
fine_coords.dat(op2.READ, coarse_cell_to_fine_coords),
241-
coarse_coords.dat(op2.READ, coarse_coords.cell_node_map()))
259+
d.dat.assemble()
260+
#op2.par_loop(kernel, Vc.mesh().cell_set,
261+
# next.dat(op2.INC, next.cell_node_map()),
262+
# fine.dat(op2.READ, coarse_cell_to_fine_nodes),
263+
# fine_coords.dat(op2.READ, coarse_cell_to_fine_coords),
264+
# coarse_coords.dat(op2.READ, coarse_coords.cell_node_map()))
265+
op3.do_loop(
266+
n := Vc.nodes.index(),
267+
kernel(
268+
next.nodal_dat()[n],
269+
fine.nodal_dat()[coarse_node_to_fine_nodes(n)],
270+
fine_coords.nodal_dat()[coarse_node_to_fine_coords(n)],
271+
coarse_coords.nodal_dat()[coarse_coords(n)],
272+
),
273+
)
274+
242275
fine = next
243276
Vf = Vc
244277
return coarse

0 commit comments

Comments
 (0)