11import sympy
2+ from utils import all_symequal
23from symfem import create_element
3-
4-
5- def all_equal (a , b ):
6- if isinstance (a , (list , tuple )):
7- for i , j in zip (a , b ):
8- if not all_equal (i , j ):
9- return False
10- return True
11- return a == b
4+ from symfem .core .symbolic import x
125
136
147def test_lagrange ():
158 space = create_element ("triangle" , "Lagrange" , 1 )
16- assert all_equal (
9+ assert all_symequal (
1710 space .tabulate_basis ([[0 , 0 ], [0 , 1 ], [1 , 0 ]]),
1811 ((1 , 0 , 0 ), (0 , 0 , 1 ), (0 , 1 , 0 )),
1912 )
2013
2114
2215def test_nedelec ():
2316 space = create_element ("triangle" , "Nedelec" , 1 )
24- assert all_equal (
17+ assert all_symequal (
2518 space .tabulate_basis ([[0 , 0 ], [1 , 0 ], [0 , 1 ]], "xxyyzz" ),
2619 ((0 , 0 , 1 , 0 , 1 , 0 ), (0 , 0 , 1 , 1 , 0 , 1 ), (- 1 , 1 , 0 , 0 , 1 , 0 )),
2720 )
2821
2922
3023def test_rt ():
3124 space = create_element ("triangle" , "Raviart-Thomas" , 1 )
32- assert all_equal (
25+ assert all_symequal (
3326 space .tabulate_basis ([[0 , 0 ], [1 , 0 ], [0 , 1 ]], "xxyyzz" ),
3427 ((0 , - 1 , 0 , 0 , 0 , 1 ), (- 1 , 0 , - 1 , 0 , 0 , 1 ), (0 , - 1 , 0 , - 1 , 1 , 0 )),
3528 )
3629
3730
3831def test_Q ():
3932 space = create_element ("quadrilateral" , "Q" , 1 )
40- assert all_equal (
33+ assert all_symequal (
4134 space .tabulate_basis ([[0 , 0 ], [1 , 0 ], [0 , 1 ], [1 , 1 ]]),
4235 ((1 , 0 , 0 , 0 ), (0 , 1 , 0 , 0 ), (0 , 0 , 1 , 0 ), (0 , 0 , 0 , 1 )),
4336 )
@@ -46,7 +39,7 @@ def test_Q():
4639def test_dual0 ():
4740 space = create_element ("dual polygon(4)" , "dual" , 0 )
4841 q = sympy .Rational (1 , 4 )
49- assert all_equal (
42+ assert all_symequal (
5043 space .tabulate_basis ([[q , q ], [- q , q ], [- q , - q ], [q , - q ]]),
5144 ((1 , ), (1 , ), (1 , ), (1 , ))
5245 )
@@ -57,9 +50,29 @@ def test_dual1():
5750 h = sympy .Rational (1 , 2 )
5851 q = sympy .Rational (1 , 4 )
5952 e = sympy .Rational (1 , 8 )
60- assert all_equal (
53+ assert all_symequal (
6154 space .tabulate_basis ([[0 , 0 ], [q , q ], [h , 0 ]]),
6255 ((q , q , q , q ),
6356 (sympy .Rational (5 , 8 ), e , e , e ),
6457 (sympy .Rational (3 , 8 ), e , e , sympy .Rational (3 , 8 )))
6558 )
59+
60+
61+ def test_lagrange_pyramid ():
62+ space = create_element ("pyramid" , "Lagrange" , 1 )
63+ x_i = x [0 ] / (1 - x [2 ])
64+ y_i = x [1 ] / (1 - x [2 ])
65+ z_i = x [2 ] / (1 - x [2 ])
66+ basis = [(1 - x_i ) * (1 - y_i ) / (1 + z_i ),
67+ x_i * (1 - y_i ) / (1 + z_i ),
68+ (1 - x_i ) * y_i / (1 + z_i ),
69+ x_i * y_i / (1 + z_i ),
70+ z_i / (1 + z_i )]
71+ assert all_symequal (basis , space .get_basis_functions ())
72+
73+ basis = [(1 - x [0 ] - x [2 ]) * (1 - x [1 ] - x [2 ]) / (1 - x [2 ]),
74+ x [0 ] * (1 - x [1 ] - x [2 ]) / (1 - x [2 ]),
75+ (1 - x [0 ] - x [2 ]) * x [1 ] / (1 - x [2 ]),
76+ x [0 ] * x [1 ] / (1 - x [2 ]),
77+ x [2 ]]
78+ assert all_symequal (basis , space .get_basis_functions ())
0 commit comments