Skip to content

Commit 65f6636

Browse files
authored
Fix #207 (#208)
* Fix #207 * reduce code cuplication * pydocstyle * use different dummy names * revert change to meaning of integrate inputs * correct integrate * allow vars and dummy_vars inputs * various fixes * flake8 and mypy * lowercase * changelog
1 parent 43fbadd commit 65f6636

File tree

7 files changed

+189
-42
lines changed

7 files changed

+189
-42
lines changed

CHANGELOG_SINCE_LAST_VERSION.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
- Added enriched Galerkin element
22
- Improved plotting
33
- Added enriched vector Galerkin element
4+
- Fixed bug in integration of piecewise functions

symfem/basis_functions.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from .functions import AnyFunction, FunctionInput, ScalarFunction, SympyFormat, ValuesToSubstitute
1111
from .geometry import PointType
1212
from .references import Reference
13-
from .symbols import AxisVariables, AxisVariablesNotSingle, t
13+
from .symbols import AxisVariables, AxisVariablesNotSingle, t, x
1414

1515

1616
class BasisFunction(AnyFunction):
@@ -282,17 +282,21 @@ def norm(self) -> ScalarFunction:
282282
"""
283283
raise self.get_function().norm()
284284

285-
def integral(self, domain: Reference, vars: AxisVariablesNotSingle = t) -> AnyFunction:
285+
def integral(
286+
self, domain: Reference, vars: AxisVariablesNotSingle = x,
287+
dummy_vars: AxisVariablesNotSingle = t
288+
) -> ScalarFunction:
286289
"""Compute the integral of the function.
287290
288291
Args:
289-
domain: The domain to integrate over
290-
vars: The variables to integrate over
292+
domain: The domain of the integral
293+
vars: The variables to integrate with respect to
294+
dummy_vars: The dummy variables to use inside the integral
291295
292296
Returns:
293297
The integral
294298
"""
295-
return self.get_function().integral(domain, vars)
299+
return self.get_function().integral(domain, vars, dummy_vars)
296300

297301
def subs(self, vars: AxisVariables, values: ValuesToSubstitute) -> BasisFunction:
298302
"""Substitute values into the function.

symfem/functions.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -314,12 +314,16 @@ def norm(self):
314314
pass
315315

316316
@abstractmethod
317-
def integral(self, domain: Reference, vars: AxisVariablesNotSingle = t):
317+
def integral(
318+
self, domain: Reference, vars: AxisVariablesNotSingle = x,
319+
dummy_vars: AxisVariablesNotSingle = t
320+
) -> ScalarFunction:
318321
"""Compute the integral of the function.
319322
320323
Args:
321324
domain: The domain of the integral
322325
vars: The variables to integrate with respect to
326+
dummy_vars: The dummy variables to use inside the integral
323327
324328
Returns:
325329
The integral
@@ -758,22 +762,27 @@ def norm(self) -> ScalarFunction:
758762
"""
759763
return ScalarFunction(abs(self._f))
760764

761-
def integral(self, domain: Reference, vars: AxisVariablesNotSingle = t) -> ScalarFunction:
765+
def integral(
766+
self, domain: Reference, vars: AxisVariablesNotSingle = x,
767+
dummy_vars: AxisVariablesNotSingle = t
768+
) -> ScalarFunction:
762769
"""Compute the integral of the function.
763770
764771
Args:
765772
domain: The domain of the integral
766773
vars: The variables to integrate with respect to
774+
dummy_vars: The dummy variables to use inside the integral
767775
768776
Returns:
769777
The integral
770778
"""
771-
limits = domain.integration_limits(vars)
772-
779+
limits = domain.integration_limits(dummy_vars)
773780
point = VectorFunction(domain.origin)
774-
for ti, a in zip(t, domain.axes):
781+
for ti, a in zip(dummy_vars, domain.axes):
775782
point += ti * VectorFunction(a)
776-
out = self._f.subs(x, point)
783+
out = self._f * 1
784+
for v, p in zip(vars, point):
785+
out = out.subs(v, p)
777786

778787
if len(limits[0]) == 2:
779788
for i in limits:
@@ -1159,12 +1168,16 @@ def norm(self) -> ScalarFunction:
11591168
a += i._f ** 2
11601169
return ScalarFunction(sympy.sqrt(a))
11611170

