|
12 | 12 | # Modified by Matthew Scroggs, 2023 |
13 | 13 |
|
14 | 14 | from finat.ufl.finiteelementbase import FiniteElementBase |
| 15 | +from finat.ufl.mixedelement import MixedElement, VectorElement, TensorElement |
15 | 16 | from ufl.sobolevspace import L2 |
16 | 17 |
|
17 | 18 | valid_restriction_domains = ("interior", "facet", "ridge", "face", "edge", "vertex", "reduced") |
18 | 19 |
|
19 | 20 |
|
20 | 21 | class RestrictedElement(FiniteElementBase): |
21 | 22 | """Represents the restriction of a finite element to a type of cell entity.""" |
| 23 | + def __new__(cls, element, restriction_domain): |
| 24 | + """ |
| 25 | + Restricted qualifier must be below Mixed/Vector/Tensor so we |
| 26 | + overload __new__ to return: |
| 27 | +
|
| 28 | + RestrictedElement(MixedElement(elem0, elem1), dom) -> MixedElement(RestrictedElement(elem0, dom), RestrictedElement(elem1, dom)) |
| 29 | +
|
| 30 | + and similarly for VectorElement and TensorElement. |
| 31 | + """ |
| 32 | + if isinstance(element, (VectorElement, TensorElement)): |
| 33 | + return element.reconstruct(sub_element=RestrictedElement(element.sub_elements[0], restriction_domain)) |
| 34 | + |
| 35 | + elif isinstance(element, MixedElement): |
| 36 | + return MixedElement([RestrictedElement(e, restriction_domain) for e in element.sub_elements]) |
| 37 | + |
| 38 | + else: # hopefully no special casing needed |
| 39 | + return super().__new__(cls) |
22 | 40 |
|
23 | 41 | def __init__(self, element, restriction_domain): |
24 | 42 | """Doc.""" |
@@ -68,9 +86,10 @@ def restriction_domain(self): |
68 | 86 | """Return the domain onto which the element is restricted.""" |
69 | 87 | return self._restriction_domain |
70 | 88 |
|
71 | | - def reconstruct(self, **kwargs): |
| 89 | + def reconstruct(self, element=None, **kwargs): |
72 | 90 | """Doc.""" |
73 | | - element = self._element.reconstruct(**kwargs) |
| 91 | + if element is None: |
| 92 | + element = self._element.reconstruct(**kwargs) |
74 | 93 | return RestrictedElement(element, self._restriction_domain) |
75 | 94 |
|
76 | 95 | def __str__(self): |
|
0 commit comments