Skip to content

pyop3: multigrid fixes #3521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: connorjward/pyop3
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions firedrake/cofunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,14 +349,17 @@ def vector(self):
:class:`Cofunction`"""
return vector.Vector(self)

def nodal_dat(self):
return op3.HierarchicalArray(self.function_space().nodal_axes, data=self.dat.data_rw_with_halos)

@property
def node_set(self):
r"""A :class:`pyop2.types.set.Set` containing the nodes of this
:class:`Cofunction`. One or (for rank-1 and 2
:class:`.FunctionSpace`\s) more degrees of freedom are stored
at each node.
"""
return self.function_space().node_set
return self.function_space().nodes

def ufl_id(self):
return self.uid
Expand Down Expand Up @@ -386,5 +389,10 @@ def __str__(self):
else:
return super(Cofunction, self).__str__()

def cell_node_map(self):
return self.function_space().cell_node_map()
@property
def cell_node_list(self):
return self.function_space().cell_node_list

@property
def owned_cell_node_list(self):
return self.function_space().owned_cell_node_list
4 changes: 2 additions & 2 deletions firedrake/cython/dmcommon.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1224,8 +1224,8 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1):
if isinstance(dm, PETSc.DMSwarm) and on_base:
raise NotImplementedError("Vertex Only Meshes cannot be extruded.")
variable = mesh.variable_layers
extruded = mesh.cell_set._extruded
extruded_periodic = mesh.cell_set._extruded_periodic
extruded = mesh.extruded
extruded_periodic = mesh.extruded_periodic
on_base_ = on_base
nodes_per_entity = np.asarray(nodes_per_entity, dtype=IntType)
if variable:
Expand Down
18 changes: 9 additions & 9 deletions firedrake/cython/mgimpl.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def coarse_to_fine_nodes(Vc, Vf, np.ndarray[PetscInt, ndim=2, mode="c"] coarse_t
PetscInt fine_layer, fine_layers, coarse_layer, coarse_layers, ratio
bint extruded

fine_map = Vf.cell_node_map().values
coarse_map = Vc.cell_node_map().values
fine_map = Vf.owned_cell_node_list
coarse_map = Vc.owned_cell_node_list

fine_cell_per_coarse_cell = coarse_to_fine_cells.shape[1]
extruded = Vc.extruded
Expand All @@ -85,7 +85,7 @@ def coarse_to_fine_nodes(Vc, Vf, np.ndarray[PetscInt, ndim=2, mode="c"] coarse_t
ndof = fine_per_cell * fine_cell_per_coarse_cell
if extruded:
ndof *= ratio
coarse_to_fine_map = np.full((Vc.dof_dset.total_size,
coarse_to_fine_map = np.full((Vc.node_count,
ndof),
-1,
dtype=IntType)
Expand Down Expand Up @@ -124,8 +124,8 @@ def fine_to_coarse_nodes(Vf, Vc, np.ndarray[PetscInt, ndim=2, mode="c"] fine_to_
PetscInt coarse_per_cell, fine_per_cell, coarse_cell, fine_cells
bint extruded

fine_map = Vf.cell_node_map().values
coarse_map = Vc.cell_node_map().values
fine_map = Vf.owned_cell_node_list
coarse_map = Vc.owned_cell_node_list

extruded = Vc.extruded

Expand All @@ -142,7 +142,7 @@ def fine_to_coarse_nodes(Vf, Vc, np.ndarray[PetscInt, ndim=2, mode="c"] fine_to_
coarse_per_fine = fine_to_coarse_cells.shape[1]
coarse_per_cell = coarse_map.shape[1]
fine_per_cell = fine_map.shape[1]
fine_to_coarse_map = np.full((Vf.dof_dset.total_size,
fine_to_coarse_map = np.full((Vf.node_count,
coarse_per_fine*coarse_per_cell),
-1,
dtype=IntType)
Expand Down Expand Up @@ -255,8 +255,8 @@ def coarse_to_fine_cells(mc, mf, clgmaps, flgmaps):
fdm = mf.topology_dm
dim = cdm.getDimension()
nref = 2 ** dim
ncoarse = mc.cell_set.size
nfine = mf.cell_set.size
ncoarse = mc.cells.owned.size
nfine = mf.cells.owned.size
co2n, _ = get_entity_renumbering(cdm, mc._cell_numbering, "cell")
_, fn2o = get_entity_renumbering(fdm, mf._cell_numbering, "cell")
coarse_to_fine = np.full((ncoarse, nref), -1, dtype=PETSc.IntType)
Expand All @@ -274,7 +274,7 @@ def coarse_to_fine_cells(mc, mf, clgmaps, flgmaps):
# Need to permute order of co2n so it maps from non-overlapped
# cells to new cells (these may have changed order). Need to
# map all known cells through.
idx = np.arange(mc.cell_set.total_size, dtype=PETSc.IntType)
idx = np.arange(mc.cells.size, dtype=PETSc.IntType)
# LocalToGlobal
co.apply(idx, result=idx)
# GlobalToLocal
Expand Down
4 changes: 4 additions & 0 deletions firedrake/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,10 @@ def vector(self):
r"""Return a :class:`.Vector` wrapping the data in this :class:`Function`"""
return vector.Vector(self)

def nodal_dat(self):
return op3.HierarchicalArray(self.function_space().nodal_axes,
data=self.dat.data_rw_with_halos)

@PETSc.Log.EventDecorator()
def interpolate(
self,
Expand Down
6 changes: 3 additions & 3 deletions firedrake/functionspacedata.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def get_node_set(mesh, key):
node_classes = mesh.node_classes(nodes_per_entity, real_tensorproduct=real_tensorproduct)
halo = halo_mod.Halo(mesh.topology_dm, global_numbering, comm=mesh.comm)
node_set = op2.Set(node_classes, halo=halo, comm=mesh.comm)
extruded = mesh.cell_set._extruded
extruded = mesh.extruded

assert global_numbering.getStorageSize() == node_set.total_size
if not extruded and node_set.total_size >= (1 << (IntType.itemsize * 8 - 4)):
Expand Down Expand Up @@ -211,7 +211,7 @@ def get_boundary_masks(mesh, key, finat_element):
with points in the closure of point p. The basis function
indices are in the index array, starting at section.getOffset(p).
"""
if not mesh.cell_set._extruded:
if not mesh.extruded:
return None
_, kind = key
assert kind in {"cell", "interior_facet"}
Expand Down Expand Up @@ -453,7 +453,7 @@ def __init__(self, mesh, ufl_element):
self.node_set = node_set
self.cell_boundary_masks = get_boundary_masks(mesh, (edofs_key, "cell"), finat_element)
self.interior_facet_boundary_masks = get_boundary_masks(mesh, (edofs_key, "interior_facet"), finat_element)
self.extruded = mesh.cell_set._extruded
self.extruded = mesh.extruded
self.mesh = mesh
# self.global_numbering = global_numbering