1162-
def integral(self, domain: Reference, vars: AxisVariablesNotSingle = t):
1171+
def integral(
1172+
self, domain: Reference, vars: AxisVariablesNotSingle = x,
1173+
dummy_vars: AxisVariablesNotSingle = t
1174+
) -> ScalarFunction:
11631175
"""Compute the integral of the function.
11641176
11651177
Args:
11661178
domain: The domain of the integral
11671179
vars: The variables to integrate with respect to
1180+
dummy_vars: The dummy variables to use inside the integral
11681181
11691182
Returns:
11701183
The integral
@@ -1587,12 +1600,16 @@ def norm(self) -> ScalarFunction:
15871600
"""
15881601
raise NotImplementedError()
15891602

1590-
def integral(self, domain: Reference, vars: AxisVariablesNotSingle = t):
1603+
def integral(
1604+
self, domain: Reference, vars: AxisVariablesNotSingle = x,
1605+
dummy_vars: AxisVariablesNotSingle = t
1606+
) -> ScalarFunction:
15911607
"""Compute the integral of the function.
15921608
15931609
Args:
15941610
domain: The domain of the integral
15951611
vars: The variables to integrate with respect to
1612+
dummy_vars: The dummy variables to use inside the integral
15961613
15971614
Returns:
15981615
The integral

symfem/piecewise_functions.py

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
import sympy
88

9-
from .functions import (AnyFunction, FunctionInput, SympyFormat, ValuesToSubstitute, VectorFunction,
10-
_to_sympy_format, parse_function_input)
9+
from .functions import (AnyFunction, FunctionInput, ScalarFunction, SympyFormat, ValuesToSubstitute,
10+
VectorFunction, _to_sympy_format, parse_function_input)
1111
from .geometry import (PointType, SetOfPoints, SetOfPointsInput, parse_set_of_points_input,
1212
point_in_quadrilateral, point_in_tetrahedron, point_in_triangle)
1313
from .references import Reference
@@ -422,20 +422,27 @@ def norm(self) -> PiecewiseFunction:
422422
return PiecewiseFunction(
423423
{shape: f.norm() for shape, f in self._pieces.items()}, self.tdim)
424424

425-
def integral(self, domain: Reference, vars: AxisVariablesNotSingle = t) -> AnyFunction:
425+
def integral(
426+
self, domain: Reference, vars: AxisVariablesNotSingle = x,
427+
dummy_vars: AxisVariablesNotSingle = t
428+
) -> ScalarFunction:
426429
"""Compute the integral of the function.
427430
428431
Args:
429432
domain: The domain of the integral
430433
vars: The variables to integrate with respect to
434+
dummy_vars: The dummy variables to use inside the integral
431435
432436
Returns:
433437
The integral
434438
"""
435-
# TODO: Add check that the domain is a subset of one piece
436-
# TODO: Add integral over multiple pieces
437-
p = self.get_piece(domain.midpoint())
438-
return p.integral(domain, vars)
439+
result = ScalarFunction(0)
440+
for shape, f in self._pieces.items():
441+
ref = _piece_reference(self.tdim, shape)
442+
sub_domain = ref.intersection(domain)
443+
if sub_domain is not None:
444+
result += f.integral(sub_domain, vars, dummy_vars)
445+
return result
439446

440447
def det(self) -> PiecewiseFunction:
441448
"""Compute the determinant.
@@ -500,25 +507,11 @@ def plot_values(
500507
value_scale: The scale factor for the function values
501508
n: The number of points per side for plotting
502509
"""
503-
from .create import create_reference
504510
from .plotting import Picture
505511
assert isinstance(img, Picture)
506512

507513
for shape, f in self._pieces.items():
508-
if self.tdim == 2:
509-
if len(shape) == 3:
510-
ref = create_reference("triangle", shape)
511-
elif len(shape) == 4:
512-
ref = create_reference("quadrilateral", shape)
513-
else:
514-
raise ValueError("Unsupported cell type")
515-
elif self.tdim == 3:
516-
if len(shape) == 4:
517-
ref = create_reference("tetrahedron", shape)
518-
else:
519-
raise ValueError("Unsupported cell type")
520-
else:
521-
raise ValueError("Unsupported tdim")
514+
ref = _piece_reference(self.tdim, shape)
522515
f.plot_values(ref, img, value_scale, n // 2)
523516

524517
def with_floats(self) -> AnyFunction:
@@ -529,3 +522,22 @@ def with_floats(self) -> AnyFunction:
529522
"""
530523
return PiecewiseFunction(
531524
{shape: f.with_floats() for shape, f in self._pieces.items()}, self.tdim)
525+
526+
527+
def _piece_reference(tdim, shape):
528+
"""Create a reference element for a single piece."""
529+
from .create import create_reference
530+
if tdim == 2:
531+
if len(shape) == 3:
532+
return create_reference("triangle", shape)
533+
elif len(shape) == 4:
534+
return create_reference("quadrilateral", shape)
535+
else:
536+
raise ValueError("Unsupported cell type")
537+
elif tdim == 3:
538+
if len(shape) == 4:
539+
return create_reference("tetrahedron", shape)
540+
else:
541+
raise ValueError("Unsupported cell type")
542+
else:
543+
raise ValueError("Unsupported tdim")

symfem/polynomials.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1076,7 +1076,7 @@ def orthonormal_basis(
10761076
ref = create_reference(cell)
10771077
if variables is None:
10781078
variables = x
1079-
norms = [sympy.sqrt((f ** 2).integral(ref, variables)) for f in poly[0]]
1079+
norms = [sympy.sqrt((f ** 2).integral(ref, dummy_vars=variables)) for f in poly[0]]
10801080
for i, n in enumerate(norms):
10811081
for j in range(len(poly)):
10821082
poly[j][i] /= n

symfem/references.py

Lines changed: 102 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,49 @@
1717
typing.Tuple[sympy.core.symbol.Symbol, sympy.core.expr.Expr]]]
1818

