Skip to content

Commit 742a0f2

Browse files
authored
Implement integrals of matrix and vector functions (#297)
* add py.typed * rename AnyFunction -> Function * integrate components of vector function * ruff * DefElement branch
1 parent 2e01eec commit 742a0f2

File tree

13 files changed

+187
-177
lines changed

13 files changed

+187
-177
lines changed

.github/workflows/defelement.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727
with:
2828
path: ./defelement
2929
repository: DefElement/DefElement
30-
ref: main
30+
ref: mscroggs/AnyFunction->Function
3131
- name: Install requirements
3232
run: |
3333
cd defelement

symfem/basis_functions.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import symfem
1111
from symfem.functions import (
12-
AnyFunction,
12+
Function,
1313
FunctionInput,
1414
ScalarFunction,
1515
SympyFormat,
@@ -22,7 +22,7 @@
2222
__all__ = ["BasisFunction", "SubbedBasisFunction"]
2323

2424

25-
class BasisFunction(AnyFunction):
25+
class BasisFunction(Function):
2626
"""A basis function of a finite element.
2727
2828
This basis function can be used before the element's basis functions have been computed. When
@@ -40,15 +40,15 @@ def __init__(self, scalar=False, vector=False, matrix=False):
4040
super().__init__(scalar=scalar, vector=vector, matrix=matrix)
4141

4242
@abstractmethod
43-
def get_function(self) -> AnyFunction:
43+
def get_function(self) -> Function:
4444
"""Get the actual basis function.
4545
4646
Returns:
4747
The basis function
4848
"""
4949
pass
5050

51-
def __add__(self, other: typing.Any) -> AnyFunction:
51+
def __add__(self, other: typing.Any) -> Function:
5252
"""Add.
5353
5454
Args:
@@ -59,7 +59,7 @@ def __add__(self, other: typing.Any) -> AnyFunction:
5959
"""
6060
return self.get_function().__add__(other)
6161

62-
def __radd__(self, other: typing.Any) -> AnyFunction:
62+
def __radd__(self, other: typing.Any) -> Function:
6363
"""Add.
6464
6565
Args:
@@ -70,7 +70,7 @@ def __radd__(self, other: typing.Any) -> AnyFunction:
7070
"""
7171
return self.get_function().__radd__(other)
7272

73-
def __sub__(self, other: typing.Any) -> AnyFunction:
73+
def __sub__(self, other: typing.Any) -> Function:
7474
"""Subtract.
7575
7676
Args:
@@ -81,7 +81,7 @@ def __sub__(self, other: typing.Any) -> AnyFunction:
8181
"""
8282
return self.get_function().__sub__(other)
8383

84-
def __rsub__(self, other: typing.Any) -> AnyFunction:
84+
def __rsub__(self, other: typing.Any) -> Function:
8585
"""Subtract.
8686
8787
Args:
@@ -92,15 +92,15 @@ def __rsub__(self, other: typing.Any) -> AnyFunction:
9292
"""
9393
return self.get_function().__rsub__(other)
9494

95-
def __neg__(self) -> AnyFunction:
95+
def __neg__(self) -> Function:
9696
"""Negate.
9797
9898
Returns:
9999
Negated function
100100
"""
101101
return self.get_function().__neg__()
102102

103-
def __truediv__(self, other: typing.Any) -> AnyFunction:
103+
def __truediv__(self, other: typing.Any) -> Function:
104104
"""Divide.
105105
106106
Args:
@@ -111,7 +111,7 @@ def __truediv__(self, other: typing.Any) -> AnyFunction:
111111
"""
112112
return self.get_function().__truediv__(other)
113113

114-
def __rtruediv__(self, other: typing.Any) -> AnyFunction:
114+
def __rtruediv__(self, other: typing.Any) -> Function:
115115
"""Divide.
116116
117117
Args:
@@ -122,7 +122,7 @@ def __rtruediv__(self, other: typing.Any) -> AnyFunction:
122122
"""
123123
return self.get_function().__rtruediv__(other)
124124

125-
def __mul__(self, other: typing.Any) -> AnyFunction:
125+
def __mul__(self, other: typing.Any) -> Function:
126126
"""Multiply.
127127
128128
Args:
@@ -133,7 +133,7 @@ def __mul__(self, other: typing.Any) -> AnyFunction:
133133
"""
134134
return self.get_function().__mul__(other)
135135

136-
def __rmul__(self, other: typing.Any) -> AnyFunction:
136+
def __rmul__(self, other: typing.Any) -> Function:
137137
"""Multiply.
138138
139139
Args:
@@ -144,7 +144,7 @@ def __rmul__(self, other: typing.Any) -> AnyFunction:
144144
"""
145145
return self.get_function().__rmul__(other)
146146

147-
def __matmul__(self, other: typing.Any) -> AnyFunction:
147+
def __matmul__(self, other: typing.Any) -> Function:
148148
"""Multiply.
149149
150150
Args:
@@ -155,7 +155,7 @@ def __matmul__(self, other: typing.Any) -> AnyFunction:
155155
"""
156156
return self.get_function().__matmul__(other)
157157

158-
def __rmatmul__(self, other: typing.Any) -> AnyFunction:
158+
def __rmatmul__(self, other: typing.Any) -> Function:
159159
"""Multiply.
160160
161161
Args:
@@ -166,7 +166,7 @@ def __rmatmul__(self, other: typing.Any) -> AnyFunction:
166166
"""
167167
return self.get_function().__rmatmul__(other)
168168

169-
def __pow__(self, other: typing.Any) -> AnyFunction:
169+
def __pow__(self, other: typing.Any) -> Function:
170170
"""Raise to a power.
171171
172172
Args:
@@ -193,7 +193,7 @@ def as_tex(self) -> str:
193193
"""
194194
return self.get_function().as_tex()
195195

196-
def diff(self, variable: sympy.core.symbol.Symbol) -> AnyFunction:
196+
def diff(self, variable: sympy.core.symbol.Symbol) -> Function:
197197
"""Differentiate the function.
198198
199199
Args:
@@ -204,7 +204,7 @@ def diff(self, variable: sympy.core.symbol.Symbol) -> AnyFunction:
204204
"""
205205
return self.get_function().diff(variable)
206206

207-
def directional_derivative(self, direction: PointType) -> AnyFunction:
207+
def directional_derivative(self, direction: PointType) -> Function:
208208
"""Compute a directional derivative.
209209
210210
Args:
@@ -215,7 +215,7 @@ def directional_derivative(self, direction: PointType) -> AnyFunction:
215215
"""
216216
return self.get_function().directional_derivative(direction)
217217

218-
def jacobian_component(self, component: typing.Tuple[int, int]) -> AnyFunction:
218+
def jacobian_component(self, component: typing.Tuple[int, int]) -> Function:
219219
"""Compute a component of the jacobian.
220220
221221
Args:
@@ -226,7 +226,7 @@ def jacobian_component(self, component: typing.Tuple[int, int]) -> AnyFunction:
226226
"""
227227
return self.get_function().jacobian_component(component)
228228

229-
def jacobian(self, dim: int) -> AnyFunction:
229+
def jacobian(self, dim: int) -> Function:
230230
"""Compute the jacobian.
231231
232232
Args:
@@ -237,7 +237,7 @@ def jacobian(self, dim: int) -> AnyFunction:
237237
"""
238238
return self.get_function().jacobian(dim)
239239

240-
def dot(self, other_in: FunctionInput) -> AnyFunction:
240+
def dot(self, other_in: FunctionInput) -> Function:
241241
"""Compute the dot product with another function.
242242
243243
Args:
@@ -248,7 +248,7 @@ def dot(self, other_in: FunctionInput) -> AnyFunction:
248248
"""
249249
return self.get_function().dot(other_in)
250250

251-
def cross(self, other_in: FunctionInput) -> AnyFunction:
251+
def cross(self, other_in: FunctionInput) -> Function:
252252
"""Compute the cross product with another function.
253253
254254
Args:
@@ -259,23 +259,23 @@ def cross(self, other_in: FunctionInput) -> AnyFunction:
259259
"""
260260
return self.get_function().cross(other_in)
261261

262-
def div(self) -> AnyFunction:
262+
def div(self) -> Function:
263263
"""Compute the divergence of the function.
264264
265265
Returns:
266266
The divergence
267267
"""
268268
return self.get_function().div()
269269

270-
def grad(self, dim: int) -> AnyFunction:
270+
def grad(self, dim: int) -> Function:
271271
"""Compute the gradient of the function.
272272
273273
Returns:
274274
The gradient
275275
"""
276276
return self.get_function().grad(dim)
277277

278-
def curl(self) -> AnyFunction:
278+
def curl(self) -> Function:
279279
"""Compute the curl of the function.
280280
281281
Returns:
@@ -296,7 +296,7 @@ def integral(
296296
domain: Reference,
297297
vars: AxisVariablesNotSingle = x,
298298
dummy_vars: AxisVariablesNotSingle = t,
299-
) -> ScalarFunction:
299+
) -> Function:
300300
"""Compute the integral of the function.
301301
302302
Args:
@@ -321,7 +321,7 @@ def subs(self, vars: AxisVariables, values: ValuesToSubstitute) -> BasisFunction
321321
"""
322322
return SubbedBasisFunction(self, vars, values)
323323

324-
def __getitem__(self, key) -> AnyFunction:
324+
def __getitem__(self, key) -> Function:
325325
"""Forward all other function calls to symbolic function."""
326326
return self.get_function().__getitem__(key)
327327

@@ -353,15 +353,15 @@ def transpose(self) -> ScalarFunction:
353353
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute 'transpose'")
354354
return self.get_function().transpose()
355355

356-
def with_floats(self) -> AnyFunction:
356+
def with_floats(self) -> Function:
357357
"""Return a version the function with floats as coefficients.
358358
359359
Returns:
360360
The function with floats as coefficients
361361
"""
362362
return self.get_function().with_floats()
363363

364-
def __iter__(self) -> typing.Iterator[AnyFunction]:
364+
def __iter__(self) -> typing.Iterator[Function]:
365365
"""Iterate through components of vector function."""
366366
f = self.get_function()
367367
return f.__iter__()
@@ -397,7 +397,7 @@ def __init__(self, f: BasisFunction, vars: AxisVariables, values: ValuesToSubsti
397397
self._vars = vars
398398
self._values = values
399399

400-
def get_function(self) -> AnyFunction:
400+
def get_function(self) -> Function:
401401
"""Return the symbolic function.
402402
403403
Returns:

symfem/elements/bernstein.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from symfem.finite_element import CiarletElement
1313
from symfem.functionals import BaseFunctional, ListOfFunctionals, PointEvaluation
14-
from symfem.functions import AnyFunction, FunctionInput
14+
from symfem.functions import Function, FunctionInput
1515
from symfem.geometry import PointType
1616
from symfem.polynomials import orthogonal_basis, polynomial_set_1d
1717
from symfem.references import Reference
@@ -137,7 +137,7 @@ def dof_point(self) -> PointType:
137137
"""
138138
return self.ref.sub_entity(*self.entity).midpoint()
139139

140-
def _eval_symbolic(self, function: AnyFunction) -> AnyFunction:
140+
def _eval_symbolic(self, function: Function) -> Function:
141141
"""Apply the functional to a function.
142142
143143
Args:

symfem/elements/dual.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import sympy
1010

1111
from symfem.finite_element import FiniteElement
12-
from symfem.functions import AnyFunction, FunctionInput, VectorFunction
12+
from symfem.functions import Function, FunctionInput, VectorFunction
1313
from symfem.geometry import PointType, SetOfPoints, SetOfPointsInput
1414
from symfem.piecewise_functions import PiecewiseFunction
1515
from symfem.references import DualPolygon, NonDefaultReferenceError
@@ -56,11 +56,11 @@ def __init__(
5656
super().__init__(
5757
reference, order, len(dual_coefficients), domain_dim, range_dim, range_shape=range_shape
5858
)
59-
self._basis_functions: typing.Union[typing.List[AnyFunction], None] = None
59+
self._basis_functions: typing.Union[typing.List[Function], None] = None
6060
self._dof_entities = dof_entities
6161
self._dof_directions = dof_directions
6262

63-
def get_polynomial_basis(self, reshape: bool = True) -> typing.List[AnyFunction]:
63+
def get_polynomial_basis(self, reshape: bool = True) -> typing.List[Function]:
6464
"""Get the symbolic polynomial basis for the element.
6565
6666
Returns:
@@ -81,9 +81,7 @@ def get_dual_matrix(self) -> sympy.matrices.dense.MutableDenseMatrix:
8181
"""
8282
raise ValueError("Dual matrix not supported for barycentric dual elements.")
8383

84-
def get_basis_functions(
85-
self, use_tensor_factorisation: bool = False
86-
) -> typing.List[AnyFunction]:
84+
def get_basis_functions(self, use_tensor_factorisation: bool = False) -> typing.List[Function]:
8785
"""Get the basis functions of the element.
8886
8987
Args:
@@ -97,7 +95,7 @@ def get_basis_functions(
9795
if self._basis_functions is None:
9896
from symfem import create_element
9997

100-
bfs: typing.List[AnyFunction] = []
98+
bfs: typing.List[Function] = []
10199
sub_e = create_element("triangle", self.fine_space, self.order)
102100
for coeff_list in self.dual_coefficients:
103101
v0 = self.reference.origin
@@ -194,10 +192,10 @@ def dof_entities(self) -> typing.List[typing.Tuple[int, int]]:
194192
def map_to_cell(
195193
self,
196194
vertices_in: SetOfPointsInput,
197-
basis: typing.Optional[typing.List[AnyFunction]] = None,
195+
basis: typing.Optional[typing.List[Function]] = None,
198196
forward_map: typing.Optional[PointType] = None,
199197
inverse_map: typing.Optional[PointType] = None,
200-
) -> typing.List[AnyFunction]:
198+
) -> typing.List[Function]:
201199
"""Map the basis onto a cell using the appropriate mapping for the element.
202200
203201
Args:

symfem/elements/guzman_neilan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from symfem.elements.lagrange import Lagrange, VectorLagrange
1414
from symfem.finite_element import CiarletElement
1515
from symfem.functionals import DotPointEvaluation, ListOfFunctionals, NormalIntegralMoment
16-
from symfem.functions import AnyFunction, FunctionInput, ScalarFunction, VectorFunction
16+
from symfem.functions import Function, FunctionInput, ScalarFunction, VectorFunction
1717
from symfem.geometry import SetOfPoints, SetOfPointsInput
1818
from symfem.moments import make_integral_moment_dofs
1919
from symfem.piecewise_functions import PiecewiseFunction
@@ -23,7 +23,7 @@
2323
__all__ = ["GuzmanNeilanFirstKind", "GuzmanNeilanSecondKind", "make_piecewise_lagrange"]
2424

2525

26-
def poly(reference: Reference, k: int) -> typing.List[AnyFunction]:
26+
def poly(reference: Reference, k: int) -> typing.List[Function]:
2727
"""Generate the P^perp polynomial set."""
2828
if k < 2:
2929
if reference.name == "triangle":

0 commit comments

Comments
 (0)