Skip to content

Commit a6387ba

Browse files
committed
LRU cache, tidying
1 parent a999c31 commit a6387ba

File tree

1 file changed

+25
-10
lines changed

1 file changed

+25
-10
lines changed

bt_ocean/finite_difference.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""Finite difference utilities.
22
"""
33

4-
from functools import partial
5-
from numbers import Real
4+
from functools import lru_cache, partial
5+
from numbers import Rational
66

77
import jax
88
import jax.numpy as jnp
@@ -21,34 +21,47 @@ def difference_coefficients(beta, order):
2121
Parameters
2222
----------
2323
24-
beta : Sequence[Real, ...]
24+
beta : Sequence
2525
Grid location displacements.
2626
order : Integral
2727
Derivative order.
2828
2929
Returns
3030
-------
3131
32-
tuple[:class:`sympy.Rational`, ...]
32+
tuple[:class:`sympy.core.expr.Expr`, ...]
3333
Finite difference coefficients.
3434
"""
3535

36-
beta = tuple(map(sp.Rational, beta))
37-
if not all(isinstance(beta_i, Real) for beta_i in beta):
38-
raise ValueError("Invalid type")
36+
def displacement_cast(v):
37+
if isinstance(v, Rational):
38+
return sp.Rational(v)
39+
elif isinstance(v, sp.core.expr.Expr):
40+
return sp.Expr(v)
41+
else:
42+
return v
43+
44+
return _difference_coefficients(tuple(map(displacement_cast, beta)), order)
45+
46+
47+
@lru_cache(maxsize=32)
48+
def _difference_coefficients(beta, order):
3949
N = len(beta)
4050
if order >= N:
4151
raise ValueError("Invalid order")
4252

43-
a = tuple(sp.Symbol("a_{" + f"{i}" + "}", real=True)
53+
assumptions = {}
54+
if all(map(bool, (beta_i.is_real for beta_i in beta))):
55+
assumptions["real"] = True
56+
a = tuple(sp.Symbol("_bt_ocean__finite_difference_{" + f"{i}" + "}", **assumptions)
4457
for i in range(N))
45-
eqs = [sum((a[i] * sp.Rational(beta[i] ** j, sp.factorial(j))
58+
eqs = [sum((a[i] * ((beta[i] ** j) / sp.factorial(j))
4659
for i in range(N)), start=sp.Integer(0))
4760
for j in range(N)]
4861
eqs[order] -= sp.Integer(1)
4962

5063
soln, = sp.linsolve(eqs, a)
51-
return tuple(map(sp.Rational, soln))
64+
return soln
5265

5366

5467
@partial(jax.jit, static_argnames={"order", "N", "axis", "i0", "i1", "boundary_expansion"})
@@ -115,6 +128,7 @@ def diff_bounded(u, dx, order, N, *, axis=-1, i0=None, i1=None, boundary_expansi
115128
dtype = u.dtype.type
116129
i0 = -(N // 2)
117130
i1 = i0 + N
131+
assert i1 > 0 # Insufficient points
118132
parity = (-1) ** order
119133

120134
for i in range(max(0, min(i0_b, u.shape[-1] - i1_b)), max(-i0, i1 - 1)):
@@ -165,6 +179,7 @@ def diff_periodic(u, dx, order, N, *, axis=-1):
165179
u = jnp.moveaxis(u, axis, -1)
166180
i0 = -(N // 2)
167181
i1 = i0 + N
182+
assert i1 > 0 # Insufficient points
168183

169184
# Periodic extension
170185
u_e = jnp.zeros_like(u, shape=u.shape[:-1] + (u.shape[-1] + N,))

0 commit comments

Comments
 (0)