Expand Down
45 changes: 40 additions & 5 deletions firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(self, mesh, element, component=None, cargo=None):
self.cargo = cargo
self.comm = mesh.comm
self._comm = mpi.internal_comm(mesh.comm, self)
self.extruded = mesh.extruded

@classmethod
def create(cls, function_space, mesh):
Expand Down Expand Up @@ -562,7 +563,7 @@ def __init__(self, mesh, element, name=None):
mesh._shared_data_cache[key] = (axes, block_axes)

self.axes = axes
self.block_axes = axes
self.block_axes = block_axes

# These properties are overridden in ProxyFunctionSpaces, but are
# provided by FunctionSpace so that we don't have to special case.
Expand Down Expand Up @@ -694,10 +695,44 @@ def local_section(self):
def _cdim(self):
return self.value_size

@utils.cached_property
def nodes(self):
ax = self.block_axes
return op3.Axis([op3.AxisComponent((ax.owned.size, ax.size),
"XXX", rank_equal=False)], "nodes", numbering=None, sf=ax.sf)

@utils.cached_property
def nodal_axes(self):
return op3.AxisTree.from_iterable([self.nodes, self.value_size])

@utils.cached_property
def cell_node_dat(self):
from firedrake.parloops import pack_pyop3_tensor
cells = self.mesh().cells
# Pass self.sub(0) to get nodes from the scalar version of this function space
packed_axes = pack_pyop3_tensor(self.block_axes, self.sub(0), cells.index(include_ghost_points=True), "cell")
return packed_axes.tabulated_offsets

