Skip to content

Commit fdea570

Browse files
committed
Rework jacobi and sphere modules in dedalus_sphere to support extended precision via xprec.
1 parent ec0e3df commit fdea570

File tree

6 files changed

+421
-315
lines changed

6 files changed

+421
-315
lines changed

dedalus/core/basis.py

+24-24
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ def multiplication_matrix(self, subproblem, arg_basis, coeffs, ncc_comp, arg_com
550550
Nmat = 3*((N+1)//2) + min((N+1)//2, (da+db+1)//2)
551551
J = arg_basis.Jacobi_matrix(size=Nmat)
552552
A, B = clenshaw.jacobi_recursion(Nmat, a_ncc, b_ncc, J)
553-
f0 = dedalus_sphere.jacobi.polynomials(1, a_ncc, b_ncc, 1)[0] * sparse.identity(Nmat)
553+
f0 = dedalus_sphere.jacobi.polynomials(1, a_ncc, b_ncc, 1)[0].astype(np.float64) * sparse.identity(Nmat)
554554
matrix = clenshaw.matrix_clenshaw(coeffs.ravel(), A, B, f0, cutoff=cutoff)
555555
convert = jacobi.conversion_matrix(Nmat, arg_basis.a, arg_basis.b, out_basis.a, out_basis.b)
556556
matrix = convert @ matrix
@@ -2040,7 +2040,7 @@ def _radius_weights(self, scale):
20402040
Q0 = dedalus_sphere.jacobi.polynomials(N, self.alpha[0], self.alpha[1], z0)
20412041
Q_proj = dedalus_sphere.jacobi.polynomials(N, self.alpha[0], self.alpha[1], z_proj)
20422042
normalization = self.dR/2
2043-
return normalization * ( (Q0 @ weights0).T ) @ (weights_proj*Q_proj)
2043+
return (normalization * ( (Q0 @ weights0).T ) @ (weights_proj*Q_proj)).astype(np.float64)
20442044

20452045
def global_radius_weights(self, scale=None):
20462046
if scale == None: scale = 1
@@ -2059,7 +2059,7 @@ def local_radius_weights(self, scale=None):
20592059
def constant_mode_value(self):
20602060
# Note the zeroth mode is constant only for k=0
20612061
Q0 = dedalus_sphere.jacobi.polynomials(1, self.alpha[0], self.alpha[1], np.array([0.0]))
2062-
return Q0[0,0]
2062+
return np.float64(Q0[0,0])
20632063

20642064
def _new_k(self, k):
20652065
return AnnulusBasis(self.coordsystem, self.shape, radii = self.radii, k=k, alpha=self.alpha, dealias=self.dealias, dtype=self.dtype,
@@ -2137,7 +2137,7 @@ def _interpolation(self, position):
21372137
a = self.alpha[0] + self.k
21382138
b = self.alpha[1] + self.k
21392139
radial_factor = (self.dR/position)**(self.k)
2140-
return radial_factor*dedalus_sphere.jacobi.polynomials(self.n_size(0), a, b, native_position)
2140+
return radial_factor*dedalus_sphere.jacobi.polynomials(self.n_size(0), a, b, native_position).astype(np.float64)
21412141

21422142
@CachedMethod
21432143
def operator_matrix(self,op,m,spintotal, size=None):
@@ -2188,7 +2188,7 @@ def multiplication_matrix(self, subproblem, arg_basis, coeffs, ncc_comp, arg_com
21882188
Nmat = 3*((N0+1)//2) + self.k
21892189
J = arg_basis.operator_matrix('Z', m, spintotal_arg, size=Nmat)
21902190
A, B = clenshaw.jacobi_recursion(Nmat, a_ncc, b_ncc, J)
2191-
f0 = dedalus_sphere.jacobi.polynomials(1, a_ncc, b_ncc, 1)[0] * sparse.identity(Nmat)
2191+
f0 = dedalus_sphere.jacobi.polynomials(1, a_ncc, b_ncc, 1)[0].astype(np.float64) * sparse.identity(Nmat)
21922192
# Conversions to account for radial prefactors
21932193
prefactor = arg_basis.jacobi_conversion(m, dk=self.k, size=Nmat)
21942194
if self.dtype == np.float64:
@@ -2299,7 +2299,7 @@ def local_radius_weights(self, scale=None):
22992299
@CachedAttribute
23002300
def constant_mode_value(self):
23012301
Qk = dedalus_sphere.zernike.polynomials(2, 1, self.alpha+self.k, 0, np.array([0]))
2302-
return Qk[0]
2302+
return np.float64(Qk[0])
23032303

23042304
def _new_k(self, k):
23052305
return DiskBasis(self.coordsystem, self.shape, radius = self.radius, k=k, alpha=self.alpha, dealias=self.dealias, dtype=self.dtype,
@@ -2399,7 +2399,7 @@ def operator_matrix(self, op, m, spin, size=None):
23992399
def interpolation(self, m, spintotal, position):
24002400
native_position = self.radial_COV.native_coord(position)
24012401
native_z = 2*native_position**2 - 1
2402-
return dedalus_sphere.zernike.polynomials(2, self.n_size(m), self.alpha + self.k, np.abs(m + spintotal), native_z)
2402+
return dedalus_sphere.zernike.polynomials(2, self.n_size(m), self.alpha + self.k, np.abs(m + spintotal), native_z).astype(np.float64)
24032403

24042404
@CachedMethod
24052405
def radius_multiplication_matrix(self, m, spintotal, order, d):
@@ -2436,7 +2436,7 @@ def multiplication_matrix(self, subproblem, arg_basis, coeffs, ncc_comp, arg_com
24362436
J = arg_basis.operator_matrix('Z', m, spintotal_arg)
24372437
A, B = clenshaw.jacobi_recursion(N, a_ncc, b_ncc, J)
24382438
# assuming that we're doing ball for now...
2439-
f0 = dedalus_sphere.zernike.polynomials(2, 1, a_ncc, b_ncc, 1)[0] * sparse.identity(N)
2439+
f0 = dedalus_sphere.zernike.polynomials(2, 1, a_ncc, b_ncc, 1)[0].astype(np.float64) * sparse.identity(N)
24402440
prefactor = arg_basis.radius_multiplication_matrix(m, spintotal_arg, diff_regtotal, d)
24412441
if self.dtype == np.float64:
24422442
coeffs_filter = coeffs.ravel()[:2*N]
@@ -2869,7 +2869,7 @@ def local_grid_colatitude(self, scale):
28692869
def _native_colatitude_grid(self, scale):
28702870
N = int(np.ceil(scale * self.shape[1]))
28712871
cos_theta, weights = dedalus_sphere.sphere.quadrature(Lmax=N-1)
2872-
theta = np.arccos(cos_theta).astype(np.float64)
2872+
theta = np.arccos(cos_theta.astype(np.float64)) # TODO: xprec doesn't yet support arccos
28732873
return theta
28742874

28752875
def global_colatitude_weights(self, scale=None):
@@ -3529,13 +3529,13 @@ def _radius_weights(self, scale):
35293529
Q0 = dedalus_sphere.jacobi.polynomials(N, self.alpha[0], self.alpha[1], z0)
35303530
Q_proj = dedalus_sphere.jacobi.polynomials(N, self.alpha[0], self.alpha[1], z_proj)
35313531
normalization = self.dR/2
3532-
return normalization * ( (Q0 @ weights0).T ) @ (weights_proj*Q_proj)
3532+
return (normalization * ( (Q0 @ weights0).T ) @ (weights_proj*Q_proj)).astype(np.float64)
35333533

35343534
@CachedAttribute
35353535
def constant_mode_value(self):
35363536
# Note the zeroth mode is constant only for k=0
35373537
Q0 = dedalus_sphere.jacobi.polynomials(1, self.alpha[0], self.alpha[1], np.array([0.0]))
3538-
return Q0[0,0]
3538+
return np.float64(Q0[0,0])
35393539

35403540
@CachedMethod
35413541
def radial_transform_factor(self, scale, data_axis, dk):
@@ -3548,7 +3548,7 @@ def interpolation(self, position):
35483548
a = self.alpha[0] + self.k
35493549
b = self.alpha[1] + self.k
35503550
radial_factor = (self.dR/position)**(self.k)
3551-
return radial_factor*dedalus_sphere.jacobi.polynomials(self.n_size(0), a, b, native_position)
3551+
return radial_factor*dedalus_sphere.jacobi.polynomials(self.n_size(0), a, b, native_position).astype(np.float64)
35523552

35533553
@CachedMethod
35543554
def transform_plan(self, grid_size, k):
@@ -3636,7 +3636,7 @@ def multiplication_matrix(self, subproblem, arg_basis, coeffs, ncc_comp, arg_com
36363636
Nmat = 3*((N0+1)//2) + self.k
36373637
J = arg_radial_basis.operator_matrix('Z', ell, regtotal_arg, size=Nmat)
36383638
A, B = clenshaw.jacobi_recursion(Nmat, a_ncc, b_ncc, J)
3639-
f0 = dedalus_sphere.jacobi.polynomials(1, a_ncc, b_ncc, 1)[0] * sparse.identity(Nmat)
3639+
f0 = dedalus_sphere.jacobi.polynomials(1, a_ncc, b_ncc, 1)[0].astype(np.float64) * sparse.identity(Nmat)
36403640
# Conversions to account for radial prefactors
36413641
prefactor = arg_radial_basis.jacobi_conversion(ell, dk=self.k, size=Nmat)
36423642
if self.dtype == np.float64:
@@ -3740,18 +3740,18 @@ def _native_radius_grid(self, scale):
37403740
def _radius_weights(self, scale):
37413741
N = int(np.ceil(scale * self.shape[2]))
37423742
z, weights = dedalus_sphere.zernike.quadrature(3, N, k=self.alpha)
3743-
return weights
3743+
return weights.astype(np.float64)
37443744

37453745
@CachedAttribute
37463746
def constant_mode_value(self):
37473747
Qk = dedalus_sphere.zernike.polynomials(3, 1, self.alpha+self.k, 0, np.array([0]))
3748-
return Qk[0]
3748+
return np.float64(Qk[0])
37493749

37503750
@CachedMethod
37513751
def interpolation(self, ell, regtotal, position):
37523752
native_position = self.radial_COV.native_coord(position)
37533753
native_z = 2*native_position**2 - 1
3754-
return dedalus_sphere.zernike.polynomials(3, self.n_size(ell), self.alpha + self.k, ell + regtotal, native_z)
3754+
return dedalus_sphere.zernike.polynomials(3, self.n_size(ell), self.alpha + self.k, ell + regtotal, native_z).astype(np.float64)
37553755

37563756
@CachedMethod
37573757
def transform_plan(self, grid_shape, regindex, axis, regtotal, k, alpha):
@@ -3855,7 +3855,7 @@ def multiplication_matrix(self, subproblem, arg_basis, coeffs, ncc_comp, arg_com
38553855
if (d >= 0) and (d % 2 == 0):
38563856
J = arg_radial_basis.operator_matrix('Z', ell, regtotal_arg, size=Nmat)
38573857
A, B = clenshaw.jacobi_recursion(N0, a_ncc, b_ncc, J)
3858-
f0 = dedalus_sphere.zernike.polynomials(3, 1, a_ncc, regtotal_ncc, 1)[0] * sparse.identity(Nmat)
3858+
f0 = dedalus_sphere.zernike.polynomials(3, 1, a_ncc, regtotal_ncc, 1)[0].astype(np.float64) * sparse.identity(Nmat)
38593859
radial_factor = arg_radial_basis.radius_multiplication_matrix(ell, regtotal_arg, diff_regtotal, d, size=Nmat)
38603860
conversion = arg_radial_basis.conversion_matrix(ell, regtotal_out, dk, size=Nmat)
38613861
prefactor = conversion @ radial_factor
@@ -4903,7 +4903,7 @@ def _radial_matrix(basis, m):
49034903
N = basis.shape[1]
49044904
z0, w0 = dedalus_sphere.zernike.quadrature(2, N, k=0)
49054905
Qk = dedalus_sphere.zernike.polynomials(2, n_size, basis.alpha+basis.k, abs(m), z0)
4906-
matrix = (w0[None, :] @ Qk.T).astype(basis.dtype)
4906+
matrix = (w0[None, :] @ Qk.T).astype(np.float64)
49074907
matrix *= basis.radius**2
49084908
matrix *= 2 * np.pi # Fourier contribution
49094909
else:
@@ -4925,8 +4925,8 @@ def _radial_matrix(basis, m):
49254925
z0, w0 = dedalus_sphere.jacobi.quadrature(N, a=0, b=0)
49264926
r0 = basis.dR / 2 * (z0 + basis.rho)
49274927
Qk = dedalus_sphere.jacobi.polynomials(n_size, basis.alpha[0]+basis.k, basis.alpha[1]+basis.k, z0)
4928-
w0_geom = r0 * w0 * (r0 / basis.dR)**(-basis.k)
4929-
matrix = (w0_geom[None, :] @ Qk.T).astype(basis.dtype)
4928+
w0_geom = r0 * w0 * (r0.astype(np.float64) / basis.dR)**(-basis.k) # TODO: xprec does not yet support power
4929+
matrix = (w0_geom[None, :] @ Qk.T).astype(np.float64)
49304930
matrix *= basis.dR / 2
49314931
matrix *= 2 * np.pi # Fourier contribution
49324932
else:
@@ -4992,7 +4992,7 @@ def _radial_matrix(basis, ell):
49924992
N = basis.shape[2]
49934993
z0, w0 = dedalus_sphere.zernike.quadrature(3, N, k=0)
49944994
Qk = dedalus_sphere.zernike.polynomials(3, n_size, basis.alpha+basis.k, ell, z0)
4995-
matrix = (w0[None, :] @ Qk.T).astype(basis.dtype)
4995+
matrix = (w0[None, :] @ Qk.T).astype(np.float64)
49964996
matrix *= basis.radius**3
49974997
matrix *= 4 * np.pi / np.sqrt(2) # SWSH contribution
49984998
else:
@@ -5014,8 +5014,8 @@ def _radial_matrix(basis, ell):
50145014
z0, w0 = dedalus_sphere.jacobi.quadrature(N, a=0, b=0)
50155015
r0 = basis.dR / 2 * (z0 + basis.rho)
50165016
Qk = dedalus_sphere.jacobi.polynomials(n_size, basis.alpha[0]+basis.k, basis.alpha[1]+basis.k, z0)
5017-
w0_geom = r0**2 * w0 * (r0 / basis.dR)**(-basis.k)
5018-
matrix = (w0_geom[None, :] @ Qk.T).astype(basis.dtype)
5017+
w0_geom = r0**2 * w0 * (r0.astype(np.float64) / basis.dR)**(-basis.k) # TODO: xprec does not yet support power
5018+
matrix = (w0_geom[None, :] @ Qk.T).astype(np.float64)
50195019
matrix *= basis.dR / 2
50205020
matrix *= 4 * np.pi / np.sqrt(2) # SWSH contribution
50215021
else:
@@ -5147,7 +5147,7 @@ def _interpolation_vectors(sphere_basis, Ntheta, s, theta):
51475147
Lmin = max(abs(m), abs(s))
51485148
interp_m = dedalus_sphere.sphere.harmonics(sphere_basis.Lmax, m, s, z)[None, :]
51495149
forward_m = forward._forward_SWSH_matrices[m][Lmin-abs(m):]
5150-
interp_vectors[m] = interp_m @ forward_m
5150+
interp_vectors[m] = interp_m.astype(np.float64) @ forward_m
51515151
else:
51525152
interp_vectors[m] = np.zeros((1, Ntheta))
51535153
return interp_vectors

dedalus/core/operators.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -2824,14 +2824,14 @@ def radial_matrix(self, spinindex_in, spinindex_out, m):
28242824
radial_basis = self.input_basis
28252825
spintotal = radial_basis.spintotal(spinindex_in)
28262826
if spinindex_out in self.spinindex_out(spinindex_in):
2827-
return self._radial_matrix(radial_basis.Lmax, spintotal, m, self.dtype)
2827+
return self._radial_matrix(radial_basis.Lmax, spintotal, m)
28282828
else:
28292829
raise ValueError("This should never happen")
28302830

28312831
@staticmethod
28322832
@CachedMethod
2833-
def _radial_matrix(Lmax, spintotal, m, dtype):
2834-
matrix = dedalus_sphere.sphere.operator('Cos', dtype)(Lmax, m, spintotal).square
2833+
def _radial_matrix(Lmax, spintotal, m):
2834+
matrix = dedalus_sphere.sphere.operator('Cos')(Lmax, m, spintotal).square
28352835
# Pad to include invalid ells
28362836
trunc = abs(spintotal) - abs(m)
28372837
if trunc > 0:

dedalus/core/transforms.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1395,7 +1395,7 @@ def _forward_matrices(self):
13951395
# Spectral conversion
13961396
if self.k > 0:
13971397
conversion = dedalus_sphere.zernike.operator(2, 'E')(+1)**self.k
1398-
W = conversion(W.shape[0], self.alpha, abs(m + self.s)) @ W
1398+
W = conversion(W.shape[0], self.alpha, abs(m + self.s)) @ W.astype(np.float64)
13991399
if not DEALIAS_BEFORE_CONVERTING():
14001400
# Truncate to specified coeff_size
14011401
W = W[:max(self.N2c-Nmin,0)]
@@ -1504,7 +1504,7 @@ def _forward_GSZP_matrix(self):
15041504
# Spectral conversion
15051505
if self.k > 0:
15061506
conversion = dedalus_sphere.zernike.operator(3, 'E')(+1)**self.k
1507-
W = conversion(W.shape[0], self.alpha, ell+self.regtotal) @ W
1507+
W = conversion(W.shape[0], self.alpha, ell+self.regtotal) @ W.astype(np.float64)
15081508
if not DEALIAS_BEFORE_CONVERTING():
15091509
# Truncate to specified coeff_size
15101510
W = W[:max(self.N3c-Nmin,0)]

0 commit comments

Comments
 (0)