|
9 | 9 | class FiniteElement: |
10 | 10 | """Abstract finite element.""" |
11 | 11 |
|
12 | | - def __init__(self, reference, order, basis, dofs, domain_dim, range_dim, |
| 12 | + def __init__(self, reference, order, space_dim, domain_dim, range_dim, |
13 | 13 | range_shape=None): |
14 | | - assert len(basis) == len(dofs) |
15 | 14 | self.reference = reference |
16 | 15 | self.order = order |
17 | | - self.basis = basis |
18 | | - self.dofs = dofs |
| 16 | + self.space_dim = space_dim |
19 | 17 | self.domain_dim = domain_dim |
20 | 18 | self.range_dim = range_dim |
21 | 19 | self.range_shape = range_shape |
22 | | - self.space_dim = len(dofs) |
23 | | - self._basis_functions = None |
24 | | - self._reshaped_basis_functions = None |
25 | 20 |
|
26 | 21 | def entity_dofs(self, entity_dim, entity_number): |
27 | 22 | """Get the numbers of the DOFs associated with the given entity.""" |
28 | | - return [i for i, j in enumerate(self.dofs) if j.entity == (entity_dim, entity_number)] |
29 | | - |
30 | | - def get_polynomial_basis(self, reshape=True): |
31 | | - """Get the polynomial basis for the element.""" |
32 | | - if reshape and self.range_shape is not None: |
33 | | - if len(self.range_shape) != 2: |
34 | | - raise NotImplementedError |
35 | | - assert self.range_shape[0] * self.range_shape[1] == self.range_dim |
36 | | - return [sympy.Matrix( |
37 | | - [b[i * self.range_shape[1]: (i + 1) * self.range_shape[1]] |
38 | | - for i in range(self.range_shape[0])]) for b in self.basis] |
39 | | - |
40 | | - return self.basis |
41 | | - |
42 | | - def get_dual_matrix(self): |
43 | | - """Get the dual matrix.""" |
44 | | - mat = [] |
45 | | - for b in self.basis: |
46 | | - row = [] |
47 | | - for d in self.dofs: |
48 | | - row.append(d.eval(b)) |
49 | | - mat.append(row) |
50 | | - return sympy.Matrix(mat) |
| 23 | + raise NotImplementedError() |
51 | 24 |
|
52 | 25 | def get_basis_functions(self, reshape=True): |
53 | 26 | """Get the basis functions of the element.""" |
54 | | - if self._basis_functions is None: |
55 | | - minv = self.get_dual_matrix().inv("LU") |
56 | | - self._basis_functions = [] |
57 | | - if self.range_dim == 1: |
58 | | - # Scalar space |
59 | | - for i, dof in enumerate(self.dofs): |
60 | | - self._basis_functions.append( |
61 | | - sym_sum(c * d for c, d in zip(minv.row(i), self.basis))) |
62 | | - else: |
63 | | - # Vector space |
64 | | - for i, dof in enumerate(self.dofs): |
65 | | - b = [zero for i in self.basis[0]] |
66 | | - for c, d in zip(minv.row(i), self.basis): |
67 | | - for j, d_j in enumerate(d): |
68 | | - b[j] += c * d_j |
69 | | - self._basis_functions.append(b) |
70 | | - |
71 | | - if reshape and self.range_shape is not None: |
72 | | - if len(self.range_shape) != 2: |
73 | | - raise NotImplementedError |
74 | | - assert self.range_shape[0] * self.range_shape[1] == self.range_dim |
75 | | - return [sympy.Matrix( |
76 | | - [b[i * self.range_shape[1]: (i + 1) * self.range_shape[1]] |
77 | | - for i in range(self.range_shape[0])]) for b in self._basis_functions] |
78 | | - |
79 | | - return self._basis_functions |
| 27 | + raise NotImplementedError() |
80 | 28 |
|
81 | 29 | def get_basis_function(self, n): |
82 | 30 | """Get a single basis function of the element.""" |
@@ -147,3 +95,103 @@ def name(self): |
147 | 95 | return self.names[0] |
148 | 96 |
|
149 | 97 | names = [] |
| 98 | + |
| 99 | + |
| 100 | +class CiarletElement(FiniteElement): |
| 101 | + """Finite element defined using the Ciarlet definition.""" |
| 102 | + |
| 103 | + def __init__(self, reference, order, basis, dofs, domain_dim, range_dim, |
| 104 | + range_shape=None): |
| 105 | + super().__init__(reference, order, len(dofs), domain_dim, range_dim, range_shape) |
| 106 | + assert len(basis) == len(dofs) |
| 107 | + self.basis = basis |
| 108 | + self.dofs = dofs |
| 109 | + self._basis_functions = None |
| 110 | + self._reshaped_basis_functions = None |
| 111 | + |
| 112 | + def entity_dofs(self, entity_dim, entity_number): |
| 113 | + """Get the numbers of the DOFs associated with the given entity.""" |
| 114 | + return [i for i, j in enumerate(self.dofs) if j.entity == (entity_dim, entity_number)] |
| 115 | + |
| 116 | + def get_polynomial_basis(self, reshape=True): |
| 117 | + """Get the polynomial basis for the element.""" |
| 118 | + if reshape and self.range_shape is not None: |
| 119 | + if len(self.range_shape) != 2: |
| 120 | + raise NotImplementedError |
| 121 | + assert self.range_shape[0] * self.range_shape[1] == self.range_dim |
| 122 | + return [sympy.Matrix( |
| 123 | + [b[i * self.range_shape[1]: (i + 1) * self.range_shape[1]] |
| 124 | + for i in range(self.range_shape[0])]) for b in self.basis] |
| 125 | + |
| 126 | + return self.basis |
| 127 | + |
| 128 | + def get_dual_matrix(self): |
| 129 | + """Get the dual matrix.""" |
| 130 | + mat = [] |
| 131 | + for b in self.basis: |
| 132 | + row = [] |
| 133 | + for d in self.dofs: |
| 134 | + row.append(d.eval(b)) |
| 135 | + mat.append(row) |
| 136 | + return sympy.Matrix(mat) |
| 137 | + |
| 138 | + def get_basis_functions(self, reshape=True): |
| 139 | + """Get the basis functions of the element.""" |
| 140 | + if self._basis_functions is None: |
| 141 | + minv = self.get_dual_matrix().inv("LU") |
| 142 | + self._basis_functions = [] |
| 143 | + if self.range_dim == 1: |
| 144 | + # Scalar space |
| 145 | + for i, dof in enumerate(self.dofs): |
| 146 | + self._basis_functions.append( |
| 147 | + sym_sum(c * d for c, d in zip(minv.row(i), self.basis))) |
| 148 | + else: |
| 149 | + # Vector space |
| 150 | + for i, dof in enumerate(self.dofs): |
| 151 | + b = [zero for i in self.basis[0]] |
| 152 | + for c, d in zip(minv.row(i), self.basis): |
| 153 | + for j, d_j in enumerate(d): |
| 154 | + b[j] += c * d_j |
| 155 | + self._basis_functions.append(b) |
| 156 | + |
| 157 | + if reshape and self.range_shape is not None: |
| 158 | + if len(self.range_shape) != 2: |
| 159 | + raise NotImplementedError |
| 160 | + assert self.range_shape[0] * self.range_shape[1] == self.range_dim |
| 161 | + return [sympy.Matrix( |
| 162 | + [b[i * self.range_shape[1]: (i + 1) * self.range_shape[1]] |
| 163 | + for i in range(self.range_shape[0])]) for b in self._basis_functions] |
| 164 | + |
| 165 | + return self._basis_functions |
| 166 | + |
| 167 | + names = [] |
| 168 | + |
| 169 | + |
| 170 | +class DirectElement(FiniteElement): |
| 171 | + """Finite element defined directly.""" |
| 172 | + |
| 173 | + def __init__(self, reference, order, basis_functions, basis_entities, domain_dim, range_dim, |
| 174 | + range_shape=None): |
| 175 | + super().__init__(reference, order, len(basis_functions), domain_dim, range_dim, |
| 176 | + range_shape) |
| 177 | + self._basis_entities = basis_entities |
| 178 | + self._basis_functions = basis_functions |
| 179 | + self._reshaped_basis_functions = None |
| 180 | + |
| 181 | + def entity_dofs(self, entity_dim, entity_number): |
| 182 | + """Get the numbers of the DOFs associated with the given entity.""" |
| 183 | + return [i for i, j in enumerate(self._basis_entities) if j == (entity_dim, entity_number)] |
| 184 | + |
| 185 | + def get_basis_functions(self, reshape=True): |
| 186 | + """Get the basis functions of the element.""" |
| 187 | + if reshape and self.range_shape is not None: |
| 188 | + if len(self.range_shape) != 2: |
| 189 | + raise NotImplementedError |
| 190 | + assert self.range_shape[0] * self.range_shape[1] == self.range_dim |
| 191 | + return [sympy.Matrix( |
| 192 | + [b[i * self.range_shape[1]: (i + 1) * self.range_shape[1]] |
| 193 | + for i in range(self.range_shape[0])]) for b in self._basis_functions] |
| 194 | + |
| 195 | + return self._basis_functions |
| 196 | + |
| 197 | + names = [] |
0 commit comments