@utils.cached_property
def cell_node_map(self):
from pyrsistent import freeze
return op3.Map({
freeze({self.mesh().topology.name: self.mesh().cells.owned.root.component.label}): [
op3.TabulatedMapComponent(self.nodes.label, self.nodes.component.label, self.cell_node_dat)
]
})

@utils.cached_property
def cell_node_list(self):
r"""A numpy array mapping mesh cells to function space nodes."""
return self._shared_data.entity_node_lists[self.mesh().cell_set]
r"""A numpy array mapping mesh cells to function space nodes (includes halo)."""
cells = self.mesh().cells
return self.cell_node_dat.buffer.data_rw_with_halos.reshape((cells.size, -1))

@utils.cached_property
def owned_cell_node_list(self):
r"""A numpy array mapping owned mesh cells to function space nodes."""
cells = self.mesh().cells
return self.cell_node_list[:cells.owned.size]

@utils.cached_property
def topological(self):
Expand Down Expand Up @@ -771,13 +806,13 @@ def node_count(self):
this process. If the :class:`FunctionSpace` has :attr:`FunctionSpace.rank` 0, this
is equal to the :attr:`FunctionSpace.dof_count`, otherwise the :attr:`FunctionSpace.dof_count` is
:attr:`dim` times the :attr:`node_count`."""
return self.node_set.total_size
return self.block_axes.size

@utils.cached_property
def dof_count(self):
r"""The number of degrees of freedom (includes halo dofs) of this
function space on this process. Cf. :attr:`FunctionSpace.node_count` ."""
return self.node_count*self.value_size
return self.axes.size

def dim(self):
r"""The global number of degrees of freedom for this function space.
Expand Down
1 change: 1 addition & 0 deletions firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,7 @@ def _entity_numbering(self, label):
for old_component_num, new_component_num in enumerate(renumbering):
old_pt = old_component_num + self.points._component_offsets[component_index]
section.setOffset(old_pt, new_component_num)
section.setDof(old_pt, 1)

return section

Expand Down
81 changes: 50 additions & 31 deletions firedrake/mg/interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pyop2 import op2
import pyop3 as op3

import firedrake
from firedrake import ufl_expr
Expand Down Expand Up @@ -81,13 +81,18 @@ def prolong(coarse, fine):
# Have to do this, because the node set core size is not right for
# this expanded stencil
for d in [coarse, coarse_coords]:
d.dat.global_to_local_begin(op2.READ)
d.dat.global_to_local_end(op2.READ)
op2.par_loop(kernel, next.node_set,
next.dat(op2.WRITE),
coarse.dat(op2.READ, fine_to_coarse),
node_locations.dat(op2.READ),
coarse_coords.dat(op2.READ, fine_to_coarse_coords))
d.dat.assemble()

