diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 435b3b6dd9..2b64671bf6 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -1280,7 +1280,7 @@ def TwoFormAssembler(form, *args, **kwargs): assert isinstance(form, (ufl.form.Form, slate.TensorBase)) mat_type = kwargs.pop('mat_type', None) sub_mat_type = kwargs.pop('sub_mat_type', None) - mat_type, sub_mat_type = _get_mat_type(mat_type, sub_mat_type, form.arguments()) + mat_type, sub_mat_type, block_shape = _get_mat_type(mat_type, sub_mat_type, form.arguments()) if mat_type == "matfree": # Arguably we should crash here, as we would be passing ignored arguments through kwargs.pop('needs_zeroing', None) @@ -1288,7 +1288,7 @@ def TwoFormAssembler(form, *args, **kwargs): kwargs.pop('allocation_integral_types', None) return MatrixFreeAssembler(form, *args, **kwargs) else: - return ExplicitMatrixAssembler(form, *args, mat_type=mat_type, sub_mat_type=sub_mat_type, **kwargs) + return ExplicitMatrixAssembler(form, *args, mat_type=mat_type, sub_mat_type=sub_mat_type, block_shape=block_shape, **kwargs) def _get_mat_type(mat_type, sub_mat_type, arguments): @@ -1332,6 +1332,8 @@ def _get_mat_type(mat_type, sub_mat_type, arguments): if mat_type == "nest": mat_type = {} + block_shape = dict() + test, trial = arguments for test_subspace in test.function_space(): for trial_subspace in trial.function_space(): @@ -1340,9 +1342,10 @@ def _get_mat_type(mat_type, sub_mat_type, arguments): mat_type[subspace_key] = "dat" else: mat_type[subspace_key] = sub_mat_type + block_shape[subspace_key] = test_subspace.value_size, trial_subspace.value_size # sub_mat_type no longer used after this - return mat_type, sub_mat_type + return mat_type, sub_mat_type, block_shape class ExplicitMatrixAssembler(ParloopFormAssembler): @@ -1369,7 +1372,8 @@ def _cache_key(cls, *args, **kwargs): @FormAssembler._skip_if_initialised def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True, mat_type=None, sub_mat_type=None, options_prefix=None, appctx=None, weight=1.0, - allocation_integral_types=None): + allocation_integral_types=None, + block_shape=None): # The previous API was that the user would specify mat_type and sub_mat_type, now # mat_type can be a dict so convert to that here. # NOTE: This function is called in TwoFormAssembler, should it be? @@ -1377,6 +1381,7 @@ def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing= super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters, needs_zeroing=needs_zeroing) self._mat_type = mat_type + self._block_shape = block_shape self._options_prefix = options_prefix self._appctx = appctx self.weight = weight @@ -1389,6 +1394,7 @@ def allocate(self): trial, self._mat_type, self._make_maps_and_regions(), + block_shape=self._block_shape, ) mat = op3.Mat.from_sparsity(sparsity) return matrix.Matrix( @@ -1401,7 +1407,7 @@ def allocate(self): ) @staticmethod - def _make_sparsity(test, trial, mat_type, maps_and_regions): + def _make_sparsity(test, trial, mat_type, maps_and_regions, block_shape=None): # Is this overly restrictive? if any(len(a.function_space()) > 1 for a in [test, trial]) and mat_type == "baij": raise ValueError("BAIJ matrix type makes no sense for mixed spaces, use 'aij'") @@ -1409,7 +1415,8 @@ def _make_sparsity(test, trial, mat_type, maps_and_regions): sparsity = op3.Sparsity( test.function_space().axes, trial.function_space().axes, - mat_type=mat_type + mat_type=mat_type, + block_shape=block_shape, ) # Pretend that we are doing assembly by looping over the right