Skip to content

Commit 2514195

Browse files
authored
Merge pull request #13 from mscroggs/restructure-and-matrix-shape
Restructure and matrix shape
2 parents bacbceb + 7baf952 commit 2514195

File tree

3 files changed

+33
-9
lines changed

3 files changed

+33
-9
lines changed

symfem/core/finite_element.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,32 @@
77
class FiniteElement:
88
"""Abstract finite element."""
99

10-
def __init__(self, reference, basis, dofs, domain_dim, range_dim):
10+
def __init__(self, reference, basis, dofs, domain_dim, range_dim,
11+
range_shape=None):
1112
assert len(basis) == len(dofs)
1213
self.reference = reference
1314
self.basis = basis
1415
self.dofs = dofs
1516
self.domain_dim = domain_dim
1617
self.range_dim = range_dim
18+
self.range_shape = range_shape
1719
self.space_dim = len(dofs)
1820
self._basis_functions = None
21+
self._reshaped_basis_functions = None
1922

20-
def get_basis_functions(self):
23+
def get_polynomial_basis(self, reshape=True):
24+
"""Get the polynomial basis for the element."""
25+
if reshape and self.range_shape is not None:
26+
if len(self.range_shape) != 2:
27+
raise NotImplementedError
28+
assert self.range_shape[0] * self.range_shape[1] == self.range_dim
29+
return [sympy.Matrix(
30+
[b[i * self.range_shape[1]: (i + 1) * self.range_shape[1]]
31+
for i in range(self.range_shape[0])]) for b in self.basis]
32+
33+
return self.basis
34+
35+
def get_basis_functions(self, reshape=True):
2136
"""Get the basis functions of the element."""
2237
if self._basis_functions is None:
2338
mat = []
@@ -44,6 +59,14 @@ def get_basis_functions(self):
4459
b[j] += c * d_j
4560
self._basis_functions.append(b)
4661

62+
if reshape and self.range_shape is not None:
63+
if len(self.range_shape) != 2:
64+
raise NotImplementedError
65+
assert self.range_shape[0] * self.range_shape[1] == self.range_dim
66+
return [sympy.Matrix(
67+
[b[i * self.range_shape[1]: (i + 1) * self.range_shape[1]]
68+
for i in range(self.range_shape[0])]) for b in self._basis_functions]
69+
4770
return self._basis_functions
4871

4972
def tabulate_basis(self, points, order="xyzxyz"):
@@ -52,7 +75,7 @@ def tabulate_basis(self, points, order="xyzxyz"):
5275
output = []
5376
for p in points:
5477
row = []
55-
for b in self.get_basis_functions():
78+
for b in self.get_basis_functions(False):
5679
row.append(subs(b, x, p))
5780
output.append(row)
5881
return output
@@ -62,15 +85,15 @@ def tabulate_basis(self, points, order="xyzxyz"):
6285
for p in points:
6386
row = []
6487
for d in range(self.range_dim):
65-
for b in self.get_basis_functions():
88+
for b in self.get_basis_functions(False):
6689
row.append(subs(b[d], x, p))
6790
output.append(row)
6891
return output
6992
if order == "xyzxyz":
7093
output = []
7194
for p in points:
7295
row = []
73-
for b in self.get_basis_functions():
96+
for b in self.get_basis_functions(False):
7497
for i in subs(b, x, p):
7598
row.append(i)
7699
output.append(row)
@@ -79,7 +102,7 @@ def tabulate_basis(self, points, order="xyzxyz"):
79102
output = []
80103
for p in points:
81104
row = []
82-
for b in self.get_basis_functions():
105+
for b in self.get_basis_functions(False):
83106
row.append(subs(b, x, p))
84107
output.append(row)
85108
return output

symfem/elements/regge.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ def __init__(self, reference, order):
3636
for j, o in enumerate(entity.origin)),
3737
tangent, entity_dim=edim))
3838

39-
super().__init__(reference, poly, dofs, reference.tdim, reference.tdim ** 2)
39+
super().__init__(reference, poly, dofs, reference.tdim, reference.tdim ** 2,
40+
(reference.tdim, reference.tdim))
4041

4142
names = ["Regge"]
4243
min_order = 0

test/test_against_basix.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ def test_against_basix(cell, symfem_type, basix_type, order):
6060
result = space.tabulate(0, points)[0]
6161

6262
if element.range_dim == 1:
63-
basis = element.get_basis_functions()
63+
basis = element.get_basis_functions(False)
6464
sym_result = [[float(subs(b, x, p)) for b in basis] for p in points]
6565
else:
66-
basis = element.get_basis_functions()
66+
basis = element.get_basis_functions(False)
6767
sym_result = [[float(subs(b, x, p)[j]) for j in range(element.range_dim) for b in basis] for p in points]
6868
assert np.allclose(result, sym_result)

0 commit comments

Comments
 (0)