Skip to content

Commit 924e0c5

Browse files
authored
RestrictedElement: nest inside MixedElement (#202)
* RestrictedElement: nest inside MixedElement * reconstruct
1 parent f3211e7 commit 924e0c5

File tree

3 files changed

+64
-8
lines changed

3 files changed

+64
-8
lines changed

finat/ufl/restrictedelement.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,31 @@
1212
# Modified by Matthew Scroggs, 2023
1313

1414
from finat.ufl.finiteelementbase import FiniteElementBase
15+
from finat.ufl.mixedelement import MixedElement, VectorElement, TensorElement
1516
from ufl.sobolevspace import L2
1617

1718
valid_restriction_domains = ("interior", "facet", "ridge", "face", "edge", "vertex", "reduced")
1819

1920

2021
class RestrictedElement(FiniteElementBase):
2122
"""Represents the restriction of a finite element to a type of cell entity."""
23+
def __new__(cls, element, restriction_domain):
24+
"""
25+
Restricted qualifier must be below Mixed/Vector/Tensor so we
26+
overload __new__ to return:
27+
28+
RestrictedElement(MixedElement(elem0, elem1), dom) -> MixedElement(RestrictedElement(elem0, dom), RestrictedElement(elem1, dom))
29+
30+
and similarly for VectorElement and TensorElement.
31+
"""
32+
if isinstance(element, (VectorElement, TensorElement)):
33+
return element.reconstruct(sub_element=RestrictedElement(element.sub_elements[0], restriction_domain))
34+
35+
elif isinstance(element, MixedElement):
36+
return MixedElement([RestrictedElement(e, restriction_domain) for e in element.sub_elements])
37+
38+
else: # hopefully no special casing needed
39+
return super().__new__(cls)
2240

2341
def __init__(self, element, restriction_domain):
2442
"""Doc."""
@@ -68,9 +86,10 @@ def restriction_domain(self):
6886
"""Return the domain onto which the element is restricted."""
6987
return self._restriction_domain
7088

71-
def reconstruct(self, **kwargs):
89+
def reconstruct(self, element=None, **kwargs):
7290
"""Doc."""
73-
element = self._element.reconstruct(**kwargs)
91+
if element is None:
92+
element = self._element.reconstruct(**kwargs)
7493
return RestrictedElement(element, self._restriction_domain)
7594

7695
def __str__(self):

test/finat/test_create_broken_element.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,3 @@ def test_create_broken_mixed_element(sub_elements):
4141
mixed = MixedElement(sub_elements)
4242
expected = MixedElement([BrokenElement(elem) for elem in sub_elements])
4343
assert BrokenElement(mixed) == expected
44-
45-
46-
if __name__ == "__main__":
47-
import os
48-
import sys
49-
pytest.main(args=[os.path.abspath(__file__)] + sys.argv[1:])
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import pytest
2+
import ufl
3+
from finat.ufl import FiniteElement, RestrictedElement, VectorElement, TensorElement, MixedElement
4+
5+
sub_elements = [
6+
FiniteElement("CG", ufl.triangle, 1),
7+
FiniteElement("BDM", ufl.triangle, 2),
8+
FiniteElement("DG", ufl.interval, 2, variant="spectral")
9+
]
10+
11+
sub_ids = [
12+
"CG(1)",
13+
"BDM(2)",
14+
"DG(2,spectral)"
15+
]
16+
17+
18+
@pytest.mark.parametrize("sub_element", sub_elements, ids=sub_ids)
19+
@pytest.mark.parametrize("shape", (1, 2, (2, 3)), ids=("1", "2", "(2,3)"))
20+
def test_create_restricted_vector_or_tensor_element(shape, sub_element):
21+
"""Check that RestrictedElement returns a nested element
22+
for mixed, vector, and tensor elements.
23+
"""
24+
if not isinstance(shape, int):
25+
make_element = lambda elem: TensorElement(elem, shape=shape)
26+
else:
27+
make_element = lambda elem: VectorElement(elem, dim=shape)
28+
29+
tensor = make_element(sub_element)
30+
expected = make_element(RestrictedElement(sub_element, "interior"))
31+
32+
assert RestrictedElement(tensor, "interior") == expected
33+
34+
35+
@pytest.mark.parametrize("sub_elements", [sub_elements, sub_elements[-1:]],
36+
ids=(f"nsubs={len(sub_elements)}", "nsubs=1"))
37+
def test_create_restricted_mixed_element(sub_elements):
38+
"""Check that RestrictedElement returns a nested element
39+
for mixed, vector, and tensor elements.
40+
"""
41+
mixed = MixedElement(sub_elements)
42+
expected = MixedElement([elem["facet"] for elem in sub_elements])
43+
assert mixed["facet"] == expected

0 commit comments

Comments
 (0)