10
10
11
11
__all__ = \
12
12
[
13
- "diff"
13
+ "diff_bounded" ,
14
+ "diff_periodic"
14
15
]
15
16
16
17
@@ -51,11 +52,11 @@ def difference_coefficients(beta, order):
51
52
52
53
53
54
@partial (jax .jit , static_argnames = {"order" , "N" , "axis" , "i0" , "i1" , "boundary_expansion" })
54
- def diff (u , dx , order , N , * , axis = - 1 , i0 = None , i1 = None , boundary_expansion = None ):
55
+ def diff_bounded (u , dx , order , N , * , axis = - 1 , i0 = None , i1 = None , boundary_expansion = None ):
55
56
"""Compute a centred finite difference approximation to a derivative for
56
- data stored on a uniform grid. Transitions to one-sided differencing as the
57
- end-points are approached. Selects an additional right -sided point if
58
- `N` is even .
57
+ data stored on a uniform grid. Result is defined on the same grid as the
58
+ input (i.e. without staggering). Transitions to one -sided differencing as
59
+ the end-points are approached .
59
60
60
61
Parameters
61
62
----------
@@ -67,7 +68,8 @@ def diff(u, dx, order, N, *, axis=-1, i0=None, i1=None, boundary_expansion=None)
67
68
order : Integral
68
69
Derivative order.
69
70
N : Integral
70
- Number of grid points in the difference approximation.
71
+ Number of grid points in the difference approximation. Centered
72
+ differencing uses an additional right-sided point if `N` is even.
71
73
axis : Integral
72
74
Axis.
73
75
i0 : Integral
@@ -115,7 +117,7 @@ def diff(u, dx, order, N, *, axis=-1, i0=None, i1=None, boundary_expansion=None)
115
117
i1 = i0 + N
116
118
parity = (- 1 ) ** order
117
119
118
- for i in range (max (- i0 , i1 - 1 )):
120
+ for i in range (max (0 , min ( i0_b , u . shape [ - 1 ] - i1_b )), max ( - i0 , i1 - 1 )):
119
121
beta = tuple (range (- i , - i + N + int (bool (boundary_expansion ))))
120
122
alpha = tuple (map (dtype , difference_coefficients (beta , order )))
121
123
if i < - i0 and i >= i0_b :
@@ -130,7 +132,7 @@ def diff(u, dx, order, N, *, axis=-1, i0=None, i1=None, boundary_expansion=None)
130
132
v = v .at [..., u .shape [- 1 ] - 1 - i ].add (
131
133
parity * alpha_j * u [..., u .shape [- 1 ] - 1 - i - beta_j ])
132
134
133
- # Center
135
+ # Center points
134
136
beta = tuple (range (i0 , i1 ))
135
137
alpha = tuple (map (dtype , difference_coefficients (beta , order )))
136
138
i0_c = max (- i0 , i0_b )
@@ -142,3 +144,35 @@ def diff(u, dx, order, N, *, axis=-1, i0=None, i1=None, boundary_expansion=None)
142
144
143
145
v = jnp .moveaxis (v , - 1 , axis )
144
146
return v / (dx ** order )
147
+
148
+
149
+ @partial (jax .jit , static_argnames = {"order" , "N" , "axis" })
150
+ def diff_periodic (u , dx , order , N , * , axis = - 1 ):
151
+ """Compute a centred finite difference approximation to a derivative for
152
+ data stored on a uniform grid. Result is defined on the same grid as the
153
+ input (i.e. without staggering). Applies periodic boundary conditions.
154
+
155
+ Arguments and return value are as for :func:`.diff_bounded`.
156
+ """
157
+
158
+ if axis < 0 :
159
+ axis = len (u .shape ) + axis
160
+ if axis < 0 or axis >= len (u .shape ):
161
+ raise ValueError ("Invalid axis" )
162
+ if u .shape [axis ] < N :
163
+ raise ValueError ("Insufficient points" )
164
+
165
+ u = jnp .moveaxis (u , axis , - 1 )
166
+ i0 = - (N // 2 )
167
+ i1 = i0 + N
168
+
169
+ # Periodic extension
170
+ u_e = jnp .zeros_like (u , shape = u .shape [:- 1 ] + (u .shape [- 1 ] + N ,))
171
+ u_e = u_e .at [..., - i0 :- i1 ].set (u )
172
+ u_e = u_e .at [..., :- i0 ].set (u [..., i0 :])
173
+ u_e = u_e .at [..., - i1 :].set (u [..., :i1 ])
174
+
175
+ v = diff_bounded (u_e , dx , order , N , axis = - 1 , i0 = - i0 , i1 = - i1 )[..., - i0 :- i1 ]
176
+
177
+ v = jnp .moveaxis (v , - 1 , axis )
178
+ return v
0 commit comments