Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions finat/ufl/restrictedelement.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,31 @@
# Modified by Matthew Scroggs, 2023

from finat.ufl.finiteelementbase import FiniteElementBase
from finat.ufl.mixedelement import MixedElement, VectorElement, TensorElement
from ufl.sobolevspace import L2

valid_restriction_domains = ("interior", "facet", "ridge", "face", "edge", "vertex", "reduced")


class RestrictedElement(FiniteElementBase):
"""Represents the restriction of a finite element to a type of cell entity."""
def __new__(cls, element, restriction_domain):
"""
Restricted qualifier must be below Mixed/Vector/Tensor so we
overload __new__ to return:

RestrictedElement(MixedElement(elem0, elem1), dom) -> MixedElement(RestrictedElement(elem0, dom), RestrictedElement(elem1, dom))

and similarly for VectorElement and TensorElement.
"""
if isinstance(element, (VectorElement, TensorElement)):
return element.reconstruct(sub_element=RestrictedElement(element.sub_elements[0], restriction_domain))

elif isinstance(element, MixedElement):
return MixedElement([RestrictedElement(e, restriction_domain) for e in element.sub_elements])

else: # hopefully no special casing needed
return super().__new__(cls)

def __init__(self, element, restriction_domain):
"""Doc."""
Expand Down Expand Up @@ -68,9 +86,10 @@ def restriction_domain(self):
"""Return the domain onto which the element is restricted."""
return self._restriction_domain

def reconstruct(self, **kwargs):
def reconstruct(self, element=None, **kwargs):
"""Doc."""
element = self._element.reconstruct(**kwargs)
if element is None:
element = self._element.reconstruct(**kwargs)
return RestrictedElement(element, self._restriction_domain)

def __str__(self):
Expand Down
6 changes: 0 additions & 6 deletions test/finat/test_create_broken_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,3 @@ def test_create_broken_mixed_element(sub_elements):
mixed = MixedElement(sub_elements)
expected = MixedElement([BrokenElement(elem) for elem in sub_elements])
assert BrokenElement(mixed) == expected


if __name__ == "__main__":
import os
import sys
pytest.main(args=[os.path.abspath(__file__)] + sys.argv[1:])
43 changes: 43 additions & 0 deletions test/finat/test_create_restricted_element.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest
import ufl
from finat.ufl import FiniteElement, RestrictedElement, VectorElement, TensorElement, MixedElement

sub_elements = [
FiniteElement("CG", ufl.triangle, 1),
FiniteElement("BDM", ufl.triangle, 2),
FiniteElement("DG", ufl.interval, 2, variant="spectral")
]

sub_ids = [
"CG(1)",
"BDM(2)",
"DG(2,spectral)"
]


@pytest.mark.parametrize("sub_element", sub_elements, ids=sub_ids)
@pytest.mark.parametrize("shape", (1, 2, (2, 3)), ids=("1", "2", "(2,3)"))
def test_create_restricted_vector_or_tensor_element(shape, sub_element):
"""Check that RestrictedElement returns a nested element
for mixed, vector, and tensor elements.
"""
if not isinstance(shape, int):
make_element = lambda elem: TensorElement(elem, shape=shape)
else:
make_element = lambda elem: VectorElement(elem, dim=shape)

tensor = make_element(sub_element)
expected = make_element(RestrictedElement(sub_element, "interior"))

assert RestrictedElement(tensor, "interior") == expected


@pytest.mark.parametrize("sub_elements", [sub_elements, sub_elements[-1:]],
ids=(f"nsubs={len(sub_elements)}", "nsubs=1"))
def test_create_restricted_mixed_element(sub_elements):
"""Check that RestrictedElement returns a nested element
for mixed, vector, and tensor elements.
"""
mixed = MixedElement(sub_elements)
expected = MixedElement([elem["facet"] for elem in sub_elements])
assert mixed["facet"] == expected