id_map = utils.owned_node_map(Vf)
op3.do_loop(
n := Vf.nodes.owned.index(),
kernel(
next.nodal_dat()[id_map(n)],
coarse.nodal_dat()[fine_to_coarse(n)],
node_locations.nodal_dat()[id_map(n)],
coarse_coords.nodal_dat()[fine_to_coarse_coords(n)],
),
)
coarse = next
Vc = Vf
return fine
Expand Down Expand Up @@ -125,7 +130,7 @@ def restrict(fine_dual, coarse_dual):
for j in range(repeat):
next_level -= 1
if j == repeat - 1:
coarse_dual.dat.eager_zero()
coarse_dual.dat.zero()
next = coarse_dual
else:
Vc = firedrake.FunctionSpace(meshes[next_level], element)
Expand All @@ -142,14 +147,19 @@ def restrict(fine_dual, coarse_dual):
# Have to do this, because the node set core size is not right for
# this expanded stencil
for d in [coarse_coords]:
d.dat.global_to_local_begin(op2.READ)
d.dat.global_to_local_end(op2.READ)
d.dat.assemble()
kernel = kernels.restrict_kernel(Vf, Vc)
op2.par_loop(kernel, fine_dual.node_set,
next.dat(op2.INC, fine_to_coarse),
fine_dual.dat(op2.READ),
node_locations.dat(op2.READ),
coarse_coords.dat(op2.READ, fine_to_coarse_coords))

id_map = utils.owned_node_map(Vf)
op3.do_loop(
n := Vf.nodes.owned.index(),
kernel(
next.nodal_dat()[fine_to_coarse(n)],
fine_dual.nodal_dat()[id_map(n)],
node_locations.nodal_dat()[id_map(n)],
coarse_coords.nodal_dat()[fine_to_coarse_coords(n)],
),
)
fine_dual = next
Vf = Vc
return coarse_dual
Expand Down Expand Up @@ -201,7 +211,7 @@ def inject(fine, coarse):
for j in range(repeat):
next_level -= 1
if j == repeat - 1:
coarse.dat.eager_zero()
coarse.dat.zero()
next = coarse
Vc = next.function_space()
else:
Expand All @@ -217,13 +227,18 @@ def inject(fine, coarse):
# Have to do this, because the node set core size is not right for
# this expanded stencil
for d in [fine, fine_coords]:
d.dat.global_to_local_begin(op2.READ)
d.dat.global_to_local_end(op2.READ)
op2.par_loop(kernel, next.node_set,
next.dat(op2.INC),
node_locations.dat(op2.READ),
fine.dat(op2.READ, coarse_node_to_fine_nodes),
fine_coords.dat(op2.READ, coarse_node_to_fine_coords))
d.dat.assemble()

id_map = utils.owned_node_map(Vc)
op3.do_loop(
n := Vc.nodes.owned.index(),
kernel(
next.nodal_dat()[id_map(n)],
node_locations.nodal_dat()[id_map(n)],
fine.nodal_dat()[coarse_node_to_fine_nodes(n)],
fine_coords.nodal_dat()[coarse_node_to_fine_coords(n)],
),
)
else:
coarse_coords = Vc.mesh().coordinates
fine_coords = Vf.mesh().coordinates
Expand All @@ -232,13 +247,17 @@ def inject(fine, coarse):
# Have to do this, because the node set core size is not right for
# this expanded stencil
for d in [fine, fine_coords]:
d.dat.global_to_local_begin(op2.READ)
d.dat.global_to_local_end(op2.READ)
op2.par_loop(kernel, Vc.mesh().cell_set,
next.dat(op2.INC, next.cell_node_map()),
fine.dat(op2.READ, coarse_cell_to_fine_nodes),
fine_coords.dat(op2.READ, coarse_cell_to_fine_coords),
coarse_coords.dat(op2.READ, coarse_coords.cell_node_map()))
d.dat.assemble()
op3.do_loop(
c := Vc.mesh().cells.owned.index(),
kernel(
next.dat[c],
fine.nodal_dat()[coarse_cell_to_fine_nodes(c)],
fine_coords.nodal_dat()[coarse_cell_to_fine_coords(c)],
coarse_coords.nodal_dat()[coarse_coords.function_space().cell_node_map(c)],
),
)

fine = next
Vf = Vc
return coarse
Loading
Loading