1
1
"""Finite difference utilities.
2
2
"""
3
3
4
- from functools import partial
5
- from numbers import Real
4
+ from functools import lru_cache , partial
5
+ from numbers import Rational
6
6
7
7
import jax
8
8
import jax .numpy as jnp
@@ -21,34 +21,47 @@ def difference_coefficients(beta, order):
21
21
Parameters
22
22
----------
23
23
24
- beta : Sequence[Real, ...]
24
+ beta : Sequence
25
25
Grid location displacements.
26
26
order : Integral
27
27
Derivative order.
28
28
29
29
Returns
30
30
-------
31
31
32
- tuple[:class:`sympy.Rational `, ...]
32
+ tuple[:class:`sympy.core.expr.Expr `, ...]
33
33
Finite difference coefficients.
34
34
"""
35
35
36
- beta = tuple (map (sp .Rational , beta ))
37
- if not all (isinstance (beta_i , Real ) for beta_i in beta ):
38
- raise ValueError ("Invalid type" )
36
+ def displacement_cast (v ):
37
+ if isinstance (v , Rational ):
38
+ return sp .Rational (v )
39
+ elif isinstance (v , sp .core .expr .Expr ):
40
+ return sp .Expr (v )
41
+ else :
42
+ return v
43
+
44
+ return _difference_coefficients (tuple (map (displacement_cast , beta )), order )
45
+
46
+
47
+ @lru_cache (maxsize = 32 )
48
+ def _difference_coefficients (beta , order ):
39
49
N = len (beta )
40
50
if order >= N :
41
51
raise ValueError ("Invalid order" )
42
52
43
- a = tuple (sp .Symbol ("a_{" + f"{ i } " + "}" , real = True )
53
+ assumptions = {}
54
+ if all (map (bool , (beta_i .is_real for beta_i in beta ))):
55
+ assumptions ["real" ] = True
56
+ a = tuple (sp .Symbol ("_bt_ocean__finite_difference_{" + f"{ i } " + "}" , ** assumptions )
44
57
for i in range (N ))
45
- eqs = [sum ((a [i ] * sp . Rational ( beta [i ] ** j , sp .factorial (j ))
58
+ eqs = [sum ((a [i ] * (( beta [i ] ** j ) / sp .factorial (j ))
46
59
for i in range (N )), start = sp .Integer (0 ))
47
60
for j in range (N )]
48
61
eqs [order ] -= sp .Integer (1 )
49
62
50
63
soln , = sp .linsolve (eqs , a )
51
- return tuple ( map ( sp . Rational , soln ))
64
+ return soln
52
65
53
66
54
67
@partial (jax .jit , static_argnames = {"order" , "N" , "axis" , "i0" , "i1" , "boundary_expansion" })
@@ -115,6 +128,7 @@ def diff_bounded(u, dx, order, N, *, axis=-1, i0=None, i1=None, boundary_expansi
115
128
dtype = u .dtype .type
116
129
i0 = - (N // 2 )
117
130
i1 = i0 + N
131
+ assert i1 > 0 # Insufficient points
118
132
parity = (- 1 ) ** order
119
133
120
134
for i in range (max (0 , min (i0_b , u .shape [- 1 ] - i1_b )), max (- i0 , i1 - 1 )):
@@ -165,6 +179,7 @@ def diff_periodic(u, dx, order, N, *, axis=-1):
165
179
u = jnp .moveaxis (u , axis , - 1 )
166
180
i0 = - (N // 2 )
167
181
i1 = i0 + N
182
+ assert i1 > 0 # Insufficient points
168
183
169
184
# Periodic extension
170
185
u_e = jnp .zeros_like (u , shape = u .shape [:- 1 ] + (u .shape [- 1 ] + N ,))
0 commit comments