1919

20+
def _which_side(vs: SetOfPoints, p: PointType, q: PointType) -> typing.Optional[int]:
21+
"""Check which side of a line or plane a set of points are.
22+
23+
Args:
24+
vs: The set of points
25+
p: A point on the line or plane
26+
q: Another point on the line (2D) or the normal to the plane (3D)
27+
28+
Returns:
29+
2 if the points are all to the left, 1 if the points are all to the left or on the line,
30+
0 if the points are all on the line, -1 if the points are all to the right or on the line,
31+
-1 if the points are all to the right, None if there are some points on either side.
32+
"""
33+
sides = []
34+
for v in vs:
35+
if len(q) == 2:
36+
cross = (v[0] - p[0]) * (q[1] - p[1]) - (v[1] - p[1]) * (q[0] - p[0])
37+
elif len(q) == 3:
38+
cross = (v[0] - p[0]) * q[0] + (v[1] - p[1]) * q[1] + (v[2] - p[2]) * q[2]
39+
else:
40+
return None
41+
if cross == 0:
42+
sides.append(0)
43+
elif cross > 0:
44+
sides.append(1)
45+
else:
46+
sides.append(-1)
47+
48+
if -1 in sides and 1 in sides:
49+
return None
50+
if 1 in sides:
51+
if 0 in sides:
52+
return 1
53+
else:
54+
return 2
55+
if -1 in sides:
56+
if 0 in sides:
57+
return -1
58+
else:
59+
return -2
60+
return 0
61+
62+
2063
def _vsub(v: PointTypeInput, w: PointTypeInput) -> PointType:
2164
"""Subtract.
2265
@@ -155,6 +198,45 @@ def clockwise_vertices(self) -> SetOfPoints:
155198
"""
156199
return self.vertices
157200

