Skip to content

Commit cf356fe

Browse files
authored
Add higher order Lagrange on pyramids (#95)
* Generalise pyramid lagrange * flake * Test pyramid lagrange
1 parent b85a2ef commit cf356fe

File tree

7 files changed

+127
-61
lines changed

7 files changed

+127
-61
lines changed

symfem/core/polynomials.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -270,20 +270,14 @@ def pyramid_polynomial_set_1d(dim, order):
270270
assert dim == 3
271271
if order == 0:
272272
return [one]
273-
if order >= 3:
274-
raise NotImplementedError()
275-
276-
poly = []
277-
for r in range(order + 1):
278-
for p in range(order - r + 1):
279-
for q in range(order - r + 1):
280-
if r == 0 and p + q < order:
281-
poly.append(x[0] ** p * x[1] ** q)
282-
else:
283-
new_p = (2 * x[0] + x[2]) ** p
284-
new_p *= (2 * x[1] + x[2]) ** q
285-
new_p *= x[2] ** r
286-
poly.append(new_p)
273+
274+
poly = polynomial_set_1d(3, order)
275+
276+
poly = [x[0] ** a * x[1] ** b * x[2] ** c / (1 - x[2]) ** (a + b + c - order)
277+
for c in range(order)
278+
for a in range(order + 1 - c) for b in range(order + 1 - c)]
279+
poly.append(x[2] ** order)
280+
287281
return poly
288282

289283

symfem/elements/lagrange_prism_pyramid.py renamed to symfem/elements/lagrange_prism.py

Lines changed: 14 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
"""Lagrange elements on a prism and pyramid."""
1+
"""Lagrange elements on a prism."""
22

33
import sympy
44
from itertools import product
55
from ..core.symbolic import one, zero
66
from ..core.finite_element import CiarletElement
7-
from ..core.polynomials import prism_polynomial_set, pyramid_polynomial_set
7+
from ..core.polynomials import prism_polynomial_set
88
from ..core.functionals import PointEvaluation, DotPointEvaluation
99
from ..core.quadrature import get_quadrature
1010

@@ -51,35 +51,21 @@ def __init__(self, reference, order, variant):
5151
for j, o in enumerate(entity.origin)),
5252
entity=(2, e_n)))
5353

54-
if reference.name == "prism":
55-
# Interior
56-
for i in product(range(1, order), repeat=3):
57-
if i[0] + i[1] < order:
58-
dofs.append(
59-
PointEvaluation(
60-
tuple(o + sum(a[j] * points[b]
61-
for a, b in zip(reference.axes, i))
62-
for j, o in enumerate(reference.origin)),
63-
entity=(3, 0)))
64-
elif reference.name == "pyramid":
65-
# Interior
66-
for i in product(range(1, order), repeat=3):
67-
if max(i[0], i[1]) + i[2] < order:
68-
dofs.append(
69-
PointEvaluation(
70-
tuple(o + sum(a[j] * points[b]
71-
for a, b in zip(reference.axes, i))
72-
for j, o in enumerate(reference.origin)),
73-
entity=(3, 0)))
74-
75-
if reference.name == "prism":
76-
poly = prism_polynomial_set(reference.tdim, 1, order)
77-
elif reference.name == "pyramid":
78-
poly = pyramid_polynomial_set(reference.tdim, 1, order)
54+
# Interior
55+
for i in product(range(1, order), repeat=3):
56+
if i[0] + i[1] < order:
57+
dofs.append(
58+
PointEvaluation(
59+
tuple(o + sum(a[j] * points[b]
60+
for a, b in zip(reference.axes, i))
61+
for j, o in enumerate(reference.origin)),
62+
entity=(3, 0)))
63+
64+
poly = prism_polynomial_set(reference.tdim, 1, order)
7965
super().__init__(reference, order, poly, dofs, reference.tdim, 1)
8066

