Skip to content

pyop3: matnest #3532

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: connorjward/pyop3
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,25 @@
import contextlib
import functools
import itertools
import operator

Check failure on line 5 in firedrake/assemble.py

View workflow job for this annotation

GitHub Actions / Run linter

F401

firedrake/assemble.py:5:1: F401 'operator' imported but unused
from collections import OrderedDict, defaultdict

Check failure on line 6 in firedrake/assemble.py

View workflow job for this annotation

GitHub Actions / Run linter

F401

firedrake/assemble.py:6:1: F401 'collections.OrderedDict' imported but unused

Check failure on line 6 in firedrake/assemble.py

View workflow job for this annotation

GitHub Actions / Run linter

F401

firedrake/assemble.py:6:1: F401 'collections.defaultdict' imported but unused
from collections.abc import Sequence # noqa: F401
from itertools import product

Check failure on line 8 in firedrake/assemble.py

View workflow job for this annotation

GitHub Actions / Run linter

F401

firedrake/assemble.py:8:1: F401 'itertools.product' imported but unused
from functools import cached_property

import cachetools

Check failure on line 11 in firedrake/assemble.py

View workflow job for this annotation

GitHub Actions / Run linter

F401

firedrake/assemble.py:11:1: F401 'cachetools' imported but unused
from pyrsistent import freeze, pmap

Check failure on line 12 in firedrake/assemble.py

View workflow job for this annotation

GitHub Actions / Run linter

F401

firedrake/assemble.py:12:1: F401 'pyrsistent.freeze' imported but unused
import finat

Check failure on line 13 in firedrake/assemble.py

View workflow job for this annotation

GitHub Actions / Run linter

F401

firedrake/assemble.py:13:1: F401 'finat' imported but unused
import loopy as lp

Check failure on line 14 in firedrake/assemble.py

View workflow job for this annotation

GitHub Actions / Run linter

F401

firedrake/assemble.py:14:1: F401 'loopy as lp' imported but unused
import firedrake
import numpy
from pyadjoint.tape import annotate_tape
from tsfc import kernel_args
from tsfc.finatinterface import create_element

Check failure on line 19 in firedrake/assemble.py

View workflow job for this annotation

GitHub Actions / Run linter

F401

firedrake/assemble.py:19:1: F401 'tsfc.finatinterface.create_element' imported but unused
from tsfc.ufl_utils import extract_firedrake_constants
import ufl
import pyop3 as op3
from firedrake import (extrusion_utils as eutils, matrix, parameters, solving,

Check failure on line 23 in firedrake/assemble.py

View workflow job for this annotation

GitHub Actions / Run linter

F401

firedrake/assemble.py:23:1: F401 'firedrake.extrusion_utils as eutils' imported but unused
tsfc_interface, utils)
from firedrake.adjoint_utils import annotate_assemble
from firedrake.ufl_expr import extract_unique_domain
Expand Down Expand Up @@ -1280,15 +1280,15 @@
assert isinstance(form, (ufl.form.Form, slate.TensorBase))
mat_type = kwargs.pop('mat_type', None)
sub_mat_type = kwargs.pop('sub_mat_type', None)
mat_type, sub_mat_type = _get_mat_type(mat_type, sub_mat_type, form.arguments())
mat_type, sub_mat_type, block_shape = _get_mat_type(mat_type, sub_mat_type, form.arguments())
if mat_type == "matfree":
# Arguably we should crash here, as we would be passing ignored arguments through
kwargs.pop('needs_zeroing', None)
kwargs.pop('weight', None)
kwargs.pop('allocation_integral_types', None)
return MatrixFreeAssembler(form, *args, **kwargs)
else:
return ExplicitMatrixAssembler(form, *args, mat_type=mat_type, sub_mat_type=sub_mat_type, **kwargs)
return ExplicitMatrixAssembler(form, *args, mat_type=mat_type, sub_mat_type=sub_mat_type, block_shape=block_shape, **kwargs)


def _get_mat_type(mat_type, sub_mat_type, arguments):
Expand Down Expand Up @@ -1332,6 +1332,8 @@

if mat_type == "nest":
mat_type = {}
block_shape = dict()

test, trial = arguments
for test_subspace in test.function_space():
for trial_subspace in trial.function_space():
Expand All @@ -1340,9 +1342,10 @@
mat_type[subspace_key] = "dat"
else:
mat_type[subspace_key] = sub_mat_type
block_shape[subspace_key] = test_subspace.value_size, trial_subspace.value_size

# sub_mat_type no longer used after this
return mat_type, sub_mat_type
return mat_type, sub_mat_type, block_shape


class ExplicitMatrixAssembler(ParloopFormAssembler):
Expand All @@ -1369,14 +1372,16 @@
@FormAssembler._skip_if_initialised
def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True,
mat_type=None, sub_mat_type=None, options_prefix=None, appctx=None, weight=1.0,
allocation_integral_types=None):
allocation_integral_types=None,
block_shape=None):
# The previous API was that the user would specify mat_type and sub_mat_type, now
# mat_type can be a dict so convert to that here.
# NOTE: This function is called in TwoFormAssembler, should it be?
# mat_type, sub_mat_type = _get_mat_type(mat_type, sub_mat_type, form.arguments())

super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters, needs_zeroing=needs_zeroing)
self._mat_type = mat_type
self._block_shape = block_shape
self._options_prefix = options_prefix
self._appctx = appctx
self.weight = weight
Expand All @@ -1389,6 +1394,7 @@
trial,
self._mat_type,
self._make_maps_and_regions(),
block_shape=self._block_shape,
)
mat = op3.Mat.from_sparsity(sparsity)
return matrix.Matrix(
Expand All @@ -1401,15 +1407,16 @@
)

@staticmethod
def _make_sparsity(test, trial, mat_type, maps_and_regions):
def _make_sparsity(test, trial, mat_type, maps_and_regions, block_shape=None):
# Is this overly restrictive?
if any(len(a.function_space()) > 1 for a in [test, trial]) and mat_type == "baij":
raise ValueError("BAIJ matrix type makes no sense for mixed spaces, use 'aij'")

sparsity = op3.Sparsity(
test.function_space().axes,
trial.function_space().axes,
mat_type=mat_type
mat_type=mat_type,
block_shape=block_shape,
)

# Pretend that we are doing assembly by looping over the right
Expand Down
Loading