201+
def intersection(self, other: Reference) -> typing.Optional[Reference]:
202+
"""Get the intersection of two references.
203+
204+
Returns:
205+
A reference element that is the intersection
206+
"""
207+
if self.gdim != other.gdim:
208+
raise ValueError("Incompatible cell dimensions")
209+
210+
for cell1, cell2 in [(self, other), (other, self)]:
211+
try:
212+
for v in cell1.vertices:
213+
if not cell2.contains(v):
214+
break
215+
else:
216+
return cell1
217+
except NotImplementedError:
218+
pass
219+
for cell1, cell2 in [(self, other), (other, self)]:
220+
if cell1.gdim == 2:
221+
for e in cell1.edges:
222+
p = cell1.vertices[e[0]]
223+
q = cell1.vertices[e[1]]
224+
dir1 = _which_side(cell1.vertices, p, q)
225+
dir2 = _which_side(cell2.vertices, p, q)
226+
if dir1 is not None and dir2 is not None and dir1 * dir2 < 0:
227+
return None
228+
if cell1.gdim == 3:
229+
for i in range(cell1.sub_entity_count(2)):
230+
face = cell1.sub_entity(2, i)
231+
p = face.midpoint()
232+
n = face.normal()
233+
dir1 = _which_side(cell1.vertices, p, n)
234+
dir2 = _which_side(cell2.vertices, p, n)
235+
if dir1 is not None and dir2 is not None and dir1 * dir2 < 0:
236+
return None
237+
238+
raise NotImplementedError("Intersection of these elements is not yet supported")
239+
158240
@abstractmethod
159241
def default_reference(self) -> Reference:
160242
"""Get the default reference for this cell type.
@@ -1028,9 +1110,17 @@ def contains(self, point: PointType) -> bool:
10281110
Returns:
10291111
Is the point contained in the reference?
10301112
"""
1031-
if self.vertices != self.reference_vertices:
1032-
raise NotImplementedError()
1033-
return 0 <= point[0] and 0 <= point[1] and sum(point) <= 1
1113+
if self.vertices == self.reference_vertices:
1114+
return 0 <= point[0] and 0 <= point[1] and sum(point) <= 1
1115+
elif self.gdim == 2:
1116+
po = _vsub(point, self.origin)
1117+
det = self.axes[0][0] * self.axes[1][1] - self.axes[0][1] * self.axes[1][0]
1118+
t0 = (self.axes[1][1] * po[0] - self.axes[1][0] * po[1]) / det
1119+
t1 = (self.axes[0][0] * po[1] - self.axes[0][1] * po[0]) / det
1120+
print(self.origin, self.axes, point)
1121+
print(t0, t1)
1122+
return 0 <= t0 and 0 <= t1 and t0 + t1 <= 1
1123+
raise NotImplementedError()
10341124

10351125

10361126
class Tetrahedron(Reference):
@@ -1203,9 +1293,15 @@ def contains(self, point: PointType) -> bool:
12031293
Returns:
12041294
Is the point contained in the reference?
12051295
"""
1206-
if self.vertices != self.reference_vertices:
1207-
raise NotImplementedError()
1208-
return 0 <= point[0] and 0 <= point[1] and 0 <= point[2] and sum(point) <= 1
1296+
if self.vertices == self.reference_vertices:
1297+
return 0 <= point[0] and 0 <= point[1] and 0 <= point[2] and sum(point) <= 1
1298+
else:
1299+
po = _vsub(point, self.origin)
1300+
minv = sympy.Matrix([[a[i] for a in self.axes] for i in range(3)]).inv()
1301+
t0 = (minv[0, 0] * po[0] + minv[0, 1] * po[1] + minv[0, 2] * po[2])
1302+
t1 = (minv[1, 0] * po[0] + minv[1, 1] * po[1] + minv[1, 2] * po[2])
1303+
t2 = (minv[2, 0] * po[0] + minv[2, 1] * po[1] + minv[2, 2] * po[2])
1304+
return 0 <= t0 and 0 <= t1 and 0 >= t2 and t0 + t1 + t2 <= 1
12091305

12101306

12111307
class Quadrilateral(Reference):

test/test_hct.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,20 @@ def test_rhct():
7878
assert f1.diff(x[0]).diff(x[0]) == 0
7979
assert f2.diff(x[0]).diff(x[0]) == 0
8080
assert f3.diff(x[1]).diff(x[1]) == 0
81+
82+
83+
def test_rhct_integral():
84+
element = symfem.create_element("triangle", "rHCT", 3)
85+
ref = element.reference
86+
f1 = element.get_basis_function(1).directional_derivative((1, 0))
87+
f2 = element.get_basis_function(6).directional_derivative((1, 0))
88+
integrand = f1 * f2
89+
90+
third = sympy.Rational(1, 3)
91+
expr = (f1*f2).pieces[((0, 1), (0, 0), (third, third))].as_sympy()
92+
assert len((f1*f2).pieces) == 3
93+
assert (f1*f2).pieces[((0, 0), (1, 0), (third, third))] == 0
94+
assert (f1*f2).pieces[((1, 0), (0, 1), (third, third))] == 0
95+
96+
assert sympy.integrate(sympy.integrate(
97+
expr, (x[1], x[0], 1 - 2 * x[0])), (x[0], 0, third)) == integrand.integral(ref, x)

0 commit comments

Comments
 (0)