Skip to content

Commit b3b93c1

Browse files
committed
MG: generate pyop3 kernels
1 parent baa2fb8 commit b3b93c1

File tree

9 files changed

+277
-128
lines changed

9 files changed

+277
-128
lines changed

firedrake/cofunction.py

+11-3
Original file line numberDiff line numberDiff line change
@@ -349,14 +349,17 @@ def vector(self):
349349
:class:`Cofunction`"""
350350
return vector.Vector(self)
351351

352+
def nodal_dat(self):
353+
return op3.HierarchicalArray(self.function_space().nodal_axes, data=self.dat.data_rw_with_halos)
354+
352355
@property
353356
def node_set(self):
354357
r"""A :class:`pyop2.types.set.Set` containing the nodes of this
355358
:class:`Cofunction`. One or (for rank-1 and 2
356359
:class:`.FunctionSpace`\s) more degrees of freedom are stored
357360
at each node.
358361
"""
359-
return self.function_space().node_set
362+
return self.function_space().nodes
360363

361364
def ufl_id(self):
362365
return self.uid
@@ -386,5 +389,10 @@ def __str__(self):
386389
else:
387390
return super(Cofunction, self).__str__()
388391

389-
def cell_node_map(self):
390-
return self.function_space().cell_node_map()
392+
@property
393+
def cell_node_list(self):
394+
return self.function_space().cell_node_list
395+
396+
@property
397+
def owned_cell_node_list(self):
398+
return self.function_space().owned_cell_node_list

firedrake/cython/mgimpl.pyx

+2-2
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def coarse_to_fine_nodes(Vc, Vf, np.ndarray[PetscInt, ndim=2, mode="c"] coarse_t
8585
ndof = fine_per_cell * fine_cell_per_coarse_cell
8686
if extruded:
8787
ndof *= ratio
88-
coarse_to_fine_map = np.full((coarse_cells,
88+
coarse_to_fine_map = np.full((Vc.node_count,
8989
ndof),
9090
-1,
9191
dtype=IntType)
@@ -142,7 +142,7 @@ def fine_to_coarse_nodes(Vf, Vc, np.ndarray[PetscInt, ndim=2, mode="c"] fine_to_
142142
coarse_per_fine = fine_to_coarse_cells.shape[1]
143143
coarse_per_cell = coarse_map.shape[1]
144144
fine_per_cell = fine_map.shape[1]
145-
fine_to_coarse_map = np.full((fine_cells,
145+
fine_to_coarse_map = np.full((Vf.node_count,
146146
coarse_per_fine*coarse_per_cell),
147147
-1,
148148
dtype=IntType)

firedrake/function.py

+3
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,9 @@ def vector(self):
381381
r"""Return a :class:`.Vector` wrapping the data in this :class:`Function`"""
382382
return vector.Vector(self)
383383

384+
def nodal_dat(self):
385+
return op3.HierarchicalArray(self.function_space().nodal_axes, data=self.dat.data_rw_with_halos)
386+
384387
@PETSc.Log.EventDecorator()
385388
def interpolate(
386389
self,

firedrake/functionspaceimpl.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -698,16 +698,29 @@ def _cdim(self):
698698
@utils.cached_property
699699
def cell_node_list(self):
700700
r"""A numpy array mapping mesh cells to function space nodes (includes halo)."""
701+
from firedrake.parloops import pack_pyop3_tensor
701702
cells = self.mesh().cells
702-
ncells = cells.size
703-
packed_axes = self.block_axes[self.mesh().closure(cells.index())]
704-
705-
return packed_axes.tabulated_offsets.buffer.data.reshape((ncells, -1))
703+
# Pass self.sub(0) to get nodes from the scalar version of this function space
704+
packed_axes = pack_pyop3_tensor(self.block_axes, self.sub(0), cells.index(include_ghost_points=True), "cell")
705+
return packed_axes.tabulated_offsets.buffer.data_rw_with_halos.reshape((cells.size, -1))
706706

707707
@utils.cached_property
708708
def owned_cell_node_list(self):
709709
r"""A numpy array mapping owned mesh cells to function space nodes."""
710-
return self.cell_node_list[:self.mesh().owned_cells.size]
710+
cells = self.mesh().cells
711+
return self.cell_node_list[:cells.owned.size]
712+
713+
@utils.cached_property
714+
def nodes(self):
715+
if self.value_size > 1:
716+
raise NotImplementedError("TODO")
717+
ax = self.block_axes
718+
return op3.Axis([op3.AxisComponent((ax.owned.size, ax.size),
719+
"XXX", rank_equal=False)], "nodes", numbering=None, sf=ax.sf)
720+
721+
@utils.cached_property
722+
def nodal_axes(self):
723+
return op3.AxisTree(self.nodes)
711724

712725
@utils.cached_property
713726
def topological(self):

firedrake/mg/interface.py

+63-29
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from pyop2 import op2
1+
import pyop3 as op3
22

33
import firedrake
44
from firedrake import ufl_expr
@@ -81,13 +81,22 @@ def prolong(coarse, fine):
8181
# Have to do this, because the node set core size is not right for
8282
# this expanded stencil
8383
for d in [coarse, coarse_coords]:
84-
d.dat.global_to_local_begin(op2.READ)
85-
d.dat.global_to_local_end(op2.READ)
86-
op2.par_loop(kernel, next.node_set,
87-
next.dat(op2.WRITE),
88-
coarse.dat(op2.READ, fine_to_coarse),
89-
node_locations.dat(op2.READ),
90-
coarse_coords.dat(op2.READ, fine_to_coarse_coords))
84+
d.dat.assemble()
85+
86+
#op2.par_loop(kernel, next.node_set,
87+
# next.dat(op2.WRITE),
88+
# coarse.dat(op2.READ, fine_to_coarse),
89+
# node_locations.dat(op2.READ),
90+
# coarse_coords.dat(op2.READ, fine_to_coarse_coords))
91+
op3.do_loop(
92+
n := Vf.nodes.index(),
93+
kernel(
94+
next.nodal_dat()[n],
95+
coarse.nodal_dat()[fine_to_coarse(n)],
96+
node_locations.nodal_dat()[n],
97+
coarse_coords.nodal_dat()[fine_to_coarse_coords(n)],
98+
),
99+
)
91100
coarse = next
92101
Vc = Vf
93102
return fine
@@ -142,14 +151,22 @@ def restrict(fine_dual, coarse_dual):
142151
# Have to do this, because the node set core size is not right for
143152
# this expanded stencil
144153
for d in [coarse_coords]:
145-
d.dat.global_to_local_begin(op2.READ)
146-
d.dat.global_to_local_end(op2.READ)
154+
d.dat.assemble()
147155
kernel = kernels.restrict_kernel(Vf, Vc)
148-
op2.par_loop(kernel, fine_dual.node_set,
149-
next.dat(op2.INC, fine_to_coarse),
150-
fine_dual.dat(op2.READ),
151-
node_locations.dat(op2.READ),
152-
coarse_coords.dat(op2.READ, fine_to_coarse_coords))
156+
#op2.par_loop(kernel, fine_dual.node_set,
157+
# next.dat(op2.INC, fine_to_coarse),
158+
# fine_dual.dat(op2.READ),
159+
# node_locations.dat(op2.READ),
160+
# coarse_coords.dat(op2.READ, fine_to_coarse_coords))
161+
op3.do_loop(
162+
n := Vf.nodes.index(),
163+
kernel(
164+
next.nodal_dat()[fine_to_coarse(n)],
165+
fine_dual.nodal_dat()[n],
166+
node_locations.nodal_dat()[n],
167+
coarse_coords.nodal_dat()[fine_to_coarse_coords(n)],
168+
),
169+
)
153170
fine_dual = next
154171
Vf = Vc
155172
return coarse_dual
@@ -217,13 +234,21 @@ def inject(fine, coarse):
217234
# Have to do this, because the node set core size is not right for
218235
# this expanded stencil
219236
for d in [fine, fine_coords]:
220-
d.dat.global_to_local_begin(op2.READ)
221-
d.dat.global_to_local_end(op2.READ)
222-
op2.par_loop(kernel, next.node_set,
223-
next.dat(op2.INC),
224-
node_locations.dat(op2.READ),
225-
fine.dat(op2.READ, coarse_node_to_fine_nodes),
226-
fine_coords.dat(op2.READ, coarse_node_to_fine_coords))
237+
d.dat.assemble()
238+
#op2.par_loop(kernel, next.node_set,
239+
# next.dat(op2.INC),
240+
# node_locations.dat(op2.READ),
241+
# fine.dat(op2.READ, coarse_node_to_fine_nodes),
242+
# fine_coords.dat(op2.READ, coarse_node_to_fine_coords))
243+
op3.do_loop(
244+
n := Vc.nodes.index(),
245+
kernel(
246+
next.nodal_dat()[n],
247+
node_locations.nodal_dat()[n],
248+
fine.nodal_dat()[coarse_node_to_fine_nodes(n)],
249+
fine_coords.nodal_dat()[coarse_node_to_fine_coords(n)],
250+
),
251+
)
227252
else:
228253
coarse_coords = Vc.mesh().coordinates
229254
fine_coords = Vf.mesh().coordinates
@@ -232,13 +257,22 @@ def inject(fine, coarse):
232257
# Have to do this, because the node set core size is not right for
233258
# this expanded stencil
234259
for d in [fine, fine_coords]:
235-
d.dat.global_to_local_begin(op2.READ)
236-
d.dat.global_to_local_end(op2.READ)
237-
op2.par_loop(kernel, Vc.mesh().cell_set,
238-
next.dat(op2.INC, next.cell_node_map()),
239-
fine.dat(op2.READ, coarse_cell_to_fine_nodes),
240-
fine_coords.dat(op2.READ, coarse_cell_to_fine_coords),
241-
coarse_coords.dat(op2.READ, coarse_coords.cell_node_map()))
260+
d.dat.assemble()
261+
#op2.par_loop(kernel, Vc.mesh().cell_set,
262+
# next.dat(op2.INC, next.cell_node_map()),
263+
# fine.dat(op2.READ, coarse_cell_to_fine_nodes),
264+
# fine_coords.dat(op2.READ, coarse_cell_to_fine_coords),
265+
# coarse_coords.dat(op2.READ, coarse_coords.cell_node_map()))
266+
op3.do_loop(
267+
c := Vc.mesh().cells.index(),
268+
kernel(
269+
next.dat[c],
270+
fine.nodal_dat()[coarse_cell_to_fine_nodes(c)],
271+
fine_coords.nodal_dat()[coarse_cell_to_fine_coords(c)],
272+
coarse_coords.nodal_dat()[coarse_coords(c)],
273+
),
274+
)
275+
242276
fine = next
243277
Vf = Vc
244278
return coarse

0 commit comments

Comments
 (0)