Skip to content

Commit 70b6817

Browse files
committed
Remove intermediate tools.jacobi interface
1 parent fdea570 commit 70b6817

File tree

7 files changed

+76
-297
lines changed

7 files changed

+76
-297
lines changed

dedalus/core/basis.py

+30-19
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,11 @@
1919
from ..tools.cache import CachedAttribute
2020
from ..tools.cache import CachedMethod
2121
from ..tools.cache import CachedClass
22-
from ..tools import jacobi
2322
from ..tools import clenshaw
2423
from ..tools.array import reshape_vector, axindex, axslice
2524
from ..tools.dispatch import MultiClass, SkipDispatchException
2625
from ..tools.general import unify
2726

28-
from .spaces import ParityInterval, Disk
2927
from .coords import Coordinate, CartesianCoordinates, S2Coordinates, SphericalCoordinates, PolarCoordinates, AzimuthalCoordinate
3028
from .domain import Domain
3129
from .field import Operand, LockedField
@@ -435,23 +433,19 @@ def __init__(self, coord, size, bounds, a, b, a0=None, b0=None, dealias=1, libra
435433
self.b0 = float(b0)
436434
self.library = library
437435
self.grid_params = (coord, bounds, a0, b0)
438-
self.constant_mode_value = 1 / np.sqrt(jacobi.mass(self.a, self.b))
436+
self.constant_mode_value = (1 / np.sqrt(dedalus_sphere.jacobi.mass(self.a, self.b))).astype(np.float64)
439437

440438
def _native_grid(self, scale):
441439
"""Native flat global grid."""
442440
N, = self.grid_shape((scale,))
443-
return jacobi.build_grid(N, a=self.a0, b=self.b0)
441+
grid, weights = dedalus_sphere.jacobi.quadrature(N, self.a0, self.b0)
442+
return grid.astype(np.float64)
444443

445444
@CachedMethod
446445
def transform_plan(self, grid_size):
447446
"""Build transform plan."""
448447
return self.transforms[self.library](grid_size, self.size, self.a, self.b, self.a0, self.b0)
449448

450-
# def weights(self, scales):
451-
# """Gauss-Jacobi weights."""
452-
# N = self.grid_shape(scales)[0]
453-
# return jacobi.build_weights(N, a=self.a, b=self.b)
454-
455449
# def __str__(self):
456450
# space = self.space
457451
# cls = self.__class__
@@ -513,14 +507,30 @@ def Jacobi_matrix(self, size):
513507
size = self.size
514508
return dedalus_sphere.jacobi.operator('Z')(size, self.a, self.b).square
515509

510+
@staticmethod
511+
def conversion_matrix(N, a0, b0, a1, b1):
512+
if not float(a1-a0).is_integer():
513+
raise ValueError("a0 and a1 must be integer-separated")
514+
if not float(b1-b0).is_integer():
515+
raise ValueError("b0 and b1 must be integer-separated")
516+
if a0 > a1:
517+
raise ValueError("a0 must be less than or equal to a1")
518+
if b0 > b1:
519+
raise ValueError("b0 must be less than or equal to b1")
520+
A = dedalus_sphere.jacobi.operator('A')(+1)
521+
B = dedalus_sphere.jacobi.operator('B')(+1)
522+
da, db = int(a1-a0), int(b1-b0)
523+
conv = A**da @ B**db
524+
return conv(N, a0, b0).astype(np.float64)
525+
516526
def ncc_matrix(self, arg_basis, coeffs, cutoff=1e-6):
517527
"""Build NCC matrix via Clenshaw algorithm."""
518528
if arg_basis is None:
519529
return super().ncc_matrix(arg_basis, coeffs)
520530
# Kronecker Clenshaw on argument Jacobi matrix
521531
elif isinstance(arg_basis, Jacobi):
522532
N = self.size
523-
J = jacobi.jacobi_matrix(N, arg_basis.a, arg_basis.b)
533+
J = dedalus_sphere.jacobi.operator('Z')(N, arg_basis.a, arg_basis.b).square.astype(np.float64)
524534
A, B = clenshaw.jacobi_recursion(N, self.a, self.b, J)
525535
f0 = self.const * sparse.identity(N)
526536
total = clenshaw.kronecker_clenshaw(coeffs, A, B, f0, cutoff=cutoff)
@@ -552,7 +562,7 @@ def multiplication_matrix(self, subproblem, arg_basis, coeffs, ncc_comp, arg_com
552562
A, B = clenshaw.jacobi_recursion(Nmat, a_ncc, b_ncc, J)
553563
f0 = dedalus_sphere.jacobi.polynomials(1, a_ncc, b_ncc, 1)[0].astype(np.float64) * sparse.identity(Nmat)
554564
matrix = clenshaw.matrix_clenshaw(coeffs.ravel(), A, B, f0, cutoff=cutoff)
555-
convert = jacobi.conversion_matrix(Nmat, arg_basis.a, arg_basis.b, out_basis.a, out_basis.b)
565+
convert = Jacobi.conversion_matrix(Nmat, arg_basis.a, arg_basis.b, out_basis.a, out_basis.b)
556566
matrix = convert @ matrix
557567
return matrix[:N, :N]
558568

@@ -604,7 +614,7 @@ def _full_matrix(input_basis, output_basis):
604614
N = input_basis.size
605615
a0, b0 = input_basis.a, input_basis.b
606616
a1, b1 = output_basis.a, output_basis.b
607-
matrix = jacobi.conversion_matrix(N, a0, b0, a1, b1)
617+
matrix = Jacobi.conversion_matrix(N, a0, b0, a1, b1)
608618
return matrix.tocsr()
609619

610620

@@ -643,8 +653,9 @@ def _output_basis(input_basis):
643653
def _full_matrix(input_basis, output_basis):
644654
N = input_basis.size
645655
a, b = input_basis.a, input_basis.b
646-
matrix = jacobi.differentiation_matrix(N, a, b) / input_basis.COV.stretch
647-
return matrix.tocsr()
656+
native_matrix = dedalus_sphere.jacobi.operator('D')(+1)(N, a, b).square.astype(np.float64)
657+
problem_matrix = native_matrix / input_basis.COV.stretch
658+
return problem_matrix.tocsr()
648659

649660

650661
class InterpolateJacobi(operators.Interpolate, operators.SpectralOperator1D):
@@ -666,9 +677,9 @@ def _full_matrix(input_basis, output_basis, position):
666677
N = input_basis.size
667678
a, b = input_basis.a, input_basis.b
668679
x = input_basis.COV.native_coord(position)
669-
interp_vector = jacobi.build_polynomials(N, a, b, x)
670-
# Return with shape (1, N)
671-
return interp_vector[None, :]
680+
x = np.array([x])
681+
matrix = dedalus_sphere.jacobi.polynomials(N, a, b, x).T
682+
return matrix.astype(np.float64)
672683

673684

674685
class IntegrateJacobi(operators.Integrate, operators.SpectralOperator1D):
@@ -688,7 +699,7 @@ def _full_matrix(input_basis, output_basis):
688699
# Build native integration vector
689700
N = input_basis.size
690701
a, b = input_basis.a, input_basis.b
691-
integ_vector = jacobi.integration_vector(N, a, b)
702+
integ_vector = dedalus_sphere.jacobi.polynomial_integrals(N, a, b).astype(np.float64)
692703
# Rescale and return with shape (1, N)
693704
return integ_vector[None, :] * input_basis.COV.stretch
694705

@@ -710,7 +721,7 @@ def _full_matrix(input_basis, output_basis):
710721
# Build native integration vector
711722
N = input_basis.size
712723
a, b = input_basis.a, input_basis.b
713-
integ_vector = jacobi.integration_vector(N, a, b)
724+
integ_vector = dedalus_sphere.jacobi.polynomial_integrals(N, a, b).astype(np.float64)
714725
ave_vector = integ_vector / 2
715726
# Rescale and return with shape (1, N)
716727
return ave_vector[None, :]

dedalus/core/spaces.py

-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import numpy as np
66

7-
from ..tools import jacobi
87
from ..tools.array import reshape_vector
98
from ..tools.cache import CachedMethod, CachedAttribute
109

dedalus/core/transforms.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
from . import basis
1212
from ..libraries.fftw import fftw_wrappers as fftw
13-
from ..tools import jacobi
1413
from ..tools.array import apply_matrix, apply_dense, axslice, splu_inverse, apply_sparse, prod
1514
from ..tools.cache import CachedAttribute
1615
from ..tools.cache import CachedMethod
@@ -118,10 +117,9 @@ def forward_matrix(self):
118117
a, a0 = self.a, self.a0
119118
b, b0 = self.b, self.b0
120119
# Gauss quadrature with base (a0, b0) polynomials
121-
base_grid = jacobi.build_grid(N, a=a0, b=b0)
122-
base_polynomials = jacobi.build_polynomials(max(M, N), a0, b0, base_grid)
123-
base_weights = jacobi.build_weights(N, a=a0, b=b0)
124-
base_transform = (base_polynomials * base_weights)
120+
base_grid, base_weights = dedalus_sphere.jacobi.quadrature(N, a0, b0)
121+
base_polynomials = dedalus_sphere.jacobi.polynomials(max(M, N), a0, b0, base_grid)
122+
base_transform = (base_polynomials * base_weights).astype(np.float64)
125123
# Zero higher coefficients for transforms with grid_size < coeff_size
126124
base_transform[N:, :] = 0
127125
if DEALIAS_BEFORE_CONVERTING():
@@ -131,7 +129,7 @@ def forward_matrix(self):
131129
if (a == a0) and (b == b0):
132130
forward_matrix = base_transform
133131
else:
134-
conversion = jacobi.conversion_matrix(base_transform.shape[0], a0, b0, a, b)
132+
conversion = basis.Jacobi.conversion_matrix(base_transform.shape[0], a0, b0, a, b)
135133
forward_matrix = conversion @ base_transform
136134
if not DEALIAS_BEFORE_CONVERTING():
137135
# Truncate to specified coeff_size
@@ -146,8 +144,8 @@ def backward_matrix(self):
146144
a, a0 = self.a, self.a0
147145
b, b0 = self.b, self.b0
148146
# Construct polynomials on the base grid
149-
base_grid = jacobi.build_grid(N, a=a0, b=b0)
150-
polynomials = jacobi.build_polynomials(M, a, b, base_grid)
147+
base_grid, base_weights = dedalus_sphere.jacobi.quadrature(N, a0, b0)
148+
polynomials = dedalus_sphere.jacobi.polynomials(M, a, b, base_grid).astype(np.float64)
151149
# Zero higher polynomials for transforms with grid_size < coeff_size
152150
polynomials[N:, :] = 0
153151
# Transpose and ensure C ordering for fast dot products
@@ -816,10 +814,10 @@ def __init__(self, grid_size, coeff_size, a, b, a0, b0):
816814
self.resize_rescale_backward = self._resize_rescale_backward
817815
else:
818816
# Conversion matrices
819-
self.forward_conversion = jacobi.conversion_matrix(self.N, a0, b0, a, b)
817+
self.forward_conversion = basis.Jacobi.conversion_matrix(self.N, a0, b0, a, b)
820818
self.forward_conversion.resize(self.M_orig, self.N)
821819
self.forward_conversion = self.forward_conversion.tocsr()
822-
self.backward_conversion = jacobi.conversion_matrix(self.M_orig, a0, b0, a, b)
820+
self.backward_conversion = basis.Jacobi.conversion_matrix(self.M_orig, a0, b0, a, b)
823821
self.backward_conversion = splu_inverse(self.backward_conversion)
824822
self.resize_rescale_forward = self._resize_rescale_forward_convert
825823
self.resize_rescale_backward = self._resize_rescale_backward_convert

dedalus/libraries/dedalus_sphere/jacobi.py

+36
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,42 @@ def polynomials(n, a, b, z, init=None, Newton=False, normalized=True, dtype=None
113113
return P[:n]
114114

115115

116+
def polynomial_integrals(n, a, b, dtype=None, **kw):
117+
"""
118+
Definite integrals of the Jacobi polynomials, evaluated with Gauss-Legendre quadrature.
119+
120+
Parameters
121+
----------
122+
n : int
123+
Number of polynomials to compute (max degree + 1).
124+
a, b : float
125+
Jacobi parameters.
126+
dtype : dtype, optional
127+
Data type. Default: module-level DEFAULT_GRID_DTYPE.
128+
**kw : dict, optional
129+
Other keywords passed to jacobi.polynomials.
130+
131+
Returns
132+
-------
133+
integrals : array
134+
Vector of polynomial integrals.
135+
"""
136+
if dtype is None:
137+
dtype = DEFAULT_GRID_DTYPE
138+
# Build Legendre quadrature
139+
grid, weights = quadrature(n, a=0, b=0, dtype=dtype)
140+
# Evaluate polynomials on Legendre grid
141+
polys = polynomials(n, a, b, grid, dtype=dtype, **kw)
142+
# Compute integrals using Legendre quadrature
143+
integrals = weights @ polys.T
144+
# Eliminate known zeros
145+
if a == b == 0:
146+
integrals[1:] = 0
147+
elif a == b:
148+
integrals[1::2] = 0
149+
return integrals
150+
151+
116152
def quadrature(n, a, b, iterations=3, probability=False, dtype=None):
117153
"""
118154
Gauss-Jacobi quadrature grid z and weights w.

dedalus/tests/test_clenshaw.py

-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from scipy import sparse
55
from dedalus.core import coords, distributor, basis, field, operators
66
from dedalus.tools.array import apply_matrix
7-
from dedalus.tools import jacobi
87
from dedalus.tools import clenshaw
98
from ..libraries import dedalus_sphere
109

dedalus/tools/clenshaw.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
from scipy import sparse
55

6-
from . import jacobi
6+
from ..libraries import dedalus_sphere
77
from .general import DeferredTuple
88

99

@@ -72,7 +72,7 @@ def jacobi_recursion(N, a, b, X):
7272
B[n] = - J[n,n-1]/J[n,n+1]
7373
"""
7474
# Jacobi matrix
75-
J = jacobi.jacobi_matrix(N, a, b)
75+
J = dedalus_sphere.jacobi.operator('Z')(N, a, b).square.astype(np.float64)
7676
JA = J.A
7777
# Identity element
7878
if np.isscalar(X):

0 commit comments

Comments
 (0)