Skip to content

Commit a999c31

Browse files
committed
Add periodic boundary conditions
1 parent 1158e38 commit a999c31

File tree

3 files changed

+77
-13
lines changed

3 files changed

+77
-13
lines changed

bt_ocean/finite_difference.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010

1111
__all__ = \
1212
[
13-
"diff"
13+
"diff_bounded",
14+
"diff_periodic"
1415
]
1516

1617

@@ -51,11 +52,11 @@ def difference_coefficients(beta, order):
5152

5253

5354
@partial(jax.jit, static_argnames={"order", "N", "axis", "i0", "i1", "boundary_expansion"})
54-
def diff(u, dx, order, N, *, axis=-1, i0=None, i1=None, boundary_expansion=None):
55+
def diff_bounded(u, dx, order, N, *, axis=-1, i0=None, i1=None, boundary_expansion=None):
5556
"""Compute a centred finite difference approximation to a derivative for
56-
data stored on a uniform grid. Transitions to one-sided differencing as the
57-
end-points are approached. Selects an additional right-sided point if
58-
`N` is even.
57+
data stored on a uniform grid. Result is defined on the same grid as the
58+
input (i.e. without staggering). Transitions to one-sided differencing as
59+
the end-points are approached.
5960
6061
Parameters
6162
----------
@@ -67,7 +68,8 @@ def diff(u, dx, order, N, *, axis=-1, i0=None, i1=None, boundary_expansion=None)
6768
order : Integral
6869
Derivative order.
6970
N : Integral
70-
Number of grid points in the difference approximation.
71+
Number of grid points in the difference approximation. Centered
72+
differencing uses an additional right-sided point if `N` is even.
7173
axis : Integral
7274
Axis.
7375
i0 : Integral
@@ -115,7 +117,7 @@ def diff(u, dx, order, N, *, axis=-1, i0=None, i1=None, boundary_expansion=None)
115117
i1 = i0 + N
116118
parity = (-1) ** order
117119

118-
for i in range(max(-i0, i1 - 1)):
120+
for i in range(max(0, min(i0_b, u.shape[-1] - i1_b)), max(-i0, i1 - 1)):
119121
beta = tuple(range(-i, -i + N + int(bool(boundary_expansion))))
120122
alpha = tuple(map(dtype, difference_coefficients(beta, order)))
121123
if i < -i0 and i >= i0_b:
@@ -130,7 +132,7 @@ def diff(u, dx, order, N, *, axis=-1, i0=None, i1=None, boundary_expansion=None)
130132
v = v.at[..., u.shape[-1] - 1 - i].add(
131133
parity * alpha_j * u[..., u.shape[-1] - 1 - i - beta_j])
132134

133-
# Center
135+
# Center points
134136
beta = tuple(range(i0, i1))
135137
alpha = tuple(map(dtype, difference_coefficients(beta, order)))
136138
i0_c = max(-i0, i0_b)
@@ -142,3 +144,35 @@ def diff(u, dx, order, N, *, axis=-1, i0=None, i1=None, boundary_expansion=None)
142144

143145
v = jnp.moveaxis(v, -1, axis)
144146
return v / (dx ** order)
147+
148+
149+
@partial(jax.jit, static_argnames={"order", "N", "axis"})
150+
def diff_periodic(u, dx, order, N, *, axis=-1):
151+
"""Compute a centred finite difference approximation to a derivative for
152+
data stored on a uniform grid. Result is defined on the same grid as the
153+
input (i.e. without staggering). Applies periodic boundary conditions.
154+
155+
Arguments and return value are as for :func:`.diff_bounded`.
156+
"""
157+
158+
if axis < 0:
159+
axis = len(u.shape) + axis
160+
if axis < 0 or axis >= len(u.shape):
161+
raise ValueError("Invalid axis")
162+
if u.shape[axis] < N:
163+
raise ValueError("Insufficient points")
164+
165+
u = jnp.moveaxis(u, axis, -1)
166+
i0 = -(N // 2)
167+
i1 = i0 + N
168+
169+
# Periodic extension
170+
u_e = jnp.zeros_like(u, shape=u.shape[:-1] + (u.shape[-1] + N,))
171+
u_e = u_e.at[..., -i0:-i1].set(u)
172+
u_e = u_e.at[..., :-i0].set(u[..., i0:])
173+
u_e = u_e.at[..., -i1:].set(u[..., :i1])
174+
175+
v = diff_bounded(u_e, dx, order, N, axis=-1, i0=-i0, i1=-i1)[..., -i0:-i1]
176+
177+
v = jnp.moveaxis(v, -1, axis)
178+
return v

bt_ocean/grid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from functools import cached_property, partial
88
from numbers import Real
99

10-
from .finite_difference import diff
10+
from .finite_difference import diff_bounded as diff
1111
from .precision import default_idtype, default_fdtype
1212

1313
__all__ = \

tests/test_finite_difference.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from bt_ocean.finite_difference import (
2-
difference_coefficients, diff as centered_difference_bounded)
2+
difference_coefficients, diff_bounded, diff_periodic)
33
from bt_ocean.precision import default_fdtype
44

55
import jax
@@ -53,7 +53,7 @@ def test_centered_difference_monomials(alpha):
5353
for i in range(order):
5454
diff_exact *= p - i
5555
for N in range(max(p + 1, order + 1), x.shape[0]):
56-
diff = centered_difference_bounded(u, dx, order=order, N=N)
56+
diff = diff_bounded(u, dx, order=order, N=N)
5757
error_norm = abs(diff - diff_exact).max()
5858
diff_exact_norm = abs(diff_exact).max()
5959
if diff_exact_norm > 0:
@@ -77,12 +77,12 @@ def test_centered_difference_convergence():
7777
W = W.at[-1].set(0.5 * dx)
7878
u = jnp.sin(jnp.pi * x)
7979

80-
D1_error = (centered_difference_bounded(u, dx, order=1, N=3)
80+
D1_error = (diff_bounded(u, dx, order=1, N=3)
8181
- jnp.pi * jnp.cos(jnp.pi * x))
8282
error_norms_1 = error_norms_1.at[i].set(jnp.sqrt((W * (D1_error ** 2)).sum())) # noqa: E501
8383
print(f"{p=:d} {error_norms_1[i]=:.6g}")
8484

85-
D2_error = (centered_difference_bounded(u, dx, order=2, N=3)
85+
D2_error = (diff_bounded(u, dx, order=2, N=3)
8686
+ (jnp.pi ** 2) * jnp.sin(jnp.pi * x))
8787
error_norms_2 = error_norms_2.at[i].set(jnp.sqrt((W * (D2_error ** 2)).sum())) # noqa: E501
8888
print(f"{p=:d} {error_norms_2[i]=:.6g}")
@@ -92,3 +92,33 @@ def test_centered_difference_convergence():
9292
print(f"{orders_2=}")
9393
assert orders_1.min() > 2
9494
assert orders_2.min() > 2
95+
96+
97+
def test_centered_difference_convergence_periodic():
98+
if default_fdtype() != np.float64 or not jax.config.x64_enabled:
99+
pytest.skip("float64 not available")
100+
101+
P = jnp.arange(7, 12, dtype=int)
102+
error_norms_1 = jnp.zeros_like(P, dtype=float)
103+
error_norms_2 = jnp.zeros_like(P, dtype=float)
104+
for i, p in enumerate(P):
105+
x = jnp.linspace(0, 1, 2 ** p + 1, dtype=float)[:-1]
106+
dx = x[1] - x[0]
107+
W = jnp.full_like(x, dx)
108+
u = jnp.sin(2 * jnp.pi * x)
109+
110+
D1_error = (diff_periodic(u, dx, order=1, N=3)
111+
- 2 * jnp.pi * jnp.cos(2 * jnp.pi * x))
112+
error_norms_1 = error_norms_1.at[i].set(jnp.sqrt((W * (D1_error ** 2)).sum())) # noqa: E501
113+
print(f"{p=:d} {error_norms_1[i]=:.6g}")
114+
115+
D2_error = (diff_periodic(u, dx, order=2, N=3)
116+
+ 4 * (jnp.pi ** 2) * jnp.sin(2 * jnp.pi * x))
117+
error_norms_2 = error_norms_2.at[i].set(jnp.sqrt((W * (D2_error ** 2)).sum())) # noqa: E501
118+
print(f"{p=:d} {error_norms_2[i]=:.6g}")
119+
orders_1 = jnp.log2(error_norms_1[:-1] / error_norms_1[1:])
120+
orders_2 = jnp.log2(error_norms_2[:-1] / error_norms_2[1:])
121+
print(f"{orders_1=}")
122+
print(f"{orders_2=}")
123+
assert orders_1.min() > 1.99
124+
assert orders_2.min() > 1.99

0 commit comments

Comments
 (0)