8167
names = ["Lagrange", "P"]
82-
references = ["prism", "pyramid"]
68+
references = ["prism"]
8369
min_order = 0
8470
continuity = "C0"
8571

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""Lagrange elements on a pyramid.
2+
3+
This element's definition appears in https://doi.org/10.1007/s10915-009-9334-9
4+
(Bergot, Cohen, Durufle, 2010)
5+
"""
6+
7+
import sympy
8+
from itertools import product
9+
from ..core.finite_element import CiarletElement
10+
from ..core.polynomials import pyramid_polynomial_set
11+
from ..core.functionals import PointEvaluation
12+
from ..core.quadrature import get_quadrature
13+
14+
15+
class Lagrange(CiarletElement):
16+
"""Lagrange finite element."""
17+
18+
def __init__(self, reference, order, variant):
19+
if order == 0:
20+
dofs = [
21+
PointEvaluation(
22+
tuple(
23+
sympy.Rational(1, reference.tdim + 1)
24+
for i in range(reference.tdim)
25+
),
26+
entity=(reference.tdim, 0)
27+
)
28+
]
29+
else:
30+
points, _ = get_quadrature(variant, order + 1)
31+
32+
dofs = []
33+
# Vertices
34+
for v_n, v in enumerate(reference.reference_vertices):
35+
dofs.append(PointEvaluation(v, entity=(0, v_n)))
36+
# Edges
37+
for e_n in range(reference.sub_entity_count(1)):
38+
entity = reference.sub_entity(1, e_n)
39+
for i in range(1, order):
40+
dofs.append(
41+
PointEvaluation(
42+
tuple(o + entity.axes[0][j] * points[i]
43+
for j, o in enumerate(entity.origin)),
44+
entity=(1, e_n)))
45+
# Faces
46+
for e_n in range(reference.sub_entity_count(2)):
47+
entity = reference.sub_entity(2, e_n)
48+
for i in product(range(1, order), repeat=2):
49+
if len(entity.vertices) == 4 or sum(i) < order:
50+
dofs.append(
51+
PointEvaluation(
52+
tuple(o + sum(a[j] * points[b]
53+
for a, b in zip(entity.axes, i[::-1]))
54+
for j, o in enumerate(entity.origin)),
55+
entity=(2, e_n)))
56+
57+
# Interior
58+
for i in product(range(1, order), repeat=3):
59+
if max(i[0], i[1]) + i[2] < order:
60+
dofs.append(
61+
PointEvaluation(
62+
tuple(o + sum(a[j] * points[b]
63+
for a, b in zip(reference.axes, i))
64+
for j, o in enumerate(reference.origin)),
65+
entity=(3, 0)))
66+
67+
poly = pyramid_polynomial_set(reference.tdim, 1, order)
68+
69+
super().__init__(reference, order, poly, dofs, reference.tdim, 1)
70+
71+
names = ["Lagrange", "P"]
72+
references = ["pyramid"]
73+
min_order = 0
74+
continuity = "C0"

test/test_against_basix.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@
3737
("Qcurl", "Nedelec 1st kind H(curl)", range(1, 3)),
3838
("Sdiv", "Brezzi-Douglas-Marini", range(1, 3)),
3939
("Scurl", "Nedelec 2nd kind H(curl)", range(1, 3))],
40-
"prism": [("Lagrange", "Lagrange", range(1, 4))],
41-
"pyramid": [("Lagrange", "Lagrange", range(1, 3))]
40+
"prism": [("Lagrange", "Lagrange", range(1, 4))]
4241
}
4342

4443

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,36 @@
11
import sympy
2+
from utils import all_symequal
23
from symfem import create_element
3-
4-
5-
def all_equal(a, b):
6-
if isinstance(a, (list, tuple)):
7-
for i, j in zip(a, b):
8-
if not all_equal(i, j):
9-
return False
10-
return True
11-
return a == b
4+
from symfem.core.symbolic import x
125

136

147
def test_lagrange():
158
space = create_element("triangle", "Lagrange", 1)
16-
assert all_equal(
9+
assert all_symequal(
1710
space.tabulate_basis([[0, 0], [0, 1], [1, 0]]),
1811
((1, 0, 0), (0, 0, 1), (0, 1, 0)),
1912
)
2013

2114

2215
def test_nedelec():
2316
space = create_element("triangle", "Nedelec", 1)
24-
assert all_equal(
17+
assert all_symequal(
2518
space.tabulate_basis([[0, 0], [1, 0], [0, 1]], "xxyyzz"),
2619
((0, 0, 1, 0, 1, 0), (0, 0, 1, 1, 0, 1), (-1, 1, 0, 0, 1, 0)),
2720
)
2821

2922

3023
def test_rt():
3124
space = create_element("triangle", "Raviart-Thomas", 1)
32-
assert all_equal(
25+
assert all_symequal(
3326
space.tabulate_basis([[0, 0], [1, 0], [0, 1]], "xxyyzz"),
3427
((0, -1, 0, 0, 0, 1), (-1, 0, -1, 0, 0, 1), (0, -1, 0, -1, 1, 0)),
3528
)
3629

3730

3831
def test_Q():
3932
space = create_element("quadrilateral", "Q", 1)
40-
assert all_equal(
33+
assert all_symequal(
4134
space.tabulate_basis([[0, 0], [1, 0], [0, 1], [1, 1]]),
4235
((1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, 1)),
4336
)
@@ -46,7 +39,7 @@ def test_Q():
4639
def test_dual0():
4740
space = create_element("dual polygon(4)", "dual", 0)
4841
q = sympy.Rational(1, 4)
49-
assert all_equal(
42+
assert all_symequal(
5043
space.tabulate_basis([[q, q], [-q, q], [-q, -q], [q, -q]]),
5144
((1, ), (1, ), (1, ), (1, ))
5245
)
@@ -57,9 +50,29 @@ def test_dual1():
5750
h = sympy.Rational(1, 2)
5851
q = sympy.Rational(1, 4)
5952
e = sympy.Rational(1, 8)
60-
assert all_equal(
53+
assert all_symequal(
6154
space.tabulate_basis([[0, 0], [q, q], [h, 0]]),
6255
((q, q, q, q),
6356
(sympy.Rational(5, 8), e, e, e),
6457
(sympy.Rational(3, 8), e, e, sympy.Rational(3, 8)))
6558
)
59+
60+
61+
def test_lagrange_pyramid():
62+
space = create_element("pyramid", "Lagrange", 1)
63+
x_i = x[0] / (1 - x[2])
64+
y_i = x[1] / (1 - x[2])
65+
z_i = x[2] / (1 - x[2])
66+
basis = [(1 - x_i) * (1 - y_i) / (1 + z_i),
67+
x_i * (1 - y_i) / (1 + z_i),
68+
(1 - x_i) * y_i / (1 + z_i),
69+
x_i * y_i / (1 + z_i),
70+
z_i / (1 + z_i)]
71+
assert all_symequal(basis, space.get_basis_functions())
72+
73+
basis = [(1 - x[0] - x[2]) * (1 - x[1] - x[2]) / (1 - x[2]),
74+
x[0] * (1 - x[1] - x[2]) / (1 - x[2]),
75+
(1 - x[0] - x[2]) * x[1] / (1 - x[2]),
76+
x[0] * x[1] / (1 - x[2]),
77+
x[2]]
78+
assert all_symequal(basis, space.get_basis_functions())

test/test_elements.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def test_element_functionals_and_continuity(
176176
(0, 1, 0), (0, 0, 1))
177177
entity_pairs = [[0, (0, 1)], [0, (2, 3)], [0, (4, 4)],
178178
[1, (1, 3)], [1, (2, 4)], [1, (6, 7)],
179-
[2, (2, 4)]]
179+
[2, (2, 3)]]
180180

181181
if space.continuity == "L2":
182182
return

test/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@
128128
"Lagrange": {"equispaced": range(4)}
129129
},
130130
"pyramid": {
131-
"Lagrange": {"equispaced": range(3)}
131+
"Lagrange": {"equispaced": range(4)}
132132
}
133133
}
134134

0 commit comments

Comments
 (0)