Skip to content

pyop3: matnest #3532

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: connorjward/pyop3
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,15 +1280,15 @@ def TwoFormAssembler(form, *args, **kwargs):
assert isinstance(form, (ufl.form.Form, slate.TensorBase))
mat_type = kwargs.pop('mat_type', None)
sub_mat_type = kwargs.pop('sub_mat_type', None)
mat_type, sub_mat_type = _get_mat_type(mat_type, sub_mat_type, form.arguments())
mat_type, sub_mat_type, block_shape = _get_mat_type(mat_type, sub_mat_type, form.arguments())
if mat_type == "matfree":
# Arguably we should crash here, as we would be passing ignored arguments through
kwargs.pop('needs_zeroing', None)
kwargs.pop('weight', None)
kwargs.pop('allocation_integral_types', None)
return MatrixFreeAssembler(form, *args, **kwargs)
else:
return ExplicitMatrixAssembler(form, *args, mat_type=mat_type, sub_mat_type=sub_mat_type, **kwargs)
return ExplicitMatrixAssembler(form, *args, mat_type=mat_type, sub_mat_type=sub_mat_type, block_shape=block_shape, **kwargs)


def _get_mat_type(mat_type, sub_mat_type, arguments):
Expand Down Expand Up @@ -1332,6 +1332,8 @@ def _get_mat_type(mat_type, sub_mat_type, arguments):

if mat_type == "nest":
mat_type = {}
block_shape = dict()

test, trial = arguments
for test_subspace in test.function_space():
for trial_subspace in trial.function_space():
Expand All @@ -1340,9 +1342,10 @@ def _get_mat_type(mat_type, sub_mat_type, arguments):
mat_type[subspace_key] = "dat"
else:
mat_type[subspace_key] = sub_mat_type
block_shape[subspace_key] = test_subspace.value_size, trial_subspace.value_size

# sub_mat_type no longer used after this
return mat_type, sub_mat_type
return mat_type, sub_mat_type, block_shape


class ExplicitMatrixAssembler(ParloopFormAssembler):
Expand All @@ -1369,14 +1372,16 @@ def _cache_key(cls, *args, **kwargs):
@FormAssembler._skip_if_initialised
def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True,
mat_type=None, sub_mat_type=None, options_prefix=None, appctx=None, weight=1.0,
allocation_integral_types=None):
allocation_integral_types=None,
block_shape=None):
# The previous API was that the user would specify mat_type and sub_mat_type, now
# mat_type can be a dict so convert to that here.
# NOTE: This function is called in TwoFormAssembler, should it be?
# mat_type, sub_mat_type = _get_mat_type(mat_type, sub_mat_type, form.arguments())

super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters, needs_zeroing=needs_zeroing)
self._mat_type = mat_type
self._block_shape = block_shape
self._options_prefix = options_prefix
self._appctx = appctx
self.weight = weight
Expand All @@ -1389,6 +1394,7 @@ def allocate(self):
trial,
self._mat_type,
self._make_maps_and_regions(),
block_shape=self._block_shape,
)
mat = op3.Mat.from_sparsity(sparsity)
return matrix.Matrix(
Expand All @@ -1401,15 +1407,16 @@ def allocate(self):
)

@staticmethod
def _make_sparsity(test, trial, mat_type, maps_and_regions):
def _make_sparsity(test, trial, mat_type, maps_and_regions, block_shape=None):
# Is this overly restrictive?
if any(len(a.function_space()) > 1 for a in [test, trial]) and mat_type == "baij":
raise ValueError("BAIJ matrix type makes no sense for mixed spaces, use 'aij'")

sparsity = op3.Sparsity(
test.function_space().axes,
trial.function_space().axes,
mat_type=mat_type
mat_type=mat_type,
block_shape=block_shape,
)

# Pretend that we are doing assembly by looping over the right
Expand Down