
pyop3: parloops #3534

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 14 commits into base: connorjward/pyop3
2 changes: 1 addition & 1 deletion firedrake/adjoint_utils/constant.py
@@ -108,7 +108,7 @@ def _ad_copy(self):
return self._constant_from_values()

def _ad_dim(self):
return self.dat.cdim
return self.dat.data_ro.size

def _ad_imul(self, other):
self.assign(self._constant_from_values(self.dat.data_ro.reshape(-1) * other))
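A minimal NumPy sketch (not part of the diff) of why the flattened data size can stand in for the removed `cdim` here, assuming a shaped Constant stores its values as a flat array:

```python
import numpy as np

# A 3x2 tensor-valued Constant holds 6 values; the flattened read-only
# data therefore has size 6, which is what _ad_dim now reports.
values = np.zeros((3, 2))
flat = values.reshape(-1)
assert flat.size == 6
```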
10 changes: 5 additions & 5 deletions firedrake/assemble.py
@@ -2,25 +2,25 @@
import contextlib
import functools
import itertools
import operator
from collections import OrderedDict, defaultdict
from collections.abc import Sequence # noqa: F401
from itertools import product
from functools import cached_property

import cachetools
from pyrsistent import freeze, pmap
import finat
import loopy as lp
import firedrake
import numpy
from pyadjoint.tape import annotate_tape
from tsfc import kernel_args
from tsfc.finatinterface import create_element
from tsfc.ufl_utils import extract_firedrake_constants
import ufl
import pyop3 as op3
from firedrake import (extrusion_utils as eutils, matrix, parameters, solving,
tsfc_interface, utils)
from firedrake.adjoint_utils import annotate_assemble
from firedrake.ufl_expr import extract_unique_domain
@@ -1001,7 +1001,7 @@
)

if needs_zeroing:
self._as_pyop3_type(tensor).eager_zero()
self._as_pyop3_type(tensor).zero()

for (lknl, _), (parloop, lgmaps) in zip(self.local_kernels, self.parloops(tensor)):
subtensor = _FormHandler.index_tensor(
@@ -1186,7 +1186,7 @@
def result(self, tensor):
# NOTE: If we could return the tensor here then that would avoid a
# halo exchange. That would be a very significant API change though.
tensor.assemble()
tensor.assemble(update_leaves=True)
return op3.utils.just_one(tensor.buffer._data)


@@ -1323,8 +1323,8 @@
if sub_mat_type is None:
sub_mat_type = parameters.parameters["default_sub_matrix_type"]

if has_real_subspace and mat_type != "nest":
raise ValueError
if has_real_subspace and mat_type not in ["nest", "matfree"]:
raise ValueError("Matrices containing real space arguments must have type 'nest' or 'matfree'")
if sub_mat_type not in {"aij", "baij"}:
raise ValueError(
f"Invalid submatrix type, '{sub_mat_type}' (not 'aij' or 'baij')"
@@ -1612,7 +1612,7 @@
dat = op2tensor[i, j].handle.getPythonContext().dat
if component is not None:
dat = op2.DatView(dat, component)
dat.eager_zero(subset=node_set)
dat.zero(subset=node_set)

def _check_tensor(self, tensor):
if tensor is not None and tensor.a.arguments() != self._form.arguments():
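A hedged sketch of the relaxed check above (the helper name is hypothetical): real-space blocks presumably cannot be assembled into a monolithic matrix, so 'matfree' is now accepted alongside 'nest'.

```python
def check_mat_type(mat_type, has_real_subspace):
    # Mirrors the condition in the diff hunk above.
    if has_real_subspace and mat_type not in ["nest", "matfree"]:
        raise ValueError(
            "Matrices containing real space arguments must have type 'nest' or 'matfree'")

check_mat_type("matfree", has_real_subspace=True)   # now accepted
# check_mat_type("aij", has_real_subspace=True)     # would raise ValueError
```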
6 changes: 4 additions & 2 deletions firedrake/assign.py
@@ -5,6 +5,7 @@
import finat.ufl
import numpy as np
import pyop3 as op3
from pyop3.exceptions import DataValueError
import pytools
from pyadjoint.tape import annotate_tape
from pyop2.utils import cached_property
@@ -256,8 +257,9 @@ def _assign_single_dat(self, lhs, subset, rvalue, assign_to_halos):
if isinstance(rvalue, numbers.Number) or rvalue.shape in {(1,), assignee.shape}:
assignee[...] = rvalue
else:
cdim = self._assignee.function_space()._cdim
assert rvalue.shape == (cdim,)
cdim = self._assignee.function_space().value_size
if rvalue.shape != (cdim,):
raise DataValueError("Assignee and assignment values are different shapes")
assignee.reshape((-1, cdim))[...] = rvalue

def _compute_rvalue(self, func_data):
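A minimal NumPy sketch (assumed sizes, not the real Firedrake objects) of the broadcast pattern that the new shape check guards:

```python
import numpy as np

value_size = 3                          # components per node
assignee = np.zeros(12)                 # flat storage for 4 nodes
rvalue = np.array([1.0, 2.0, 3.0])

if rvalue.shape != (value_size,):
    raise ValueError("Assignee and assignment values are different shapes")
# Viewing the flat buffer as (num_nodes, value_size) broadcasts the value
# across every node, as assignee.reshape((-1, cdim))[...] = rvalue does above.
assignee.reshape((-1, value_size))[...] = rvalue
```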
2 changes: 1 addition & 1 deletion firedrake/bcs.py
@@ -198,7 +198,7 @@ def zero(self, r):
if r.function_space() != self._function_space:
raise RuntimeError(f"{r} defined on an incompatible FunctionSpace")

r.dat.eager_zero(subset=self.constrained_points)
r.dat.zero(subset=self.constrained_points)

@PETSc.Log.EventDecorator()
def set(self, r, val):
16 changes: 12 additions & 4 deletions firedrake/cofunction.py
@@ -237,7 +237,7 @@ def assign(self, expr, subset=None):
expr = ufl.as_ufl(expr)
if isinstance(expr, ufl.classes.Zero):
with stop_annotating(modifies=(self,)):
self.dat.eager_zero(subset=subset)
self.dat.zero(subset=subset)
return self
elif (isinstance(expr, Cofunction)
and expr.function_space() == self.function_space()):
@@ -349,14 +349,17 @@ def vector(self):
:class:`Cofunction`"""
return vector.Vector(self)

def nodal_dat(self):
return op3.HierarchicalArray(self.function_space().nodal_axes, data=self.dat.data_rw_with_halos)

@property
def node_set(self):
r"""A :class:`pyop2.types.set.Set` containing the nodes of this
:class:`Cofunction`. One or (for rank-1 and 2
:class:`.FunctionSpace`\s) more degrees of freedom are stored
at each node.
"""
return self.function_space().node_set
return self.function_space().nodes

def ufl_id(self):
return self.uid
@@ -386,5 +389,10 @@ def __str__(self):
else:
return super(Cofunction, self).__str__()

def cell_node_map(self):
return self.function_space().cell_node_map()
@property
def cell_node_list(self):
return self.function_space().cell_node_list

@property
def owned_cell_node_list(self):
return self.function_space().owned_cell_node_list
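A small illustrative sketch (made-up numbers) of how the two new cell-node properties are expected to relate: the owned list is the leading, non-halo slice of the full list.

```python
import numpy as np

cell_node_list = np.array([[0, 1, 2],
                           [1, 2, 3],
                           [3, 4, 5]])          # last row: a halo cell
num_owned_cells = 2
owned_cell_node_list = cell_node_list[:num_owned_cells]   # drops halo rows
```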
18 changes: 7 additions & 11 deletions firedrake/constant.py
@@ -3,8 +3,8 @@
import finat.ufl

from tsfc.ufl_utils import TSFCConstantMixin
from pyop2.exceptions import DataTypeError, DataValueError
import pyop3 as op3
from pyop3.exceptions import DataValueError
from firedrake.petsc import PETSc
from firedrake.utils import ScalarType
from ufl.classes import all_ufl_classes, ufl_classes, terminal_classes
@@ -29,15 +29,12 @@ def _create_const(value, comm):
shape = data.shape
rank = len(shape)

if comm is not None:
raise NotImplementedError("Won't be a back door for real space here, do elsewhere")

if rank == 0:
axes = op3.AxisTree(op3.Axis(1))
axes = op3.AxisTree()
else:
axes = op3.AxisTree(op3.Axis(shape[0]))
for size in shape[1:]:
axes = axes.add_axis(op3.Axis(size), *axes.leaf)
axes = op3.AxisTree(op3.Axis({"XXX": shape[0]}, label="dim0"))
for i, s in enumerate(shape[1:]):
axes = axes.add_axis(op3.Axis({"XXX": s}, label=f"dim{i+1}"), *axes.leaf)
dat = op3.HierarchicalArray(axes, data=data.flatten())
return dat, rank, shape

@@ -198,11 +195,10 @@ def assign(self, value):
self

"""
if self.ufl_shape() and np.array(value).shape != self.ufl_shape():
raise DataValueError("Cannot assign to constant, value has incorrect shape")
self.dat.data_wo[...] = value
return self
# TODO pyop3
# except (DataTypeError, DataValueError) as e:
# raise ValueError(e)

def __iadd__(self, o):
raise NotImplementedError("Augmented assignment to Constant not implemented")
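A hedged illustration of the layout built in `_create_const` above, assuming the dat stores the row-major flattening of the value while the axis tree records one labelled axis per shape dimension:

```python
import numpy as np

value = np.arange(6.0).reshape(2, 3)
# dim0 has extent 2, dim1 has extent 3; a scalar value gives an empty tree.
axis_extents = {f"dim{i}": extent for i, extent in enumerate(value.shape)}
flat_data = value.flatten()
assert axis_extents == {"dim0": 2, "dim1": 3}
assert flat_data.size == value.size
```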
6 changes: 4 additions & 2 deletions firedrake/cython/dmcommon.pyx
@@ -597,6 +597,8 @@ def closure_ordering(mesh, closure_data, closure_sizes):
CHKERR(PetscMalloc1(nverts_per_cell, &facet_verts))

closure_data_reord = tuple(np.empty_like(d) for d in closure_data)
# Must be accessed collectively by all ranks before the cell loop below.
mesh._global_vertex_numbering
for cell in range(mesh.num_cells()):
# 1. Order vertices
for vi, vert in enumerate(closure_data[0][cell]):
@@ -1224,8 +1226,8 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1):
if isinstance(dm, PETSc.DMSwarm) and on_base:
raise NotImplementedError("Vertex Only Meshes cannot be extruded.")
variable = mesh.variable_layers
extruded = mesh.cell_set._extruded
extruded_periodic = mesh.cell_set._extruded_periodic
extruded = mesh.extruded
extruded_periodic = mesh.extruded_periodic
on_base_ = on_base
nodes_per_entity = np.asarray(nodes_per_entity, dtype=IntType)
if variable:
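A hedged sketch of the collective-access idiom the new `mesh._global_vertex_numbering` line relies on (the class and property here are stand-ins): any attribute whose construction involves MPI communication must be touched by every rank before entering rank-local work.

```python
class MeshLike:
    """Stand-in mesh; the real property triggers a collective PETSc call."""

    @property
    def _global_vertex_numbering(self):
        # Imagine an MPI-collective construction, cached on first access.
        return object()


mesh = MeshLike()
mesh._global_vertex_numbering        # touch collectively, before the cell loop
for cell in range(4):
    pass                             # per-cell work proceeds independently
```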
18 changes: 9 additions & 9 deletions firedrake/cython/mgimpl.pyx
@@ -64,8 +64,8 @@ def coarse_to_fine_nodes(Vc, Vf, np.ndarray[PetscInt, ndim=2, mode="c"] coarse_t
PetscInt fine_layer, fine_layers, coarse_layer, coarse_layers, ratio
bint extruded

fine_map = Vf.cell_node_map().values
coarse_map = Vc.cell_node_map().values
fine_map = Vf.owned_cell_node_list
coarse_map = Vc.owned_cell_node_list

fine_cell_per_coarse_cell = coarse_to_fine_cells.shape[1]
extruded = Vc.extruded
@@ -85,7 +85,7 @@
ndof = fine_per_cell * fine_cell_per_coarse_cell
if extruded:
ndof *= ratio
coarse_to_fine_map = np.full((Vc.dof_dset.total_size,
coarse_to_fine_map = np.full((Vc.node_count,
ndof),
-1,
dtype=IntType)
@@ -124,8 +124,8 @@ def fine_to_coarse_nodes(Vf, Vc, np.ndarray[PetscInt, ndim=2, mode="c"] fine_to_
PetscInt coarse_per_cell, fine_per_cell, coarse_cell, fine_cells
bint extruded

fine_map = Vf.cell_node_map().values
coarse_map = Vc.cell_node_map().values
fine_map = Vf.owned_cell_node_list
coarse_map = Vc.owned_cell_node_list

extruded = Vc.extruded

@@ -142,7 +142,7 @@
coarse_per_fine = fine_to_coarse_cells.shape[1]
coarse_per_cell = coarse_map.shape[1]
fine_per_cell = fine_map.shape[1]
fine_to_coarse_map = np.full((Vf.dof_dset.total_size,
fine_to_coarse_map = np.full((Vf.node_count,
coarse_per_fine*coarse_per_cell),
-1,
dtype=IntType)
@@ -255,8 +255,8 @@ def coarse_to_fine_cells(mc, mf, clgmaps, flgmaps):
fdm = mf.topology_dm
dim = cdm.getDimension()
nref = 2 ** dim
ncoarse = mc.cell_set.size
nfine = mf.cell_set.size
ncoarse = mc.cells.owned.size
nfine = mf.cells.owned.size
co2n, _ = get_entity_renumbering(cdm, mc._cell_numbering, "cell")
_, fn2o = get_entity_renumbering(fdm, mf._cell_numbering, "cell")
coarse_to_fine = np.full((ncoarse, nref), -1, dtype=PETSc.IntType)
@@ -274,7 +274,7 @@
# Need to permute order of co2n so it maps from non-overlapped
# cells to new cells (these may have changed order). Need to
# map all known cells through.
idx = np.arange(mc.cell_set.total_size, dtype=PETSc.IntType)
idx = np.arange(mc.cells.size, dtype=PETSc.IntType)
# LocalToGlobal
co.apply(idx, result=idx)
# GlobalToLocal
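A hedged sketch of the allocation pattern shared by both transfer maps above: one row per node of the source space (owned plus halo), pre-filled with -1 meaning "not yet populated".

```python
import numpy as np

node_count = 10     # stands in for Vc.node_count / Vf.node_count
ndof = 4            # target-space dofs reachable from each node
transfer_map = np.full((node_count, ndof), -1, dtype=np.int32)
assert (transfer_map == -1).all()   # filled in later, cell by cell
```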
6 changes: 5 additions & 1 deletion firedrake/function.py
@@ -381,6 +381,10 @@ def vector(self):
r"""Return a :class:`.Vector` wrapping the data in this :class:`Function`"""
return vector.Vector(self)

def nodal_dat(self):
return op3.HierarchicalArray(self.function_space().nodal_axes,
data=self.dat.data_rw_with_halos)

@PETSc.Log.EventDecorator()
def interpolate(
self,
@@ -471,7 +475,7 @@ def assign(self, expr, subset=Ellipsis):
except (DataTypeError, DataValueError) as e:
raise ValueError(e)
elif expr == 0:
self.dat.eager_zero(subset=subset)
self.dat.zero(subset=subset)
else:
from firedrake.assign import Assigner
Assigner(self, expr, subset).assign()
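A brief usage sketch of the zero fast path above (standard Firedrake API, assumed mesh and space): assigning the literal 0 short-circuits to an in-place zeroing of the dat rather than building an Assigner.

```python
from firedrake import UnitSquareMesh, FunctionSpace, Function

mesh = UnitSquareMesh(4, 4)
V = FunctionSpace(mesh, "CG", 1)
f = Function(V)
f.assign(0)   # hits the `expr == 0` branch and calls f.dat.zero()
```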
6 changes: 3 additions & 3 deletions firedrake/functionspacedata.py
@@ -103,7 +103,7 @@ def get_node_set(mesh, key):
node_classes = mesh.node_classes(nodes_per_entity, real_tensorproduct=real_tensorproduct)
halo = halo_mod.Halo(mesh.topology_dm, global_numbering, comm=mesh.comm)
node_set = op2.Set(node_classes, halo=halo, comm=mesh.comm)
extruded = mesh.cell_set._extruded
extruded = mesh.extruded

assert global_numbering.getStorageSize() == node_set.total_size
if not extruded and node_set.total_size >= (1 << (IntType.itemsize * 8 - 4)):
@@ -211,7 +211,7 @@ def get_boundary_masks(mesh, key, finat_element):
with points in the closure of point p. The basis function
indices are in the index array, starting at section.getOffset(p).
"""
if not mesh.cell_set._extruded:
if not mesh.extruded:
return None
_, kind = key
assert kind in {"cell", "interior_facet"}
@@ -453,7 +453,7 @@ def __init__(self, mesh, ufl_element):
self.node_set = node_set
self.cell_boundary_masks = get_boundary_masks(mesh, (edofs_key, "cell"), finat_element)
self.interior_facet_boundary_masks = get_boundary_masks(mesh, (edofs_key, "interior_facet"), finat_element)
self.extruded = mesh.cell_set._extruded
self.extruded = mesh.extruded
self.mesh = mesh
# self.global_numbering = global_numbering

37 changes: 31 additions & 6 deletions firedrake/functionspaceimpl.py
@@ -102,6 +102,7 @@ def __init__(self, mesh, element, component=None, cargo=None):
self.cargo = cargo
self.comm = mesh.comm
self._comm = mpi.internal_comm(mesh.comm, self)
self.extruded = mesh.extruded

@classmethod
def create(cls, function_space, mesh):
@@ -263,7 +264,7 @@ def get_work_function(self, zero=True):
if not out:
cache[function] = True
if zero:
function.dat.eager_zero()
function.dat.zero()
return function
if len(cache) == self.max_work_functions:
raise ValueError("Can't check out more than %d work functions." %
@@ -562,7 +563,7 @@ def __init__(self, mesh, element, name=None):
mesh._shared_data_cache[key] = (axes, block_axes)

self.axes = axes
self.block_axes = axes
self.block_axes = block_axes

# These properties are overridden in ProxyFunctionSpaces, but are
# provided by FunctionSpace so that we don't have to special case.
@@ -675,6 +676,10 @@ def _local_ises(self):
@utils.cached_property
def local_section(self):
section = PETSc.Section().create(comm=self.comm)
if self._ufl_function_space.ufl_element().family() == "Real":
# Real function spaces do not need the section populated
return section

points = self._mesh.points
section.setChart(0, points.size)
perm = PETSc.IS().createGeneral(points.numbering.data_ro, comm=self.comm)
@@ -696,8 +701,28 @@ def _cdim(self):

@utils.cached_property
def cell_node_list(self):
r"""A numpy array mapping mesh cells to function space nodes."""
return self._shared_data.entity_node_lists[self.mesh().cell_set]
r"""A numpy array mapping mesh cells to function space nodes (includes halo)."""
from firedrake.parloops import pack_pyop3_tensor
cells = self.mesh().cells
# Pass self.sub(0) to get nodes from the scalar version of this function space
packed_axes = pack_pyop3_tensor(self.block_axes, self.sub(0), cells.index(include_ghost_points=True), "cell")
return packed_axes.tabulated_offsets.buffer.data_rw_with_halos.reshape((cells.size, -1))

@utils.cached_property
def owned_cell_node_list(self):
r"""A numpy array mapping owned mesh cells to function space nodes."""
cells = self.mesh().cells
return self.cell_node_list[:cells.owned.size]

@utils.cached_property
def nodes(self):
ax = self.block_axes
return op3.Axis([op3.AxisComponent((ax.owned.size, ax.size),
"XXX", rank_equal=False)], "nodes", numbering=None, sf=ax.sf)

@utils.cached_property
def nodal_axes(self):
return op3.AxisTree.from_iterable([self.nodes, self.value_size])

@utils.cached_property
def topological(self):
@@ -771,13 +796,13 @@ def node_count(self):
this process. If the :class:`FunctionSpace` has :attr:`FunctionSpace.rank` 0, this
is equal to the :attr:`FunctionSpace.dof_count`, otherwise the :attr:`FunctionSpace.dof_count` is
:attr:`dim` times the :attr:`node_count`."""
return self.node_set.total_size
return self.block_axes.size

@utils.cached_property
def dof_count(self):
r"""The number of degrees of freedom (includes halo dofs) of this
function space on this process. Cf. :attr:`FunctionSpace.node_count` ."""
return self.node_count*self.value_size
return self.axes.size

def dim(self):
r"""The global number of degrees of freedom for this function space.
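A small numeric sketch of the invariant the rewritten `node_count` and `dof_count` properties are expected to preserve, assuming every node carries `value_size` degrees of freedom:

```python
node_count = 25                         # what self.block_axes.size reports
value_size = 3                          # components stored at each node
dof_count = node_count * value_size     # what self.axes.size should equal
assert dof_count == 75
```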
2 changes: 1 addition & 1 deletion firedrake/interpolation.py
@@ -947,7 +947,7 @@ def argfs_map(pt):
target_plex = target_mesh.topology
op3.do_loop(
c := target_plex.owned_cells.index(),
sparsity[target_plex.closure(c), target_plex.closure(c)].assign(666),
sparsity[target_plex.closure(c), target_plex.closure(c)].assign(666, eager=False),
)
tensor = op3.Mat.from_sparsity(sparsity)
f = tensor
2 changes: 1 addition & 1 deletion firedrake/linear_solver.py
@@ -121,7 +121,7 @@ def _rhs(self):

def _lifted(self, b):
u, update, blift = self._rhs
u.dat.eager_zero()
u.dat.zero()
for bc in self.A.bcs:
bc.apply(u)
update(tensor=blift)
2 changes: 1 addition & 1 deletion firedrake/logging.py
@@ -15,7 +15,7 @@
"RED", "GREEN", "BLUE")


packages = ("pyop2", "tsfc", "firedrake", "UFL")
packages = ("pyop2", "pyop3", "tsfc", "firedrake", "UFL")


logger = logging.getLogger("firedrake")
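A brief usage sketch: with "pyop3" added to the packages tuple, its logger is configured alongside the other component loggers, so messages emitted under that name are handled the same way as pyop2 or tsfc ones.

```python
import logging

# After firedrake.logging has installed handlers for everything in `packages`,
# pyop3 messages flow through the same machinery as the other components.
logging.getLogger("pyop3").info("pyop3 logging is now routed with the rest")
```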